mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-14 11:33:42 +02:00
refactor: correlators._solve_GEVP simplified and optimized, test added.
This commit is contained in:
parent
d60212739b
commit
ba054fa11c
2 changed files with 18 additions and 5 deletions
|
@ -1219,8 +1219,6 @@ def _sort_vectors(vec_set, ts):
|
||||||
return sorted_vec_set
|
return sorted_vec_set
|
||||||
|
|
||||||
|
|
||||||
def _GEVP_solver(Gt, G0): # Just so normalization an sorting does not need to be repeated. Here we could later put in some checks
|
def _GEVP_solver(Gt, G0):
|
||||||
sp_val, sp_vecs = scipy.linalg.eigh(Gt, G0)
|
"""Helper function for solving the GEVP and sorting the eigenvectors."""
|
||||||
sp_vecs = [sp_vecs[:, np.argsort(sp_val)[-i]] for i in range(1, sp_vecs.shape[0] + 1)]
|
return scipy.linalg.eigh(Gt, G0)[1].T[::-1]
|
||||||
sp_vecs = [v / np.sqrt((v.T @ G0 @ v)) for v in sp_vecs]
|
|
||||||
return sp_vecs
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy
|
||||||
import pyerrors as pe
|
import pyerrors as pe
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -290,6 +291,20 @@ def test_matrix_symmetric():
|
||||||
assert np.all([np.all(o == o.T) for o in sym_corr_mat])
|
assert np.all([np.all(o == o.T) for o in sym_corr_mat])
|
||||||
|
|
||||||
|
|
||||||
|
def test_GEVP_solver():
|
||||||
|
|
||||||
|
mat1 = np.random.rand(15, 15)
|
||||||
|
mat2 = np.random.rand(15, 15)
|
||||||
|
mat1 = mat1 @ mat1.T
|
||||||
|
mat2 = mat2 @ mat2.T
|
||||||
|
|
||||||
|
sp_val, sp_vecs = scipy.linalg.eigh(mat1, mat2)
|
||||||
|
sp_vecs = [sp_vecs[:, np.argsort(sp_val)[-i]] for i in range(1, sp_vecs.shape[0] + 1)]
|
||||||
|
sp_vecs = [v / np.sqrt((v.T @ mat2 @ v)) for v in sp_vecs]
|
||||||
|
|
||||||
|
assert np.allclose(sp_vecs, pe.correlators._GEVP_solver(mat1, mat2), atol=1e-14)
|
||||||
|
|
||||||
|
|
||||||
def test_hankel():
|
def test_hankel():
|
||||||
corr_content = []
|
corr_content = []
|
||||||
for t in range(8):
|
for t in range(8):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue