From d4b86f5f732ca7b101373a2416c89c43479c7056 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Mon, 25 Oct 2021 13:58:16 +0100 Subject: [PATCH] linalg.matmul now works with any number of operands --- pyerrors/linalg.py | 46 +++++++++++++++++++++++++++++++++++++------- tests/test_linalg.py | 22 +++++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index e05ecdda..99fa0bb3 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -124,18 +124,50 @@ def derived_array(func, data, **kwargs): return final_result -def matmul(x1, x2): - if isinstance(x1[0, 0], CObs) or isinstance(x2[0, 0], CObs): - Lr, Li = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x1) - Rr, Ri = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x2) - Nr = derived_array(lambda x: x[0] @ x[2] - x[1] @ x[3], [Lr, Li, Rr, Ri]) - Ni = derived_array(lambda x: x[0] @ x[3] + x[1] @ x[2], [Lr, Li, Rr, Ri]) +def matmul(*operands): + if any(isinstance(o[0, 0], CObs) for o in operands): + extended_operands = [] + for op in operands: + tmp = np.vectorize(lambda x: (np.real(x), np.imag(x)))(op) + extended_operands.append(tmp[0]) + extended_operands.append(tmp[1]) + + def multi_dot(operands, part): + stack_r = operands[0] + stack_i = operands[1] + for op_r, op_i in zip(operands[2::2], operands[3::2]): + tmp_r = stack_r @ op_r - stack_i @ op_i + tmp_i = stack_r @ op_i + stack_i @ op_r + + stack_r = tmp_r + stack_i = tmp_i + + if part == 'Real': + return stack_r + else: + return stack_i + + def multi_dot_r(operands): + return multi_dot(operands, 'Real') + + def multi_dot_i(operands): + return multi_dot(operands, 'Imag') + + Nr = derived_array(multi_dot_r, extended_operands) + Ni = derived_array(multi_dot_i, extended_operands) + res = np.empty_like(Nr) for (n, m), entry in np.ndenumerate(Nr): res[n, m] = CObs(Nr[n, m], Ni[n, m]) + return res else: - return derived_array(lambda x: x[0] @ x[1], [x1, x2]) + def multi_dot(operands): + stack = operands[0] + for op in operands[1:]: + stack = stack @ op + return stack + return derived_array(multi_dot, operands) def inv(x): diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 49a1aaa0..3350e9c9 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -29,6 +29,28 @@ def test_matmul(): assert e.is_zero(), t +def test_multi_dot(): + for dim in [4, 8]: + my_list = [] + length = 1000 + np.random.randint(200) + for i in range(dim ** 2): + my_list.append(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2'])) + my_array = np.array(my_list).reshape((dim, dim)) + tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array + for t, e in np.ndenumerate(tt): + assert e.is_zero(), t + + my_list = [] + length = 1000 + np.random.randint(200) + for i in range(dim ** 2): + my_list.append(pe.CObs(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']), + pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']))) + my_array = np.array(my_list).reshape((dim, dim)) + tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array + for t, e in np.ndenumerate(tt): + assert e.is_zero(), t + + def test_matrix_inverse(): content = [] for t in range(9):