mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-14 19:43:41 +02:00
feat: alternative matrix multiplication routine jack_matmul implemented
This commit is contained in:
parent
972c8bd366
commit
865830af4c
2 changed files with 99 additions and 3 deletions
|
@ -1,7 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from autograd import jacobian
|
from autograd import jacobian
|
||||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
import autograd.numpy as anp # Thinly-wrapped numpy
|
||||||
from .obs import derived_observable, CObs, Obs, _merge_idx, _expand_deltas_for_merge, _filter_zeroes
|
from .obs import derived_observable, CObs, Obs, _merge_idx, _expand_deltas_for_merge, _filter_zeroes, import_jackknife
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from autograd.extend import defvjp
|
from autograd.extend import defvjp
|
||||||
|
@ -121,8 +121,13 @@ def derived_array(func, data, **kwargs):
|
||||||
def matmul(*operands):
|
def matmul(*operands):
|
||||||
"""Matrix multiply all operands.
|
"""Matrix multiply all operands.
|
||||||
|
|
||||||
Supports real and complex valued matrices and is faster compared to
|
Parameters
|
||||||
standard multiplication via the @ operator.
|
----------
|
||||||
|
operands : numpy.ndarray
|
||||||
|
Arbitrary number of 2d-numpy arrays which can be real or complex
|
||||||
|
Obs valued.
|
||||||
|
|
||||||
|
This implementation is faster compared to standard multiplication via the @ operator.
|
||||||
"""
|
"""
|
||||||
if any(isinstance(o[0, 0], CObs) for o in operands):
|
if any(isinstance(o[0, 0], CObs) for o in operands):
|
||||||
extended_operands = []
|
extended_operands = []
|
||||||
|
@ -169,6 +174,56 @@ def matmul(*operands):
|
||||||
return derived_array(multi_dot, operands)
|
return derived_array(multi_dot, operands)
|
||||||
|
|
||||||
|
|
||||||
|
def jack_matmul(a, b):
|
||||||
|
"""Matrix multiply both operands making use of the jackknife approximation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a : numpy.ndarray
|
||||||
|
First matrix, can be real or complex Obs valued
|
||||||
|
b : numpy.ndarray
|
||||||
|
Second matrix, can be real or complex Obs valued
|
||||||
|
|
||||||
|
For large matrices this is considerably faster compared to matmul.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if any(isinstance(o[0, 0], CObs) for o in [a, b]):
|
||||||
|
def _exp_to_jack(matrix):
|
||||||
|
base_matrix = np.empty_like(matrix)
|
||||||
|
for (n, m), entry in np.ndenumerate(matrix):
|
||||||
|
base_matrix[n, m] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
def _imp_from_jack(matrix, name):
|
||||||
|
base_matrix = np.empty_like(matrix)
|
||||||
|
for (n, m), entry in np.ndenumerate(matrix):
|
||||||
|
base_matrix[n, m] = CObs(import_jackknife(entry.real, name),
|
||||||
|
import_jackknife(entry.imag, name))
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
j_a = _exp_to_jack(a)
|
||||||
|
j_b = _exp_to_jack(b)
|
||||||
|
r = j_a @ j_b
|
||||||
|
return _imp_from_jack(r, a.ravel()[0].real.names[0])
|
||||||
|
else:
|
||||||
|
def _exp_to_jack(matrix):
|
||||||
|
base_matrix = np.empty_like(matrix)
|
||||||
|
for (n, m), entry in np.ndenumerate(matrix):
|
||||||
|
base_matrix[n, m] = entry.export_jackknife()
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
def _imp_from_jack(matrix, name):
|
||||||
|
base_matrix = np.empty_like(matrix)
|
||||||
|
for (n, m), entry in np.ndenumerate(matrix):
|
||||||
|
base_matrix[n, m] = import_jackknife(entry, name)
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
j_a = _exp_to_jack(a)
|
||||||
|
j_b = _exp_to_jack(b)
|
||||||
|
r = j_a @ j_b
|
||||||
|
return _imp_from_jack(r, a.ravel()[0].names[0])
|
||||||
|
|
||||||
|
|
||||||
def inv(x):
|
def inv(x):
|
||||||
"""Inverse of Obs or CObs valued matrices."""
|
"""Inverse of Obs or CObs valued matrices."""
|
||||||
return _mat_mat_op(anp.linalg.inv, x)
|
return _mat_mat_op(anp.linalg.inv, x)
|
||||||
|
|
|
@ -7,6 +7,27 @@ import pytest
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_real_matrix(dimension):
|
||||||
|
base_matrix = np.empty((dimension, dimension), dtype=object)
|
||||||
|
for (n, m), entry in np.ndenumerate(base_matrix):
|
||||||
|
exponent_real = np.random.normal(0, 1)
|
||||||
|
exponent_imag = np.random.normal(0, 1)
|
||||||
|
base_matrix[n, m] = pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t'])
|
||||||
|
|
||||||
|
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
def get_complex_matrix(dimension):
|
||||||
|
base_matrix = np.empty((dimension, dimension), dtype=object)
|
||||||
|
for (n, m), entry in np.ndenumerate(base_matrix):
|
||||||
|
exponent_real = np.random.normal(0, 1)
|
||||||
|
exponent_imag = np.random.normal(0, 1)
|
||||||
|
base_matrix[n, m] = pe.CObs(pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']),
|
||||||
|
pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']))
|
||||||
|
|
||||||
|
return base_matrix
|
||||||
|
|
||||||
|
|
||||||
def test_matmul():
|
def test_matmul():
|
||||||
for dim in [4, 8]:
|
for dim in [4, 8]:
|
||||||
my_list = []
|
my_list = []
|
||||||
|
@ -29,6 +50,26 @@ def test_matmul():
|
||||||
assert e.is_zero(), t
|
assert e.is_zero(), t
|
||||||
|
|
||||||
|
|
||||||
|
def test_jack_matmul():
|
||||||
|
tt = get_real_matrix(8)
|
||||||
|
check1 = pe.linalg.jack_matmul(tt, tt) - pe.linalg.matmul(tt, tt)
|
||||||
|
[o.gamma_method() for o in check1.ravel()]
|
||||||
|
assert np.all([o.is_zero_within_error(0.1) for o in check1.ravel()])
|
||||||
|
trace1 = np.trace(check1)
|
||||||
|
trace1.gamma_method()
|
||||||
|
assert trace1.dvalue < 0.001
|
||||||
|
|
||||||
|
tt2 = get_complex_matrix(8)
|
||||||
|
check2 = pe.linalg.jack_matmul(tt2, tt2) - pe.linalg.matmul(tt2, tt2)
|
||||||
|
[o.gamma_method() for o in check2.ravel()]
|
||||||
|
assert np.all([o.real.is_zero_within_error(0.1) for o in check2.ravel()])
|
||||||
|
assert np.all([o.imag.is_zero_within_error(0.1) for o in check2.ravel()])
|
||||||
|
trace2 = np.trace(check2)
|
||||||
|
trace2.gamma_method()
|
||||||
|
assert trace2.real.dvalue < 0.001
|
||||||
|
assert trace2.imag.dvalue < 0.001
|
||||||
|
|
||||||
|
|
||||||
def test_multi_dot():
|
def test_multi_dot():
|
||||||
for dim in [4, 8]:
|
for dim in [4, 8]:
|
||||||
my_list = []
|
my_list = []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue