diff --git a/pyerrors/obs.py b/pyerrors/obs.py index ea10e578..f7029a63 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1146,14 +1146,13 @@ def derived_observable(func, data, array_mode=False, **kwargs): final_result = np.zeros(new_values.shape, dtype=object) - # TODO: array mode does not work when matrices are defined on differenet ensembles if array_mode is True: - class Zero_grad(): + class _Zero_grad(): def __init__(self): self.grad = 0 - zero_grad = Zero_grad() + zero_grad = _Zero_grad() d_extracted = {} g_extracted = {} @@ -1161,7 +1160,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): d_extracted[name] = [] for i_dat, dat in enumerate(data): ens_length = len(new_idl_d[name]) - d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, ))) + d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas.get(name, np.zeros(ens_length)), o.idl.get(name, new_idl_d[name]), o.shape.get(name, ens_length), new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, ))) for name in new_cov_names: g_extracted[name] = [] for i_dat, dat in enumerate(data): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index bdbce655..73bcf5bb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -29,10 +29,10 @@ def get_complex_matrix(dimension): def test_matmul(): - for dim in [4, 8]: + for dim in [4, 6]: for const in [1, pe.cov_Obs(1.0, 0.002, 'cov')]: my_list = [] - length = 1000 + np.random.randint(200) + length = 100 + np.random.randint(200) for i in range(dim ** 2): my_list.append(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2'])) my_array = const * np.array(my_list).reshape((dim, dim)) @@ -41,7 +41,7 @@ def test_matmul(): assert e.is_zero(), t my_list = [] - length = 1000 + np.random.randint(200) + length = 100 + np.random.randint(200) for i in range(dim ** 2): my_list.append(pe.CObs(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']), pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']))) @@ -189,7 +189,7 @@ def test_matmul_irregular_histories(): standard_array = [] for i in range(dim ** 2): standard_array.append(pe.Obs([np.random.normal(1.1, 0.2, length)], ['ens1'])) - standard_matrix = np.array(standard_array).reshape((dim, dim)) * pe.cov_Obs(1.0, 0.002, 'cov') # * pe.pseudo_Obs(0.1, 0.002, 'qr') + standard_matrix = np.array(standard_array).reshape((dim, dim)) * pe.cov_Obs(1.0, 0.002, 'cov') * pe.pseudo_Obs(0.1, 0.002, 'qr') for idl in [range(1, 501, 2), range(250, 273), [2, 8, 19, 20, 78]]: irregular_array = []