diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 93d20850..6459d466 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -220,7 +220,7 @@ class Corr: def anti_symmetric(self): """Anti-symmetrize the correlator around x0=0.""" if self.N != 1: - raise Exception('anti_symmetric cannot be safely applied to multi-dimensional correlators.') + raise TypeError('anti_symmetric cannot be safely applied to multi-dimensional correlators.') if self.T % 2 != 0: raise Exception("Can not symmetrize odd T") @@ -242,7 +242,7 @@ class Corr: def is_matrix_symmetric(self): """Checks whether a correlator matrices is symmetric on every timeslice.""" if self.N == 1: - raise Exception("Only works for correlator matrices.") + raise TypeError("Only works for correlator matrices.") for t in range(self.T): if self[t] is None: continue @@ -254,6 +254,17 @@ class Corr: return False return True + def trace(self): + if self.N == 1: + raise TypeError("Only works for correlator matrices.") + newcontent = [] + for t in range(self.T): + if _check_for_none(self, self.content[t]): + newcontent.append(None) + else: + newcontent.append(np.trace(self.content[t])) + return Corr(newcontent) + def matrix_symmetric(self): """Symmetrizes the correlator matrices on every timeslice.""" if self.N == 1: @@ -1080,8 +1091,10 @@ class Corr: def __matmul__(self, y): if isinstance(y, np.ndarray): + if y.ndim != 2 or y.shape[0] != y.shape[1]: + raise ValueError("Can only multiply correlators by square matrices.") if not self.N == y.shape[0]: - raise TypeError("Shape mismatch") + raise ValueError("matmul: mismatch of matrix dimensions") newcontent = [] for t in range(self.T): if _check_for_none(self, self.content[t]): @@ -1089,8 +1102,19 @@ class Corr: else: newcontent.append(self.content[t] @ y) return Corr(newcontent) + elif isinstance(y, Corr): + if not self.N == y.N: + raise ValueError("matmul: mismatch of matrix dimensions") + newcontent = [] + for t in range(self.T): + if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]): + newcontent.append(None) + else: + newcontent.append(self.content[t] @ y.content[t]) + return Corr(newcontent) + else: - raise NotImplementedError("Matmul not implemented for this type.") + raise TypeError("Matmul not implemented for this type.") def __truediv__(self, y): if isinstance(y, Corr):