feat: jack_matmul no works with an arbitrary number of operands

This commit is contained in:
Fabian Joswig 2021-11-18 10:46:30 +00:00
parent a673a8f656
commit c31034565a

View file

@ -174,20 +174,19 @@ def matmul(*operands):
return derived_array(multi_dot, operands) return derived_array(multi_dot, operands)
def jack_matmul(a, b): def jack_matmul(*operands):
"""Matrix multiply both operands making use of the jackknife approximation. """Matrix multiply both operands making use of the jackknife approximation.
Parameters Parameters
---------- ----------
a : numpy.ndarray operands : numpy.ndarray
First matrix, can be real or complex Obs valued Arbitrary number of 2d-numpy arrays which can be real or complex
b : numpy.ndarray Obs valued.
Second matrix, can be real or complex Obs valued
For large matrices this is considerably faster compared to matmul. For large matrices this is considerably faster compared to matmul.
""" """
if any(isinstance(o[0, 0], CObs) for o in [a, b]): if any(isinstance(o[0, 0], CObs) for o in operands):
def _exp_to_jack(matrix): def _exp_to_jack(matrix):
base_matrix = np.empty_like(matrix) base_matrix = np.empty_like(matrix)
for (n, m), entry in np.ndenumerate(matrix): for (n, m), entry in np.ndenumerate(matrix):
@ -201,10 +200,10 @@ def jack_matmul(a, b):
import_jackknife(entry.imag, name)) import_jackknife(entry.imag, name))
return base_matrix return base_matrix
j_a = _exp_to_jack(a) r = _exp_to_jack(operands[0])
j_b = _exp_to_jack(b) for op in operands[1:]:
r = j_a @ j_b r = r @ _exp_to_jack(op)
return _imp_from_jack(r, a.ravel()[0].real.names[0]) return _imp_from_jack(r, op.ravel()[0].real.names[0])
else: else:
def _exp_to_jack(matrix): def _exp_to_jack(matrix):
base_matrix = np.empty_like(matrix) base_matrix = np.empty_like(matrix)
@ -218,10 +217,10 @@ def jack_matmul(a, b):
base_matrix[n, m] = import_jackknife(entry, name) base_matrix[n, m] = import_jackknife(entry, name)
return base_matrix return base_matrix
j_a = _exp_to_jack(a) r = _exp_to_jack(operands[0])
j_b = _exp_to_jack(b) for op in operands[1:]:
r = j_a @ j_b r = r @ _exp_to_jack(op)
return _imp_from_jack(r, a.ravel()[0].names[0]) return _imp_from_jack(r, op.ravel()[0].names[0])
def inv(x): def inv(x):