mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-12-15 11:51:50 +01:00
[Feat] Add type hints to pyerrors modules
This commit is contained in:
parent
997d360db3
commit
3db8eb2989
11 changed files with 236 additions and 207 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import gc
|
||||
from collections.abc import Sequence
|
||||
import warnings
|
||||
|
|
@ -15,6 +16,8 @@ from autograd import elementwise_grad as egrad
|
|||
from numdifftools import Jacobian as num_jacobian
|
||||
from numdifftools import Hessian as num_hessian
|
||||
from .obs import Obs, derived_observable, covariance, cov_Obs, invert_corr_cov_cholesky
|
||||
from numpy import ndarray
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class Fit_result(Sequence):
|
||||
|
|
@ -36,10 +39,10 @@ class Fit_result(Sequence):
|
|||
def __init__(self):
|
||||
self.fit_parameters = None
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def __getitem__(self, idx: int) -> Obs:
|
||||
return self.fit_parameters[idx]
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.fit_parameters)
|
||||
|
||||
def gamma_method(self, **kwargs):
|
||||
|
|
@ -48,7 +51,7 @@ class Fit_result(Sequence):
|
|||
|
||||
gm = gamma_method
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
my_str = 'Goodness of fit:\n'
|
||||
if hasattr(self, 'chisquare_by_dof'):
|
||||
my_str += '\u03C7\u00b2/d.o.f. = ' + f'{self.chisquare_by_dof:2.6f}' + '\n'
|
||||
|
|
@ -65,12 +68,12 @@ class Fit_result(Sequence):
|
|||
my_str += str(i_par) + '\t' + ' ' * int(par >= 0) + str(par).rjust(int(par < 0.0)) + '\n'
|
||||
return my_str
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
m = max(map(len, list(self.__dict__.keys()))) + 1
|
||||
return '\n'.join([key.rjust(m) + ': ' + repr(value) for key, value in sorted(self.__dict__.items())])
|
||||
|
||||
|
||||
def least_squares(x, y, func, priors=None, silent=False, **kwargs):
|
||||
def least_squares(x: Any, y: Union[Dict[str, ndarray], List[Obs], ndarray, Dict[str, List[Obs]]], func: Union[Callable, Dict[str, Callable]], priors: Optional[Union[Dict[int, str], List[str], List[Obs], Dict[int, Obs]]]=None, silent: bool=False, **kwargs) -> Fit_result:
|
||||
r'''Performs a non-linear fit to y = func(x).
|
||||
```
|
||||
|
||||
|
|
@ -503,7 +506,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
|
|||
return output
|
||||
|
||||
|
||||
def total_least_squares(x, y, func, silent=False, **kwargs):
|
||||
def total_least_squares(x: List[Obs], y: List[Obs], func: Callable, silent: bool=False, **kwargs) -> Fit_result:
|
||||
r'''Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
|
||||
|
||||
Parameters
|
||||
|
|
@ -707,7 +710,7 @@ def total_least_squares(x, y, func, silent=False, **kwargs):
|
|||
return output
|
||||
|
||||
|
||||
def fit_lin(x, y, **kwargs):
|
||||
def fit_lin(x: List[Union[Obs, int, float]], y: List[Obs], **kwargs) -> List[Obs]:
|
||||
"""Performs a linear fit to y = n + m * x and returns two Obs n, m.
|
||||
|
||||
Parameters
|
||||
|
|
@ -738,7 +741,7 @@ def fit_lin(x, y, **kwargs):
|
|||
raise TypeError('Unsupported types for x')
|
||||
|
||||
|
||||
def qqplot(x, o_y, func, p, title=""):
|
||||
def qqplot(x: ndarray, o_y: List[Obs], func: Callable, p: List[Obs], title: str=""):
|
||||
"""Generates a quantile-quantile plot of the fit result which can be used to
|
||||
check if the residuals of the fit are gaussian distributed.
|
||||
|
||||
|
|
@ -768,7 +771,7 @@ def qqplot(x, o_y, func, p, title=""):
|
|||
plt.draw()
|
||||
|
||||
|
||||
def residual_plot(x, y, func, fit_res, title=""):
|
||||
def residual_plot(x: ndarray, y: List[Obs], func: Callable, fit_res: List[Obs], title: str=""):
|
||||
"""Generates a plot which compares the fit to the data and displays the corresponding residuals
|
||||
|
||||
For uncorrelated data the residuals are expected to be distributed ~N(0,1).
|
||||
|
|
@ -805,7 +808,7 @@ def residual_plot(x, y, func, fit_res, title=""):
|
|||
plt.draw()
|
||||
|
||||
|
||||
def error_band(x, func, beta):
|
||||
def error_band(x: List[int], func: Callable, beta: List[Obs]) -> ndarray:
|
||||
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
|
||||
|
||||
Returns
|
||||
|
|
@ -829,7 +832,7 @@ def error_band(x, func, beta):
|
|||
return err
|
||||
|
||||
|
||||
def ks_test(objects=None):
|
||||
def ks_test(objects: Optional[List[Fit_result]]=None):
|
||||
"""Performs a Kolmogorov–Smirnov test for the p-values of all fit object.
|
||||
|
||||
Parameters
|
||||
|
|
@ -873,7 +876,7 @@ def ks_test(objects=None):
|
|||
print(scipy.stats.kstest(p_values, 'uniform'))
|
||||
|
||||
|
||||
def _extract_val_and_dval(string):
|
||||
def _extract_val_and_dval(string: str) -> Tuple[float, float]:
|
||||
split_string = string.split('(')
|
||||
if '.' in split_string[0] and '.' not in split_string[1][:-1]:
|
||||
factor = 10 ** -len(split_string[0].partition('.')[2])
|
||||
|
|
@ -882,7 +885,7 @@ def _extract_val_and_dval(string):
|
|||
return float(split_string[0]), float(split_string[1][:-1]) * factor
|
||||
|
||||
|
||||
def _construct_prior_obs(i_prior, i_n):
|
||||
def _construct_prior_obs(i_prior: Union[Obs, str], i_n: int) -> Obs:
|
||||
if isinstance(i_prior, Obs):
|
||||
return i_prior
|
||||
elif isinstance(i_prior, str):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue