being a bit more concrete with literals

This commit is contained in:
Justus Kuhlmann 2025-11-04 08:05:11 +00:00
commit dcf6a1f8ad

View file

@ -10,8 +10,9 @@ from .misc import dump_object, _assert_equal_properties
from .fits import least_squares, Fit_result from .fits import least_squares, Fit_result
from .roots import find_root from .roots import find_root
from . import linalg from . import linalg
from .input.json import dump_to_json
from numpy import ndarray, ufunc from numpy import ndarray, ufunc
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union, Literal
class Corr: class Corr:
@ -45,7 +46,7 @@ class Corr:
__slots__ = ["content", "N", "T", "tag", "prange"] __slots__ = ["content", "N", "T", "tag", "prange"]
def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None): def __init__(self, data_input: Union[list[Obs, CObs], list[ndarray[ndarray[Obs, CObs]]], ndarray[ndarray[Corr]]], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
""" Initialize a Corr object. """ Initialize a Corr object.
Parameters Parameters
@ -303,7 +304,7 @@ class Corr:
transposed = [None if _check_for_none(self, G) else G.T for G in self.content] transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
return 0.5 * (Corr(transposed) + self) return 0.5 * (Corr(transposed) + self)
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]: def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors. r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
@ -409,7 +410,7 @@ class Corr:
else: else:
return reordered_vecs return reordered_vecs
def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr": def Eigenvalue(self, t0: int, ts: Optional[int]=None, state: int=0, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", **kwargs) -> "Corr":
"""Determines the eigenvalue of the GEVP by solving and projecting the correlator """Determines the eigenvalue of the GEVP by solving and projecting the correlator
Parameters Parameters
@ -495,7 +496,7 @@ class Corr:
new_content.append(self.content[t]) new_content.append(self.content[t])
return Corr(new_content) return Corr(new_content)
def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr": def correlate(self, partner: Union[Corr, Obs]) -> "Corr":
"""Correlate the correlator with another correlator or Obs """Correlate the correlator with another correlator or Obs
Parameters Parameters
@ -577,14 +578,14 @@ class Corr:
return (self + T_partner) / 2 return (self + T_partner) / 2
def deriv(self, variant: Optional[str]="symmetric") -> "Corr": def deriv(self, variant: Literal["symmetric", "forward", "backward", "improved", "log"]="symmetric") -> "Corr":
"""Return the first derivative of the correlator with respect to x0. """Return the first derivative of the correlator with respect to x0.
Parameters Parameters
---------- ----------
variant : str variant : str
decides which definition of the finite differences derivative is used. decides which definition of the finite differences derivative is used.
Available choice: symmetric, forward, backward, improved, log, default: symmetric Available choices: symmetric, forward, backward, improved, log, default: symmetric
""" """
if self.N != 1: if self.N != 1:
raise ValueError("deriv only implemented for one-dimensional correlators.") raise ValueError("deriv only implemented for one-dimensional correlators.")
@ -638,7 +639,7 @@ class Corr:
else: else:
raise ValueError("Unknown variant.") raise ValueError("Unknown variant.")
def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr": def second_deriv(self, variant: Literal["symmetric", "big_symmetric", "improved", "log"]="symmetric") -> "Corr":
r"""Return the second derivative of the correlator with respect to x0. r"""Return the second derivative of the correlator with respect to x0.
Parameters Parameters
@ -698,7 +699,7 @@ class Corr:
else: else:
raise ValueError("Unknown variant.") raise ValueError("Unknown variant.")
def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr": def m_eff(self, variant: Literal["log", "cosh", "periodic", "sinh", "arccosh", "logsym"]='log', guess: float=1.0) -> "Corr":
"""Returns the effective mass of the correlator as correlator object """Returns the effective mass of the correlator as correlator object
Parameters Parameters
@ -813,7 +814,7 @@ class Corr:
result = least_squares(xs, ys, function, silent=silent, **kwargs) result = least_squares(xs, ys, function, silent=silent, **kwargs)
return result return result
def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs: def plateau(self, plateau_range: Optional[list[int]]=None, method: Literal['fit', 'avg']="fit", auto_gamma: bool=False) -> Obs:
""" Extract a plateau value from a Corr object """ Extract a plateau value from a Corr object
Parameters Parameters
@ -862,7 +863,7 @@ class Corr:
self.prange = prange self.prange = prange
return return
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None): def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Union[Obs, float, int, None]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Union[int, float, None]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
"""Plots the correlator using the tag of the correlator as label if available. """Plots the correlator using the tag of the correlator as label if available.
Parameters Parameters
@ -1029,11 +1030,8 @@ class Corr:
specifies a custom path for the file (default '.') specifies a custom path for the file (default '.')
""" """
if datatype == "json.gz": if datatype == "json.gz":
from .input.json import dump_to_json path = kwargs.get("path", ".")
if 'path' in kwargs: file_name = path + '/' + filename
file_name = kwargs.get('path') + '/' + filename
else:
file_name = filename
dump_to_json(self, file_name) dump_to_json(self, file_name)
elif datatype == "pickle": elif datatype == "pickle":
dump_object(self, filename, **kwargs) dump_object(self, filename, **kwargs)
@ -1078,7 +1076,7 @@ class Corr:
__array_priority__ = 10000 __array_priority__ = 10000
def __eq__(self, y: Any) -> ndarray: def __eq__(self, y: Any) -> ndarray[bool, None]:
if isinstance(y, Corr): if isinstance(y, Corr):
comp = np.asarray(y.content, dtype=object) comp = np.asarray(y.content, dtype=object)
else: else: