feat: linalg.einsum now works with real, complex and float matrices

This commit is contained in:
Fabian Joswig 2021-12-01 12:17:38 +00:00
parent fe1aeb5354
commit 3f6703ad6a

View file

@ -174,35 +174,6 @@ def matmul(*operands):
return derived_array(multi_dot, operands)
def _exp_to_jack(matrix):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = entry.export_jackknife()
return base_matrix
def _imp_from_jack(matrix, name, idl):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = import_jackknife(entry, name, [idl])
return base_matrix
def _exp_to_jack_c(matrix):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()
return base_matrix
def _imp_from_jack_c(matrix, name, idl):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]),
import_jackknife(entry.imag, name, [idl]))
return base_matrix
def jack_matmul(*operands):
"""Matrix multiply both operands making use of the jackknife approximation.
@ -215,6 +186,31 @@ def jack_matmul(*operands):
For large matrices this is considerably faster compared to matmul.
"""
def _exp_to_jack(matrix):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = entry.export_jackknife()
return base_matrix
def _imp_from_jack(matrix, name, idl):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = import_jackknife(entry, name, [idl])
return base_matrix
def _exp_to_jack_c(matrix):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()
return base_matrix
def _imp_from_jack_c(matrix, name, idl):
base_matrix = np.empty_like(matrix)
for index, entry in np.ndenumerate(matrix):
base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]),
import_jackknife(entry.imag, name, [idl]))
return base_matrix
if any(isinstance(o.flat[0], CObs) for o in operands):
name = operands[0].flat[0].real.names[0]
idl = operands[0].flat[0].real.idl[name]
@ -251,12 +247,40 @@ def einsum(subscripts, *operands):
Obs valued.
"""
if any(isinstance(o.flat[0], CObs) for o in operands):
name = operands[0].flat[0].real.names[0]
idl = operands[0].flat[0].real.idl[name]
else:
name = operands[0].flat[0].names[0]
idl = operands[0].flat[0].idl[name]
def _exp_to_jack(matrix):
base_matrix = []
for index, entry in np.ndenumerate(matrix):
base_matrix.append(entry.export_jackknife())
return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape)
def _exp_to_jack_c(matrix):
base_matrix = []
for index, entry in np.ndenumerate(matrix):
base_matrix.append(entry.real.export_jackknife() + 1j * entry.imag.export_jackknife())
return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape)
def _imp_from_jack(matrix, name, idl):
base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object)
for index in np.ndindex(matrix.shape[:-1]):
base_matrix[index] = import_jackknife(matrix[index], name, [idl])
return base_matrix
def _imp_from_jack_c(matrix, name, idl):
base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object)
for index in np.ndindex(matrix.shape[:-1]):
base_matrix[index] = CObs(import_jackknife(matrix[index].real, name, [idl]),
import_jackknife(matrix[index].imag, name, [idl]))
return base_matrix
for op in operands:
if isinstance(op.flat[0], CObs):
name = op.flat[0].real.names[0]
idl = op.flat[0].real.idl[name]
break
elif isinstance(op.flat[0], Obs):
name = op.flat[0].names[0]
idl = op.flat[0].idl[name]
break
conv_operands = []
for op in operands:
@ -267,15 +291,22 @@ def einsum(subscripts, *operands):
else:
conv_operands.append(op)
result = np.einsum(subscripts, *conv_operands)
tmp_subscripts = ','.join([o + '...' for o in subscripts.split(',')])
extended_subscripts = '->'.join([o + '...' for o in tmp_subscripts.split('->')[:-1]] + [tmp_subscripts.split('->')[-1]])
jack_einsum = np.einsum(extended_subscripts, *conv_operands)
if result.dtype == complex:
return _imp_from_jack_c(result, name, idl)
elif result.dtype == float:
return _imp_from_jack(result, name, idl)
if jack_einsum.dtype == complex:
result = _imp_from_jack_c(jack_einsum, name, idl)
elif jack_einsum.dtype == float:
result =_imp_from_jack(jack_einsum, name, idl)
else:
raise Exception("Result has unexpected datatype")
if result.shape == ():
return result.flat[0]
else:
return result
def inv(x):
"""Inverse of Obs or CObs valued matrices."""