From 6d5a9b9d837fa32bb1ee171ce29b5f3bbc0ff96b Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 3 Jan 2025 23:15:40 +0100 Subject: [PATCH] [Fix] Simplify type annotations in input modules --- pyerrors/fits.py | 2 +- pyerrors/input/dobs.py | 33 +++++++++++++++++---------------- pyerrors/input/sfcf.py | 12 +++++++----- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/pyerrors/fits.py b/pyerrors/fits.py index 2a2950ae..fe96713a 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -710,7 +710,7 @@ def total_least_squares(x: list[Obs], y: list[Obs], func: Callable, silent: bool return output -def fit_lin(x: list[Union[Obs, int, float]], y: list[Obs], **kwargs) -> list[Obs]: +def fit_lin(x: Sequence[Union[Obs, int, float]], y: Sequence[Obs], **kwargs) -> list[Obs]: """Performs a linear fit to y = n + m * x and returns two Obs n, m. Parameters diff --git a/pyerrors/input/dobs.py b/pyerrors/input/dobs.py index d4ac638c..af60eb9c 100644 --- a/pyerrors/input/dobs.py +++ b/pyerrors/input/dobs.py @@ -18,9 +18,9 @@ from typing import Any, Optional, Union # Based on https://stackoverflow.com/a/10076823 -def _etree_to_dict(t: _Element) -> dict[str, Union[str, dict[str, str], dict[str, Union[str, dict[str, str]]]]]: +def _etree_to_dict(t: _Element) -> dict: """ Convert the content of an XML file to a python dict""" - d = {t.tag: {} if t.attrib else None} + d: dict = {t.tag: {} if t.attrib else None} children = list(t) if children: dd = defaultdict(list) @@ -70,7 +70,7 @@ def _dict_to_xmlstring(d: dict[str, Any]) -> str: return iters -def _dict_to_xmlstring_spaces(d: dict[str, dict[str, dict[str, Union[str, dict[str, str], list[dict[str, str]]]]]], space: str=' ') -> str: +def _dict_to_xmlstring_spaces(d: dict, space: str=' ') -> str: s = _dict_to_xmlstring(d) o = '' c = 0 @@ -89,7 +89,7 @@ def _dict_to_xmlstring_spaces(d: dict[str, dict[str, dict[str, Union[str, dict[s return o -def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: None=None) -> str: +def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None) -> str: """Export a list of Obs or structures containing Obs to an xml string according to the Zeuthen pobs format. @@ -119,7 +119,7 @@ def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', if symbol is None: symbol = [] - od = {} + od: dict[str, Any] = {} ename = obsl[0].e_names[0] names = list(obsl[0].deltas.keys()) nr = len(names) @@ -182,7 +182,7 @@ def create_pobs_string(obsl: list[Obs], name: str, spec: str='', origin: str='', return rs -def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: None=None, gz: bool=True): +def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, enstag: Optional[str]=None, gz: bool=True): """Export a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen pobs format. @@ -223,12 +223,13 @@ def write_pobs(obsl: list[Obs], fname: str, name: str, spec: str='', origin: str if not fname.endswith('.gz'): fname += '.gz' - fp = gzip.open(fname, 'wb') - fp.write(pobsstring.encode('utf-8')) + gp = gzip.open(fname, 'wb') + gp.write(pobsstring.encode('utf-8')) + gp.close() else: fp = open(fname, 'w', encoding='utf-8') fp.write(pobsstring) - fp.close() + fp.close() def _import_data(string: str) -> list[Union[int, float]]: @@ -305,7 +306,7 @@ def _import_cdata(cd: _Element) -> tuple[str, ndarray, ndarray]: return cd[0].text.strip(), cov, grad -def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]: +def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[dict, list[Obs]]: """Import a list of Obs from an xml.gz file in the Zeuthen pobs format. Tags are not written or recovered automatically. @@ -358,7 +359,7 @@ def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_inse deltas = [] names = [] - idl = [] + idl: list[list[int]] = [] for i in range(5, len(pobs)): delta, name, idx = _import_rdata(pobs[i]) deltas.append(delta) @@ -405,7 +406,7 @@ def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_inse # this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py -def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]: +def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[dict, list[Obs]]: """Import a list of Obs from a string in the Zeuthen dobs format. Tags are not written or recovered automatically. @@ -579,7 +580,7 @@ def import_dobs_string(content: bytes, full_output: bool=False, separator_insert return res -def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[dict[str, Union[str, dict[str, str], list[Obs]]], list[Obs]]: +def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[dict, list[Obs]]: """Import a list of Obs from an xml.gz file in the Zeuthen dobs format. Tags are not written or recovered automatically. @@ -666,7 +667,7 @@ def _dobsdict_to_xmlstring(d: dict[str, Any]) -> str: return iters -def _dobsdict_to_xmlstring_spaces(d: dict[str, Union[dict[str, Union[dict[str, str], dict[str, Union[str, dict[str, str]]], dict[str, Union[str, dict[str, Union[str, list[str]]], list[dict[str, Union[str, int, list[dict[str, str]]]]]]]]], dict[str, Union[dict[str, str], dict[str, Union[str, dict[str, str]]], dict[str, Union[str, dict[str, Union[str, list[str]]], list[dict[str, Union[str, int, list[dict[str, str]]]]], list[dict[str, Union[str, list[dict[str, str]]]]]]]]]]], space: str=' ') -> str: +def _dobsdict_to_xmlstring_spaces(d: dict, space: str=' ') -> str: s = _dobsdict_to_xmlstring(d) o = '' c = 0 @@ -685,7 +686,7 @@ def _dobsdict_to_xmlstring_spaces(d: dict[str, Union[dict[str, Union[dict[str, s return o -def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: None=None, enstags: Optional[dict[Any, Any]]=None) -> str: +def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None) -> str: """Generate the string for the export of a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen dobs format. @@ -876,7 +877,7 @@ def create_dobs_string(obsl: list[Obs], name: str, spec: str='dobs v1.0', origin return rs -def write_dobs(obsl: list[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: None=None, enstags: None=None, gz: bool=True): +def write_dobs(obsl: list[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[list[Union[str, Any]]]=None, who: Optional[str]=None, enstags: Optional[dict]=None, gz: bool=True): """Export a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen dobs format. diff --git a/pyerrors/input/sfcf.py b/pyerrors/input/sfcf.py index 46cfa5d4..c47f81d2 100644 --- a/pyerrors/input/sfcf.py +++ b/pyerrors/input/sfcf.py @@ -78,7 +78,7 @@ def read_sfcf(path: str, prefix: str, name: str, quarks: str='.*', corr_type: st return ret[name][quarks][str(noffset)][str(wf)][str(wf2)] -def read_sfcf_multi(path: str, prefix: str, name_list: list[str], quarks_list: list[str]=['.*'], corr_type_list: list[str]=['bi'], noffset_list: list[int]=[0], wf_list: list[int]=[0], wf2_list: list[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> dict[str, dict[str, dict[str, dict[str, dict[str, list[Obs]]]]]]: +def read_sfcf_multi(path: str, prefix: str, name_list: list[str], quarks_list: list[str]=['.*'], corr_type_list: list[str]=['bi'], noffset_list: list[int]=[0], wf_list: list[int]=[0], wf2_list: list[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> dict: """Read sfcf files from given folder structure. Parameters @@ -425,7 +425,7 @@ def _specs2key(*specs) -> str: return sep.join(specs) -def _read_o_file(cfg_path: str, name: str, needed_keys: list[str], intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], version: str, im: int) -> dict[str, list[float]]: +def _read_o_file(cfg_path: str, name: str, needed_keys: list[str], intern: dict[str, dict], version: str, im: int) -> dict[str, list[float]]: return_vals = {} for key in needed_keys: file = cfg_path + '/' + name @@ -463,7 +463,9 @@ def _extract_corr_type(corr_type: str) -> tuple[bool, bool]: return b2b, single -def _find_files(rep_path: str, prefix: str, compact: bool, files: list[Union[range, str, Any]]=[]) -> list[str]: +def _find_files(rep_path: str, prefix: str, compact: bool, files: Optional[list]=None) -> list[str]: + if files is None: + files = [] sub_ls = [] if not files == []: files.sort(key=lambda x: int(re.findall(r'\d+', x)[-1])) @@ -530,7 +532,7 @@ def _find_correlator(file_name: str, version: str, pattern: str, b2b: bool, sile return start_read, T -def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], needed_keys: list[str], im: int) -> dict[str, list[float]]: +def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict], needed_keys: list[str], im: int) -> dict[str, list[float]]: return_vals = {} with open(rep_path + cfg_file) as fp: lines = fp.readlines() @@ -561,7 +563,7 @@ def _read_compact_file(rep_path: str, cfg_file: str, intern: dict[str, dict[str, return return_vals -def _read_compact_rep(path: str, rep: str, sub_ls: list[str], intern: dict[str, dict[str, Union[bool, dict[str, dict[str, dict[str, dict[str, dict[str, Union[int, str]]]]]], int]]], needed_keys: list[str], im: int) -> dict[str, list[ndarray]]: +def _read_compact_rep(path: str, rep: str, sub_ls: list[str], intern: dict[str, dict], needed_keys: list[str], im: int) -> dict[str, list[ndarray]]: rep_path = path + '/' + rep + '/' no_cfg = len(sub_ls)