diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index ba913ad8..c9b3c538 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -943,8 +943,8 @@ def derived_observable(func, data, **kwargs): n_obs = len(raveled_data) new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x])) - is_merged = len(list(filter(lambda o: o.is_merged is True, raveled_data))) > 1 - reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 1 + is_merged = len(list(filter(lambda o: o.is_merged is True, raveled_data))) > 0 + reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0 new_idl_d = {} for name in new_names: idl = [] diff --git a/tests/pyerrors_test.py b/tests/pyerrors_test.py index 1cefe09d..8b0f0704 100644 --- a/tests/pyerrors_test.py +++ b/tests/pyerrors_test.py @@ -276,6 +276,15 @@ def test_cobs(): assert (other / my_cobs * my_cobs - other).is_zero() +def test_reweighting(): + my_obs = pe.Obs([np.random.rand(1000)], ['t']) + assert not my_obs.reweighted + r_obs = pe.reweight(my_obs, [my_obs]) + assert r_obs[0].reweighted + r_obs2 = r_obs[0] * my_obs + assert r_obs2.reweighted + + def test_irregular_error_propagation(): obs_list = [pe.Obs([np.random.rand(100)], ['t']), pe.Obs([np.random.rand(50)], ['t'], idl=[range(1, 100, 2)]),