refactor: two loops over new_sample_names merged.

This commit is contained in:
Fabian Joswig 2021-12-08 15:17:32 +00:00
parent 3f0040a815
commit 140268c1c9

View file

@ -1075,16 +1075,6 @@ def derived_observable(func, data, array_mode=False, **kwargs):
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names}
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
new_idl_d = {}
for name in new_sample_names:
idl = []
for i_data in raveled_data:
tmp_idl = i_data.idl.get(name)
if tmp_idl is not None:
idl.append(tmp_idl)
new_idl_d[name] = _merge_idx(idl)
if not is_merged[name]:
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
if data.ndim == 1:
values = np.array([o.value for o in data])
@ -1098,13 +1088,21 @@ def derived_observable(func, data, array_mode=False, **kwargs):
multi = 1
new_r_values = {}
new_idl_d = {}
for name in new_sample_names:
idl = []
tmp_values = np.zeros(n_obs)
for i, item in enumerate(raveled_data):
tmp_values[i] = item.r_values.get(name, item.value)
tmp_idl = item.idl.get(name)
if tmp_idl is not None:
idl.append(tmp_idl)
if multi > 0:
tmp_values = np.array(tmp_values).reshape(data.shape)
new_r_values[name] = func(tmp_values, **kwargs)
new_idl_d[name] = _merge_idx(idl)
if not is_merged[name]:
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
if 'man_grad' in kwargs:
deriv = np.asarray(kwargs.get('man_grad'))