diff --git a/pyerrors/input/dobs.py b/pyerrors/input/dobs.py index aea9b7a9..201bbff8 100644 --- a/pyerrors/input/dobs.py +++ b/pyerrors/input/dobs.py @@ -1,3 +1,4 @@ +from __future__ import annotations from collections import defaultdict import gzip import lxml.etree as et @@ -11,10 +12,13 @@ from ..obs import Obs from ..obs import _merge_idx from ..covobs import Covobs from .. import version as pyerrorsversion +from lxml.etree import _Element +from numpy import ndarray +from typing import Any, Dict, List, Optional, Tuple, Union # Based on https://stackoverflow.com/a/10076823 -def _etree_to_dict(t): +def _etree_to_dict(t: _Element) -> Dict[str, Union[str, Dict[str, str], Dict[str, Union[str, Dict[str, str]]]]]: """ Convert the content of an XML file to a python dict""" d = {t.tag: {} if t.attrib else None} children = list(t) @@ -38,7 +42,7 @@ def _etree_to_dict(t): return d -def _dict_to_xmlstring(d): +def _dict_to_xmlstring(d: Dict[str, Any]) -> str: if isinstance(d, dict): iters = '' for k in d: @@ -66,7 +70,7 @@ def _dict_to_xmlstring(d): return iters -def _dict_to_xmlstring_spaces(d, space=' '): +def _dict_to_xmlstring_spaces(d: Dict[str, Dict[str, Dict[str, Union[str, Dict[str, str], List[Dict[str, str]]]]]], space: str=' ') -> str: s = _dict_to_xmlstring(d) o = '' c = 0 @@ -85,7 +89,7 @@ def _dict_to_xmlstring_spaces(d, space=' '): return o -def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None): +def create_pobs_string(obsl: List[Obs], name: str, spec: str='', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, enstag: None=None) -> str: """Export a list of Obs or structures containing Obs to an xml string according to the Zeuthen pobs format. @@ -113,6 +117,8 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None): XML formatted string of the input data """ + if symbol is None: + symbol = [] od = {} ename = obsl[0].e_names[0] names = list(obsl[0].deltas.keys()) @@ -176,7 +182,7 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None): return rs -def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz=True): +def write_pobs(obsl: List[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, enstag: None=None, gz: bool=True): """Export a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen pobs format. @@ -206,6 +212,8 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz ------- None """ + if symbol is None: + symbol = [] pobsstring = create_pobs_string(obsl, name, spec, origin, symbol, enstag) if not fname.endswith('.xml') and not fname.endswith('.gz'): @@ -223,30 +231,30 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz fp.close() -def _import_data(string): +def _import_data(string: str) -> List[Union[int, float]]: return json.loads("[" + ",".join(string.replace(' +', ' ').split()) + "]") -def _check(condition): +def _check(condition: bool): if not condition: raise Exception("XML file format not supported") class _NoTagInDataError(Exception): """Raised when tag is not in data""" - def __init__(self, tag): + def __init__(self, tag: str): self.tag = tag super().__init__('Tag %s not in data!' % (self.tag)) -def _find_tag(dat, tag): +def _find_tag(dat: _Element, tag: str) -> int: for i in range(len(dat)): if dat[i].tag == tag: return i raise _NoTagInDataError(tag) -def _import_array(arr): +def _import_array(arr: _Element) -> Union[List[Union[str, List[int], List[ndarray]]], ndarray]: name = arr[_find_tag(arr, 'id')].text.strip() index = _find_tag(arr, 'layout') try: @@ -284,12 +292,12 @@ def _import_array(arr): _check(False) -def _import_rdata(rd): +def _import_rdata(rd: _Element) -> Tuple[List[ndarray], str, List[int]]: name, idx, mask, deltas = _import_array(rd) return deltas, name, idx -def _import_cdata(cd): +def _import_cdata(cd: _Element) -> Tuple[str, ndarray, ndarray]: _check(cd[0].tag == "id") _check(cd[1][0].text.strip() == "cov") cov = _import_array(cd[1]) @@ -297,7 +305,7 @@ def _import_cdata(cd): return cd[0].text.strip(), cov, grad -def read_pobs(fname, full_output=False, gz=True, separator_insertion=None): +def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]: """Import a list of Obs from an xml.gz file in the Zeuthen pobs format. Tags are not written or recovered automatically. @@ -309,7 +317,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None): full_output : bool If True, a dict containing auxiliary information and the data is returned. If False, only the data is returned as list. - separatior_insertion: str or int + separator_insertion: str or int str: replace all occurences of "separator_insertion" within the replica names by "|%s" % (separator_insertion) when constructing the names of the replica. int: Insert the separator "|" at the position given by separator_insertion. @@ -397,7 +405,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None): # this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py -def import_dobs_string(content, full_output=False, separator_insertion=True): +def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]: """Import a list of Obs from a string in the Zeuthen dobs format. Tags are not written or recovered automatically. @@ -409,7 +417,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True): full_output : bool If True, a dict containing auxiliary information and the data is returned. If False, only the data is returned as list. - separatior_insertion: str, int or bool + separator_insertion: str, int or bool str: replace all occurences of "separator_insertion" within the replica names by "|%s" % (separator_insertion) when constructing the names of the replica. int: Insert the separator "|" at the position given by separator_insertion. @@ -571,7 +579,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True): return res -def read_dobs(fname, full_output=False, gz=True, separator_insertion=True): +def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]: """Import a list of Obs from an xml.gz file in the Zeuthen dobs format. Tags are not written or recovered automatically. @@ -618,7 +626,7 @@ def read_dobs(fname, full_output=False, gz=True, separator_insertion=True): return import_dobs_string(content, full_output, separator_insertion=separator_insertion) -def _dobsdict_to_xmlstring(d): +def _dobsdict_to_xmlstring(d: Dict[str, Any]) -> str: if isinstance(d, dict): iters = '' for k in d: @@ -658,7 +666,7 @@ def _dobsdict_to_xmlstring(d): return iters -def _dobsdict_to_xmlstring_spaces(d, space=' '): +def _dobsdict_to_xmlstring_spaces(d: Dict[str, Union[Dict[str, Union[Dict[str, str], Dict[str, Union[str, Dict[str, str]]], Dict[str, Union[str, Dict[str, Union[str, List[str]]], List[Dict[str, Union[str, int, List[Dict[str, str]]]]]]]]], Dict[str, Union[Dict[str, str], Dict[str, Union[str, Dict[str, str]]], Dict[str, Union[str, Dict[str, Union[str, List[str]]], List[Dict[str, Union[str, int, List[Dict[str, str]]]]], List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]]]], space: str=' ') -> str: s = _dobsdict_to_xmlstring(d) o = '' c = 0 @@ -677,7 +685,7 @@ def _dobsdict_to_xmlstring_spaces(d, space=' '): return o -def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None): +def create_dobs_string(obsl: List[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, who: None=None, enstags: Optional[Dict[Any, Any]]=None) -> str: """Generate the string for the export of a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen dobs format. @@ -708,6 +716,8 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N xml_str : str XML string generated from the data """ + if symbol is None: + symbol = [] if enstags is None: enstags = {} od = {} @@ -866,7 +876,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N return rs -def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None, gz=True): +def write_dobs(obsl: List[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, who: None=None, enstags: None=None, gz: bool=True): """Export a list of Obs or structures containing Obs to a .xml.gz file according to the Zeuthen dobs format. @@ -900,6 +910,8 @@ def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=No ------- None """ + if symbol is None: + symbol = [] if enstags is None: enstags = {} diff --git a/pyerrors/input/json.py b/pyerrors/input/json.py index ca3fb0d2..2463959d 100644 --- a/pyerrors/input/json.py +++ b/pyerrors/input/json.py @@ -1,3 +1,4 @@ +from __future__ import annotations import rapidjson as json import gzip import getpass @@ -12,9 +13,11 @@ from ..covobs import Covobs from ..correlators import Corr from ..misc import _assert_equal_properties from .. import version as pyerrorsversion +from numpy import float32, float64, int64, ndarray +from typing import Any, Dict, List, Optional, Tuple, Union -def create_json_string(ol, description='', indent=1): +def create_json_string(ol: Any, description: Union[str, Dict[str, Union[str, Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], float32]], Dict[str, Dict[str, Dict[str, str]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], Dict[int64, float64]]]]='', indent: int=1) -> str: """Generate the string for the export of a list of Obs or structures containing Obs to a .json(.gz) file @@ -216,7 +219,7 @@ def create_json_string(ol, description='', indent=1): return json.dumps(d, indent=indent, ensure_ascii=False, default=_jsonifier, write_mode=json.WM_COMPACT) -def dump_to_json(ol, fname, description='', indent=1, gz=True): +def dump_to_json(ol: Union[Corr, List[Union[Obs, List[Obs], Corr, ndarray]], ndarray, List[Union[Obs, List[Obs], ndarray]], List[Obs]], fname: str, description: Union[str, Dict[str, Union[str, Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], float32]], Dict[str, Dict[str, Dict[str, str]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], Dict[int64, float64]]]]='', indent: int=1, gz: bool=True): """Export a list of Obs or structures containing Obs to a .json(.gz) file. Dict keys that are not JSON-serializable such as floats are converted to strings. @@ -258,7 +261,7 @@ def dump_to_json(ol, fname, description='', indent=1, gz=True): fp.close() -def _parse_json_dict(json_dict, verbose=True, full_output=False): +def _parse_json_dict(json_dict: Dict[str, Any], verbose: bool=True, full_output: bool=False) -> Any: """Reconstruct a list of Obs or structures containing Obs from a dict that was built out of a json string. @@ -470,7 +473,7 @@ def _parse_json_dict(json_dict, verbose=True, full_output=False): return ol -def import_json_string(json_string, verbose=True, full_output=False): +def import_json_string(json_string: str, verbose: bool=True, full_output: bool=False) -> Union[Obs, List[Obs], Corr]: """Reconstruct a list of Obs or structures containing Obs from a json string. The following structures are supported: Obs, list, numpy.ndarray, Corr @@ -500,7 +503,7 @@ def import_json_string(json_string, verbose=True, full_output=False): return _parse_json_dict(json.loads(json_string), verbose, full_output) -def load_json(fname, verbose=True, gz=True, full_output=False): +def load_json(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False) -> Any: """Import a list of Obs or structures containing Obs from a .json(.gz) file. The following structures are supported: Obs, list, numpy.ndarray, Corr @@ -545,7 +548,7 @@ def load_json(fname, verbose=True, gz=True, full_output=False): return _parse_json_dict(d, verbose, full_output) -def _ol_from_dict(ind, reps='DICTOBS'): +def _ol_from_dict(ind: Union[Dict[Optional[Union[int, bool]], str], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], List[str]]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]], reps: str='DICTOBS') -> Union[Tuple[List[Any], Dict[Optional[Union[int, bool]], str]], Tuple[List[Union[Obs, List[Obs], Corr, ndarray]], Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]]: """Convert a dictionary of Obs objects to a list and a dictionary that contains placeholders instead of the Obs objects. @@ -625,7 +628,7 @@ def _ol_from_dict(ind, reps='DICTOBS'): return ol, nd -def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=True): +def dump_dict_to_json(od: Union[Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], List[str]]], List[Union[Obs, List[Obs], Corr, ndarray]], Dict[Optional[Union[int, bool]], str], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]], fname: str, description: Union[str, float32, Dict[int64, float64]]='', indent: int=1, reps: str='DICTOBS', gz: bool=True): """Export a dict of Obs or structures containing Obs to a .json(.gz) file Parameters @@ -665,7 +668,7 @@ def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=Tr dump_to_json(ol, fname, description=desc_dict, indent=indent, gz=gz) -def _od_from_list_and_dict(ol, ind, reps='DICTOBS'): +def _od_from_list_and_dict(ol: List[Union[Obs, List[Obs], Corr, ndarray]], ind: Dict[str, Dict[str, Optional[Union[str, Dict[str, Union[int, str]], List[str], float]]]], reps: str='DICTOBS') -> Dict[str, Dict[str, Any]]: """Parse a list of Obs or structures containing Obs and an accompanying dict, where the structures have been replaced by placeholders to a dict that contains the structures. @@ -728,7 +731,7 @@ def _od_from_list_and_dict(ol, ind, reps='DICTOBS'): return nd -def load_json_dict(fname, verbose=True, gz=True, full_output=False, reps='DICTOBS'): +def load_json_dict(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False, reps: str='DICTOBS') -> Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str, Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]]]: """Import a dict of Obs or structures containing Obs from a .json(.gz) file. The following structures are supported: Obs, list, numpy.ndarray, Corr diff --git a/pyerrors/input/misc.py b/pyerrors/input/misc.py index c62f502c..0c09b429 100644 --- a/pyerrors/input/misc.py +++ b/pyerrors/input/misc.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import fnmatch import re @@ -8,9 +9,10 @@ import matplotlib.pyplot as plt from matplotlib import gridspec from ..obs import Obs from ..fits import fit_lin +from typing import Dict, Optional -def fit_t0(t2E_dict, fit_range, plot_fit=False, observable='t0'): +def fit_t0(t2E_dict: Dict[float, Obs], fit_range: int, plot_fit: Optional[bool]=False, observable: str='t0') -> Obs: """Compute the root of (flow-based) data based on a dictionary that contains the necessary information in key-value pairs a la (flow time: observable at flow time). diff --git a/pyerrors/input/openQCD.py b/pyerrors/input/openQCD.py index 278977d2..6c23e3f2 100644 --- a/pyerrors/input/openQCD.py +++ b/pyerrors/input/openQCD.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import fnmatch import struct @@ -8,9 +9,11 @@ from ..obs import CObs from ..correlators import Corr from .misc import fit_t0 from .utils import sort_names +from io import BufferedReader +from typing import Dict, List, Optional, Tuple, Union -def read_rwms(path, prefix, version='2.0', names=None, **kwargs): +def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[List[str]]=None, **kwargs) -> List[Obs]: """Read rwms format from given folder structure. Returns a list of length nrw Parameters @@ -229,7 +232,7 @@ def read_rwms(path, prefix, version='2.0', names=None, **kwargs): return result -def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent, postfix='ms', **kwargs): +def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, postfix: str='ms', **kwargs) -> Dict[float, Obs]: """Extract a dictionary with the flowed Yang-Mills action density from given .ms.dat files. Returns a dictionary with Obs as values and flow times as keys. @@ -422,7 +425,7 @@ def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent, return E_dict -def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs): +def extract_t0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs: """Extract t0/a^2 from given .ms.dat files. Returns t0 as Obs. It is assumed that all boundary effects have @@ -495,7 +498,7 @@ def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi return fit_t0(t2E_dict, fit_range, plot_fit=kwargs.get('plot_fit')) -def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs): +def extract_w0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs: """Extract w0/a from given .ms.dat files. Returns w0 as Obs. It is assumed that all boundary effects have @@ -577,7 +580,7 @@ def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi return np.sqrt(fit_t0(tdtt2E_dict, fit_range, plot_fit=kwargs.get('plot_fit'), observable='w0')) -def _parse_array_openQCD2(d, n, size, wa, quadrupel=False): +def _parse_array_openQCD2(d: int, n: Tuple[int, int], size: int, wa: Union[Tuple[float, float, float, float, float, float, float, float], Tuple[float, float]], quadrupel: bool=False) -> List[List[float]]: arr = [] if d == 2: for i in range(n[0]): @@ -596,7 +599,7 @@ def _parse_array_openQCD2(d, n, size, wa, quadrupel=False): return arr -def _find_files(path, prefix, postfix, ext, known_files=[]): +def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: Union[str, List[str]]=[]) -> List[str]: found = [] files = [] @@ -636,7 +639,7 @@ def _find_files(path, prefix, postfix, ext, known_files=[]): return files -def _read_array_openQCD2(fp): +def _read_array_openQCD2(fp: BufferedReader) -> Dict[str, Union[int, Tuple[int, int], List[List[float]]]]: t = fp.read(4) d = struct.unpack('i', t)[0] t = fp.read(4 * d) @@ -662,7 +665,7 @@ def _read_array_openQCD2(fp): return {'d': d, 'n': n, 'size': size, 'arr': arr} -def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs): +def read_qtop(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", **kwargs) -> Obs: """Read the topologial charge based on openQCD gradient flow measurements. Parameters @@ -715,7 +718,7 @@ def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs): return _read_flow_obs(path, prefix, c, dtr_cnfg=dtr_cnfg, version=version, obspos=0, **kwargs) -def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs): +def read_gf_coupling(path: str, prefix: str, c: float, dtr_cnfg: int=1, Zeuthen_flow: bool=True, **kwargs) -> Obs: """Read the gradient flow coupling based on sfqcd gradient flow measurements. See 1607.06423 for details. Note: The current implementation only works for c=0.3 and T=L. The definition of the coupling in 1607.06423 requires projection to topological charge zero which is not done within this function but has to be performed in a separate step. @@ -787,7 +790,7 @@ def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs): return t * t * (5 / 3 * plaq - 1 / 12 * C2x1) / normdict[L] -def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum_t=True, **kwargs): +def _read_flow_obs(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", obspos: int=0, sum_t: bool=True, **kwargs) -> Obs: """Read a flow observable based on openQCD gradient flow measurements. Parameters @@ -1059,7 +1062,7 @@ def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum return result -def qtop_projection(qtop, target=0): +def qtop_projection(qtop: Obs, target: int=0) -> Obs: """Returns the projection to the topological charge sector defined by target. Parameters @@ -1085,7 +1088,7 @@ def qtop_projection(qtop, target=0): return reto -def read_qtop_sector(path, prefix, c, target=0, **kwargs): +def read_qtop_sector(path: str, prefix: str, c: float, target: int=0, **kwargs) -> Obs: """Constructs reweighting factors to a specified topological sector. Parameters @@ -1143,7 +1146,7 @@ def read_qtop_sector(path, prefix, c, target=0, **kwargs): return qtop_projection(qtop, target=target) -def read_ms5_xsf(path, prefix, qc, corr, sep="r", **kwargs): +def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwargs) -> Corr: """ Read data from files in the specified directory with the specified prefix and quark combination extension, and return a `Corr` object containing the data. diff --git a/pyerrors/input/pandas.py b/pyerrors/input/pandas.py index 13482983..f6cdb7bf 100644 --- a/pyerrors/input/pandas.py +++ b/pyerrors/input/pandas.py @@ -1,3 +1,4 @@ +from __future__ import annotations import warnings import gzip import sqlite3 @@ -6,9 +7,11 @@ from ..obs import Obs from ..correlators import Corr from .json import create_json_string, import_json_string import numpy as np +from pandas.core.frame import DataFrame +from pandas.core.series import Series -def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs): +def to_sql(df: DataFrame, table_name: str, db: str, if_exists: str='fail', gz: bool=True, **kwargs): """Write DataFrame including Obs or Corr valued columns to sqlite database. Parameters @@ -34,7 +37,7 @@ def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs): con.close() -def read_sql(sql, db, auto_gamma=False, **kwargs): +def read_sql(sql: str, db: str, auto_gamma: bool=False, **kwargs) -> DataFrame: """Execute SQL query on sqlite database and obtain DataFrame including Obs or Corr valued columns. Parameters @@ -58,7 +61,7 @@ def read_sql(sql, db, auto_gamma=False, **kwargs): return _deserialize_df(extract_df, auto_gamma=auto_gamma) -def dump_df(df, fname, gz=True): +def dump_df(df: DataFrame, fname: str, gz: bool=True): """Exports a pandas DataFrame containing Obs valued columns to a (gzipped) csv file. Before making use of pandas to_csv functionality Obs objects are serialized via the standardized @@ -97,7 +100,7 @@ def dump_df(df, fname, gz=True): out.to_csv(fname, index=False) -def load_df(fname, auto_gamma=False, gz=True): +def load_df(fname: str, auto_gamma: bool=False, gz: bool=True) -> DataFrame: """Imports a pandas DataFrame from a csv.(gz) file in which Obs objects are serialized as json strings. Parameters @@ -131,7 +134,7 @@ def load_df(fname, auto_gamma=False, gz=True): return _deserialize_df(re_import, auto_gamma=auto_gamma) -def _serialize_df(df, gz=False): +def _serialize_df(df: DataFrame, gz: bool=False) -> DataFrame: """Serializes all Obs or Corr valued columns into json strings according to the pyerrors json specification. Parameters @@ -152,7 +155,7 @@ def _serialize_df(df, gz=False): return out -def _deserialize_df(df, auto_gamma=False): +def _deserialize_df(df: DataFrame, auto_gamma: bool=False) -> DataFrame: """Deserializes all pyerrors json strings into Obs or Corr objects according to the pyerrors json specification. Parameters @@ -188,7 +191,7 @@ def _deserialize_df(df, auto_gamma=False): return df -def _need_to_serialize(col): +def _need_to_serialize(col: Series) -> bool: serialize = False i = 0 while i < len(col) and col[i] is None: diff --git a/pyerrors/input/sfcf.py b/pyerrors/input/sfcf.py index e9f2837e..596a52e4 100644 --- a/pyerrors/input/sfcf.py +++ b/pyerrors/input/sfcf.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import fnmatch import re @@ -5,12 +6,14 @@ import numpy as np # Thinly-wrapped numpy from ..obs import Obs from .utils import sort_names, check_idl import itertools +from numpy import ndarray +from typing import Any, Dict, List, Tuple, Union sep = "/" -def read_sfcf(path, prefix, name, quarks='.*', corr_type="bi", noffset=0, wf=0, wf2=0, version="1.0c", cfg_separator="n", silent=False, **kwargs): +def read_sfcf(path: str, prefix: str, name: str, quarks: str='.*', corr_type: str="bi", noffset: int=0, wf: int=0, wf2: int=0, version: str="1.0c", cfg_separator: str="n", silent: bool=False, **kwargs) -> List[Obs]: """Read sfcf files from given folder structure. Parameters @@ -75,7 +78,7 @@ def read_sfcf(path, prefix, name, quarks='.*', corr_type="bi", noffset=0, wf=0, return ret[name][quarks][str(noffset)][str(wf)][str(wf2)] -def read_sfcf_multi(path, prefix, name_list, quarks_list=['.*'], corr_type_list=['bi'], noffset_list=[0], wf_list=[0], wf2_list=[0], version="1.0c", cfg_separator="n", silent=False, keyed_out=False, **kwargs): +def read_sfcf_multi(path: str, prefix: str, name_list: List[str], quarks_list: List[str]=['.*'], corr_type_list: List[str]=['bi'], noffset_list: List[int]=[0], wf_list: List[int]=[0], wf2_list: List[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, List[Obs]]]]]]: """Read sfcf files from given folder structure. Parameters @@ -407,22 +410,22 @@ def read_sfcf_multi(path, prefix, name_list, quarks_list=['.*'], corr_type_list= return result_dict -def _lists2key(*lists): +def _lists2key(*lists) -> List[str]: keys = [] for tup in itertools.product(*lists): keys.append(sep.join(tup)) return keys -def _key2specs(key): +def _key2specs(key: str) -> List[str]: return key.split(sep) -def _specs2key(*specs): +def _specs2key(*specs) -> str: return sep.join(specs) -def _read_o_file(cfg_path, name, needed_keys, intern, version, im): +def _read_o_file(cfg_path: str, name: str, needed_keys: List[str], intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], version: str, im: int) -> Dict[str, List[float]]: return_vals = {} for key in needed_keys: file = cfg_path + '/' + name @@ -447,7 +450,7 @@ def _read_o_file(cfg_path, name, needed_keys, intern, version, im): return return_vals -def _extract_corr_type(corr_type): +def _extract_corr_type(corr_type: str) -> Tuple[bool, bool]: if corr_type == 'bb': b2b = True single = True @@ -460,7 +463,7 @@ def _extract_corr_type(corr_type): return b2b, single -def _find_files(rep_path, prefix, compact, files=[]): +def _find_files(rep_path: str, prefix: str, compact: bool, files: List[Union[range, str, Any]]=[]) -> List[str]: sub_ls = [] if not files == []: files.sort(key=lambda x: int(re.findall(r'\d+', x)[-1])) @@ -487,7 +490,7 @@ def _find_files(rep_path, prefix, compact, files=[]): return files -def _make_pattern(version, name, noffset, wf, wf2, b2b, quarks): +def _make_pattern(version: str, name: str, noffset: str, wf: str, wf2: Union[str, int], b2b: bool, quarks: str) -> str: if version == "0.0": pattern = "# " + name + " : offset " + str(noffset) + ", wf " + str(wf) if b2b: @@ -501,7 +504,7 @@ def _make_pattern(version, name, noffset, wf, wf2, b2b, quarks): return pattern -def _find_correlator(file_name, version, pattern, b2b, silent=False): +def _find_correlator(file_name: str, version: str, pattern: str, b2b: bool, silent: bool=False) -> Tuple[int, int]: T = 0 with open(file_name, "r") as my_file: @@ -527,7 +530,7 @@ def _find_correlator(file_name, version, pattern, b2b, silent=False): return start_read, T -def _read_compact_file(rep_path, cfg_file, intern, needed_keys, im): +def _read_compact_file(rep_path: str, cfg_file: str, intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], needed_keys: List[str], im: int) -> Dict[str, List[float]]: return_vals = {} with open(rep_path + cfg_file) as fp: lines = fp.readlines() @@ -558,7 +561,7 @@ def _read_compact_file(rep_path, cfg_file, intern, needed_keys, im): return return_vals -def _read_compact_rep(path, rep, sub_ls, intern, needed_keys, im): +def _read_compact_rep(path: str, rep: str, sub_ls: List[str], intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], needed_keys: List[str], im: int) -> Dict[str, List[ndarray]]: rep_path = path + '/' + rep + '/' no_cfg = len(sub_ls) @@ -580,7 +583,7 @@ def _read_compact_rep(path, rep, sub_ls, intern, needed_keys, im): return return_vals -def _read_chunk(chunk, gauge_line, cfg_sep, start_read, T, corr_line, b2b, pattern, im, single): +def _read_chunk(chunk: List[str], gauge_line: int, cfg_sep: str, start_read: int, T: int, corr_line: int, b2b: bool, pattern: str, im: int, single: bool) -> Tuple[int, List[float]]: try: idl = int(chunk[gauge_line].split(cfg_sep)[-1]) except Exception: @@ -597,7 +600,7 @@ def _read_chunk(chunk, gauge_line, cfg_sep, start_read, T, corr_line, b2b, patte return idl, data -def _read_append_rep(filename, pattern, b2b, cfg_separator, im, single): +def _read_append_rep(filename: str, pattern: str, b2b: bool, cfg_separator: str, im: int, single: bool) -> Tuple[int, List[int], List[List[float]]]: with open(filename, 'r') as fp: content = fp.readlines() data_starts = [] @@ -646,7 +649,7 @@ def _read_append_rep(filename, pattern, b2b, cfg_separator, im, single): return T, rep_idl, data -def _get_rep_names(ls, ens_name=None): +def _get_rep_names(ls: List[str], ens_name: None=None) -> List[str]: new_names = [] for entry in ls: try: @@ -661,7 +664,7 @@ def _get_rep_names(ls, ens_name=None): return new_names -def _get_appended_rep_names(ls, prefix, name, ens_name=None): +def _get_appended_rep_names(ls: List[str], prefix: str, name: str, ens_name: None=None) -> List[str]: new_names = [] for exc in ls: if not fnmatch.fnmatch(exc, prefix + '*.' + name): diff --git a/pyerrors/input/utils.py b/pyerrors/input/utils.py index eaf41f06..5ac00ba8 100644 --- a/pyerrors/input/utils.py +++ b/pyerrors/input/utils.py @@ -1,11 +1,13 @@ """Utilities for the input""" +from __future__ import annotations import re import fnmatch import os +from typing import List -def sort_names(ll): +def sort_names(ll: List[str]) -> List[str]: """Sorts a list of names of replika with searches for `r` and `id` in the replikum string. If this search fails, a fallback method is used, where the strings are simply compared and the first diffeing numeral is used for differentiation.