mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-07-01 01:09:27 +02:00
feat: first working version of array_mode in dervived_observable
This commit is contained in:
parent
ed47d50286
commit
147bc6b24b
2 changed files with 22 additions and 120 deletions
|
@ -1015,7 +1015,7 @@ def _filter_zeroes(deltas, idx, eps=Obs.filter_eps):
|
|||
return deltas, idx
|
||||
|
||||
|
||||
def derived_observable(func, data, **kwargs):
|
||||
def derived_observable(func, data, array_mode=False, **kwargs):
|
||||
"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.
|
||||
|
||||
Parameters
|
||||
|
@ -1138,14 +1138,28 @@ def derived_observable(func, data, **kwargs):
|
|||
|
||||
final_result = np.zeros(new_values.shape, dtype=object)
|
||||
|
||||
if array_mode is True:
|
||||
d_extracted = {}
|
||||
for name in new_names:
|
||||
d_extracted[name] = []
|
||||
for i_dat, dat in enumerate(data):
|
||||
ens_length = len(new_idl_d[name])
|
||||
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
|
||||
|
||||
for i_val, new_val in np.ndenumerate(new_values):
|
||||
new_deltas = {}
|
||||
new_grad = {}
|
||||
if array_mode is True:
|
||||
for name in new_names:
|
||||
ens_length = d_extracted[name][0].shape[-1]
|
||||
new_deltas[name] = np.zeros(ens_length)
|
||||
for i_dat, dat in enumerate(d_extracted[name]):
|
||||
new_deltas[name] += np.tensordot(deriv[i_val + (i_dat, )], dat)
|
||||
for j_obs, obs in np.ndenumerate(data):
|
||||
for name in obs.names:
|
||||
if name in obs.cov_names:
|
||||
new_grad[name] = new_grad.get(name, 0) + deriv[i_val + j_obs] * obs.covobs[name].grad
|
||||
else:
|
||||
elif array_mode is False:
|
||||
new_deltas[name] = new_deltas.get(name, 0) + deriv[i_val + j_obs] * _expand_deltas_for_merge(obs.deltas[name], obs.idl[name], obs.shape[name], new_idl_d[name])
|
||||
|
||||
new_covobs = {name: Covobs(0, allcov[name], name, grad=new_grad[name]) for name in new_grad}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue