From 17f77532253fe1286777ee3a8d69bad7a7b5d42d Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Thu, 11 Nov 2021 11:15:25 +0000 Subject: [PATCH 1/2] test for derived_array and irregular histories added --- tests/linalg_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 41a391a3..f2bd3e60 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -51,6 +51,29 @@ def test_multi_dot(): assert e.is_zero(), t +def test_matmul_irregular_histories(): + dim = 2 + length = 500 + + standard_array = [] + for i in range(dim ** 2): + standard_array.append(pe.Obs([np.random.normal(1.1, 0.2, length)], ['ens1'])) + standard_matrix = np.array(standard_array).reshape((dim, dim)) + + for idl in [range(1, 501, 2), range(250, 273), [2, 8, 19, 20, 78]]: + irregular_array = [] + for i in range(dim ** 2): + irregular_array.append(pe.Obs([np.random.normal(1.1, 0.2, len(idl))], ['ens1'], idl=[idl])) + irregular_matrix = np.array(irregular_array).reshape((dim, dim)) + + t1 = standard_matrix @ irregular_matrix + t2 = pe.linalg.matmul(standard_matrix, irregular_matrix) + + assert np.all([o.is_zero() for o in (t1 - t2).ravel()]) + assert np.all([o.is_merged for o in t1.ravel()]) + assert np.all([o.is_merged for o in t2.ravel()]) + + def test_matrix_inverse(): content = [] for t in range(9): From 0644ecf9aaeba0581605d6ecce038c54b135ef79 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 12 Nov 2021 09:39:18 +0000 Subject: [PATCH 2/2] bugs in derived_array and gamma_method fixed which caused odd behaviour in connection with irregular monte carlo chains --- pyerrors/linalg.py | 2 +- pyerrors/obs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index 18a6a98c..c1a042fb 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -86,7 +86,7 @@ def derived_array(func, data, **kwargs): for name in new_names: d_extracted[name] = [] for i_dat, dat in enumerate(data): - ens_length = new_idl_d[name][-1] - new_idl_d[name][0] + 1 + ens_length = len(new_idl_d[name]) d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, ))) for i_val, new_val in np.ndenumerate(new_values): diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 6a2b414a..085c9a95 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -248,7 +248,7 @@ class Obs: r_length = [] for r_name in e_content[e_name]: - if self.idl[r_name] is range: + if isinstance(self.idl[r_name], range): r_length.append(len(self.idl[r_name])) else: r_length.append((self.idl[r_name][-1] - self.idl[r_name][0] + 1)) @@ -339,7 +339,7 @@ class Obs: idx -- List or range of configs on which the deltas are defined. shape -- Number of configs in idx. """ - if type(idx) is range: + if isinstance(idx, range): return deltas else: ret = np.zeros(idx[-1] - idx[0] + 1)