mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-16 15:20:24 +01:00
feat: new_cov_names and new_sample_names added to derived_array
This commit is contained in:
parent
15dd10f19a
commit
5789c0cef6
1 changed files with 8 additions and 5 deletions
|
@ -1048,6 +1048,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
||||||
raveled_data = data.ravel()
|
raveled_data = data.ravel()
|
||||||
|
|
||||||
# Workaround for matrix operations containing non Obs data
|
# Workaround for matrix operations containing non Obs data
|
||||||
|
# TODO: Find more elegant solution here.
|
||||||
for i_data in raveled_data:
|
for i_data in raveled_data:
|
||||||
if isinstance(i_data, Obs):
|
if isinstance(i_data, Obs):
|
||||||
first_name = i_data.names[0]
|
first_name = i_data.names[0]
|
||||||
|
@ -1070,11 +1071,13 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
||||||
|
|
||||||
n_obs = len(raveled_data)
|
n_obs = len(raveled_data)
|
||||||
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
|
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
|
||||||
|
new_cov_names = sorted(set([y for x in [o.cov_names for o in raveled_data] for y in x]))
|
||||||
|
new_sample_names = sorted(set(new_names) - set(new_cov_names))
|
||||||
|
|
||||||
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_names}
|
is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names}
|
||||||
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
|
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
|
||||||
new_idl_d = {}
|
new_idl_d = {}
|
||||||
for name in new_names:
|
for name in new_sample_names:
|
||||||
idl = []
|
idl = []
|
||||||
for i_data in raveled_data:
|
for i_data in raveled_data:
|
||||||
tmp = i_data.idl.get(name)
|
tmp = i_data.idl.get(name)
|
||||||
|
@ -1096,7 +1099,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
||||||
multi = 1
|
multi = 1
|
||||||
|
|
||||||
new_r_values = {}
|
new_r_values = {}
|
||||||
for name in new_names:
|
for name in new_sample_names:
|
||||||
tmp_values = np.zeros(n_obs)
|
tmp_values = np.zeros(n_obs)
|
||||||
for i, item in enumerate(raveled_data):
|
for i, item in enumerate(raveled_data):
|
||||||
tmp = item.r_values.get(name)
|
tmp = item.r_values.get(name)
|
||||||
|
@ -1140,7 +1143,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
||||||
|
|
||||||
if array_mode is True:
|
if array_mode is True:
|
||||||
d_extracted = {}
|
d_extracted = {}
|
||||||
for name in new_names:
|
for name in new_sample_names:
|
||||||
d_extracted[name] = []
|
d_extracted[name] = []
|
||||||
for i_dat, dat in enumerate(data):
|
for i_dat, dat in enumerate(data):
|
||||||
ens_length = len(new_idl_d[name])
|
ens_length = len(new_idl_d[name])
|
||||||
|
@ -1150,7 +1153,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
|
||||||
new_deltas = {}
|
new_deltas = {}
|
||||||
new_grad = {}
|
new_grad = {}
|
||||||
if array_mode is True:
|
if array_mode is True:
|
||||||
for name in new_names:
|
for name in new_sample_names:
|
||||||
ens_length = d_extracted[name][0].shape[-1]
|
ens_length = d_extracted[name][0].shape[-1]
|
||||||
new_deltas[name] = np.zeros(ens_length)
|
new_deltas[name] = np.zeros(ens_length)
|
||||||
for i_dat, dat in enumerate(d_extracted[name]):
|
for i_dat, dat in enumerate(d_extracted[name]):
|
||||||
|
|
Loading…
Add table
Reference in a new issue