fix: array mode now works with elements defined on different ensembles

This commit is contained in:
Fabian Joswig 2021-12-07 07:36:24 +00:00
parent b0610544a8
commit df6b151c13
2 changed files with 7 additions and 8 deletions

View file

@ -1146,14 +1146,13 @@ def derived_observable(func, data, array_mode=False, **kwargs):
final_result = np.zeros(new_values.shape, dtype=object)
# TODO: array mode does not work when matrices are defined on differenet ensembles
if array_mode is True:
class Zero_grad():
class _Zero_grad():
def __init__(self):
self.grad = 0
zero_grad = Zero_grad()
zero_grad = _Zero_grad()
d_extracted = {}
g_extracted = {}
@ -1161,7 +1160,7 @@ def derived_observable(func, data, array_mode=False, **kwargs):
d_extracted[name] = []
for i_dat, dat in enumerate(data):
ens_length = len(new_idl_d[name])
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas.get(name, np.zeros(ens_length)), o.idl.get(name, new_idl_d[name]), o.shape.get(name, ens_length), new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
for name in new_cov_names:
g_extracted[name] = []
for i_dat, dat in enumerate(data):