[Fix] Fix type hints in misc.py and remove strict zips for python 3.9

compatability
This commit is contained in:
Fabian Joswig 2025-01-03 18:39:34 +01:00
parent 9389ad67c9
commit 23d4f4c320
3 changed files with 18 additions and 11 deletions

View file

@ -75,8 +75,11 @@ def dump_object(obj: Corr, name: str, **kwargs):
-------
None
"""
if 'path' in kwargs:
file_name = kwargs.get('path') + '/' + name + '.p'
path = kwargs.get('path')
if path is not None:
if not isinstance(path, str):
raise Exception("Path has to be a string.")
file_name = path + '/' + name + '.p'
else:
file_name = name + '.p'
with open(file_name, 'wb') as fb:
@ -100,7 +103,7 @@ def load_object(path: str) -> Union[Obs, Corr]:
return pickle.load(file)
def pseudo_Obs(value: Union[float, int64, float64, int], dvalue: Union[float, float64, int], name: str, samples: int=1000) -> Obs:
def pseudo_Obs(value: Union[float, int], dvalue: Union[float, int], name: str, samples: int=1000) -> Obs:
"""Generate an Obs object with given value, dvalue and name for test purposes
Parameters
@ -123,11 +126,11 @@ def pseudo_Obs(value: Union[float, int64, float64, int], dvalue: Union[float, fl
return Obs([np.zeros(samples) + value], [name])
else:
for _ in range(100):
deltas = [np.random.normal(0.0, dvalue * np.sqrt(samples), samples)]
deltas = np.array([np.random.normal(0.0, dvalue * np.sqrt(samples), samples)])
deltas -= np.mean(deltas)
deltas *= dvalue / np.sqrt((np.var(deltas) / samples)) / np.sqrt(1 + 3 / samples)
deltas += value
res = Obs(deltas, [name])
res = Obs(list(deltas), [name])
res.gamma_method(S=2, tau_exp=0)
if abs(res.dvalue - dvalue) < 1e-10 * dvalue:
break
@ -179,7 +182,7 @@ def gen_correlated_data(means: Union[ndarray, List[float]], cov: ndarray, name:
return [Obs([dat], [name]) for dat in corr_data.T]
def _assert_equal_properties(ol: Union[List[Obs], List[CObs], ndarray], otype: Type[Obs]=Obs):
def _assert_equal_properties(ol: Union[List[Obs], List[CObs], ndarray]):
otype = type(ol[0])
for o in ol[1:]:
if not isinstance(o, otype):

View file

@ -114,7 +114,7 @@ class Obs:
self.N: int = 0
self.idl: dict[str, Union[list[int], range]] = {}
if idl is not None:
for name, idx in sorted(zip(names, idl, strict=True)):
for name, idx in sorted(zip(names, idl)):
if isinstance(idx, range):
self.idl[name] = idx
elif isinstance(idx, (list, np.ndarray)):
@ -130,17 +130,17 @@ class Obs:
else:
raise TypeError('incompatible type for idl[%s].' % name)
else:
for name, sample in sorted(zip(names, samples, strict=True)):
for name, sample in sorted(zip(names, samples)):
self.idl[name] = range(1, len(sample) + 1)
if means is not None:
for name, sample, mean in sorted(zip(names, samples, means, strict=True)):
for name, sample, mean in sorted(zip(names, samples, means)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
self.r_values[name] = mean
self.deltas[name] = sample
else:
for name, sample in sorted(zip(names, samples, strict=True)):
for name, sample in sorted(zip(names, samples)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
if len(sample) != self.shape[name]:
@ -648,7 +648,7 @@ class Obs:
if save:
fig1.savefig(save)
return dict(zip(labels, sizes, strict=True))
return dict(zip(labels, sizes))
def dump(self, filename: str, datatype: str="json.gz", description: str="", **kwargs):
"""Dump the Obs to a file 'name' of chosen format.

View file

@ -4,3 +4,7 @@ build-backend = "setuptools.build_meta"
[tool.ruff.lint]
ignore = ["F403"]
[tool.mypy]
warn_unused_configs = true
ignore_missing_imports = true