mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-14 22:30:25 +01:00
[Fix] more work on typehints
This commit is contained in:
parent
bbf0b689a1
commit
4814675ff6
7 changed files with 56 additions and 43 deletions
|
@ -45,7 +45,7 @@ class Corr:
|
|||
|
||||
__slots__ = ["content", "N", "T", "tag", "prange"]
|
||||
|
||||
def __init__(self, data_input: Any, padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
|
||||
def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None):
|
||||
""" Initialize a Corr object.
|
||||
|
||||
Parameters
|
||||
|
@ -285,7 +285,7 @@ class Corr:
|
|||
"""Calculates the per-timeslice trace of a correlator matrix."""
|
||||
if self.N == 1:
|
||||
raise ValueError("Only works for correlator matrices.")
|
||||
newcontent: list[Union[None, float]] = []
|
||||
newcontent: list[Union[None, Obs, CObs]] = []
|
||||
for t in range(self.T):
|
||||
if _check_for_none(self, self.content[t]):
|
||||
newcontent.append(None)
|
||||
|
@ -715,8 +715,8 @@ class Corr:
|
|||
"""
|
||||
if self.N != 1:
|
||||
raise Exception('Correlator must be projected before getting m_eff')
|
||||
newcontent: list[Union[None, Obs]] = []
|
||||
if variant == 'log':
|
||||
newcontent = []
|
||||
for t in range(self.T - 1):
|
||||
if ((self.content[t] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
|
||||
newcontent.append(None)
|
||||
|
@ -730,7 +730,6 @@ class Corr:
|
|||
return np.log(Corr(newcontent, padding=[0, 1]))
|
||||
|
||||
elif variant == 'logsym':
|
||||
newcontent = []
|
||||
for t in range(1, self.T - 1):
|
||||
if ((self.content[t - 1] is None) or (self.content[t + 1] is None)) or (self.content[t + 1][0].value == 0):
|
||||
newcontent.append(None)
|
||||
|
@ -752,7 +751,6 @@ class Corr:
|
|||
def root_function(x, d):
|
||||
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
|
||||
|
||||
newcontent = []
|
||||
for t in range(self.T - 1):
|
||||
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t + 1][0].value == 0):
|
||||
newcontent.append(None)
|
||||
|
@ -769,7 +767,6 @@ class Corr:
|
|||
return Corr(newcontent, padding=[0, 1])
|
||||
|
||||
elif variant == 'arccosh':
|
||||
newcontent = []
|
||||
for t in range(1, self.T - 1):
|
||||
if (self.content[t] is None) or (self.content[t + 1] is None) or (self.content[t - 1] is None) or (self.content[t][0].value == 0):
|
||||
newcontent.append(None)
|
||||
|
@ -782,7 +779,7 @@ class Corr:
|
|||
else:
|
||||
raise ValueError('Unknown variant.')
|
||||
|
||||
def fit(self, function: Callable, fitrange: Optional[Union[str, list[int]]]=None, silent: bool=False, **kwargs) -> Fit_result:
|
||||
def fit(self, function: Callable, fitrange: Optional[list[int]]=None, silent: bool=False, **kwargs) -> Fit_result:
|
||||
r'''Fits function to the data
|
||||
|
||||
Parameters
|
||||
|
@ -865,7 +862,7 @@ class Corr:
|
|||
self.prange = prange
|
||||
return
|
||||
|
||||
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None):
|
||||
def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None):
|
||||
"""Plots the correlator using the tag of the correlator as label if available.
|
||||
|
||||
Parameters
|
||||
|
@ -1081,14 +1078,14 @@ class Corr:
|
|||
|
||||
__array_priority__ = 10000
|
||||
|
||||
def __eq__(self, y: Union[Corr, Obs, int]) -> ndarray:
|
||||
def __eq__(self, y: Any) -> ndarray:
|
||||
if isinstance(y, Corr):
|
||||
comp = np.asarray(y.content, dtype=object)
|
||||
else:
|
||||
comp = np.asarray(y)
|
||||
return np.asarray(self.content, dtype=object) == comp
|
||||
|
||||
def __add__(self, y: Any) -> "Corr":
|
||||
def __add__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if ((self.N != y.N) or (self.T != y.T)):
|
||||
raise ValueError("Addition of Corrs with different shape")
|
||||
|
@ -1116,7 +1113,7 @@ class Corr:
|
|||
else:
|
||||
raise TypeError("Corr + wrong type")
|
||||
|
||||
def __mul__(self, y: Any) -> "Corr":
|
||||
def __mul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
|
||||
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
|
||||
|
@ -1187,7 +1184,7 @@ class Corr:
|
|||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
|
||||
def __truediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
|
||||
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
|
||||
|
@ -1245,10 +1242,10 @@ class Corr:
|
|||
newcontent = [None if _check_for_none(self, item) else -1. * item for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
def __sub__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
|
||||
def __sub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
return self + (-y)
|
||||
|
||||
def __pow__(self, y: Union[float, int]) -> "Corr":
|
||||
def __pow__(self, y: Union[Obs, CObs, float, int]) -> "Corr":
|
||||
if isinstance(y, (Obs, int, float, CObs)):
|
||||
newcontent = [None if _check_for_none(self, item) else item**y for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
@ -1321,16 +1318,16 @@ class Corr:
|
|||
return self._apply_func_to_corr(np.arctanh)
|
||||
|
||||
# Right hand side operations (require tweak in main module to work)
|
||||
def __radd__(self, y):
|
||||
def __radd__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
return self + y
|
||||
|
||||
def __rsub__(self, y: int) -> "Corr":
|
||||
def __rsub__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
return -self + y
|
||||
|
||||
def __rmul__(self, y: Union[float, int]) -> "Corr":
|
||||
def __rmul__(self, y: Union[Corr, Obs, CObs, int, float, complex, ndarray]) -> "Corr":
|
||||
return self * y
|
||||
|
||||
def __rtruediv__(self, y: int) -> "Corr":
|
||||
def __rtruediv__(self, y: Union[Corr, Obs, CObs, int, float, ndarray]) -> "Corr":
|
||||
return (self / y) ** (-1)
|
||||
|
||||
@property
|
||||
|
@ -1353,7 +1350,7 @@ class Corr:
|
|||
|
||||
return self._apply_func_to_corr(return_imag)
|
||||
|
||||
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: None=None) -> "Corr":
|
||||
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: Optional[ndarray]=None) -> "Corr":
|
||||
r''' Project large correlation matrix to lowest states
|
||||
|
||||
This method can be used to reduce the size of an (N x N) correlation matrix
|
||||
|
@ -1448,7 +1445,7 @@ def _check_for_none(corr: Corr, entry: Optional[ndarray]) -> bool:
|
|||
return len(list(filter(None, np.asarray(entry).flatten()))) < corr.N ** 2
|
||||
|
||||
|
||||
def _GEVP_solver(Gt: Optional[ndarray], G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
|
||||
def _GEVP_solver(Gt: ndarray, G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
|
||||
r"""Helper function for solving the GEVP and sorting the eigenvectors.
|
||||
|
||||
Solves $G(t)v_i=\lambda_i G(t_0)v_i$ and returns the eigenvectors v_i
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class Covobs:
|
||||
|
||||
def __init__(self, mean: Optional[Union[float, int]], cov: Any, name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
|
||||
def __init__(self, mean: Union[float, int], cov: Union[list, ndarray], name: str, pos: Optional[int]=None, grad: Optional[Union[ndarray, list[float]]]=None):
|
||||
""" Initialize Covobs object.
|
||||
|
||||
Parameters
|
||||
|
@ -47,7 +47,7 @@ class Covobs:
|
|||
"""
|
||||
return np.dot(np.transpose(self.grad), np.dot(self.cov, self.grad)).item()
|
||||
|
||||
def _set_cov(self, cov: Any):
|
||||
def _set_cov(self, cov: Union[list, ndarray]):
|
||||
""" Set the covariance matrix of the covobs
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -827,9 +827,20 @@ def residual_plot(x: ndarray, y: list[Obs], func: Callable, fit_res: list[Obs],
|
|||
plt.draw()
|
||||
|
||||
|
||||
def error_band(x: list[int], func: Callable, beta: list[Obs]) -> ndarray:
|
||||
def error_band(x: list[int], func: Callable, beta: Union[Fit_result, list[Obs]]) -> ndarray:
|
||||
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : list[int]
|
||||
A list of sample points where the error band is evaluated.
|
||||
|
||||
func : Callable
|
||||
The function representing the fit model.
|
||||
|
||||
beta : Union[Fit_result, list[Obs]]
|
||||
Optimized fit parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
err : np.array(Obs)
|
||||
|
|
|
@ -29,7 +29,7 @@ def print_config():
|
|||
print(f"{key: <10}\t {value}")
|
||||
|
||||
|
||||
def errorbar(x, y, axes=plt, **kwargs):
|
||||
def errorbar(x: Union[ndarray[int, float, Obs], list[int, float, Obs]], y: Union[ndarray[int, float, Obs], list[int, float, Obs]], axes=plt, **kwargs):
|
||||
"""pyerrors wrapper for the errorbars method of matplotlib
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -12,7 +12,7 @@ from scipy.stats import skew, skewtest, kurtosis, kurtosistest
|
|||
import numdifftools as nd
|
||||
from itertools import groupby
|
||||
from .covobs import Covobs
|
||||
from numpy import bool, float64, int64, ndarray
|
||||
from numpy import float64, int64, ndarray
|
||||
from typing import Any, Callable, Optional, Union, Sequence, TYPE_CHECKING
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
@ -501,7 +501,7 @@ class Obs:
|
|||
"""
|
||||
return np.isclose(0.0, self.value, 1e-14, atol) and all(np.allclose(0.0, delta, 1e-14, atol) for delta in self.deltas.values()) and all(np.allclose(0.0, delta.errsq(), 1e-14, atol) for delta in self.covobs.values())
|
||||
|
||||
def plot_tauint(self, save: None=None):
|
||||
def plot_tauint(self, save: Optional[str]=None):
|
||||
"""Plot integrated autocorrelation time for each ensemble.
|
||||
|
||||
Parameters
|
||||
|
@ -541,7 +541,7 @@ class Obs:
|
|||
if save:
|
||||
fig.savefig(save + "_" + str(e))
|
||||
|
||||
def plot_rho(self, save: None=None):
|
||||
def plot_rho(self, save: Optional[str]=None):
|
||||
"""Plot normalized autocorrelation function time for each ensemble.
|
||||
|
||||
Parameters
|
||||
|
@ -626,7 +626,7 @@ class Obs:
|
|||
plt.title(e_name + f'\nskew: {skew(y_test):.3f} (p={skewtest(y_test).pvalue:.3f}), kurtosis: {kurtosis(y_test):.3f} (p={kurtosistest(y_test).pvalue:.3f})')
|
||||
plt.draw()
|
||||
|
||||
def plot_piechart(self, save: None=None) -> dict[str, float64]:
|
||||
def plot_piechart(self, save: Optional[str]=None) -> dict[str, float64]:
|
||||
"""Plot piechart which shows the fractional contribution of each
|
||||
ensemble to the error and returns a dictionary containing the fractions.
|
||||
|
||||
|
@ -708,7 +708,7 @@ class Obs:
|
|||
tmp_jacks[1:] = (n * mean - full_data) / (n - 1)
|
||||
return tmp_jacks
|
||||
|
||||
def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: None=None) -> ndarray:
|
||||
def export_bootstrap(self, samples: int=500, random_numbers: Optional[ndarray]=None, save_rng: Optional[str]=None) -> ndarray:
|
||||
"""Export bootstrap samples from the Obs
|
||||
|
||||
Parameters
|
||||
|
@ -784,19 +784,19 @@ class Obs:
|
|||
return int(m.hexdigest(), 16) & 0xFFFFFFFF
|
||||
|
||||
# Overload comparisons
|
||||
def __lt__(self, other: Union[Obs, float, float64]) -> Union[bool, bool]:
|
||||
def __lt__(self, other: Union[Obs, float, float64, int]) -> bool:
|
||||
return self.value < other
|
||||
|
||||
def __le__(self, other: Union[Obs, float, int]) -> bool:
|
||||
def __le__(self, other: Union[Obs, float, float64, int]) -> bool:
|
||||
return self.value <= other
|
||||
|
||||
def __gt__(self, other: Union[Obs, float]) -> Union[bool, bool]:
|
||||
def __gt__(self, other: Union[Obs, float, float64, int]) -> bool:
|
||||
return self.value > other
|
||||
|
||||
def __ge__(self, other: Union[Obs, float, int]) -> Union[bool, bool]:
|
||||
def __ge__(self, other: Union[Obs, float, float64, int]) -> bool:
|
||||
return self.value >= other
|
||||
|
||||
def __eq__(self, other: Optional[Union[Obs, float64, int, float]]) -> Union[bool, bool]:
|
||||
def __eq__(self, other: Optional[Union[Obs, float, float64, int]]) -> bool:
|
||||
if other is None:
|
||||
return False
|
||||
return (self - other).is_zero()
|
||||
|
@ -815,10 +815,10 @@ class Obs:
|
|||
else:
|
||||
return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1])
|
||||
|
||||
def __radd__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
def __radd__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
return self + y
|
||||
|
||||
def __mul__(self, y: Any) -> Union[Obs, ndarray, CObs, NotImplementedType]:
|
||||
def __mul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
if isinstance(y, Obs):
|
||||
return derived_observable(lambda x, **kwargs: x[0] * x[1], [self, y], man_grad=[y.value, self.value])
|
||||
else:
|
||||
|
@ -831,10 +831,10 @@ class Obs:
|
|||
else:
|
||||
return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y])
|
||||
|
||||
def __rmul__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
def __rmul__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
return self * y
|
||||
|
||||
def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, ndarray]:
|
||||
def __sub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
if isinstance(y, Obs):
|
||||
return derived_observable(lambda x, **kwargs: x[0] - x[1], [self, y], man_grad=[1, -1])
|
||||
else:
|
||||
|
@ -845,7 +845,7 @@ class Obs:
|
|||
else:
|
||||
return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1])
|
||||
|
||||
def __rsub__(self, y: Union[float, int]) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
def __rsub__(self, y: Any) -> Union[Obs, NotImplementedType, CObs, ndarray]:
|
||||
return -1 * (self - y)
|
||||
|
||||
def __pos__(self) -> Obs:
|
||||
|
@ -959,6 +959,8 @@ class CObs:
|
|||
if isinstance(self.imag, Obs):
|
||||
self.imag.gamma_method(**kwargs)
|
||||
|
||||
gm = gamma_method
|
||||
|
||||
def is_zero(self) -> bool:
|
||||
"""Checks whether both real and imaginary part are zero within machine precision."""
|
||||
return self.real == 0.0 and self.imag == 0.0
|
||||
|
@ -1057,7 +1059,7 @@ class CObs:
|
|||
return f"({self.real:{format_type}}{self.imag:+{significance}}j)"
|
||||
|
||||
|
||||
def gamma_method(x: Union[Corr, Obs, ndarray, list[Obs]], **kwargs) -> ndarray:
|
||||
def gamma_method(x: Union[Corr, Obs, CObs, ndarray, list[Obs, CObs]], **kwargs) -> ndarray:
|
||||
"""Vectorized version of the gamma_method applicable to lists or arrays of Obs.
|
||||
|
||||
See docstring of pe.Obs.gamma_method for details.
|
||||
|
@ -1192,7 +1194,7 @@ def _expand_deltas_for_merge(deltas: ndarray, idx: Union[range, list[int]], shap
|
|||
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) * len(new_idx) / len(idx) * scalefactor
|
||||
|
||||
|
||||
def derived_observable(func: Callable, data: Any, array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]:
|
||||
def derived_observable(func: Callable, data: Union[list[Obs], ndarray], array_mode: bool=False, **kwargs) -> Union[Obs, ndarray]:
|
||||
"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.
|
||||
|
||||
Parameters
|
||||
|
|
|
@ -181,6 +181,8 @@ def test_fit_correlator():
|
|||
with pytest.raises(ValueError):
|
||||
my_corr.fit(f, [0, 2, 3])
|
||||
|
||||
fit_res = my_corr.fit(f, fitrange=[0, 1])
|
||||
|
||||
|
||||
def test_plateau():
|
||||
my_corr = pe.correlators.Corr([pe.pseudo_Obs(1.01324, 0.05, 't'), pe.pseudo_Obs(1.042345, 0.008, 't')])
|
||||
|
@ -226,7 +228,7 @@ def test_utility():
|
|||
corr.print()
|
||||
corr.print([2, 4])
|
||||
corr.show()
|
||||
corr.show(comp=corr)
|
||||
corr.show(comp=corr, x_range=[2, 5.], y_range=[2, 3.], hide_sigma=0.5, references=[.1, .2, .6], title='TEST')
|
||||
|
||||
corr.dump('test_dump', datatype="pickle", path='.')
|
||||
corr.dump('test_dump', datatype="pickle")
|
||||
|
|
|
@ -407,6 +407,7 @@ def test_cobs():
|
|||
obs2 = pe.pseudo_Obs(-0.2, 0.03, 't')
|
||||
|
||||
my_cobs = pe.CObs(obs1, obs2)
|
||||
my_cobs.gm()
|
||||
assert +my_cobs == my_cobs
|
||||
assert -my_cobs == 0 - my_cobs
|
||||
my_cobs == my_cobs
|
||||
|
|
Loading…
Add table
Reference in a new issue