feat: first working version of array_mode in dervived_observable

This commit is contained in:
Fabian Joswig 2021-12-02 12:50:08 +00:00
parent ed47d50286
commit 147bc6b24b
2 changed files with 22 additions and 120 deletions

View file

@ -1015,7 +1015,7 @@ def _filter_zeroes(deltas, idx, eps=Obs.filter_eps):
return deltas, idx
def derived_observable(func, data, **kwargs):
def derived_observable(func, data, array_mode=False, **kwargs):
"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.
Parameters
@ -1138,14 +1138,28 @@ def derived_observable(func, data, **kwargs):
final_result = np.zeros(new_values.shape, dtype=object)
if array_mode is True:
d_extracted = {}
for name in new_names:
d_extracted[name] = []
for i_dat, dat in enumerate(data):
ens_length = len(new_idl_d[name])
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
for i_val, new_val in np.ndenumerate(new_values):
new_deltas = {}
new_grad = {}
if array_mode is True:
for name in new_names:
ens_length = d_extracted[name][0].shape[-1]
new_deltas[name] = np.zeros(ens_length)
for i_dat, dat in enumerate(d_extracted[name]):
new_deltas[name] += np.tensordot(deriv[i_val + (i_dat, )], dat)
for j_obs, obs in np.ndenumerate(data):
for name in obs.names:
if name in obs.cov_names:
new_grad[name] = new_grad.get(name, 0) + deriv[i_val + j_obs] * obs.covobs[name].grad
else:
elif array_mode is False:
new_deltas[name] = new_deltas.get(name, 0) + deriv[i_val + j_obs] * _expand_deltas_for_merge(obs.deltas[name], obs.idl[name], obs.shape[name], new_idl_d[name])
new_covobs = {name: Covobs(0, allcov[name], name, grad=new_grad[name]) for name in new_grad}