mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 06:40:24 +01:00
test: test for linalg.einsum added
This commit is contained in:
parent
019196bad3
commit
4c8d758889
2 changed files with 54 additions and 1 deletions
|
@ -298,7 +298,7 @@ def einsum(subscripts, *operands):
|
|||
if jack_einsum.dtype == complex:
|
||||
result = _imp_from_jack_c(jack_einsum, name, idl)
|
||||
elif jack_einsum.dtype == float:
|
||||
result =_imp_from_jack(jack_einsum, name, idl)
|
||||
result = _imp_from_jack(jack_einsum, name, idl)
|
||||
else:
|
||||
raise Exception("Result has unexpected datatype")
|
||||
|
||||
|
|
|
@ -93,6 +93,59 @@ def test_jack_matmul():
|
|||
assert trace4.real.dvalue < 0.001
|
||||
assert trace4.imag.dvalue < 0.001
|
||||
|
||||
|
||||
def test_einsum():
|
||||
|
||||
def _perform_real_check(arr):
|
||||
[o.gamma_method() for o in arr]
|
||||
assert np.all([o.is_zero_within_error(0.001) for o in arr])
|
||||
assert np.all([o.dvalue < 0.001 for o in arr])
|
||||
|
||||
def _perform_complex_check(arr):
|
||||
[o.gamma_method() for o in arr]
|
||||
assert np.all([o.real.is_zero_within_error(0.001) for o in arr])
|
||||
assert np.all([o.real.dvalue < 0.001 for o in arr])
|
||||
assert np.all([o.imag.is_zero_within_error(0.001) for o in arr])
|
||||
assert np.all([o.imag.dvalue < 0.001 for o in arr])
|
||||
|
||||
|
||||
tt = [get_real_matrix(4), get_real_matrix(3)]
|
||||
q = np.tensordot(tt[0], tt[1], 0)
|
||||
c1 = tt[1] @ q
|
||||
c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
|
||||
check1 = c1 - c2
|
||||
_perform_real_check(check1.ravel())
|
||||
check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0])
|
||||
_perform_real_check([check2])
|
||||
check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1])
|
||||
_perform_real_check([check3])
|
||||
|
||||
tt = [get_real_matrix(4), np.random.random((3, 3))]
|
||||
q = np.tensordot(tt[0], tt[1], 0)
|
||||
c1 = tt[1] @ q
|
||||
c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
|
||||
check1 = c1 - c2
|
||||
_perform_real_check(check1.ravel())
|
||||
|
||||
tt = [get_complex_matrix(4), get_complex_matrix(3)]
|
||||
q = np.tensordot(tt[0], tt[1], 0)
|
||||
c1 = tt[1] @ q
|
||||
c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
|
||||
check1 = c1 - c2
|
||||
_perform_complex_check(check1.ravel())
|
||||
check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0])
|
||||
_perform_complex_check([check2])
|
||||
check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1])
|
||||
_perform_complex_check([check3])
|
||||
|
||||
tt = [get_complex_matrix(4), np.random.random((3, 3))]
|
||||
q = np.tensordot(tt[0], tt[1], 0)
|
||||
c1 = tt[1] @ q
|
||||
c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
|
||||
check1 = c1 - c2
|
||||
_perform_complex_check(check1.ravel())
|
||||
|
||||
|
||||
def test_multi_dot():
|
||||
for dim in [4, 6]:
|
||||
my_list = []
|
||||
|
|
Loading…
Add table
Reference in a new issue