[Fix] Fix type hints in misc.py and remove strict zips for python 3.9

compatability
This commit is contained in:
Fabian Joswig 2025-01-03 18:39:34 +01:00
parent 9389ad67c9
commit 23d4f4c320
3 changed files with 18 additions and 11 deletions

View file

@ -75,8 +75,11 @@ def dump_object(obj: Corr, name: str, **kwargs):
------- -------
None None
""" """
if 'path' in kwargs: path = kwargs.get('path')
file_name = kwargs.get('path') + '/' + name + '.p' if path is not None:
if not isinstance(path, str):
raise Exception("Path has to be a string.")
file_name = path + '/' + name + '.p'
else: else:
file_name = name + '.p' file_name = name + '.p'
with open(file_name, 'wb') as fb: with open(file_name, 'wb') as fb:
@ -100,7 +103,7 @@ def load_object(path: str) -> Union[Obs, Corr]:
return pickle.load(file) return pickle.load(file)
def pseudo_Obs(value: Union[float, int64, float64, int], dvalue: Union[float, float64, int], name: str, samples: int=1000) -> Obs: def pseudo_Obs(value: Union[float, int], dvalue: Union[float, int], name: str, samples: int=1000) -> Obs:
"""Generate an Obs object with given value, dvalue and name for test purposes """Generate an Obs object with given value, dvalue and name for test purposes
Parameters Parameters
@ -123,11 +126,11 @@ def pseudo_Obs(value: Union[float, int64, float64, int], dvalue: Union[float, fl
return Obs([np.zeros(samples) + value], [name]) return Obs([np.zeros(samples) + value], [name])
else: else:
for _ in range(100): for _ in range(100):
deltas = [np.random.normal(0.0, dvalue * np.sqrt(samples), samples)] deltas = np.array([np.random.normal(0.0, dvalue * np.sqrt(samples), samples)])
deltas -= np.mean(deltas) deltas -= np.mean(deltas)
deltas *= dvalue / np.sqrt((np.var(deltas) / samples)) / np.sqrt(1 + 3 / samples) deltas *= dvalue / np.sqrt((np.var(deltas) / samples)) / np.sqrt(1 + 3 / samples)
deltas += value deltas += value
res = Obs(deltas, [name]) res = Obs(list(deltas), [name])
res.gamma_method(S=2, tau_exp=0) res.gamma_method(S=2, tau_exp=0)
if abs(res.dvalue - dvalue) < 1e-10 * dvalue: if abs(res.dvalue - dvalue) < 1e-10 * dvalue:
break break
@ -179,7 +182,7 @@ def gen_correlated_data(means: Union[ndarray, List[float]], cov: ndarray, name:
return [Obs([dat], [name]) for dat in corr_data.T] return [Obs([dat], [name]) for dat in corr_data.T]
def _assert_equal_properties(ol: Union[List[Obs], List[CObs], ndarray], otype: Type[Obs]=Obs): def _assert_equal_properties(ol: Union[List[Obs], List[CObs], ndarray]):
otype = type(ol[0]) otype = type(ol[0])
for o in ol[1:]: for o in ol[1:]:
if not isinstance(o, otype): if not isinstance(o, otype):

View file

@ -114,7 +114,7 @@ class Obs:
self.N: int = 0 self.N: int = 0
self.idl: dict[str, Union[list[int], range]] = {} self.idl: dict[str, Union[list[int], range]] = {}
if idl is not None: if idl is not None:
for name, idx in sorted(zip(names, idl, strict=True)): for name, idx in sorted(zip(names, idl)):
if isinstance(idx, range): if isinstance(idx, range):
self.idl[name] = idx self.idl[name] = idx
elif isinstance(idx, (list, np.ndarray)): elif isinstance(idx, (list, np.ndarray)):
@ -130,17 +130,17 @@ class Obs:
else: else:
raise TypeError('incompatible type for idl[%s].' % name) raise TypeError('incompatible type for idl[%s].' % name)
else: else:
for name, sample in sorted(zip(names, samples, strict=True)): for name, sample in sorted(zip(names, samples)):
self.idl[name] = range(1, len(sample) + 1) self.idl[name] = range(1, len(sample) + 1)
if means is not None: if means is not None:
for name, sample, mean in sorted(zip(names, samples, means, strict=True)): for name, sample, mean in sorted(zip(names, samples, means)):
self.shape[name] = len(self.idl[name]) self.shape[name] = len(self.idl[name])
self.N += self.shape[name] self.N += self.shape[name]
self.r_values[name] = mean self.r_values[name] = mean
self.deltas[name] = sample self.deltas[name] = sample
else: else:
for name, sample in sorted(zip(names, samples, strict=True)): for name, sample in sorted(zip(names, samples)):
self.shape[name] = len(self.idl[name]) self.shape[name] = len(self.idl[name])
self.N += self.shape[name] self.N += self.shape[name]
if len(sample) != self.shape[name]: if len(sample) != self.shape[name]:
@ -648,7 +648,7 @@ class Obs:
if save: if save:
fig1.savefig(save) fig1.savefig(save)
return dict(zip(labels, sizes, strict=True)) return dict(zip(labels, sizes))
def dump(self, filename: str, datatype: str="json.gz", description: str="", **kwargs): def dump(self, filename: str, datatype: str="json.gz", description: str="", **kwargs):
"""Dump the Obs to a file 'name' of chosen format. """Dump the Obs to a file 'name' of chosen format.

View file

@ -4,3 +4,7 @@ build-backend = "setuptools.build_meta"
[tool.ruff.lint] [tool.ruff.lint]
ignore = ["F403"] ignore = ["F403"]
[tool.mypy]
warn_unused_configs = true
ignore_missing_imports = true