From 865830af4c929a8d196b4d64dcea20e61d28815f Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 17 Nov 2021 16:57:08 +0000 Subject: [PATCH] feat: alternative matrix multiplication routine jack_matmul implemented --- pyerrors/linalg.py | 61 +++++++++++++++++++++++++++++++++++++++++--- tests/linalg_test.py | 41 +++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index d993f291..0dc9b9cd 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -1,7 +1,7 @@ import numpy as np from autograd import jacobian import autograd.numpy as anp # Thinly-wrapped numpy -from .obs import derived_observable, CObs, Obs, _merge_idx, _expand_deltas_for_merge, _filter_zeroes +from .obs import derived_observable, CObs, Obs, _merge_idx, _expand_deltas_for_merge, _filter_zeroes, import_jackknife from functools import partial from autograd.extend import defvjp @@ -121,8 +121,13 @@ def derived_array(func, data, **kwargs): def matmul(*operands): """Matrix multiply all operands. - Supports real and complex valued matrices and is faster compared to - standard multiplication via the @ operator. + Parameters + ---------- + operands : numpy.ndarray + Arbitrary number of 2d-numpy arrays which can be real or complex + Obs valued. + + This implementation is faster compared to standard multiplication via the @ operator. """ if any(isinstance(o[0, 0], CObs) for o in operands): extended_operands = [] @@ -169,6 +174,56 @@ def matmul(*operands): return derived_array(multi_dot, operands) +def jack_matmul(a, b): + """Matrix multiply both operands making use of the jackknife approximation. + + Parameters + ---------- + a : numpy.ndarray + First matrix, can be real or complex Obs valued + b : numpy.ndarray + Second matrix, can be real or complex Obs valued + + For large matrices this is considerably faster compared to matmul. + """ + + if any(isinstance(o[0, 0], CObs) for o in [a, b]): + def _exp_to_jack(matrix): + base_matrix = np.empty_like(matrix) + for (n, m), entry in np.ndenumerate(matrix): + base_matrix[n, m] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife() + return base_matrix + + def _imp_from_jack(matrix, name): + base_matrix = np.empty_like(matrix) + for (n, m), entry in np.ndenumerate(matrix): + base_matrix[n, m] = CObs(import_jackknife(entry.real, name), + import_jackknife(entry.imag, name)) + return base_matrix + + j_a = _exp_to_jack(a) + j_b = _exp_to_jack(b) + r = j_a @ j_b + return _imp_from_jack(r, a.ravel()[0].real.names[0]) + else: + def _exp_to_jack(matrix): + base_matrix = np.empty_like(matrix) + for (n, m), entry in np.ndenumerate(matrix): + base_matrix[n, m] = entry.export_jackknife() + return base_matrix + + def _imp_from_jack(matrix, name): + base_matrix = np.empty_like(matrix) + for (n, m), entry in np.ndenumerate(matrix): + base_matrix[n, m] = import_jackknife(entry, name) + return base_matrix + + j_a = _exp_to_jack(a) + j_b = _exp_to_jack(b) + r = j_a @ j_b + return _imp_from_jack(r, a.ravel()[0].names[0]) + + def inv(x): """Inverse of Obs or CObs valued matrices.""" return _mat_mat_op(anp.linalg.inv, x) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2b6f200c..4cfa9107 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -7,6 +7,27 @@ import pytest np.random.seed(0) +def get_real_matrix(dimension): + base_matrix = np.empty((dimension, dimension), dtype=object) + for (n, m), entry in np.ndenumerate(base_matrix): + exponent_real = np.random.normal(0, 1) + exponent_imag = np.random.normal(0, 1) + base_matrix[n, m] = pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']) + + + return base_matrix + +def get_complex_matrix(dimension): + base_matrix = np.empty((dimension, dimension), dtype=object) + for (n, m), entry in np.ndenumerate(base_matrix): + exponent_real = np.random.normal(0, 1) + exponent_imag = np.random.normal(0, 1) + base_matrix[n, m] = pe.CObs(pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t']), + pe.Obs([np.random.normal(1.0, 0.1, 100)], ['t'])) + + return base_matrix + + def test_matmul(): for dim in [4, 8]: my_list = [] @@ -29,6 +50,26 @@ def test_matmul(): assert e.is_zero(), t +def test_jack_matmul(): + tt = get_real_matrix(8) + check1 = pe.linalg.jack_matmul(tt, tt) - pe.linalg.matmul(tt, tt) + [o.gamma_method() for o in check1.ravel()] + assert np.all([o.is_zero_within_error(0.1) for o in check1.ravel()]) + trace1 = np.trace(check1) + trace1.gamma_method() + assert trace1.dvalue < 0.001 + + tt2 = get_complex_matrix(8) + check2 = pe.linalg.jack_matmul(tt2, tt2) - pe.linalg.matmul(tt2, tt2) + [o.gamma_method() for o in check2.ravel()] + assert np.all([o.real.is_zero_within_error(0.1) for o in check2.ravel()]) + assert np.all([o.imag.is_zero_within_error(0.1) for o in check2.ravel()]) + trace2 = np.trace(check2) + trace2.gamma_method() + assert trace2.real.dvalue < 0.001 + assert trace2.imag.dvalue < 0.001 + + def test_multi_dot(): for dim in [4, 8]: my_list = []