Matmul overloaded for correlator class. (#199)

* feat: matmul method added to correlator class.

* feat: corr, corr matmul and correlator matrix trace added.

* tests: tests for matmul and trace added.

* tests: slightly reduced tolerance and good guess bad guess test.

* feat: rmatmul added and __array_priority__ set.

* tests: additional tests for rmatmul added.

* tests: one more tests for rmatmul added.

* docs: docstring added to Corr.trace.

* tests: associative property test added for complex Corr matmul.

* fix: Corr.roll method now also works for correlator matrices by
explicitly specifying the axis.

Co-authored-by: Matteo Di Carlo <matteo.dicarlo93@gmail.com>

* feat: exception type for correlator trace of 1dim correlator changed.

* tests: trace N=1 exception tested.

---------

Co-authored-by: Matteo Di Carlo <matteo.dicarlo93@gmail.com>
This commit is contained in:
Fabian Joswig 2023-07-17 11:48:57 +01:00 committed by GitHub
parent 7d1858f6c4
commit f1150f09c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 185 additions and 6 deletions

View file

@ -220,7 +220,7 @@ class Corr:
def anti_symmetric(self):
"""Anti-symmetrize the correlator around x0=0."""
if self.N != 1:
raise Exception('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
raise TypeError('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
if self.T % 2 != 0:
raise Exception("Can not symmetrize odd T")
@ -242,7 +242,7 @@ class Corr:
def is_matrix_symmetric(self):
"""Checks whether a correlator matrices is symmetric on every timeslice."""
if self.N == 1:
raise Exception("Only works for correlator matrices.")
raise TypeError("Only works for correlator matrices.")
for t in range(self.T):
if self[t] is None:
continue
@ -254,6 +254,18 @@ class Corr:
return False
return True
def trace(self):
"""Calculates the per-timeslice trace of a correlator matrix."""
if self.N == 1:
raise ValueError("Only works for correlator matrices.")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
else:
newcontent.append(np.trace(self.content[t]))
return Corr(newcontent)
def matrix_symmetric(self):
"""Symmetrizes the correlator matrices on every timeslice."""
if self.N == 1:
@ -405,7 +417,7 @@ class Corr:
dt : int
number of timeslices
"""
return Corr(list(np.roll(np.array(self.content, dtype=object), dt)))
return Corr(list(np.roll(np.array(self.content, dtype=object), dt, axis=0)))
def reverse(self):
"""Reverse the time ordering of the Corr"""
@ -1020,6 +1032,8 @@ class Corr:
# This is because Obs*Corr checks Obs.__mul__ first and does not catch an exception.
# One could try and tell Obs to check if the y in __mul__ is a Corr and
__array_priority__ = 10000
def __add__(self, y):
if isinstance(y, Corr):
if ((self.N != y.N) or (self.T != y.T)):
@ -1076,6 +1090,49 @@ class Corr:
else:
raise TypeError("Corr * wrong type")
def __matmul__(self, y):
if isinstance(y, np.ndarray):
if y.ndim != 2 or y.shape[0] != y.shape[1]:
raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]:
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
else:
newcontent.append(self.content[t] @ y)
return Corr(newcontent)
elif isinstance(y, Corr):
if not self.N == y.N:
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
newcontent.append(None)
else:
newcontent.append(self.content[t] @ y.content[t])
return Corr(newcontent)
else:
return NotImplemented
def __rmatmul__(self, y):
if isinstance(y, np.ndarray):
if y.ndim != 2 or y.shape[0] != y.shape[1]:
raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]:
raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
else:
newcontent.append(y @ self.content[t])
return Corr(newcontent)
else:
return NotImplemented
def __truediv__(self, y):
if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):