Merge branch 'develop' of github.com:fjosw/pyerrors into develop

This commit is contained in:
Fabian Joswig 2022-08-01 16:45:28 +01:00
commit a1d1c412ed
3 changed files with 50 additions and 7 deletions

View file

@ -237,13 +237,30 @@ class Corr:
raise Exception("Corr could not be symmetrized: No redundant values") raise Exception("Corr could not be symmetrized: No redundant values")
return Corr(newcontent, prange=self.prange) return Corr(newcontent, prange=self.prange)
def is_matrix_symmetric(self):
"""Checks whether a correlator matrices is symmetric on every timeslice."""
if self.N == 1:
raise Exception("Only works for correlator matrices.")
for t in range(self.T):
if self[t] is None:
continue
for i in range(self.N):
for j in range(i + 1, self.N):
if self[t][i, j] is self[t][j, i]:
continue
if hash(self[t][i, j]) != hash(self[t][j, i]):
return False
return True
def matrix_symmetric(self): def matrix_symmetric(self):
"""Symmetrizes the correlator matrices on every timeslice.""" """Symmetrizes the correlator matrices on every timeslice."""
if self.N > 1:
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
return 0.5 * (Corr(transposed) + self)
if self.N == 1: if self.N == 1:
raise Exception("Trying to symmetrize a correlator matrix, that already has N=1.") raise Exception("Trying to symmetrize a correlator matrix, that already has N=1.")
if self.is_matrix_symmetric():
return 1.0 * self
else:
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
return 0.5 * (Corr(transposed) + self)
def GEVP(self, t0, ts=None, sort="Eigenvalue", **kwargs): def GEVP(self, t0, ts=None, sort="Eigenvalue", **kwargs):
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors. r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
@ -284,7 +301,11 @@ class Corr:
warnings.warn("Argument 'sorted_list' is deprecated, use 'sort' instead.", DeprecationWarning) warnings.warn("Argument 'sorted_list' is deprecated, use 'sort' instead.", DeprecationWarning)
sort = kwargs.get("sorted_list") sort = kwargs.get("sorted_list")
if self.is_matrix_symmetric():
symmetric_corr = self
else:
symmetric_corr = self.matrix_symmetric() symmetric_corr = self.matrix_symmetric()
if sort is None: if sort is None:
if (ts is None): if (ts is None):
raise Exception("ts is required if sort=None.") raise Exception("ts is required if sort=None.")

View file

@ -163,7 +163,7 @@ def read_DistillationContraction_hd5(path, ens_id, diagrams=["direct"], idl=None
Nt = h5file["DistillationContraction/Metadata"].attrs.get("Nt")[0] Nt = h5file["DistillationContraction/Metadata"].attrs.get("Nt")[0]
identifier = [] identifier = []
for in_file in range(4): for in_file in range(len(h5file["DistillationContraction/Metadata/DmfInputFiles"].attrs.keys()) - 1):
encoded_info = h5file["DistillationContraction/Metadata/DmfInputFiles"].attrs.get("DmfInputFiles_" + str(in_file)) encoded_info = h5file["DistillationContraction/Metadata/DmfInputFiles"].attrs.get("DmfInputFiles_" + str(in_file))
full_info = encoded_info[0].decode().split("/")[-1].replace(".h5", "").split("_") full_info = encoded_info[0].decode().split("/")[-1].replace(".h5", "").split("_")
my_tuple = (full_info[0], full_info[1][1:], full_info[2], full_info[3]) my_tuple = (full_info[0], full_info[1][1:], full_info[2], full_info[3])
@ -174,8 +174,8 @@ def read_DistillationContraction_hd5(path, ens_id, diagrams=["direct"], idl=None
for diagram in diagrams: for diagram in diagrams:
real_data = np.zeros(Nt) real_data = np.zeros(Nt)
for x0 in range(Nt): for x0 in range(Nt):
raw_data = h5file["DistillationContraction/Correlators/" + diagram + "/" + str(x0)] raw_data = h5file["DistillationContraction/Correlators/" + diagram + "/" + str(x0)][:]["re"].astype(np.double)
real_data += np.roll(raw_data[:]["re"].astype(np.double), -x0) real_data += np.roll(raw_data, -x0)
real_data /= Nt real_data /= Nt
corr_data[diagram].append(real_data) corr_data[diagram].append(real_data)

View file

@ -367,6 +367,28 @@ def test_matrix_symmetric():
corr3.matrix_symmetric() corr3.matrix_symmetric()
def test_is_matrix_symmetric():
corr_data = []
for t in range(4):
mat = np.zeros((4, 4), dtype=object)
for i in range(4):
for j in range(i, 4):
obs = pe.pseudo_Obs(0.1, 0.047, "rgetrasrewe53455b153v13v5/*/*sdfgb")
mat[i, j] = obs
if i != j:
mat[j, i] = obs
corr_data.append(mat)
corr = pe.Corr(corr_data, padding=[0, 2])
assert corr.is_matrix_symmetric()
corr[0][0, 1] = 1.0 * corr[0][0, 1]
assert corr.is_matrix_symmetric()
corr[3][2, 1] = (1 + 1e-14) * corr[3][2, 1]
assert corr.is_matrix_symmetric()
corr[0][0, 1] = 1.1 * corr[0][0, 1]
assert not corr.is_matrix_symmetric()
def test_GEVP_solver(): def test_GEVP_solver():
mat1 = np.random.rand(15, 15) mat1 = np.random.rand(15, 15)