From 4c8d75888917ef7c256fdb34ffe1b80345986942 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 1 Dec 2021 12:19:32 +0000 Subject: [PATCH] test: test for linalg.einsum added --- pyerrors/linalg.py | 2 +- tests/linalg_test.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index bbb59367..bb44870d 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -298,7 +298,7 @@ def einsum(subscripts, *operands): if jack_einsum.dtype == complex: result = _imp_from_jack_c(jack_einsum, name, idl) elif jack_einsum.dtype == float: - result =_imp_from_jack(jack_einsum, name, idl) + result = _imp_from_jack(jack_einsum, name, idl) else: raise Exception("Result has unexpected datatype") diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 46ee6c89..d34da29f 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -93,6 +93,59 @@ def test_jack_matmul(): assert trace4.real.dvalue < 0.001 assert trace4.imag.dvalue < 0.001 + +def test_einsum(): + + def _perform_real_check(arr): + [o.gamma_method() for o in arr] + assert np.all([o.is_zero_within_error(0.001) for o in arr]) + assert np.all([o.dvalue < 0.001 for o in arr]) + + def _perform_complex_check(arr): + [o.gamma_method() for o in arr] + assert np.all([o.real.is_zero_within_error(0.001) for o in arr]) + assert np.all([o.real.dvalue < 0.001 for o in arr]) + assert np.all([o.imag.is_zero_within_error(0.001) for o in arr]) + assert np.all([o.imag.dvalue < 0.001 for o in arr]) + + + tt = [get_real_matrix(4), get_real_matrix(3)] + q = np.tensordot(tt[0], tt[1], 0) + c1 = tt[1] @ q + c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q) + check1 = c1 - c2 + _perform_real_check(check1.ravel()) + check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0]) + _perform_real_check([check2]) + check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1]) + _perform_real_check([check3]) + + tt = [get_real_matrix(4), np.random.random((3, 3))] + q = np.tensordot(tt[0], tt[1], 0) + c1 = tt[1] @ q + c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q) + check1 = c1 - c2 + _perform_real_check(check1.ravel()) + + tt = [get_complex_matrix(4), get_complex_matrix(3)] + q = np.tensordot(tt[0], tt[1], 0) + c1 = tt[1] @ q + c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q) + check1 = c1 - c2 + _perform_complex_check(check1.ravel()) + check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0]) + _perform_complex_check([check2]) + check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1]) + _perform_complex_check([check3]) + + tt = [get_complex_matrix(4), np.random.random((3, 3))] + q = np.tensordot(tt[0], tt[1], 0) + c1 = tt[1] @ q + c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q) + check1 = c1 - c2 + _perform_complex_check(check1.ravel()) + + def test_multi_dot(): for dim in [4, 6]: my_list = []