diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 41a391a3..f2bd3e60 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -51,6 +51,29 @@ def test_multi_dot(): assert e.is_zero(), t +def test_matmul_irregular_histories(): + dim = 2 + length = 500 + + standard_array = [] + for i in range(dim ** 2): + standard_array.append(pe.Obs([np.random.normal(1.1, 0.2, length)], ['ens1'])) + standard_matrix = np.array(standard_array).reshape((dim, dim)) + + for idl in [range(1, 501, 2), range(250, 273), [2, 8, 19, 20, 78]]: + irregular_array = [] + for i in range(dim ** 2): + irregular_array.append(pe.Obs([np.random.normal(1.1, 0.2, len(idl))], ['ens1'], idl=[idl])) + irregular_matrix = np.array(irregular_array).reshape((dim, dim)) + + t1 = standard_matrix @ irregular_matrix + t2 = pe.linalg.matmul(standard_matrix, irregular_matrix) + + assert np.all([o.is_zero() for o in (t1 - t2).ravel()]) + assert np.all([o.is_merged for o in t1.ravel()]) + assert np.all([o.is_merged for o in t2.ravel()]) + + def test_matrix_inverse(): content = [] for t in range(9):