mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
feat: new_cov_names and new_sample_names added to derived_array
This commit is contained in:
parent
15dd10f19a
commit
5789c0cef6
1 changed files with 8 additions and 5 deletions
|
@ -1048,6 +1048,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
raveled_data = data.ravel()
|
||||
|
||||
# Workaround for matrix operations containing non Obs data
|
||||
# TODO: Find more elegant solution here.
|
||||
for i_data in raveled_data:
|
||||
if isinstance(i_data, Obs):
|
||||
first_name = i_data.names[0]
|
||||
|
@ -1070,11 +1071,13 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
|
||||
n_obs = len(raveled_data)
|
||||
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
|
||||
new_cov_names = sorted(set([y for x in [o.cov_names for o in raveled_data] for y in x]))
|
||||
new_sample_names = sorted(set(new_names) - set(new_cov_names))
|
||||
|
||||
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_names}
|
||||
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names}
|
||||
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
|
||||
new_idl_d = {}
|
||||
for name in new_names:
|
||||
for name in new_sample_names:
|
||||
idl = []
|
||||
for i_data in raveled_data:
|
||||
tmp = i_data.idl.get(name)
|
||||
|
@ -1096,7 +1099,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
multi = 1
|
||||
|
||||
new_r_values = {}
|
||||
for name in new_names:
|
||||
for name in new_sample_names:
|
||||
tmp_values = np.zeros(n_obs)
|
||||
for i, item in enumerate(raveled_data):
|
||||
tmp = item.r_values.get(name)
|
||||
|
@ -1140,7 +1143,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
|
||||
if array_mode is True:
|
||||
d_extracted = {}
|
||||
for name in new_names:
|
||||
for name in new_sample_names:
|
||||
d_extracted[name] = []
|
||||
for i_dat, dat in enumerate(data):
|
||||
ens_length = len(new_idl_d[name])
|
||||
|
@ -1150,7 +1153,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
|||
new_deltas = {}
|
||||
new_grad = {}
|
||||
if array_mode is True:
|
||||
for name in new_names:
|
||||
for name in new_sample_names:
|
||||
ens_length = d_extracted[name][0].shape[-1]
|
||||
new_deltas[name] = np.zeros(ens_length)
|
||||
for i_dat, dat in enumerate(d_extracted[name]):
|
||||
|
|
Loading…
Add table
Reference in a new issue