1
0
Fork 0
mirror of https://github.com/fjosw/pyerrors.git synced 2025-03-16 15:20:24 +01:00

feat: new_cov_names and new_sample_names added to derived_array

This commit is contained in:
Fabian Joswig 2021-12-02 16:54:51 +00:00
parent 15dd10f19a
commit 5789c0cef6

View file

@ -1048,6 +1048,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
raveled_data = data.ravel() raveled_data = data.ravel()
# Workaround for matrix operations containing non Obs data # Workaround for matrix operations containing non Obs data
# TODO: Find more elegant solution here.
for i_data in raveled_data: for i_data in raveled_data:
if isinstance(i_data, Obs): if isinstance(i_data, Obs):
first_name = i_data.names[0] first_name = i_data.names[0]
@ -1070,11 +1071,13 @@ def derived_observable(func, data, array_mode=False, **kwargs):
n_obs = len(raveled_data) n_obs = len(raveled_data)
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x])) new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
new_cov_names = sorted(set([y for x in [o.cov_names for o in raveled_data] for y in x]))
new_sample_names = sorted(set(new_names) - set(new_cov_names))
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_names} is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names}
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0 reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
new_idl_d = {} new_idl_d = {}
for name in new_names: for name in new_sample_names:
idl = [] idl = []
for i_data in raveled_data: for i_data in raveled_data:
tmp = i_data.idl.get(name) tmp = i_data.idl.get(name)
@ -1096,7 +1099,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
multi = 1 multi = 1
new_r_values = {} new_r_values = {}
for name in new_names: for name in new_sample_names:
tmp_values = np.zeros(n_obs) tmp_values = np.zeros(n_obs)
for i, item in enumerate(raveled_data): for i, item in enumerate(raveled_data):
tmp = item.r_values.get(name) tmp = item.r_values.get(name)
@ -1140,7 +1143,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
if array_mode is True: if array_mode is True:
d_extracted = {} d_extracted = {}
for name in new_names: for name in new_sample_names:
d_extracted[name] = [] d_extracted[name] = []
for i_dat, dat in enumerate(data): for i_dat, dat in enumerate(data):
ens_length = len(new_idl_d[name]) ens_length = len(new_idl_d[name])
@ -1150,7 +1153,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
new_deltas = {} new_deltas = {}
new_grad = {} new_grad = {}
if array_mode is True: if array_mode is True:
for name in new_names: for name in new_sample_names:
ens_length = d_extracted[name][0].shape[-1] ens_length = d_extracted[name][0].shape[-1]
new_deltas[name] = np.zeros(ens_length) new_deltas[name] = np.zeros(ens_length)
for i_dat, dat in enumerate(d_extracted[name]): for i_dat, dat in enumerate(d_extracted[name]):