[Fix] Simplify type hints

This commit is contained in:
Fabian Joswig 2025-01-03 22:43:19 +01:00
parent d45b43e6de
commit 1c6053ef61
12 changed files with 84 additions and 85 deletions

View file

@ -10,8 +10,8 @@ from .misc import dump_object, _assert_equal_properties
from .fits import least_squares, Fit_result
from .roots import find_root
from . import linalg
from numpy import float64, int64, ndarray, ufunc
from typing import Any, Callable, List, Optional, Tuple, Union
from numpy import ndarray, ufunc
from typing import Any, Callable, Optional, Union
class Corr:
@ -45,7 +45,7 @@ class Corr:
__slots__ = ["content", "N", "T", "tag", "prange"]
def __init__(self, data_input: Any, padding: List[int]=[0, 0], prange: Optional[List[int]]=None):
def __init__(self, data_input: Any, padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
""" Initialize a Corr object.
Parameters
@ -122,7 +122,7 @@ class Corr:
self.T = len(self.content)
self.prange = prange
def __getitem__(self, idx: Union[slice, int]) -> Union[CObs, Obs, ndarray, List[ndarray]]:
def __getitem__(self, idx: Union[slice, int]) -> Union[CObs, Obs, ndarray, list[ndarray]]:
"""Return the content of timeslice idx"""
idx_content = self.content[idx]
if idx_content is None:
@ -155,7 +155,7 @@ class Corr:
gm = gamma_method
def projected(self, vector_l: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, vector_r: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, normalize: bool=False) -> "Corr":
def projected(self, vector_l: Optional[Union[ndarray, list[Optional[ndarray]]]]=None, vector_r: Optional[Union[ndarray, list[Optional[ndarray]]]]=None, normalize: bool=False) -> "Corr":
"""We need to project the Correlator with a Vector to get a single value at each timeslice.
The method can use one or two vectors.
@ -209,7 +209,7 @@ class Corr:
newcontent = [None if (item is None) else item[i, j] for item in self.content]
return Corr(newcontent)
def plottable(self) -> Union[Tuple[List[int], List[float64], List[float64]], Tuple[List[int], List[float], List[float64]]]:
def plottable(self) -> tuple[list[int], list[float]]:
"""Outputs the correlator in a plotable format.
Outputs three lists containing the timeslice index, the value on each
@ -303,7 +303,7 @@ class Corr:
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
return 0.5 * (Corr(transposed) + self)
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[List[List[Optional[ndarray]]], ndarray, List[Optional[ndarray]]]:
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
@ -786,7 +786,7 @@ class Corr:
else:
raise ValueError('Unknown variant.')
def fit(self, function: Callable, fitrange: Optional[Union[str, List[int]]]=None, silent: bool=False, **kwargs) -> Fit_result:
def fit(self, function: Callable, fitrange: Optional[Union[str, list[int]]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Fits function to the data
Parameters
@ -820,7 +820,7 @@ class Corr:
result = least_squares(xs, ys, function, silent=silent, **kwargs)
return result
def plateau(self, plateau_range: Optional[List[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs:
def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs:
""" Extract a plateau value from a Corr object
Parameters
@ -857,7 +857,7 @@ class Corr:
else:
raise ValueError("Unsupported plateau method: " + method)
def set_prange(self, prange: List[int]):
def set_prange(self, prange: list[int]):
"""Sets the attribute prange of the Corr object."""
if not len(prange) == 2:
raise ValueError("prange must be a list or array with two values")
@ -869,7 +869,7 @@ class Corr:
self.prange = prange
return
def show(self, x_range: Optional[List[int64]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None):
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None):
"""Plots the correlator using the tag of the correlator as label if available.
Parameters
@ -1047,10 +1047,10 @@ class Corr:
else:
raise ValueError("Unknown datatype " + str(datatype))
def print(self, print_range: Optional[List[int]]=None):
def print(self, print_range: Optional[list[int]]=None):
print(self.__repr__(print_range))
def __repr__(self, print_range: Optional[List[int]]=None) -> str:
def __repr__(self, print_range: Optional[list[int]]=None) -> str:
if print_range is None:
print_range = [0, None]
@ -1415,7 +1415,7 @@ class Corr:
return Corr(newcontent)
def _sort_vectors(vec_set_in: List[Optional[ndarray]], ts: int) -> List[Optional[Union[ndarray, List[ndarray]]]]:
def _sort_vectors(vec_set_in: list[Optional[ndarray]], ts: int) -> list[Optional[Union[ndarray, list[ndarray]]]]:
"""Helper function used to find a set of Eigenvectors consistent over all timeslices"""
if isinstance(vec_set_in[ts][0][0], Obs):