From 3f0040a81545b7ca0847d0718d749fe49db1253d Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 8 Dec 2021 15:09:40 +0000 Subject: [PATCH] refactor: generation of new r_values in derived_observable simplified. --- pyerrors/obs.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 7b13d8dd..681a398b 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1079,9 +1079,9 @@ def derived_observable(func, data, array_mode=False, **kwargs): for name in new_sample_names: idl = [] for i_data in raveled_data: - tmp = i_data.idl.get(name) - if tmp is not None: - idl.append(tmp) + tmp_idl = i_data.idl.get(name) + if tmp_idl is not None: + idl.append(tmp_idl) new_idl_d[name] = _merge_idx(idl) if not is_merged[name]: is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]]))) @@ -1101,10 +1101,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): for name in new_sample_names: tmp_values = np.zeros(n_obs) for i, item in enumerate(raveled_data): - tmp = item.r_values.get(name) - if tmp is None: - tmp = item.value - tmp_values[i] = tmp + tmp_values[i] = item.r_values.get(name, item.value) if multi > 0: tmp_values = np.array(tmp_values).reshape(data.shape) new_r_values[name] = func(tmp_values, **kwargs)