feat: corr, corr matmul and correlator matrix trace added.

This commit is contained in:
Fabian Joswig 2023-07-13 12:18:10 +01:00
parent 4071c8279c
commit 323c3430f5
No known key found for this signature in database

View file

@ -220,7 +220,7 @@ class Corr:
def anti_symmetric(self):
"""Anti-symmetrize the correlator around x0=0."""
if self.N != 1:
raise Exception('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
raise TypeError('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
if self.T % 2 != 0:
raise Exception("Can not symmetrize odd T")
@ -242,7 +242,7 @@ class Corr:
def is_matrix_symmetric(self):
"""Checks whether a correlator matrices is symmetric on every timeslice."""
if self.N == 1:
raise Exception("Only works for correlator matrices.")
raise TypeError("Only works for correlator matrices.")
for t in range(self.T):
if self[t] is None:
continue
@ -254,6 +254,17 @@ class Corr:
return False
return True
def trace(self):
if self.N == 1:
raise TypeError("Only works for correlator matrices.")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
else:
newcontent.append(np.trace(self.content[t]))
return Corr(newcontent)
def matrix_symmetric(self):
"""Symmetrizes the correlator matrices on every timeslice."""
if self.N == 1:
@ -1080,8 +1091,10 @@ class Corr:
def __matmul__(self, y):
if isinstance(y, np.ndarray):
if y.ndim != 2 or y.shape[0] != y.shape[1]:
raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]:
raise TypeError("Shape mismatch")
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
@ -1089,8 +1102,19 @@ class Corr:
else:
newcontent.append(self.content[t] @ y)
return Corr(newcontent)
elif isinstance(y, Corr):
if not self.N == y.N:
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
newcontent.append(None)
else:
newcontent.append(self.content[t] @ y.content[t])
return Corr(newcontent)
else:
raise NotImplementedError("Matmul not implemented for this type.")
raise TypeError("Matmul not implemented for this type.")
def __truediv__(self, y):
if isinstance(y, Corr):