fix: array mode now also works with covobs with N>1

This commit is contained in:
Fabian Joswig 2021-12-07 08:09:38 +00:00
parent df6b151c13
commit e8bcf8de6f
2 changed files with 10 additions and 8 deletions

View file

@ -1148,23 +1148,25 @@ def derived_observable(func, data, array_mode=False, **kwargs):
if array_mode is True:
class _Zero_grad():
def __init__(self):
self.grad = 0
new_covobs_lengths = dict(set([y for x in [[(n, o.covobs[n].N) for n in o.cov_names] for o in raveled_data] for y in x]))
zero_grad = _Zero_grad()
class _Zero_grad():
def __init__(self, N):
# self.grad = np.zeros(N)
self.grad = np.zeros((N, 1))
d_extracted = {}
g_extracted = {}
for name in new_sample_names:
d_extracted[name] = []
ens_length = len(new_idl_d[name])
for i_dat, dat in enumerate(data):
ens_length = len(new_idl_d[name])
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas.get(name, np.zeros(ens_length)), o.idl.get(name, new_idl_d[name]), o.shape.get(name, ens_length), new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
for name in new_cov_names:
g_extracted[name] = []
zero_grad = _Zero_grad(new_covobs_lengths[name])
for i_dat, dat in enumerate(data):
g_extracted[name].append(np.array([o.covobs.get(name, zero_grad).grad for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (1, )))
g_extracted[name].append(np.array([o.covobs.get(name, zero_grad).grad for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (new_covobs_lengths[name], 1)))
for i_val, new_val in np.ndenumerate(new_values):
new_deltas = {}