feat: einsum function added to linalg module

This commit is contained in:
Fabian Joswig 2021-12-01 09:35:40 +00:00
parent 6bc8102f87
commit fe1aeb5354

View file

@ -239,6 +239,44 @@ def jack_matmul(*operands):
return _imp_from_jack(r, name, idl)
def einsum(subscripts, *operands):
"""Wrapper for numpy.einsum
Parameters
----------
subscripts : str
Subscripts for summation (see numpy documentation for details)
operands : numpy.ndarray
Arbitrary number of 2d-numpy arrays which can be real or complex
Obs valued.
"""
if any(isinstance(o.flat[0], CObs) for o in operands):
name = operands[0].flat[0].real.names[0]
idl = operands[0].flat[0].real.idl[name]
else:
name = operands[0].flat[0].names[0]
idl = operands[0].flat[0].idl[name]
conv_operands = []
for op in operands:
if isinstance(op.flat[0], CObs):
conv_operands.append(_exp_to_jack_c(op))
elif isinstance(op.flat[0], Obs):
conv_operands.append(_exp_to_jack(op))
else:
conv_operands.append(op)
result = np.einsum(subscripts, *conv_operands)
if result.dtype == complex:
return _imp_from_jack_c(result, name, idl)
elif result.dtype == float:
return _imp_from_jack(result, name, idl)
else:
raise Exception("Result has unexpected datatype")
def inv(x):
"""Inverse of Obs or CObs valued matrices."""
return _mat_mat_op(anp.linalg.inv, x)