[Fix] more work on typehints

This commit is contained in:
Simon Kuberski 2025-02-17 15:24:18 +01:00
parent bbf0b689a1
commit 4814675ff6
7 changed files with 56 additions and 43 deletions

View file

@ -45,7 +45,7 @@ class Corr:
__slots__ = ["content", "N", "T", "tag", "prange"] __slots__ = ["content", "N", "T", "tag", "prange"]
def __init__(self, data_input: Any, padding: list[int]=[0, 0], prange: Optional[list[int]]=None): def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
""" Initialize a Corr object. """ Initialize a Corr object.
Parameters Parameters
@ -285,7 +285,7 @@ class Corr:
"""Calculates the per-timeslice trace of a correlator matrix.""" """Calculates the per-timeslice trace of a correlator matrix."""
if self.N == 1: if self.N == 1:
raise ValueError("Only works for correlator matrices.") raise ValueError("Only works for correlator matrices.")
newcontent: list[Union[None, float]] = [] newcontent: list[Union[None, Obs, CObs]] = []
for t in range(self.T): for t in range(self.T):
if _check_for_none(self, self.content[t]): if _check_for_none(self, self.content[t]):
newcontent.append(None) newcontent.append(None)
@ -715,8 +715,8 @@ class Corr:
""" """
if self.N != 1: if self.N != 1:
raise Exception('Correlator must be projected before getting m_eff') raise Exception('Correlator must be projected before getting m_eff')
newcontent: list[Union[None, Obs]] = []
if variant == 'log': if variant == 'log':
newcontent = []
for t in range(self.T - 1): for t in range(self.T - 1):
if ((self.content[t] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0): if ((self.content[t] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
newcontent.append(None) newcontent.append(None)
@ -730,7 +730,6 @@ class Corr:
return np.log(Corr(newcontent, padding=[0, 1])) return np.log(Corr(newcontent, padding=[0, 1]))
elif variant == 'logsym': elif variant == 'logsym':
newcontent = []
for t in range(1, self.T - 1): for t in range(1, self.T - 1):
if ((self.content[t - 1] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0): if ((self.content[t - 1] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
newcontent.append(None) newcontent.append(None)
@ -752,7 +751,6 @@ class Corr:
def root_function(x, d): def root_function(x, d):
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
newcontent = []
for t in range(self.T - 1): for t in range(self.T - 1):
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 1][0].value == 0): if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 1][0].value == 0):
newcontent.append(None) newcontent.append(None)
@ -769,7 +767,6 @@ class Corr:
return Corr(newcontent, padding=[0, 1]) return Corr(newcontent, padding=[0, 1])
elif variant == 'arccosh': elif variant == 'arccosh':
newcontent = []
for t in range(1, self.T - 1): for t in range(1, self.T - 1):
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t - 1] is None) or (self.content[t][0].value == 0): if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t - 1] is None) or (self.content[t][0].value == 0):
newcontent.append(None) newcontent.append(None)
@ -782,7 +779,7 @@ class Corr:
else: else:
raise ValueError('Unknown variant.') raise ValueError('Unknown variant.')
def fit(self, function: Callable, fitrange: Optional[Union[str, list[int]]]=None, silent: bool=False, **kwargs) -> Fit_result: def fit(self, function: Callable, fitrange: Optional[list[int]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Fits function to the data r'''Fits function to the data
Parameters Parameters
@ -865,7 +862,7 @@ class Corr:
self.prange = prange self.prange = prange
return return
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None): def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
"""Plots the correlator using the tag of the correlator as label if available. """Plots the correlator using the tag of the correlator as label if available.
Parameters Parameters
@ -1081,14 +1078,14 @@ class Corr:
__array_priority__ = 10000 __array_priority__ = 10000
def __eq__(self, y: Union[Corr, Obs, int]) -> ndarray: def __eq__(self, y: Any) -> ndarray:
if isinstance(y, Corr): if isinstance(y, Corr):
comp = np.asarray(y.content, dtype=object) comp = np.asarray(y.content, dtype=object)
else: else:
comp = np.asarray(y) comp = np.asarray(y)
return np.asarray(self.content, dtype=object) == comp return np.asarray(self.content, dtype=object) == comp
def __add__(self, y: Any) -> "Corr": def __add__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
if isinstance(y, Corr): if isinstance(y, Corr):
if ((self.N != y.N) or (self.T != y.T)): if ((self.N != y.N) or (self.T != y.T)):
raise ValueError("Addition of Corrs with different shape") raise ValueError("Addition of Corrs with different shape")
@ -1116,7 +1113,7 @@ class Corr:
else: else:
raise TypeError("Corr + wrong type") raise TypeError("Corr + wrong type")
def __mul__(self, y: Any) -> "Corr": def __mul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
if isinstance(y, Corr): if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T): if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T") raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
@ -1187,7 +1184,7 @@ class Corr:
else: else:
return NotImplemented return NotImplemented
def __truediv__(self, y: Union[Corr, float, ndarray, int]) -> "Corr": def __truediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
if isinstance(y, Corr): if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T): if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T") raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
@ -1245,10 +1242,10 @@ class Corr:
newcontent = [None if _check_for_none(self, item) else -1. * item for item in self.content] newcontent = [None if _check_for_none(self, item) else -1. * item for item in self.content]
return Corr(newcontent, prange=self.prange) return Corr(newcontent, prange=self.prange)
def __sub__(self, y: Union[Corr, float, ndarray, int]) -> "Corr": def __sub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self + (-y) return self + (-y)
def __pow__(self, y: Union[float, int]) -> "Corr": def __pow__(self, y: Union[Obs, CObs, float, int]) -> "Corr":
if isinstance(y, (Obs, int, float, CObs)): if isinstance(y, (Obs, int, float, CObs)):
newcontent = [None if _check_for_none(self, item) else item**y for item in self.content] newcontent = [None if _check_for_none(self, item) else item**y for item in self.content]
return Corr(newcontent, prange=self.prange) return Corr(newcontent, prange=self.prange)
@ -1321,16 +1318,16 @@ class Corr:
return self._apply_func_to_corr(np.arctanh) return self._apply_func_to_corr(np.arctanh)
# Right hand side operations (require tweak in main module to work) # Right hand side operations (require tweak in main module to work)
def __radd__(self, y): def __radd__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self + y return self + y
def __rsub__(self, y: int) -> "Corr": def __rsub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return -self + y return -self + y
def __rmul__(self, y: Union[float, int]) -> "Corr": def __rmul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self * y return self * y
def __rtruediv__(self, y: int) -> "Corr": def __rtruediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
return (self / y) ** (-1) return (self / y) ** (-1)
@property @property
@ -1353,7 +1350,7 @@ class Corr:
return self._apply_func_to_corr(return_imag) return self._apply_func_to_corr(return_imag)
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: None=None) -> "Corr": def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: Optional[ndarray]=None) -> "Corr":
r''' Project large correlation matrix to lowest states r''' Project large correlation matrix to lowest states
This method can be used to reduce the size of an (N x N) correlation matrix This method can be used to reduce the size of an (N x N) correlation matrix
@ -1448,7 +1445,7 @@ def _check_for_none(corr: Corr, entry: Optional[ndarray]) -> bool:
return len(list(filter(None, np.asarray(entry).flatten()))) < corr.N ** 2 return len(list(filter(None, np.asarray(entry).flatten()))) < corr.N ** 2
def _GEVP_solver(Gt: Optional[ndarray], G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray: def _GEVP_solver(Gt: ndarray, G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
r"""Helper function for solving the GEVP and sorting the eigenvectors. r"""Helper function for solving the GEVP and sorting the eigenvectors.
Solves $G(t)v_i=\lambda_i G(t_0)v_i$ and returns the eigenvectors v_i Solves $G(t)v_i=\lambda_i G(t_0)v_i$ and returns the eigenvectors v_i

View file

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from numpy import ndarray from numpy import ndarray
from typing import Any, Optional, Union from typing import Optional, Union
class Covobs: class Covobs:
def __init__(self, mean: Optional[Union[float, int]], cov: Any, name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None): def __init__(self, mean: Union[float, int], cov: Union[list, ndarray], name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
""" Initialize Covobs object. """ Initialize Covobs object.
Parameters Parameters
@ -47,7 +47,7 @@ class Covobs:
""" """
return np.dot(np.transpose(self.grad), np.dot(self.cov, self.grad)).item() return np.dot(np.transpose(self.grad), np.dot(self.cov, self.grad)).item()
def _set_cov(self, cov: Any): def _set_cov(self, cov: Union[list, ndarray]):
""" Set the covariance matrix of the covobs """ Set the covariance matrix of the covobs
Parameters Parameters

View file

@ -827,9 +827,20 @@ def residual_plot(x: ndarray, y: list[Obs], func: Callable, fit_res: list[Obs],
plt.draw() plt.draw()
def error_band(x: list[int], func: Callable, beta: list[Obs]) -> ndarray: def error_band(x: list[int], func: Callable, beta: Union[Fit_result, list[Obs]]) -> ndarray:
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta. """Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
Parameters
----------
x : list[int]
A list of sample points where the error band is evaluated.
func : Callable
The function representing the fit model.
beta : Union[Fit_result, list[Obs]]
Optimized fit parameters.
Returns Returns
------- -------
err : np.array(Obs) err : np.array(Obs)

View file

@ -29,7 +29,7 @@ def print_config():
print(f"{key: <10}\t {value}") print(f"{key: <10}\t {value}")
def errorbar(x, y, axes=plt, **kwargs): def errorbar(x: Union[ndarray[int, float, Obs], list[int, float, Obs]], y: Union[ndarray[int, float, Obs], list[int, float, Obs]], axes=plt, **kwargs):
"""pyerrors wrapper for the errorbars method of matplotlib """pyerrors wrapper for the errorbars method of matplotlib
Parameters Parameters

View file

@ -12,7 +12,7 @@ from scipy.stats import skew, skewtest, kurtosis, kurtosistest
import numdifftools as nd import numdifftools as nd
from itertools import groupby from itertools import groupby
from .covobs import Covobs from .covobs import Covobs
from numpy import bool, float64, int64, ndarray from numpy import float64, int64, ndarray
from typing import Any, Callable, Optional, Union, Sequence, TYPE_CHECKING from typing import Any, Callable, Optional, Union, Sequence, TYPE_CHECKING
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
@ -501,7 +501,7 @@ class Obs:
""" """
return np.isclose(0.0, self.value, 1e-14, atol) and all(np.allclose(0.0, delta, 1e-14, atol) for delta in self.deltas.values()) and all(np.allclose(0.0, delta.errsq(), 1e-14, atol) for delta in self.covobs.values()) return np.isclose(0.0, self.value, 1e-14, atol) and all(np.allclose(0.0, delta, 1e-14, atol) for delta in self.deltas.values()) and all(np.allclose(0.0, delta.errsq(), 1e-14, atol) for delta in self.covobs.values())
def plot_tauint(self, save: None=None): def plot_tauint(self, save: Optional[str]=None):
"""Plot integrated autocorrelation time for each ensemble. """Plot integrated autocorrelation time for each ensemble.
Parameters Parameters
@ -541,7 +541,7 @@ class Obs:
if save: if save:
fig.savefig(save + "_" + str(e)) fig.savefig(save + "_" + str(e))
def plot_rho(self, save: None=None): def plot_rho(self, save: Optional[str]=None):
"""Plot normalized autocorrelation function time for each ensemble. """Plot normalized autocorrelation function time for each ensemble.
Parameters Parameters
@ -626,7 +626,7 @@ class Obs:
plt.title(e_name + f'\nskew: {skew(y_test):.3f} (p={skewtest(y_test).pvalue:.3f}), kurtosis: {kurtosis(y_test):.3f} (p={kurtosistest(y_test).pvalue:.3f})') plt.title(e_name + f'\nskew: {skew(y_test):.3f} (p={skewtest(y_test).pvalue:.3f}), kurtosis: {kurtosis(y_test):.3f} (p={kurtosistest(y_test).pvalue:.3f})')
plt.draw() plt.draw()
def plot_piechart(self, save: None=None) -> dict[str, float64]: def plot_piechart(self, save: Optional[str]=None) -> dict[str, float64]:
"""Plot piechart which shows the fractional contribution of each """Plot piechart which shows the fractional contribution of each
ensemble to the error and returns a dictionary containing the fractions. ensemble to the error and returns a dictionary containing the fractions.
@ -708,7 +708,7 @@ class Obs:
tmp_jacks[1:] = (n * mean - full_data) / (n - 1) tmp_jacks[1:] = (n * mean - full_data) / (n - 1)
return tmp_jacks return tmp_jacks
def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: None=None) -> ndarray: def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: Optional[str]=None) -> ndarray:
"""Export bootstrap samples from the Obs """Export bootstrap samples from the Obs
Parameters Parameters
@ -784,19 +784,19 @@ class Obs:
return int(m.hexdigest(), 16) & 0xFFFFFFFF return int(m.hexdigest(), 16) & 0xFFFFFFFF
# Overload comparisons # Overload comparisons
def __lt__(self, other: Union[Obs, float, float64]) -> Union[bool, bool]: def __lt__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value < other return self.value < other
def __le__(self, other: Union[Obs, float, int]) -> bool: def __le__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value <= other return self.value <= other
def __gt__(self, other: Union[Obs, float]) -> Union[bool, bool]: def __gt__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value > other return self.value > other
def __ge__(self, other: Union[Obs, float, int]) -> Union[bool, bool]: def __ge__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value >= other return self.value >= other
def __eq__(self, other: Optional[Union[Obs, float64, int, float]]) -> Union[bool, bool]: def __eq__(self, other: Optional[Union[Obs, float, float64, int]]) -> bool:
if other is None: if other is None:
return False return False
return (self - other).is_zero() return (self - other).is_zero()
@ -815,10 +815,10 @@ class Obs:
else: else:
return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1]) return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1])
def __radd__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]: def __radd__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return self + y return self + y
def __mul__(self, y: Any) -> Union[Obs, ndarray, CObs, NotImplementedType]: def __mul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
if isinstance(y, Obs): if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] * x[1], [self, y], man_grad=[y.value, self.value]) return derived_observable(lambda x, **kwargs: x[0] * x[1], [self, y], man_grad=[y.value, self.value])
else: else:
@ -831,10 +831,10 @@ class Obs:
else: else:
return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y]) return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y])
def __rmul__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]: def __rmul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return self * y return self * y
def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]: def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
if isinstance(y, Obs): if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] - x[1], [self, y], man_grad=[1, -1]) return derived_observable(lambda x, **kwargs: x[0] - x[1], [self, y], man_grad=[1, -1])
else: else:
@ -845,7 +845,7 @@ class Obs:
else: else:
return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1]) return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1])
def __rsub__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]: def __rsub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return -1 * (self - y) return -1 * (self - y)
def __pos__(self) -> Obs: def __pos__(self) -> Obs:
@ -959,6 +959,8 @@ class CObs:
if isinstance(self.imag, Obs): if isinstance(self.imag, Obs):
self.imag.gamma_method(**kwargs) self.imag.gamma_method(**kwargs)
gm = gamma_method
def is_zero(self) -> bool: def is_zero(self) -> bool:
"""Checks whether both real and imaginary part are zero within machine precision.""" """Checks whether both real and imaginary part are zero within machine precision."""
return self.real == 0.0 and self.imag == 0.0 return self.real == 0.0 and self.imag == 0.0
@ -1057,7 +1059,7 @@ class CObs:
return f"({self.real:{format_type}}{self.imag:+{significance}}j)" return f"({self.real:{format_type}}{self.imag:+{significance}}j)"
def gamma_method(x: Union[Corr, Obs, ndarray, list[Obs]], **kwargs) -> ndarray: def gamma_method(x: Union[Corr, Obs, CObs, ndarray, list[Obs, CObs]], **kwargs) -> ndarray:
"""Vectorized version of the gamma_method applicable to lists or arrays of Obs. """Vectorized version of the gamma_method applicable to lists or arrays of Obs.
See docstring of pe.Obs.gamma_method for details. See docstring of pe.Obs.gamma_method for details.
@ -1192,7 +1194,7 @@ def _expand_deltas_for_merge(deltas: ndarray, idx: Union[range, list[int]], shap
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) * len(new_idx) / len(idx) * scalefactor return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) * len(new_idx) / len(idx) * scalefactor
def derived_observable(func: Callable, data: Any, array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]: def derived_observable(func: Callable, data: Union[list[Obs], ndarray], array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]:
"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation. """Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.
Parameters Parameters

View file

@ -181,6 +181,8 @@ def test_fit_correlator():
with pytest.raises(ValueError): with pytest.raises(ValueError):
my_corr.fit(f, [0, 2, 3]) my_corr.fit(f, [0, 2, 3])
fit_res = my_corr.fit(f, fitrange=[0, 1])
def test_plateau(): def test_plateau():
my_corr = pe.correlators.Corr([pe.pseudo_Obs(1.01324, 0.05, 't'), pe.pseudo_Obs(1.042345, 0.008, 't')]) my_corr = pe.correlators.Corr([pe.pseudo_Obs(1.01324, 0.05, 't'), pe.pseudo_Obs(1.042345, 0.008, 't')])
@ -226,7 +228,7 @@ def test_utility():
corr.print() corr.print()
corr.print([2, 4]) corr.print([2, 4])
corr.show() corr.show()
corr.show(comp=corr) corr.show(comp=corr, x_range=[2, 5.], y_range=[2, 3.], hide_sigma=0.5, references=[.1, .2, .6], title='TEST')
corr.dump('test_dump', datatype="pickle", path='.') corr.dump('test_dump', datatype="pickle", path='.')
corr.dump('test_dump', datatype="pickle") corr.dump('test_dump', datatype="pickle")

View file

@ -407,6 +407,7 @@ def test_cobs():
obs2 = pe.pseudo_Obs(-0.2, 0.03, 't') obs2 = pe.pseudo_Obs(-0.2, 0.03, 't')
my_cobs = pe.CObs(obs1, obs2) my_cobs = pe.CObs(obs1, obs2)
my_cobs.gm()
assert +my_cobs == my_cobs assert +my_cobs == my_cobs
assert -my_cobs == 0 - my_cobs assert -my_cobs == 0 - my_cobs
my_cobs == my_cobs my_cobs == my_cobs