mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-15 03:53:41 +02:00
[Feat] Add type hints to pyerrors modules
This commit is contained in:
parent
997d360db3
commit
3db8eb2989
11 changed files with 236 additions and 207 deletions
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import warnings
|
||||
from itertools import permutations
|
||||
import numpy as np
|
||||
|
@ -9,6 +10,8 @@ from .misc import dump_object, _assert_equal_properties
|
|||
from .fits import least_squares
|
||||
from .roots import find_root
|
||||
from . import linalg
|
||||
from numpy import float64, int64, ndarray, ufunc
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class Corr:
|
||||
|
@ -42,7 +45,7 @@ class Corr:
|
|||
|
||||
__slots__ = ["content", "N", "T", "tag", "prange"]
|
||||
|
||||
def __init__(self, data_input, padding=[0, 0], prange=None):
|
||||
def __init__(self, data_input: Any, padding: List[int]=[0, 0], prange: Optional[List[int]]=None):
|
||||
""" Initialize a Corr object.
|
||||
|
||||
Parameters
|
||||
|
@ -119,7 +122,7 @@ class Corr:
|
|||
self.T = len(self.content)
|
||||
self.prange = prange
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def __getitem__(self, idx: Union[slice, int]) -> Union[CObs, Obs, ndarray, List[ndarray]]:
|
||||
"""Return the content of timeslice idx"""
|
||||
if self.content[idx] is None:
|
||||
return None
|
||||
|
@ -151,7 +154,7 @@ class Corr:
|
|||
|
||||
gm = gamma_method
|
||||
|
||||
def projected(self, vector_l=None, vector_r=None, normalize=False):
|
||||
def projected(self, vector_l: Optional[Union[ndarray, List[Optional[ndarray]]]]=None, vector_r: None=None, normalize: bool=False) -> "Corr":
|
||||
"""We need to project the Correlator with a Vector to get a single value at each timeslice.
|
||||
|
||||
The method can use one or two vectors.
|
||||
|
@ -190,7 +193,7 @@ class Corr:
|
|||
newcontent = [None if (_check_for_none(self, self.content[t]) or vector_l[t] is None or vector_r[t] is None) else np.asarray([vector_l[t].T @ self.content[t] @ vector_r[t]]) for t in range(self.T)]
|
||||
return Corr(newcontent)
|
||||
|
||||
def item(self, i, j):
|
||||
def item(self, i: int, j: int) -> "Corr":
|
||||
"""Picks the element [i,j] from every matrix and returns a correlator containing one Obs per timeslice.
|
||||
|
||||
Parameters
|
||||
|
@ -205,7 +208,7 @@ class Corr:
|
|||
newcontent = [None if (item is None) else item[i, j] for item in self.content]
|
||||
return Corr(newcontent)
|
||||
|
||||
def plottable(self):
|
||||
def plottable(self) -> Union[Tuple[List[int], List[float64], List[float64]], Tuple[List[int], List[float], List[float64]]]:
|
||||
"""Outputs the correlator in a plotable format.
|
||||
|
||||
Outputs three lists containing the timeslice index, the value on each
|
||||
|
@ -219,7 +222,7 @@ class Corr:
|
|||
|
||||
return x_list, y_list, y_err_list
|
||||
|
||||
def symmetric(self):
|
||||
def symmetric(self) -> "Corr":
|
||||
""" Symmetrize the correlator around x0=0."""
|
||||
if self.N != 1:
|
||||
raise ValueError('symmetric cannot be safely applied to multi-dimensional correlators.')
|
||||
|
@ -240,7 +243,7 @@ class Corr:
|
|||
raise ValueError("Corr could not be symmetrized: No redundant values")
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
def anti_symmetric(self):
|
||||
def anti_symmetric(self) -> "Corr":
|
||||
"""Anti-symmetrize the correlator around x0=0."""
|
||||
if self.N != 1:
|
||||
raise TypeError('anti_symmetric cannot be safely applied to multi-dimensional correlators.')
|
||||
|
@ -277,7 +280,7 @@ class Corr:
|
|||
return False
|
||||
return True
|
||||
|
||||
def trace(self):
|
||||
def trace(self) -> "Corr":
|
||||
"""Calculates the per-timeslice trace of a correlator matrix."""
|
||||
if self.N == 1:
|
||||
raise ValueError("Only works for correlator matrices.")
|
||||
|
@ -289,7 +292,7 @@ class Corr:
|
|||
newcontent.append(np.trace(self.content[t]))
|
||||
return Corr(newcontent)
|
||||
|
||||
def matrix_symmetric(self):
|
||||
def matrix_symmetric(self) -> "Corr":
|
||||
"""Symmetrizes the correlator matrices on every timeslice."""
|
||||
if self.N == 1:
|
||||
raise ValueError("Trying to symmetrize a correlator matrix, that already has N=1.")
|
||||
|
@ -299,7 +302,7 @@ class Corr:
|
|||
transposed = [None if _check_for_none(self, G) else G.T for G in self.content]
|
||||
return 0.5 * (Corr(transposed) + self)
|
||||
|
||||
def GEVP(self, t0, ts=None, sort="Eigenvalue", vector_obs=False, **kwargs):
|
||||
def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[List[List[Optional[ndarray]]], ndarray, List[Optional[ndarray]]]:
|
||||
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
|
||||
|
||||
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
|
||||
|
@ -405,7 +408,7 @@ class Corr:
|
|||
else:
|
||||
return reordered_vecs
|
||||
|
||||
def Eigenvalue(self, t0, ts=None, state=0, sort="Eigenvalue", **kwargs):
|
||||
def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr":
|
||||
"""Determines the eigenvalue of the GEVP by solving and projecting the correlator
|
||||
|
||||
Parameters
|
||||
|
@ -418,7 +421,7 @@ class Corr:
|
|||
vec = self.GEVP(t0, ts=ts, sort=sort, **kwargs)[state]
|
||||
return self.projected(vec)
|
||||
|
||||
def Hankel(self, N, periodic=False):
|
||||
def Hankel(self, N: int, periodic: bool=False) -> "Corr":
|
||||
"""Constructs an NxN Hankel matrix
|
||||
|
||||
C(t) c(t+1) ... c(t+n-1)
|
||||
|
@ -459,7 +462,7 @@ class Corr:
|
|||
|
||||
return Corr(new_content)
|
||||
|
||||
def roll(self, dt):
|
||||
def roll(self, dt: int) -> "Corr":
|
||||
"""Periodically shift the correlator by dt timeslices
|
||||
|
||||
Parameters
|
||||
|
@ -469,11 +472,11 @@ class Corr:
|
|||
"""
|
||||
return Corr(list(np.roll(np.array(self.content, dtype=object), dt, axis=0)))
|
||||
|
||||
def reverse(self):
|
||||
def reverse(self) -> "Corr":
|
||||
"""Reverse the time ordering of the Corr"""
|
||||
return Corr(self.content[:: -1])
|
||||
|
||||
def thin(self, spacing=2, offset=0):
|
||||
def thin(self, spacing: int=2, offset: int=0) -> "Corr":
|
||||
"""Thin out a correlator to suppress correlations
|
||||
|
||||
Parameters
|
||||
|
@ -491,7 +494,7 @@ class Corr:
|
|||
new_content.append(self.content[t])
|
||||
return Corr(new_content)
|
||||
|
||||
def correlate(self, partner):
|
||||
def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr":
|
||||
"""Correlate the correlator with another correlator or Obs
|
||||
|
||||
Parameters
|
||||
|
@ -520,7 +523,7 @@ class Corr:
|
|||
|
||||
return Corr(new_content)
|
||||
|
||||
def reweight(self, weight, **kwargs):
|
||||
def reweight(self, weight: Obs, **kwargs) -> "Corr":
|
||||
"""Reweight the correlator.
|
||||
|
||||
Parameters
|
||||
|
@ -543,7 +546,7 @@ class Corr:
|
|||
new_content.append(np.array(reweight(weight, t_slice, **kwargs)))
|
||||
return Corr(new_content)
|
||||
|
||||
def T_symmetry(self, partner, parity=+1):
|
||||
def T_symmetry(self, partner: "Corr", parity: int=+1) -> "Corr":
|
||||
"""Return the time symmetry average of the correlator and its partner
|
||||
|
||||
Parameters
|
||||
|
@ -573,7 +576,7 @@ class Corr:
|
|||
|
||||
return (self + T_partner) / 2
|
||||
|
||||
def deriv(self, variant="symmetric"):
|
||||
def deriv(self, variant: Optional[str]="symmetric") -> "Corr":
|
||||
"""Return the first derivative of the correlator with respect to x0.
|
||||
|
||||
Parameters
|
||||
|
@ -638,7 +641,7 @@ class Corr:
|
|||
else:
|
||||
raise ValueError("Unknown variant.")
|
||||
|
||||
def second_deriv(self, variant="symmetric"):
|
||||
def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr":
|
||||
r"""Return the second derivative of the correlator with respect to x0.
|
||||
|
||||
Parameters
|
||||
|
@ -701,7 +704,7 @@ class Corr:
|
|||
else:
|
||||
raise ValueError("Unknown variant.")
|
||||
|
||||
def m_eff(self, variant='log', guess=1.0):
|
||||
def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr":
|
||||
"""Returns the effective mass of the correlator as correlator object
|
||||
|
||||
Parameters
|
||||
|
@ -785,7 +788,7 @@ class Corr:
|
|||
else:
|
||||
raise ValueError('Unknown variant.')
|
||||
|
||||
def fit(self, function, fitrange=None, silent=False, **kwargs):
|
||||
def fit(self, function: Callable, fitrange: Optional[Union[str, List[int]]]=None, silent: bool=False, **kwargs) -> Fit_result:
|
||||
r'''Fits function to the data
|
||||
|
||||
Parameters
|
||||
|
@ -819,7 +822,7 @@ class Corr:
|
|||
result = least_squares(xs, ys, function, silent=silent, **kwargs)
|
||||
return result
|
||||
|
||||
def plateau(self, plateau_range=None, method="fit", auto_gamma=False):
|
||||
def plateau(self, plateau_range: Optional[List[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs:
|
||||
""" Extract a plateau value from a Corr object
|
||||
|
||||
Parameters
|
||||
|
@ -856,7 +859,7 @@ class Corr:
|
|||
else:
|
||||
raise ValueError("Unsupported plateau method: " + method)
|
||||
|
||||
def set_prange(self, prange):
|
||||
def set_prange(self, prange: List[Union[int, float]]):
|
||||
"""Sets the attribute prange of the Corr object."""
|
||||
if not len(prange) == 2:
|
||||
raise ValueError("prange must be a list or array with two values")
|
||||
|
@ -868,7 +871,7 @@ class Corr:
|
|||
self.prange = prange
|
||||
return
|
||||
|
||||
def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, fit_key=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None):
|
||||
def show(self, x_range: Optional[List[int64]]=None, comp: Optional[Corr]=None, y_range: None=None, logscale: bool=False, plateau: None=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: None=None, save: None=None, auto_gamma: bool=False, hide_sigma: None=None, references: None=None, title: None=None):
|
||||
"""Plots the correlator using the tag of the correlator as label if available.
|
||||
|
||||
Parameters
|
||||
|
@ -993,7 +996,7 @@ class Corr:
|
|||
else:
|
||||
raise TypeError("'save' has to be a string.")
|
||||
|
||||
def spaghetti_plot(self, logscale=True):
|
||||
def spaghetti_plot(self, logscale: bool=True):
|
||||
"""Produces a spaghetti plot of the correlator suited to monitor exceptional configurations.
|
||||
|
||||
Parameters
|
||||
|
@ -1022,7 +1025,7 @@ class Corr:
|
|||
plt.title(name)
|
||||
plt.draw()
|
||||
|
||||
def dump(self, filename, datatype="json.gz", **kwargs):
|
||||
def dump(self, filename: str, datatype: str="json.gz", **kwargs):
|
||||
"""Dumps the Corr into a file of chosen type
|
||||
Parameters
|
||||
----------
|
||||
|
@ -1046,10 +1049,10 @@ class Corr:
|
|||
else:
|
||||
raise ValueError("Unknown datatype " + str(datatype))
|
||||
|
||||
def print(self, print_range=None):
|
||||
def print(self, print_range: Optional[List[int]]=None):
|
||||
print(self.__repr__(print_range))
|
||||
|
||||
def __repr__(self, print_range=None):
|
||||
def __repr__(self, print_range: Optional[List[int]]=None) -> str:
|
||||
if print_range is None:
|
||||
print_range = [0, None]
|
||||
|
||||
|
@ -1074,7 +1077,7 @@ class Corr:
|
|||
content_string += '\n'
|
||||
return content_string
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
# We define the basic operations, that can be performed with correlators.
|
||||
|
@ -1084,14 +1087,14 @@ class Corr:
|
|||
|
||||
__array_priority__ = 10000
|
||||
|
||||
def __eq__(self, y):
|
||||
def __eq__(self, y: Union[Corr, Obs, int]) -> ndarray:
|
||||
if isinstance(y, Corr):
|
||||
comp = np.asarray(y.content, dtype=object)
|
||||
else:
|
||||
comp = np.asarray(y)
|
||||
return np.asarray(self.content, dtype=object) == comp
|
||||
|
||||
def __add__(self, y):
|
||||
def __add__(self, y: Any) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if ((self.N != y.N) or (self.T != y.T)):
|
||||
raise ValueError("Addition of Corrs with different shape")
|
||||
|
@ -1119,7 +1122,7 @@ class Corr:
|
|||
else:
|
||||
raise TypeError("Corr + wrong type")
|
||||
|
||||
def __mul__(self, y):
|
||||
def __mul__(self, y: Any) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
|
||||
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
|
||||
|
@ -1147,7 +1150,7 @@ class Corr:
|
|||
else:
|
||||
raise TypeError("Corr * wrong type")
|
||||
|
||||
def __matmul__(self, y):
|
||||
def __matmul__(self, y: Union[Corr, ndarray]) -> "Corr":
|
||||
if isinstance(y, np.ndarray):
|
||||
if y.ndim != 2 or y.shape[0] != y.shape[1]:
|
||||
raise ValueError("Can only multiply correlators by square matrices.")
|
||||
|
@ -1174,7 +1177,7 @@ class Corr:
|
|||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __rmatmul__(self, y):
|
||||
def __rmatmul__(self, y: ndarray) -> "Corr":
|
||||
if isinstance(y, np.ndarray):
|
||||
if y.ndim != 2 or y.shape[0] != y.shape[1]:
|
||||
raise ValueError("Can only multiply correlators by square matrices.")
|
||||
|
@ -1190,7 +1193,7 @@ class Corr:
|
|||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, y):
|
||||
def __truediv__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
|
||||
if isinstance(y, Corr):
|
||||
if not ((self.N == 1 or y.N == 1 or self.N == y.N) and self.T == y.T):
|
||||
raise ValueError("Multiplication of Corr object requires N=N or N=1 and T=T")
|
||||
|
@ -1244,37 +1247,37 @@ class Corr:
|
|||
else:
|
||||
raise TypeError('Corr / wrong type')
|
||||
|
||||
def __neg__(self):
|
||||
def __neg__(self) -> "Corr":
|
||||
newcontent = [None if _check_for_none(self, item) else -1. * item for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
def __sub__(self, y):
|
||||
def __sub__(self, y: Union[Corr, float, ndarray, int]) -> "Corr":
|
||||
return self + (-y)
|
||||
|
||||
def __pow__(self, y):
|
||||
def __pow__(self, y: Union[float, int]) -> "Corr":
|
||||
if isinstance(y, (Obs, int, float, CObs)):
|
||||
newcontent = [None if _check_for_none(self, item) else item**y for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
else:
|
||||
raise TypeError('Type of exponent not supported')
|
||||
|
||||
def __abs__(self):
|
||||
def __abs__(self) -> "Corr":
|
||||
newcontent = [None if _check_for_none(self, item) else np.abs(item) for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
# The numpy functions:
|
||||
def sqrt(self):
|
||||
def sqrt(self) -> "Corr":
|
||||
return self ** 0.5
|
||||
|
||||
def log(self):
|
||||
def log(self) -> "Corr":
|
||||
newcontent = [None if _check_for_none(self, item) else np.log(item) for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
def exp(self):
|
||||
def exp(self) -> "Corr":
|
||||
newcontent = [None if _check_for_none(self, item) else np.exp(item) for item in self.content]
|
||||
return Corr(newcontent, prange=self.prange)
|
||||
|
||||
def _apply_func_to_corr(self, func):
|
||||
def _apply_func_to_corr(self, func: Union[Callable, ufunc]) -> "Corr":
|
||||
newcontent = [None if _check_for_none(self, item) else func(item) for item in self.content]
|
||||
for t in range(self.T):
|
||||
if _check_for_none(self, newcontent[t]):
|
||||
|
@ -1287,57 +1290,57 @@ class Corr:
|
|||
raise ValueError('Operation returns undefined correlator')
|
||||
return Corr(newcontent)
|
||||
|
||||
def sin(self):
|
||||
def sin(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.sin)
|
||||
|
||||
def cos(self):
|
||||
def cos(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.cos)
|
||||
|
||||
def tan(self):
|
||||
def tan(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.tan)
|
||||
|
||||
def sinh(self):
|
||||
def sinh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.sinh)
|
||||
|
||||
def cosh(self):
|
||||
def cosh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.cosh)
|
||||
|
||||
def tanh(self):
|
||||
def tanh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.tanh)
|
||||
|
||||
def arcsin(self):
|
||||
def arcsin(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arcsin)
|
||||
|
||||
def arccos(self):
|
||||
def arccos(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arccos)
|
||||
|
||||
def arctan(self):
|
||||
def arctan(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arctan)
|
||||
|
||||
def arcsinh(self):
|
||||
def arcsinh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arcsinh)
|
||||
|
||||
def arccosh(self):
|
||||
def arccosh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arccosh)
|
||||
|
||||
def arctanh(self):
|
||||
def arctanh(self) -> "Corr":
|
||||
return self._apply_func_to_corr(np.arctanh)
|
||||
|
||||
# Right hand side operations (require tweak in main module to work)
|
||||
def __radd__(self, y):
|
||||
return self + y
|
||||
|
||||
def __rsub__(self, y):
|
||||
def __rsub__(self, y: int) -> "Corr":
|
||||
return -self + y
|
||||
|
||||
def __rmul__(self, y):
|
||||
def __rmul__(self, y: Union[float, int]) -> "Corr":
|
||||
return self * y
|
||||
|
||||
def __rtruediv__(self, y):
|
||||
def __rtruediv__(self, y: int) -> "Corr":
|
||||
return (self / y) ** (-1)
|
||||
|
||||
@property
|
||||
def real(self):
|
||||
def real(self) -> "Corr":
|
||||
def return_real(obs_OR_cobs):
|
||||
if isinstance(obs_OR_cobs.flatten()[0], CObs):
|
||||
return np.vectorize(lambda x: x.real)(obs_OR_cobs)
|
||||
|
@ -1347,7 +1350,7 @@ class Corr:
|
|||
return self._apply_func_to_corr(return_real)
|
||||
|
||||
@property
|
||||
def imag(self):
|
||||
def imag(self) -> "Corr":
|
||||
def return_imag(obs_OR_cobs):
|
||||
if isinstance(obs_OR_cobs.flatten()[0], CObs):
|
||||
return np.vectorize(lambda x: x.imag)(obs_OR_cobs)
|
||||
|
@ -1356,7 +1359,7 @@ class Corr:
|
|||
|
||||
return self._apply_func_to_corr(return_imag)
|
||||
|
||||
def prune(self, Ntrunc, tproj=3, t0proj=2, basematrix=None):
|
||||
def prune(self, Ntrunc: int, tproj: int=3, t0proj: int=2, basematrix: None=None) -> "Corr":
|
||||
r''' Project large correlation matrix to lowest states
|
||||
|
||||
This method can be used to reduce the size of an (N x N) correlation matrix
|
||||
|
@ -1414,7 +1417,7 @@ class Corr:
|
|||
return Corr(newcontent)
|
||||
|
||||
|
||||
def _sort_vectors(vec_set_in, ts):
|
||||
def _sort_vectors(vec_set_in: List[Optional[ndarray]], ts: int) -> List[Optional[Union[ndarray, List[ndarray]]]]:
|
||||
"""Helper function used to find a set of Eigenvectors consistent over all timeslices"""
|
||||
|
||||
if isinstance(vec_set_in[ts][0][0], Obs):
|
||||
|
@ -1446,12 +1449,12 @@ def _sort_vectors(vec_set_in, ts):
|
|||
return sorted_vec_set
|
||||
|
||||
|
||||
def _check_for_none(corr, entry):
|
||||
def _check_for_none(corr: Corr, entry: Optional[ndarray]) -> bool:
|
||||
"""Checks if entry for correlator corr is None"""
|
||||
return len(list(filter(None, np.asarray(entry).flatten()))) < corr.N ** 2
|
||||
|
||||
|
||||
def _GEVP_solver(Gt, G0, method='eigh', chol_inv=None):
|
||||
def _GEVP_solver(Gt: Optional[ndarray], G0: ndarray, method: str='eigh', chol_inv: Optional[ndarray]=None) -> ndarray:
|
||||
r"""Helper function for solving the GEVP and sorting the eigenvectors.
|
||||
|
||||
Solves $G(t)v_i=\lambda_i G(t_0)v_i$ and returns the eigenvectors v_i
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue