feat: rmatmul added and __array_priority__ set.

This commit is contained in:
Fabian Joswig 2023-07-13 16:51:43 +01:00
parent 199e06c5f9
commit 48a468c872
No known key found for this signature in database

View file

@ -1033,6 +1033,8 @@ class Corr:
# This is because Obs*Corr checks Obs.__mul__ first and does not catch an exception.
# One could try and tell Obs to check if the y in __mul__ is a Corr and
__array_priority__ = 10000
def __add__(self, y):
if isinstance(y, Corr):
if ((self.N != y.N) or (self.T != y.T)):
@ -1114,7 +1116,23 @@ class Corr:
return Corr(newcontent)
else:
raise TypeError("Matmul not implemented for this type.")
return NotImplemented
def __rmatmul__(self, y):
if isinstance(y, np.ndarray):
if y.ndim != 2 or y.shape[0] != y.shape[1]:
raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]:
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
else:
newcontent.append(y @ self.content[t])
return Corr(newcontent)
else:
return NotImplemented
def __truediv__(self, y):
if isinstance(y, Corr):