[Fix] Simplify type hints

This commit is contained in:
Fabian Joswig 2025-01-03 22:43:19 +01:00
commit 1c6053ef61
12 changed files with 84 additions and 85 deletions

View file

@ -17,7 +17,7 @@ from numdifftools import Jacobian as num_jacobian
from numdifftools import Hessian as num_hessian
from .obs import Obs, derived_observable, covariance, cov_Obs, invert_corr_cov_cholesky
from numpy import ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
class Fit_result(Sequence):
@ -73,7 +73,7 @@ class Fit_result(Sequence):
return '\n'.join([key.rjust(m) + ': ' + repr(value) for key, value in sorted(self.__dict__.items())])
def least_squares(x: Any, y: Union[Dict[str, ndarray], List[Obs], ndarray, Dict[str, List[Obs]]], func: Union[Callable, Dict[str, Callable]], priors: Optional[Union[Dict[int, str], List[str], List[Obs], Dict[int, Obs]]]=None, silent: bool=False, **kwargs) -> Fit_result:
def least_squares(x: Any, y: Union[dict[str, ndarray], list[Obs], ndarray, dict[str, list[Obs]]], func: Union[Callable, dict[str, Callable]], priors: Optional[Union[dict[int, str], list[str], list[Obs], dict[int, Obs]]]=None, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x).
```
@ -506,7 +506,7 @@ def least_squares(x: Any, y: Union[Dict[str, ndarray], List[Obs], ndarray, Dict[
return output
def total_least_squares(x: List[Obs], y: List[Obs], func: Callable, silent: bool=False, **kwargs) -> Fit_result:
def total_least_squares(x: list[Obs], y: list[Obs], func: Callable, silent: bool=False, **kwargs) -> Fit_result:
r'''Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
Parameters
@ -710,7 +710,7 @@ def total_least_squares(x: List[Obs], y: List[Obs], func: Callable, silent: bool
return output
def fit_lin(x: List[Union[Obs, int, float]], y: List[Obs], **kwargs) -> List[Obs]:
def fit_lin(x: list[Union[Obs, int, float]], y: list[Obs], **kwargs) -> list[Obs]:
"""Performs a linear fit to y = n + m * x and returns two Obs n, m.
Parameters
@ -741,7 +741,7 @@ def fit_lin(x: List[Union[Obs, int, float]], y: List[Obs], **kwargs) -> List[Obs
raise TypeError('Unsupported types for x')
def qqplot(x: ndarray, o_y: List[Obs], func: Callable, p: List[Obs], title: str=""):
def qqplot(x: ndarray, o_y: list[Obs], func: Callable, p: list[Obs], title: str=""):
"""Generates a quantile-quantile plot of the fit result which can be used to
check if the residuals of the fit are gaussian distributed.
@ -771,7 +771,7 @@ def qqplot(x: ndarray, o_y: List[Obs], func: Callable, p: List[Obs], title: str=
plt.draw()
def residual_plot(x: ndarray, y: List[Obs], func: Callable, fit_res: List[Obs], title: str=""):
def residual_plot(x: ndarray, y: list[Obs], func: Callable, fit_res: list[Obs], title: str=""):
"""Generates a plot which compares the fit to the data and displays the corresponding residuals
For uncorrelated data the residuals are expected to be distributed ~N(0,1).
@ -808,7 +808,7 @@ def residual_plot(x: ndarray, y: List[Obs], func: Callable, fit_res: List[Obs],
plt.draw()
def error_band(x: List[int], func: Callable, beta: List[Obs]) -> ndarray:
def error_band(x: list[int], func: Callable, beta: list[Obs]) -> ndarray:
"""Calculate the error band for an array of sample values x, for given fit function func with optimized parameters beta.
Returns
@ -832,7 +832,7 @@ def error_band(x: List[int], func: Callable, beta: List[Obs]) -> ndarray:
return err
def ks_test(objects: Optional[List[Fit_result]]=None):
def ks_test(objects: Optional[list[Fit_result]]=None):
"""Performs a KolmogorovSmirnov test for the p-values of all fit object.
Parameters
@ -876,7 +876,7 @@ def ks_test(objects: Optional[List[Fit_result]]=None):
print(scipy.stats.kstest(p_values, 'uniform'))
def _extract_val_and_dval(string: str) -> Tuple[float, float]:
def _extract_val_and_dval(string: str) -> tuple[float, float]:
split_string = string.split('(')
if '.' in split_string[0] and '.' not in split_string[1][:-1]:
factor = 10 ** -len(split_string[0].partition('.')[2])