[Fix] more work on typehints

This commit is contained in:
Simon Kuberski 2025-02-17 15:24:18 +01:00
parent bbf0b689a1
commit 4814675ff6
7 changed files with 56 additions and 43 deletions

View file

@ -45,7 +45,7 @@ class Corr:
__slots__ = ["content", "N", "T", "tag", "prange"]
def __init__(self, data_input: Any, padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
""" Initialize a Corr object.
Parameters
@ -285,7 +285,7 @@ class Corr:
"""Calculates the per-timeslice trace of a correlator matrix."""
if self.N == 1:
raise ValueError("Only works for correlator matrices.")
newcontent: list[Union[None, float]] = []
newcontent: list[Union[None, Obs, CObs]] = []
for t in range(self.T):
if _check_for_none(self, self.content[t]):
newcontent.append(None)
@ -715,8 +715,8 @@ class Corr:
"""
if self.N != 1:
raise Exception('Correlator must be projected before getting m_eff')
newcontent: list[Union[None, Obs]] = []
if variant == 'log':
newcontent = []
for t in range(self.T - 1):
if ((self.content[t] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
newcontent.append(None)
@ -730,7 +730,6 @@ class Corr:
return np.log(Corr(newcontent, padding=[0, 1]))
elif variant == 'logsym':
newcontent = []
for t in range(1, self.T - 1):
if ((self.content[t - 1] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
newcontent.append(None)
@ -752,7 +751,6 @@ class Corr:
def root_function(x, d):
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
newcontent = []
for t in range(self.T - 1):
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 1][0].value == 0):
newcontent.append(None)
@ -769,7 +767,6 @@ class Corr:
return Corr(newcontent, padding=[0, 1])
elif variant == 'arccosh':
newcontent = []
for t in range(1, self.T - 1):
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t - 1] is None) or (self.content[t][0].value == 0):
newcontent.append(None)
@ -782,7 +779,7 @@ class Corr:
else:
raise ValueError('Unknown variant.')
def fit(self, function: Callable, fitrange: Optional[Union[str, list[int]]]=None, silent: bool=False, **kwargs) -> Fit_result:
def fit(self, function: Callable, fitrange: Optional[list[int]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Fits function to the data
Parameters
@ -865,7 +862,7 @@ class Corr:
self.prange = prange
return
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None):
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
"""Plots the correlator using the tag of the correlator as label if available.
Parameters
@ -1081,14 +1078,14 @@ class Corr:
__array_priority__ = 10000
def __eq__(self, y: Union[Corr, Obs, int]) -> ndarray:
def __eq__(self, y: Any) -> ndarray:
if isinstance(y, Corr):
comp = np.asarray(y.content, dtype=object)
else:
comp = np.asarray(y)
return np.asarray(self.content, dtype=object) == comp
def __add__(self, y: Any) -> "Corr":
def __add__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
if isinstance(y, Corr):
if ((self.N != y.N) or (self.T != y.T)):
raise ValueError("Addition of Corrs with different shape")
@ -1116,7 +1113,7 @@ class Corr:
else:
raise TypeError("Corr + wrong type")
def __mul__(self, y: Any) -> "Corr":
def __mul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
@ -1187,7 +1184,7 @@ class Corr:
else:
return NotImplemented
def __truediv__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
def __truediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
if isinstance(y, Corr):
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
@ -1245,10 +1242,10 @@ class Corr:
newcontent = [None if _check_for_none(self, item) else -1. * item for item in self.content]
return Corr(newcontent, prange=self.prange)
def __sub__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
def __sub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self + (-y)
def __pow__(self, y: Union[float, int]) -> "Corr":
def __pow__(self, y: Union[Obs, CObs, float, int]) -> "Corr":
if isinstance(y, (Obs, int, float, CObs)):
newcontent = [None if _check_for_none(self, item) else item**y for item in self.content]
return Corr(newcontent, prange=self.prange)
@ -1321,16 +1318,16 @@ class Corr:
return self._apply_func_to_corr(np.arctanh)
# Right hand side operations (require tweak in main module to work)
def __radd__(self, y):
def __radd__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self + y
def __rsub__(self, y: int) -> "Corr":
def __rsub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return -self + y
def __rmul__(self, y: Union[float, int]) -> "Corr":
def __rmul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
return self * y
def __rtruediv__(self, y: int) -> "Corr":
def __rtruediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
return (self / y) ** (-1)
@property
@ -1353,7 +1350,7 @@ class Corr:
return self._apply_func_to_corr(return_imag)
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: None=None) -> "Corr":
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: Optional[ndarray]=None) -> "Corr":
r''' Project large correlation matrix to lowest states
This method can be used to reduce the size of an (N x N) correlation matrix
@ -1448,7 +1445,7 @@ def _check_for_none(corr: Corr, entry: Optional[ndarray]) -> bool:
return len(list(filter(None, np.asarray(entry).flatten()))) < corr.N ** 2
def _GEVP_solver(Gt: Optional[ndarray], G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
def _GEVP_solver(Gt: ndarray, G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
r"""Helper function for solving the GEVP and sorting the eigenvectors.
Solves $G(t)v_i=\lambda_i G(t_0)v_i$ and returns the eigenvectors v_i

View file

@ -1,12 +1,12 @@
from __future__ import annotations
import numpy as np
from numpy import ndarray
from typing import Any, Optional, Union
from typing import Optional, Union
class Covobs:
def __init__(self, mean: Optional[Union[float, int]], cov: Any, name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
def __init__(self, mean: Union[float, int], cov: Union[list, ndarray], name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
""" Initialize Covobs object.
Parameters
@ -47,7 +47,7 @@ class Covobs:
"""
return np.dot(np.transpose(self.grad), np.dot(self.cov, self.grad)).item()
def _set_cov(self, cov: Any):
def _set_cov(self, cov: Union[list, ndarray]):
""" Set the covariance matrix of the covobs
Parameters

View file

@ -827,9 +827,20 @@ def residual_plot(x: ndarray, y: list[Obs], func: Callable, fit_res: list[Obs],
plt.draw()
def error_band(x: list[int], func: Callable, beta: list[Obs]) -> ndarray:
def error_band(x: list[int], func: Callable, beta: Union[Fit_result, list[Obs]]) -> ndarray:
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
Parameters
----------
x : list[int]
A list of sample points where the error band is evaluated.
func : Callable
The function representing the fit model.
beta : Union[Fit_result, list[Obs]]
Optimized fit parameters.
Returns
-------
err : np.array(Obs)

View file

@ -29,7 +29,7 @@ def print_config():
print(f"{key: <10}\t {value}")
def errorbar(x, y, axes=plt, **kwargs):
def errorbar(x: Union[ndarray[int, float, Obs], list[int, float, Obs]], y: Union[ndarray[int, float, Obs], list[int, float, Obs]], axes=plt, **kwargs):
"""pyerrors wrapper for the errorbars method of matplotlib
Parameters

View file

@ -12,7 +12,7 @@ from scipy.stats import skew, skewtest, kurtosis, kurtosistest
import numdifftools as nd
from itertools import groupby
from .covobs import Covobs
from numpy import bool, float64, int64, ndarray
from numpy import float64, int64, ndarray
from typing import Any, Callable, Optional, Union, Sequence, TYPE_CHECKING
if sys.version_info >= (3, 10):
@ -501,7 +501,7 @@ class Obs:
"""
return np.isclose(0.0, self.value, 1e-14, atol) and all(np.allclose(0.0, delta, 1e-14, atol) for delta in self.deltas.values()) and all(np.allclose(0.0, delta.errsq(), 1e-14, atol) for delta in self.covobs.values())
def plot_tauint(self, save: None=None):
def plot_tauint(self, save: Optional[str]=None):
"""Plot integrated autocorrelation time for each ensemble.
Parameters
@ -541,7 +541,7 @@ class Obs:
if save:
fig.savefig(save + "_" + str(e))
def plot_rho(self, save: None=None):
def plot_rho(self, save: Optional[str]=None):
"""Plot normalized autocorrelation function time for each ensemble.
Parameters
@ -626,7 +626,7 @@ class Obs:
plt.title(e_name + f'\nskew: {skew(y_test):.3f} (p={skewtest(y_test).pvalue:.3f}), kurtosis: {kurtosis(y_test):.3f} (p={kurtosistest(y_test).pvalue:.3f})')
plt.draw()
def plot_piechart(self, save: None=None) -> dict[str, float64]:
def plot_piechart(self, save: Optional[str]=None) -> dict[str, float64]:
"""Plot piechart which shows the fractional contribution of each
ensemble to the error and returns a dictionary containing the fractions.
@ -708,7 +708,7 @@ class Obs:
tmp_jacks[1:] = (n * mean - full_data) / (n - 1)
return tmp_jacks
def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: None=None) -> ndarray:
def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: Optional[str]=None) -> ndarray:
"""Export bootstrap samples from the Obs
Parameters
@ -784,19 +784,19 @@ class Obs:
return int(m.hexdigest(), 16) & 0xFFFFFFFF
# Overload comparisons
def __lt__(self, other: Union[Obs, float, float64]) -> Union[bool, bool]:
def __lt__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value < other
def __le__(self, other: Union[Obs, float, int]) -> bool:
def __le__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value <= other
def __gt__(self, other: Union[Obs, float]) -> Union[bool, bool]:
def __gt__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value > other
def __ge__(self, other: Union[Obs, float, int]) -> Union[bool, bool]:
def __ge__(self, other: Union[Obs, float, float64, int]) -> bool:
return self.value >= other
def __eq__(self, other: Optional[Union[Obs, float64, int, float]]) -> Union[bool, bool]:
def __eq__(self, other: Optional[Union[Obs, float, float64, int]]) -> bool:
if other is None:
return False
return (self - other).is_zero()
@ -815,10 +815,10 @@ class Obs:
else:
return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1])
def __radd__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
def __radd__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return self + y
def __mul__(self, y: Any) -> Union[Obs, ndarray, CObs, NotImplementedType]:
def __mul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] * x[1], [self, y], man_grad=[y.value, self.value])
else:
@ -831,10 +831,10 @@ class Obs:
else:
return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y])
def __rmul__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
def __rmul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return self * y
def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
if isinstance(y, Obs):
return derived_observable(lambda x, **kwargs: x[0] - x[1], [self, y], man_grad=[1, -1])
else:
@ -845,7 +845,7 @@ class Obs:
else:
return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1])
def __rsub__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
def __rsub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
return -1 * (self - y)
def __pos__(self) -> Obs:
@ -959,6 +959,8 @@ class CObs:
if isinstance(self.imag, Obs):
self.imag.gamma_method(**kwargs)
gm = gamma_method
def is_zero(self) -> bool:
"""Checks whether both real and imaginary part are zero within machine precision."""
return self.real == 0.0 and self.imag == 0.0
@ -1057,7 +1059,7 @@ class CObs:
return f"({self.real:{format_type}}{self.imag:+{significance}}j)"
def gamma_method(x: Union[Corr, Obs, ndarray, list[Obs]], **kwargs) -> ndarray:
def gamma_method(x: Union[Corr, Obs, CObs, ndarray, list[Obs, CObs]], **kwargs) -> ndarray:
"""Vectorized version of the gamma_method applicable to lists or arrays of Obs.
See docstring of pe.Obs.gamma_method for details.
@ -1192,7 +1194,7 @@ def _expand_deltas_for_merge(deltas: ndarray, idx: Union[range, list[int]], shap
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) * len(new_idx) / len(idx) * scalefactor
def derived_observable(func: Callable, data: Any, array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]:
def derived_observable(func: Callable, data: Union[list[Obs], ndarray], array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]:
"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.
Parameters

View file

@ -181,6 +181,8 @@ def test_fit_correlator():
with pytest.raises(ValueError):
my_corr.fit(f, [0, 2, 3])
fit_res = my_corr.fit(f, fitrange=[0, 1])
def test_plateau():
my_corr = pe.correlators.Corr([pe.pseudo_Obs(1.01324, 0.05, 't'), pe.pseudo_Obs(1.042345, 0.008, 't')])
@ -226,7 +228,7 @@ def test_utility():
corr.print()
corr.print([2, 4])
corr.show()
corr.show(comp=corr)
corr.show(comp=corr, x_range=[2, 5.], y_range=[2, 3.], hide_sigma=0.5, references=[.1, .2, .6], title='TEST')
corr.dump('test_dump', datatype="pickle", path='.')
corr.dump('test_dump', datatype="pickle")

View file

@ -407,6 +407,7 @@ def test_cobs():
obs2 = pe.pseudo_Obs(-0.2, 0.03, 't')
my_cobs = pe.CObs(obs1, obs2)
my_cobs.gm()
assert +my_cobs == my_cobs
assert -my_cobs == 0 - my_cobs
my_cobs == my_cobs