From 100228d373621161eb56557a2b8ccf98edc2b46a Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 29 Oct 2021 11:56:58 +0100 Subject: [PATCH] reweighted is now bool, performance of reweighted check improved --- pyerrors/pyerrors.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index 2506f253..bf5af0f3 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -109,7 +109,7 @@ class Obs: self._dvalue = 0.0 self.ddvalue = 0.0 - self.reweighted = 0 + self.reweighted = False self.tag = None @@ -934,7 +934,7 @@ def derived_observable(func, data, **kwargs): new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x])) is_merged = len(list(filter(lambda o: o.is_merged is True, raveled_data))) > 1 - reweighted = np.max([o.reweighted for o in raveled_data]) + reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 1 new_idl_d = {} for name in new_names: idl = [] @@ -1097,7 +1097,7 @@ def reweight(weight, obs, **kwargs): new_weight = Obs([w_deltas[name] + weight.r_values[name] for name in sorted(weight.names)], sorted(weight.names), idl=[obs[i].idl[name] for name in sorted(weight.names)]) result.append(derived_observable(lambda x, **kwargs: x[0] / x[1], [tmp_obs, new_weight], **kwargs)) - result[-1].reweighted = 1 + result[-1].reweighted = True result[-1].is_merged = obs[i].is_merged return result @@ -1117,9 +1117,9 @@ def correlate(obs_a, obs_b): if obs_a.shape[name] != obs_b.shape[name]: raise Exception('Shapes of ensemble', name, 'do not fit') - if obs_a.reweighted == 1: + if obs_a.reweighted is True: warnings.warn("The first observable is already reweighted.", RuntimeWarning) - if obs_b.reweighted == 1: + if obs_b.reweighted is True: warnings.warn("The second observable is already reweighted.", RuntimeWarning) new_samples = [] @@ -1128,7 +1128,7 @@ def correlate(obs_a, obs_b): o = Obs(new_samples, sorted(obs_a.names)) o.is_merged = obs_a.is_merged or obs_b.is_merged - o.reweighted = np.max([obs_a.reweighted, obs_b.reweighted]) + o.reweighted = obs_a.reweighted or obs_b.reweighted return o