[Fix] Fix type annotations for first part of obs.py

This commit is contained in:
Fabian Joswig 2025-01-03 17:09:05 +01:00
parent 2d34b355ed
commit a9e082c333

View file

@ -63,13 +63,13 @@ class Obs:
'idl', 'tag', '_covobs', '__dict__']
S_global = 2.0
S_dict = {}
S_dict: dict[str, float] = {}
tau_exp_global = 0.0
tau_exp_dict = {}
tau_exp_dict: dict[str, float] = {}
N_sigma_global = 1.0
N_sigma_dict = {}
N_sigma_dict: dict[str, int] = {}
def __init__(self, samples: Union[List[List[int]], List[ndarray], ndarray, List[List[float64]], List[List[float]]], names: List[Union[int, Any, str]], idl: Optional[Any]=None, **kwargs):
def __init__(self, samples: Union[List[List[int]], List[ndarray], ndarray, List[List[float64]], List[List[float]]], names: List[str], idl: Optional[list[Union[list[int], range]]]=None, **kwargs):
""" Initialize Obs object.
Parameters
@ -82,7 +82,8 @@ class Obs:
list of ranges or lists on which the samples are defined
"""
if kwargs.get("means") is None and len(samples):
means: Optional[list[float]] = kwargs.get("means")
if means is None and len(samples):
if len(samples) != len(names):
raise ValueError('Length of samples and names incompatible.')
if idl is not None:
@ -100,17 +101,17 @@ class Obs:
if min(len(x) for x in samples) <= 4:
raise ValueError('Samples have to have at least 5 entries.')
self.names = sorted(names)
self.names: list[str] = sorted(names)
self.shape = {}
self.r_values = {}
self.deltas = {}
self._covobs = {}
self._covobs: dict[str, Covobs] = {}
self._value = 0
self.N = 0
self.idl = {}
self._value: float = 0.0
self.N: int = 0
self.idl: dict[str, Union[list[int], range]] = {}
if idl is not None:
for name, idx in sorted(zip(names, idl)):
for name, idx in sorted(zip(names, idl, strict=True)):
if isinstance(idx, range):
self.idl[name] = idx
elif isinstance(idx, (list, np.ndarray)):
@ -124,19 +125,19 @@ class Obs:
else:
self.idl[name] = list(idx)
else:
raise TypeError('incompatible type for idl[%s].' % (name))
raise TypeError('incompatible type for idl[%s].' % name)
else:
for name, sample in sorted(zip(names, samples)):
for name, sample in sorted(zip(names, samples, strict=True)):
self.idl[name] = range(1, len(sample) + 1)
if kwargs.get("means") is not None:
for name, sample, mean in sorted(zip(names, samples, kwargs.get("means"))):
if means is not None:
for name, sample, mean in sorted(zip(names, samples, means, strict=True)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
self.r_values[name] = mean
self.deltas[name] = sample
else:
for name, sample in sorted(zip(names, samples)):
for name, sample in sorted(zip(names, samples, strict=True)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
if len(sample) != self.shape[name]:
@ -642,7 +643,7 @@ class Obs:
if save:
fig1.savefig(save)
return dict(zip(labels, sizes))
return dict(zip(labels, sizes, strict=True))
def dump(self, filename: str, datatype: str="json.gz", description: str="", **kwargs):
"""Dump the Obs to a file 'name' of chosen format.