[Fix] Start fixing remaining type hints

This commit is contained in:
Fabian Joswig 2025-01-03 19:01:20 +01:00
parent 23d4f4c320
commit 1916de15ec
3 changed files with 27 additions and 28 deletions

View file

@ -76,7 +76,7 @@ class Corr:
T = data_input[0, 0].T T = data_input[0, 0].T
N = data_input.shape[0] N = data_input.shape[0]
input_as_list = [] input_as_list: list[Union[None, ndarray]] = []
for t in range(T): for t in range(T):
if any([(item.content[t] is None) for item in data_input.flatten()]): if any([(item.content[t] is None) for item in data_input.flatten()]):
if not all([(item.content[t] is None) for item in data_input.flatten()]): if not all([(item.content[t] is None) for item in data_input.flatten()]):
@ -100,7 +100,7 @@ class Corr:
if all([isinstance(item, (Obs, CObs)) or item is None for item in data_input]): if all([isinstance(item, (Obs, CObs)) or item is None for item in data_input]):
_assert_equal_properties([o for o in data_input if o is not None]) _assert_equal_properties([o for o in data_input if o is not None])
self.content = [np.asarray([item]) if item is not None else None for item in data_input] self.content: list[Union[None, ndarray]] = [np.asarray([item]) if item is not None else None for item in data_input]
self.N = 1 self.N = 1
elif all([isinstance(item, np.ndarray) or item is None for item in data_input]) and any([isinstance(item, np.ndarray) for item in data_input]): elif all([isinstance(item, np.ndarray) or item is None for item in data_input]) and any([isinstance(item, np.ndarray) for item in data_input]):
self.content = data_input self.content = data_input
@ -124,12 +124,13 @@ class Corr:
def __getitem__(self, idx: Union[slice, int]) -> Union[CObs, Obs, ndarray, List[ndarray]]: def __getitem__(self, idx: Union[slice, int]) -> Union[CObs, Obs, ndarray, List[ndarray]]:
"""Return the content of timeslice idx""" """Return the content of timeslice idx"""
if self.content[idx] is None: idx_content = self.content[idx]
if idx_content is None:
return None return None
elif len(self.content[idx]) == 1: elif len(idx_content) == 1:
return self.content[idx][0] return idx_content[0]
else: else:
return self.content[idx] return idx_content
@property @property
def reweighted(self): def reweighted(self):
@ -154,7 +155,7 @@ class Corr:
gm = gamma_method gm = gamma_method
def projected(self, vector_l: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, vector_r: None=None, normalize: bool=False) -> "Corr": def projected(self, vector_l: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, vector_r: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, normalize: bool=False) -> "Corr":
"""We need to project the Correlator with a Vector to get a single value at each timeslice. """We need to project the Correlator with a Vector to get a single value at each timeslice.
The method can use one or two vectors. The method can use one or two vectors.
@ -166,7 +167,7 @@ class Corr:
if vector_l is None: if vector_l is None:
vector_l, vector_r = np.asarray([1.] + (self.N - 1) * [0.]), np.asarray([1.] + (self.N - 1) * [0.]) vector_l, vector_r = np.asarray([1.] + (self.N - 1) * [0.]), np.asarray([1.] + (self.N - 1) * [0.])
elif (vector_r is None): elif vector_r is None:
vector_r = vector_l vector_r = vector_l
if isinstance(vector_l, list) and not isinstance(vector_r, list): if isinstance(vector_l, list) and not isinstance(vector_r, list):
if len(vector_l) != self.T: if len(vector_l) != self.T:
@ -177,7 +178,7 @@ class Corr:
raise ValueError("Length of vector list must be equal to T") raise ValueError("Length of vector list must be equal to T")
vector_l = [vector_l] * self.T vector_l = [vector_l] * self.T
if not isinstance(vector_l, list): if isinstance(vector_l, ndarray) and isinstance(vector_r, ndarray):
if not vector_l.shape == vector_r.shape == (self.N,): if not vector_l.shape == vector_r.shape == (self.N,):
raise ValueError("Vectors are of wrong shape!") raise ValueError("Vectors are of wrong shape!")
if normalize: if normalize:
@ -239,7 +240,7 @@ class Corr:
newcontent.append(None) newcontent.append(None)
else: else:
newcontent.append(0.5 * (self.content[t] + self.content[self.T - t])) newcontent.append(0.5 * (self.content[t] + self.content[self.T - t]))
if (all([x is None for x in newcontent])): if all([x is None for x in newcontent]):
raise ValueError("Corr could not be symmetrized: No redundant values") raise ValueError("Corr could not be symmetrized: No redundant values")
return Corr(newcontent, prange=self.prange) return Corr(newcontent, prange=self.prange)
@ -284,7 +285,7 @@ class Corr:
"""Calculates the per-timeslice trace of a correlator matrix.""" """Calculates the per-timeslice trace of a correlator matrix."""
if self.N == 1: if self.N == 1:
raise ValueError("Only works for correlator matrices.") raise ValueError("Only works for correlator matrices.")
newcontent = [] newcontent: list[Union[None, float]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)
@ -486,7 +487,7 @@ class Corr:
offset : int offset : int
Offset the equal spacing Offset the equal spacing
""" """
new_content = [] new_content: list[Union[None, list, ndarray]] = []
for t in range(self.T): for t in range(self.T):
if (offset + t) % spacing != 0: if (offset + t) % spacing != 0:
new_content.append(None) new_content.append(None)
@ -506,7 +507,7 @@ class Corr:
""" """
if self.N != 1: if self.N != 1:
raise ValueError("Only one-dimensional correlators can be safely correlated.") raise ValueError("Only one-dimensional correlators can be safely correlated.")
new_content = [] new_content: list[Union[None, ndarray]] = []
for x0, t_slice in enumerate(self.content): for x0, t_slice in enumerate(self.content):
if _check_for_none(self, t_slice): if _check_for_none(self, t_slice):
new_content.append(None) new_content.append(None)
@ -538,7 +539,7 @@ class Corr:
""" """
if self.N != 1: if self.N != 1:
raise Exception("Reweighting only implemented for one-dimensional correlators.") raise Exception("Reweighting only implemented for one-dimensional correlators.")
new_content = [] new_content: list[Union[None, ndarray]] = []
for t_slice in self.content: for t_slice in self.content:
if _check_for_none(self, t_slice): if _check_for_none(self, t_slice):
new_content.append(None) new_content.append(None)
@ -660,8 +661,8 @@ class Corr:
""" """
if self.N != 1: if self.N != 1:
raise ValueError("second_deriv only implemented for one-dimensional correlators.") raise ValueError("second_deriv only implemented for one-dimensional correlators.")
newcontent: list[Union[None, ndarray, Obs]] = []
if variant == "symmetric": if variant == "symmetric":
newcontent = []
for t in range(1, self.T - 1): for t in range(1, self.T - 1):
if (self.content[t - 1] is None) or (self.content[t + 1] is None): if (self.content[t - 1] is None) or (self.content[t + 1] is None):
newcontent.append(None) newcontent.append(None)
@ -671,7 +672,6 @@ class Corr:
raise ValueError("Derivative is undefined at all timeslices") raise ValueError("Derivative is undefined at all timeslices")
return Corr(newcontent, padding=[1, 1]) return Corr(newcontent, padding=[1, 1])
elif variant == "big_symmetric": elif variant == "big_symmetric":
newcontent = []
for t in range(2, self.T - 2): for t in range(2, self.T - 2):
if (self.content[t - 2] is None) or (self.content[t + 2] is None): if (self.content[t - 2] is None) or (self.content[t + 2] is None):
newcontent.append(None) newcontent.append(None)
@ -681,7 +681,6 @@ class Corr:
raise ValueError("Derivative is undefined at all timeslices") raise ValueError("Derivative is undefined at all timeslices")
return Corr(newcontent, padding=[2, 2]) return Corr(newcontent, padding=[2, 2])
elif variant == "improved": elif variant == "improved":
newcontent = []
for t in range(2, self.T - 2): for t in range(2, self.T - 2):
if (self.content[t - 2] is None) or (self.content[t - 1] is None) or (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 2] is None): if (self.content[t - 2] is None) or (self.content[t - 1] is None) or (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 2] is None):
newcontent.append(None) newcontent.append(None)
@ -691,7 +690,6 @@ class Corr:
raise ValueError("Derivative is undefined at all timeslices") raise ValueError("Derivative is undefined at all timeslices")
return Corr(newcontent, padding=[2, 2]) return Corr(newcontent, padding=[2, 2])
elif variant == 'log': elif variant == 'log':
newcontent = []
for t in range(self.T): for t in range(self.T):
if (self.content[t] is None) or (self.content[t] <= 0): if (self.content[t] is None) or (self.content[t] <= 0):
newcontent.append(None) newcontent.append(None)
@ -859,7 +857,7 @@ class Corr:
else: else:
raise ValueError("Unsupported plateau method: " + method) raise ValueError("Unsupported plateau method: " + method)
def set_prange(self, prange: List[Union[int, float]]): def set_prange(self, prange: List[int]):
"""Sets the attribute prange of the Corr object.""" """Sets the attribute prange of the Corr object."""
if not len(prange) == 2: if not len(prange) == 2:
raise ValueError("prange must be a list or array with two values") raise ValueError("prange must be a list or array with two values")
@ -1098,7 +1096,7 @@ class Corr:
if isinstance(y, Corr): if isinstance(y, Corr):
if ((self.N != y.N) or (self.T != y.T)): if ((self.N != y.N) or (self.T != y.T)):
raise ValueError("Addition of Corrs with different shape") raise ValueError("Addition of Corrs with different shape")
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]): if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1107,7 +1105,7 @@ class Corr:
return Corr(newcontent) return Corr(newcontent)
elif isinstance(y, (Obs, int, float, CObs, complex)): elif isinstance(y, (Obs, int, float, CObs, complex)):
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1126,7 +1124,7 @@ class Corr:
if isinstance(y, Corr): if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T): if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T") raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]): if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1156,7 +1154,7 @@ class Corr:
raise ValueError("Can only multiply correlators by square matrices.") raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]: if not self.N == y.shape[0]:
raise ValueError("matmul: mismatch of matrix dimensions") raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1183,7 +1181,7 @@ class Corr:
raise ValueError("Can only multiply correlators by square matrices.") raise ValueError("Can only multiply correlators by square matrices.")
if not self.N == y.shape[0]: if not self.N == y.shape[0]:
raise ValueError("matmul: mismatch of matrix dimensions") raise ValueError("matmul: mismatch of matrix dimensions")
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1197,7 +1195,7 @@ class Corr:
if isinstance(y, Corr): if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T): if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T") raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]): if _check_for_none(self, self.content[t]) or _check_for_none(y, y.content[t]):
newcontent.append(None) newcontent.append(None)
@ -1232,7 +1230,7 @@ class Corr:
elif isinstance(y, (int, float)): elif isinstance(y, (int, float)):
if y == 0: if y == 0:
raise ValueError('Division by zero will return undefined correlator') raise ValueError('Division by zero will return undefined correlator')
newcontent = [] newcontent: list[Union[None, ndarray, Obs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)

View file

@ -48,6 +48,7 @@ def matmul(*operands) -> ndarray:
Nr = derived_observable(multi_dot_r, extended_operands, array_mode=True) Nr = derived_observable(multi_dot_r, extended_operands, array_mode=True)
Ni = derived_observable(multi_dot_i, extended_operands, array_mode=True) Ni = derived_observable(multi_dot_i, extended_operands, array_mode=True)
assert isinstance(Nr, ndarray) and isinstance(Ni, ndarray)
res = np.empty_like(Nr) res = np.empty_like(Nr)
for (n, m), entry in np.ndenumerate(Nr): for (n, m), entry in np.ndenumerate(Nr):
res[n, m] = CObs(Nr[n, m], Ni[n, m]) res[n, m] = CObs(Nr[n, m], Ni[n, m])

View file

@ -3,10 +3,10 @@ import numpy as np
import scipy.linalg import scipy.linalg
from .obs import Obs from .obs import Obs
from .linalg import svd, eig from .linalg import svd, eig
from typing import List from typing import Optional
def matrix_pencil_method(corrs: List[Obs], k: int=1, p: None=None, **kwargs) -> List[Obs]: def matrix_pencil_method(corrs: list[Obs], k: int=1, p: Optional[int]=None, **kwargs) -> list[Obs]:
"""Matrix pencil method to extract k energy levels from data """Matrix pencil method to extract k energy levels from data
Implementation of the matrix pencil method based on Implementation of the matrix pencil method based on