mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-14 19:43:41 +02:00
feat: alternative matrix multiplication routine jack_matmul implemented
This commit is contained in:
parent
972c8bd366
commit
865830af4c
2 changed files with 99 additions and 3 deletions
|
@ -7,6 +7,27 @@ import pytest
|
|||
np.random.seed(0)
|
||||
|
||||
|
||||
def get_real_matrix(dimension):
|
||||
base_matrix = np.empty((dimension, dimension), dtype=object)
|
||||
for (n, m), entry in np.ndenumerate(base_matrix):
|
||||
exponent_real = np.random.normal(0, 1)
|
||||
exponent_imag = np.random.normal(0, 1)
|
||||
base_matrix[n, m] = pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t'])
|
||||
|
||||
|
||||
return base_matrix
|
||||
|
||||
def get_complex_matrix(dimension):
|
||||
base_matrix = np.empty((dimension, dimension), dtype=object)
|
||||
for (n, m), entry in np.ndenumerate(base_matrix):
|
||||
exponent_real = np.random.normal(0, 1)
|
||||
exponent_imag = np.random.normal(0, 1)
|
||||
base_matrix[n, m] = pe.CObs(pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']),
|
||||
pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']))
|
||||
|
||||
return base_matrix
|
||||
|
||||
|
||||
def test_matmul():
|
||||
for dim in [4, 8]:
|
||||
my_list = []
|
||||
|
@ -29,6 +50,26 @@ def test_matmul():
|
|||
assert e.is_zero(), t
|
||||
|
||||
|
||||
def test_jack_matmul():
|
||||
tt = get_real_matrix(8)
|
||||
check1 = pe.linalg.jack_matmul(tt, tt) - pe.linalg.matmul(tt, tt)
|
||||
[o.gamma_method() for o in check1.ravel()]
|
||||
assert np.all([o.is_zero_within_error(0.1) for o in check1.ravel()])
|
||||
trace1 = np.trace(check1)
|
||||
trace1.gamma_method()
|
||||
assert trace1.dvalue < 0.001
|
||||
|
||||
tt2 = get_complex_matrix(8)
|
||||
check2 = pe.linalg.jack_matmul(tt2, tt2) - pe.linalg.matmul(tt2, tt2)
|
||||
[o.gamma_method() for o in check2.ravel()]
|
||||
assert np.all([o.real.is_zero_within_error(0.1) for o in check2.ravel()])
|
||||
assert np.all([o.imag.is_zero_within_error(0.1) for o in check2.ravel()])
|
||||
trace2 = np.trace(check2)
|
||||
trace2.gamma_method()
|
||||
assert trace2.real.dvalue < 0.001
|
||||
assert trace2.imag.dvalue < 0.001
|
||||
|
||||
|
||||
def test_multi_dot():
|
||||
for dim in [4, 8]:
|
||||
my_list = []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue