mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
refactor: two loops over new_sample_names merged.
This commit is contained in:
parent
3f0040a815
commit
140268c1c9
1 changed files with 8 additions and 10 deletions
|
@ -1075,16 +1075,6 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
|
||||
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names}
|
||||
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
|
||||
new_idl_d = {}
|
||||
for name in new_sample_names:
|
||||
idl = []
|
||||
for i_data in raveled_data:
|
||||
tmp_idl = i_data.idl.get(name)
|
||||
if tmp_idl is not None:
|
||||
idl.append(tmp_idl)
|
||||
new_idl_d[name] = _merge_idx(idl)
|
||||
if not is_merged[name]:
|
||||
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
|
||||
|
||||
if data.ndim == 1:
|
||||
values = np.array([o.value for o in data])
|
||||
|
@ -1098,13 +1088,21 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
multi = 1
|
||||
|
||||
new_r_values = {}
|
||||
new_idl_d = {}
|
||||
for name in new_sample_names:
|
||||
idl = []
|
||||
tmp_values = np.zeros(n_obs)
|
||||
for i, item in enumerate(raveled_data):
|
||||
tmp_values[i] = item.r_values.get(name, item.value)
|
||||
tmp_idl = item.idl.get(name)
|
||||
if tmp_idl is not None:
|
||||
idl.append(tmp_idl)
|
||||
if multi > 0:
|
||||
tmp_values = np.array(tmp_values).reshape(data.shape)
|
||||
new_r_values[name] = func(tmp_values, **kwargs)
|
||||
new_idl_d[name] = _merge_idx(idl)
|
||||
if not is_merged[name]:
|
||||
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
|
||||
|
||||
if 'man_grad' in kwargs:
|
||||
deriv = np.asarray(kwargs.get('man_grad'))
|
||||
|
|
Loading…
Add table
Reference in a new issue