From 48a468c87291823ba72814570df248f99a0ea5d9 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Thu, 13 Jul 2023 16:51:43 +0100 Subject: [PATCH] feat: rmatmul added and __array_priority__ set. --- pyerrors/correlators.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 6459d466..ea767ec5 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -1033,6 +1033,8 @@ class Corr: # This is because Obs*Corr checks Obs.__mul__ first and does not catch an exception. # One could try and tell Obs to check if the y in __mul__ is a Corr and + __array_priority__ = 10000 + def __add__(self, y): if isinstance(y, Corr): if ((self.N != y.N) or (self.T != y.T)): @@ -1114,7 +1116,23 @@ class Corr: return Corr(newcontent) else: - raise TypeError("Matmul not implemented for this type.") + return NotImplemented + + def __rmatmul__(self, y): + if isinstance(y, np.ndarray): + if y.ndim != 2 or y.shape[0] != y.shape[1]: + raise ValueError("Can only multiply correlators by square matrices.") + if not self.N == y.shape[0]: + raise ValueError("matmul: mismatch of matrix dimensions") + newcontent = [] + for t in range(self.T): + if _check_for_none(self, self.content[t]): + newcontent.append(None) + else: + newcontent.append(y @ self.content[t]) + return Corr(newcontent) + else: + return NotImplemented def __truediv__(self, y): if isinstance(y, Corr):