refactor: generation of new r_values in derived_observable simplified.

This commit is contained in:
Fabian Joswig 2021-12-08 15:09:40 +00:00
parent 5ced94e086
commit 3f0040a815

View file

@ -1079,9 +1079,9 @@ def derived_observable(func, data, array_mode=False, **kwargs):
for name in new_sample_names:
idl = []
for i_data in raveled_data:
tmp = i_data.idl.get(name)
if tmp is not None:
idl.append(tmp)
tmp_idl = i_data.idl.get(name)
if tmp_idl is not None:
idl.append(tmp_idl)
new_idl_d[name] = _merge_idx(idl)
if not is_merged[name]:
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
@ -1101,10 +1101,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
for name in new_sample_names:
tmp_values = np.zeros(n_obs)
for i, item in enumerate(raveled_data):
tmp = item.r_values.get(name)
if tmp is None:
tmp = item.value
tmp_values[i] = tmp
tmp_values[i] = item.r_values.get(name, item.value)
if multi > 0:
tmp_values = np.array(tmp_values).reshape(data.shape)
new_r_values[name] = func(tmp_values, **kwargs)