being a bit more concrete with literals

This commit is contained in:
Justus Kuhlmann 2025-11-04 08:05:11 +00:00
commit dcf6a1f8ad

View file

@ -10,8 +10,9 @@ from .misc import dump_object, _assert_equal_properties
from .fits import least_squares, Fit_result
from .roots import find_root
from . import linalg
from .input.json import dump_to_json
from numpy import ndarray, ufunc
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, Literal
class Corr:
@ -45,7 +46,7 @@ class Corr:
__slots__ = ["content", "N", "T", "tag", "prange"]
def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
def __init__(self, data_input: Union[list[Obs, CObs], list[ndarray[ndarray[Obs, CObs]]], ndarray[ndarray[Corr]]], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
""" Initialize a Corr object.
Parameters
@ -303,7 +304,7 @@ class Corr:
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
return 0.5 * (Corr(transposed) + self)
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]:
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
@ -409,7 +410,7 @@ class Corr:
else:
return reordered_vecs
def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr":
def Eigenvalue(self, t0: int, ts: Optional[int]=None, state: int=0, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", **kwargs) -> "Corr":
"""Determines the eigenvalue of the GEVP by solving and projecting the correlator
Parameters
@ -495,7 +496,7 @@ class Corr:
new_content.append(self.content[t])
return Corr(new_content)
def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr":
def correlate(self, partner: Union[Corr, Obs]) -> "Corr":
"""Correlate the correlator with another correlator or Obs
Parameters
@ -577,14 +578,14 @@ class Corr:
return (self + T_partner) / 2
def deriv(self, variant: Optional[str]="symmetric") -> "Corr":
def deriv(self, variant: Literal["symmetric", "forward", "backward", "improved", "log"]="symmetric") -> "Corr":
"""Return the first derivative of the correlator with respect to x0.
Parameters
----------
variant : str
decides which definition of the finite differences derivative is used.
Available choice: symmetric, forward, backward, improved, log, default: symmetric
Available choices: symmetric, forward, backward, improved, log, default: symmetric
"""
if self.N != 1:
raise ValueError("deriv only implemented for one-dimensional correlators.")
@ -638,7 +639,7 @@ class Corr:
else:
raise ValueError("Unknown variant.")
def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr":
def second_deriv(self, variant: Literal["symmetric", "big_symmetric", "improved", "log"]="symmetric") -> "Corr":
r"""Return the second derivative of the correlator with respect to x0.
Parameters
@ -698,7 +699,7 @@ class Corr:
else:
raise ValueError("Unknown variant.")
def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr":
def m_eff(self, variant: Literal["log", "cosh", "periodic", "sinh", "arccosh", "logsym"]='log', guess: float=1.0) -> "Corr":
"""Returns the effective mass of the correlator as correlator object
Parameters
@ -813,7 +814,7 @@ class Corr:
result = least_squares(xs, ys, function, silent=silent, **kwargs)
return result
def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs:
def plateau(self, plateau_range: Optional[list[int]]=None, method: Literal['fit', 'avg']="fit", auto_gamma: bool=False) -> Obs:
""" Extract a plateau value from a Corr object
Parameters
@ -862,7 +863,7 @@ class Corr:
self.prange = prange
return
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Union[Obs, float, int, None]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Union[int, float, None]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
"""Plots the correlator using the tag of the correlator as label if available.
Parameters
@ -1029,11 +1030,8 @@ class Corr:
specifies a custom path for the file (default '.')
"""
if datatype == "json.gz":
from .input.json import dump_to_json
if 'path' in kwargs:
file_name = kwargs.get('path') + '/' + filename
else:
file_name = filename
path = kwargs.get("path", ".")
file_name = path + '/' + filename
dump_to_json(self, file_name)
elif datatype == "pickle":
dump_object(self, filename, **kwargs)
@ -1078,7 +1076,7 @@ class Corr:
__array_priority__ = 10000
def __eq__(self, y: Any) -> ndarray:
def __eq__(self, y: Any) -> ndarray[bool, None]:
if isinstance(y, Corr):
comp = np.asarray(y.content, dtype=object)
else: