reweighted is now bool, performance of reweighted check improved

This commit is contained in:
Fabian Joswig 2021-10-29 11:56:58 +01:00
parent 607d77de25
commit 100228d373

View file

@ -109,7 +109,7 @@ class Obs:
self._dvalue = 0.0
self.ddvalue = 0.0
self.reweighted = 0
self.reweighted = False
self.tag = None
@ -934,7 +934,7 @@ def derived_observable(func, data, **kwargs):
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
is_merged = len(list(filter(lambda o: o.is_merged is True, raveled_data))) > 1
reweighted = np.max([o.reweighted for o in raveled_data])
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 1
new_idl_d = {}
for name in new_names:
idl = []
@ -1097,7 +1097,7 @@ def reweight(weight, obs, **kwargs):
new_weight = Obs([w_deltas[name] + weight.r_values[name] for name in sorted(weight.names)], sorted(weight.names), idl=[obs[i].idl[name] for name in sorted(weight.names)])
result.append(derived_observable(lambda x, **kwargs: x[0] / x[1], [tmp_obs, new_weight], **kwargs))
result[-1].reweighted = 1
result[-1].reweighted = True
result[-1].is_merged = obs[i].is_merged
return result
@ -1117,9 +1117,9 @@ def correlate(obs_a, obs_b):
if obs_a.shape[name] != obs_b.shape[name]:
raise Exception('Shapes of ensemble', name, 'do not fit')
if obs_a.reweighted == 1:
if obs_a.reweighted is True:
warnings.warn("The first observable is already reweighted.", RuntimeWarning)
if obs_b.reweighted == 1:
if obs_b.reweighted is True:
warnings.warn("The second observable is already reweighted.", RuntimeWarning)
new_samples = []
@ -1128,7 +1128,7 @@ def correlate(obs_a, obs_b):
o = Obs(new_samples, sorted(obs_a.names))
o.is_merged = obs_a.is_merged or obs_b.is_merged
o.reweighted = np.max([obs_a.reweighted, obs_b.reweighted])
o.reweighted = obs_a.reweighted or obs_b.reweighted
return o