diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index 075b9605..fb121e6a 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -205,7 +205,10 @@ def jack_matmul(*operands): r = _exp_to_jack(operands[0]) for op in operands[1:]: - r = r @ _exp_to_jack(op) + if isinstance(op[0, 0], CObs): + r = r @ _exp_to_jack(op) + else: + r = r @ op return _imp_from_jack(r) else: name = operands[0][0, 0].names[0] @@ -225,7 +228,10 @@ def jack_matmul(*operands): r = _exp_to_jack(operands[0]) for op in operands[1:]: - r = r @ _exp_to_jack(op) + if isinstance(op[0, 0], Obs): + r = r @ _exp_to_jack(op) + else: + r = r @ op return _imp_from_jack(r) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5af84b52..46ee6c89 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -55,23 +55,46 @@ def test_jack_matmul(): check1 = pe.linalg.jack_matmul(tt, tt) - pe.linalg.matmul(tt, tt) [o.gamma_method() for o in check1.ravel()] assert np.all([o.is_zero_within_error(0.1) for o in check1.ravel()]) + assert np.all([o.dvalue < 0.001 for o in check1.ravel()]) trace1 = np.trace(check1) trace1.gamma_method() assert trace1.dvalue < 0.001 - tt2 = get_complex_matrix(8) - check2 = pe.linalg.jack_matmul(tt2, tt2) - pe.linalg.matmul(tt2, tt2) + tr = np.random.rand(8, 8) + check2 = pe.linalg.jack_matmul(tt, tr) - pe.linalg.matmul(tt, tr) [o.gamma_method() for o in check2.ravel()] - assert np.all([o.real.is_zero_within_error(0.1) for o in check2.ravel()]) - assert np.all([o.imag.is_zero_within_error(0.1) for o in check2.ravel()]) + assert np.all([o.is_zero_within_error(0.1) for o in check2.ravel()]) + assert np.all([o.dvalue < 0.001 for o in check2.ravel()]) trace2 = np.trace(check2) trace2.gamma_method() - assert trace2.real.dvalue < 0.001 - assert trace2.imag.dvalue < 0.001 + assert trace2.dvalue < 0.001 + tt2 = get_complex_matrix(8) + check3 = pe.linalg.jack_matmul(tt2, tt2) - pe.linalg.matmul(tt2, tt2) + [o.gamma_method() for o in check3.ravel()] + assert np.all([o.real.is_zero_within_error(0.1) for o in check3.ravel()]) + assert np.all([o.imag.is_zero_within_error(0.1) for o in check3.ravel()]) + assert np.all([o.real.dvalue < 0.001 for o in check3.ravel()]) + assert np.all([o.imag.dvalue < 0.001 for o in check3.ravel()]) + trace3 = np.trace(check3) + trace3.gamma_method() + assert trace3.real.dvalue < 0.001 + assert trace3.imag.dvalue < 0.001 + + tr2 = np.random.rand(8, 8) + 1j * np.random.rand(8, 8) + check4 = pe.linalg.jack_matmul(tt2, tr2) - pe.linalg.matmul(tt2, tr2) + [o.gamma_method() for o in check4.ravel()] + assert np.all([o.real.is_zero_within_error(0.1) for o in check4.ravel()]) + assert np.all([o.imag.is_zero_within_error(0.1) for o in check4.ravel()]) + assert np.all([o.real.dvalue < 0.001 for o in check4.ravel()]) + assert np.all([o.imag.dvalue < 0.001 for o in check4.ravel()]) + trace4 = np.trace(check4) + trace4.gamma_method() + assert trace4.real.dvalue < 0.001 + assert trace4.imag.dvalue < 0.001 def test_multi_dot(): - for dim in [4, 8]: + for dim in [4, 6]: my_list = [] length = 1000 + np.random.randint(200) for i in range(dim ** 2): @@ -167,8 +190,8 @@ def test_complex_matrix_inverse(): base_matrix = np.empty((dimension, dimension), dtype=object) matrix = np.empty((dimension, dimension), dtype=complex) for (n, m), entry in np.ndenumerate(base_matrix): - exponent_real = np.random.normal(3, 5) - exponent_imag = np.random.normal(3, 5) + exponent_real = np.random.normal(2, 3) + exponent_imag = np.random.normal(2, 3) base_matrix[n, m] = pe.CObs(pe.pseudo_Obs(2 + 10 ** exponent_real, 10 ** (exponent_real - 1), 't'), pe.pseudo_Obs(2 + 10 ** exponent_imag, 10 ** (exponent_imag - 1), 't'))