[Feat] Add type hints to pyerrors modules

This commit is contained in:
Fabian Joswig 2024-12-25 11:09:58 +01:00
commit 3db8eb2989
11 changed files with 236 additions and 207 deletions

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import gc
from collections.abc import Sequence
import warnings
@ -15,6 +16,8 @@ from autograd import elementwise_grad as egrad
from numdifftools import Jacobian as num_jacobian
from numdifftools import Hessian as num_hessian
from .obs import Obs, derived_observable, covariance, cov_Obs, invert_corr_cov_cholesky
from numpy import ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
class Fit_result(Sequence):
@ -36,10 +39,10 @@ class Fit_result(Sequence):
def __init__(self):
self.fit_parameters = None
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Obs:
return self.fit_parameters[idx]
def __len__(self):
def __len__(self) -> int:
return len(self.fit_parameters)
def gamma_method(self, **kwargs):
@ -48,7 +51,7 @@ class Fit_result(Sequence):
gm = gamma_method
def __str__(self):
def __str__(self) -> str:
my_str = 'Goodness of fit:\n'
if hasattr(self, 'chisquare_by_dof'):
my_str += '\u03C7\u00b2/d.o.f. = ' + f'{self.chisquare_by_dof:2.6f}' + '\n'
@ -65,12 +68,12 @@ class Fit_result(Sequence):
my_str += str(i_par) + '\t' + ' ' * int(par >= 0) + str(par).rjust(int(par < 0.0)) + '\n'
return my_str
def __repr__(self):
def __repr__(self) -> str:
m = max(map(len, list(self.__dict__.keys()))) + 1
return '\n'.join([key.rjust(m) + ': ' + repr(value) for key, value in sorted(self.__dict__.items())])
def least_squares(x, y, func, priors=None, silent=False, **kwargs):
def least_squares(x: Any, y: Union[Dict[str, ndarray], List[Obs], ndarray, Dict[str, List[Obs]]], func: Union[Callable, Dict[str, Callable]], priors: Optional[Union[Dict[int, str], List[str], List[Obs], Dict[int, Obs]]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x).
```
@ -503,7 +506,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
return output
def total_least_squares(x, y, func, silent=False, **kwargs):
def total_least_squares(x: List[Obs], y: List[Obs], func: Callable, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
Parameters
@ -707,7 +710,7 @@ def total_least_squares(x, y, func, silent=False, **kwargs):
return output
def fit_lin(x, y, **kwargs):
def fit_lin(x: List[Union[Obs, int, float]], y: List[Obs], **kwargs) -> List[Obs]:
"""Performs a linear fit to y = n + m * x and returns two Obs n, m.
Parameters
@ -738,7 +741,7 @@ def fit_lin(x, y, **kwargs):
raise TypeError('Unsupported types for x')
def qqplot(x, o_y, func, p, title=""):
def qqplot(x: ndarray, o_y: List[Obs], func: Callable, p: List[Obs], title: str=""):
"""Generates a quantile-quantile plot of the fit result which can be used to
check if the residuals of the fit are gaussian distributed.
@ -768,7 +771,7 @@ def qqplot(x, o_y, func, p, title=""):
plt.draw()
def residual_plot(x, y, func, fit_res, title=""):
def residual_plot(x: ndarray, y: List[Obs], func: Callable, fit_res: List[Obs], title: str=""):
"""Generates a plot which compares the fit to the data and displays the corresponding residuals
For uncorrelated data the residuals are expected to be distributed ~N(0,1).
@ -805,7 +808,7 @@ def residual_plot(x, y, func, fit_res, title=""):
plt.draw()
def error_band(x, func, beta):
def error_band(x: List[int], func: Callable, beta: List[Obs]) -> ndarray:
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
Returns
@ -829,7 +832,7 @@ def error_band(x, func, beta):
return err
def ks_test(objects=None):
def ks_test(objects: Optional[List[Fit_result]]=None):
"""Performs a KolmogorovSmirnov test for the p-values of all fit object.
Parameters
@ -873,7 +876,7 @@ def ks_test(objects=None):
print(scipy.stats.kstest(p_values, 'uniform'))
def _extract_val_and_dval(string):
def _extract_val_and_dval(string: str) -> Tuple[float, float]:
split_string = string.split('(')
if '.' in split_string[0] and '.' not in split_string[1][:-1]:
factor = 10 ** -len(split_string[0].partition('.')[2])
@ -882,7 +885,7 @@ def _extract_val_and_dval(string):
return float(split_string[0]), float(split_string[1][:-1]) * factor
def _construct_prior_obs(i_prior, i_n):
def _construct_prior_obs(i_prior: Union[Obs, str], i_n: int) -> Obs:
if isinstance(i_prior, Obs):
return i_prior
elif isinstance(i_prior, str):