[Fix] Simplify type annotations in input modules

This commit is contained in:
Fabian Joswig 2025-01-03 23:15:40 +01:00
parent b8700ef962
commit 6d5a9b9d83
3 changed files with 25 additions and 22 deletions

View file

@ -710,7 +710,7 @@ def total_least_squares(x: list[Obs], y: list[Obs], func: Callable, silent: bool
return output
def fit_lin(x: list[Union[Obs, int, float]], y: list[Obs], **kwargs) -> list[Obs]:
def fit_lin(x: Sequence[Union[Obs, int, float]], y: Sequence[Obs], **kwargs) -> list[Obs]:
"""Performs a linear fit to y = n + m * x and returns two Obs n, m.
Parameters

View file

@ -18,9 +18,9 @@ from typing import Any, Optional, Union
# Based on https://stackoverflow.com/a/10076823
def _etree_to_dict(t: _Element) -> dict[str, Union[str, dict[str, str], dict[str, Union[str, dict[str, str]]]]]:
def _etree_to_dict(t: _Element) -> dict:
""" Convert the content of an XML file to a python dict"""
d = {t.tag: {} if t.attrib else None}
d: dict = {t.tag: {} if t.attrib else None}
children = list(t)
if children:
dd = defaultdict(list)
@ -70,7 +70,7 @@ def _dict_to_xmlstring(d: dict[str, Any]) -> str:
return iters
def _dict_to_xmlstring_spaces(d: dict[str, dict[str, dict[str, Union[str, dict[str, str], list[dict[str, str]]]]]], space: str=' ') -> str:
def _dict_to_xmlstring_spaces(d: dict, space: str=' ') -> str:
s = _dict_to_xmlstring(d)
o = ''
c = 0
@ -89,7 +89,7 @@ def _dict_to_xmlstring_spaces(d: dict[str, dict[str, dict[str, Union[str, dict[s
return o
def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: None=None) -> str:
def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None) -> str:
"""Export a list of Obs or structures containing Obs to an xml string
according to the Zeuthen pobs format.
@ -119,7 +119,7 @@ def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='',
if symbol is None:
symbol = []
od = {}
od: dict[str, Any] = {}
ename = obsl[0].e_names[0]
names = list(obsl[0].deltas.keys())
nr = len(names)
@ -182,7 +182,7 @@ def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='',
return rs
def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: None=None, gz: bool=True):
def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen pobs format.
@ -223,12 +223,13 @@ def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str
if not fname.endswith('.gz'):
fname += '.gz'
fp = gzip.open(fname, 'wb')
fp.write(pobsstring.encode('utf-8'))
gp = gzip.open(fname, 'wb')
gp.write(pobsstring.encode('utf-8'))
gp.close()
else:
fp = open(fname, 'w', encoding='utf-8')
fp.write(pobsstring)
fp.close()
fp.close()
def _import_data(string: str) -> list[Union[int, float]]:
@ -305,7 +306,7 @@ def _import_cdata(cd: _Element) -> tuple[str, ndarray, ndarray]:
return cd[0].text.strip(), cov, grad
def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]:
def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[dict, list[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen pobs format.
Tags are not written or recovered automatically.
@ -358,7 +359,7 @@ def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_inse
deltas = []
names = []
idl = []
idl: list[list[int]] = []
for i in range(5, len(pobs)):
delta, name, idx = _import_rdata(pobs[i])
deltas.append(delta)
@ -405,7 +406,7 @@ def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_inse
# this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py
def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]:
def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[dict, list[Obs]]:
"""Import a list of Obs from a string in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@ -579,7 +580,7 @@ def import_dobs_string(content: bytes, full_output: bool=False, separator_insert
return res
def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]:
def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[dict, list[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@ -666,7 +667,7 @@ def _dobsdict_to_xmlstring(d: dict[str, Any]) -> str:
return iters
def _dobsdict_to_xmlstring_spaces(d: dict[str, Union[dict[str, Union[dict[str, str], dict[str, Union[str, dict[str, str]]], dict[str, Union[str, dict[str, Union[str, list[str]]], list[dict[str, Union[str, int, list[dict[str, str]]]]]]]]], dict[str, Union[dict[str, str], dict[str, Union[str, dict[str, str]]], dict[str, Union[str, dict[str, Union[str, list[str]]], list[dict[str, Union[str, int, list[dict[str, str]]]]], list[dict[str, Union[str, list[dict[str, str]]]]]]]]]]], space: str=' ') -> str:
def _dobsdict_to_xmlstring_spaces(d: dict, space: str=' ') -> str:
s = _dobsdict_to_xmlstring(d)
o = ''
c = 0
@ -685,7 +686,7 @@ def _dobsdict_to_xmlstring_spaces(d: dict[str, Union[dict[str, Union[dict[str, s
return o
def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: None=None, enstags: Optional[dict[Any, Any]]=None) -> str:
def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None) -> str:
"""Generate the string for the export of a list of Obs or structures containing Obs
to a .xml.gz file according to the Zeuthen dobs format.
@ -876,7 +877,7 @@ def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin
return rs
def write_dobs(obsl: list[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: None=None, enstags: None=None, gz: bool=True):
def write_dobs(obsl: list[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen dobs format.

View file

@ -78,7 +78,7 @@ def read_sfcf(path: str, prefix: str, name: str, quarks: str='.*', corr_type: st
return ret[name][quarks][str(noffset)][str(wf)][str(wf2)]
def read_sfcf_multi(path: str, prefix: str, name_list: list[str], quarks_list: list[str]=['.*'], corr_type_list: list[str]=['bi'], noffset_list: list[int]=[0], wf_list: list[int]=[0], wf2_list: list[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> dict[str, dict[str, dict[str, dict[str, dict[str, list[Obs]]]]]]:
def read_sfcf_multi(path: str, prefix: str, name_list: list[str], quarks_list: list[str]=['.*'], corr_type_list: list[str]=['bi'], noffset_list: list[int]=[0], wf_list: list[int]=[0], wf2_list: list[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> dict:
"""Read sfcf files from given folder structure.
Parameters
@ -425,7 +425,7 @@ def _specs2key(*specs) -> str:
return sep.join(specs)
def _read_o_file(cfg_path: str, name: str, needed_keys: list[str], intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], version: str, im: int) -> dict[str, list[float]]:
def _read_o_file(cfg_path: str, name: str, needed_keys: list[str], intern: dict[str, dict], version: str, im: int) -> dict[str, list[float]]:
return_vals = {}
for key in needed_keys:
file = cfg_path + '/' + name
@ -463,7 +463,9 @@ def _extract_corr_type(corr_type: str) -> tuple[bool, bool]:
return b2b, single
def _find_files(rep_path: str, prefix: str, compact: bool, files: list[Union[range, str, Any]]=[]) -> list[str]:
def _find_files(rep_path: str, prefix: str, compact: bool, files: Optional[list]=None) -> list[str]:
if files is None:
files = []
sub_ls = []
if not files == []:
files.sort(key=lambda x: int(re.findall(r'\d+', x)[-1]))
@ -530,7 +532,7 @@ def _find_correlator(file_name: str, version: str, pattern: str, b2b: bool, sile
return start_read, T
def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], needed_keys: list[str], im: int) -> dict[str, list[float]]:
def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict], needed_keys: list[str], im: int) -> dict[str, list[float]]:
return_vals = {}
with open(rep_path + cfg_file) as fp:
lines = fp.readlines()
@ -561,7 +563,7 @@ def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict[str,
return return_vals
def _read_compact_rep(path: str, rep: str, sub_ls: list[str], intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], needed_keys: list[str], im: int) -> dict[str, list[ndarray]]:
def _read_compact_rep(path: str, rep: str, sub_ls: list[str], intern: dict[str, dict], needed_keys: list[str], im: int) -> dict[str, list[ndarray]]:
rep_path = path + '/' + rep + '/'
no_cfg = len(sub_ls)