[Feat] Added type hints to input modules

This commit is contained in:
Fabian Joswig 2024-12-25 11:14:34 +01:00
parent 3db8eb2989
commit 9fe375a747
7 changed files with 96 additions and 68 deletions

View file

@ -1,3 +1,4 @@
from __future__ import annotations
from collections import defaultdict
import gzip
import lxml.etree as et
@ -11,10 +12,13 @@ from ..obs import Obs
from ..obs import _merge_idx
from ..covobs import Covobs
from .. import version as pyerrorsversion
from lxml.etree import _Element
from numpy import ndarray
from typing import Any, Dict, List, Optional, Tuple, Union
# Based on https://stackoverflow.com/a/10076823
def _etree_to_dict(t):
def _etree_to_dict(t: _Element) -> Dict[str, Union[str, Dict[str, str], Dict[str, Union[str, Dict[str, str]]]]]:
""" Convert the content of an XML file to a python dict"""
d = {t.tag: {} if t.attrib else None}
children = list(t)
@ -38,7 +42,7 @@ def _etree_to_dict(t):
return d
def _dict_to_xmlstring(d):
def _dict_to_xmlstring(d: Dict[str, Any]) -> str:
if isinstance(d, dict):
iters = ''
for k in d:
@ -66,7 +70,7 @@ def _dict_to_xmlstring(d):
return iters
def _dict_to_xmlstring_spaces(d, space=' '):
def _dict_to_xmlstring_spaces(d: Dict[str, Dict[str, Dict[str, Union[str, Dict[str, str], List[Dict[str, str]]]]]], space: str=' ') -> str:
s = _dict_to_xmlstring(d)
o = ''
c = 0
@ -85,7 +89,7 @@ def _dict_to_xmlstring_spaces(d, space=' '):
return o
def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
def create_pobs_string(obsl: List[Obs], name: str, spec: str='', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, enstag: None=None) -> str:
"""Export a list of Obs or structures containing Obs to an xml string
according to the Zeuthen pobs format.
@ -113,6 +117,8 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
XML formatted string of the input data
"""
if symbol is None:
symbol = []
od = {}
ename = obsl[0].e_names[0]
names = list(obsl[0].deltas.keys())
@ -176,7 +182,7 @@ def create_pobs_string(obsl, name, spec='', origin='', symbol=[], enstag=None):
return rs
def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz=True):
def write_pobs(obsl: List[Obs], fname: str, name: str, spec: str='', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, enstag: None=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen pobs format.
@ -206,6 +212,8 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz
-------
None
"""
if symbol is None:
symbol = []
pobsstring = create_pobs_string(obsl, name, spec, origin, symbol, enstag)
if not fname.endswith('.xml') and not fname.endswith('.gz'):
@ -223,30 +231,30 @@ def write_pobs(obsl, fname, name, spec='', origin='', symbol=[], enstag=None, gz
fp.close()
def _import_data(string):
def _import_data(string: str) -> List[Union[int, float]]:
return json.loads("[" + ",".join(string.replace(' +', ' ').split()) + "]")
def _check(condition):
def _check(condition: bool):
if not condition:
raise Exception("XML file format not supported")
class _NoTagInDataError(Exception):
"""Raised when tag is not in data"""
def __init__(self, tag):
def __init__(self, tag: str):
self.tag = tag
super().__init__('Tag %s not in data!' % (self.tag))
def _find_tag(dat, tag):
def _find_tag(dat: _Element, tag: str) -> int:
for i in range(len(dat)):
if dat[i].tag == tag:
return i
raise _NoTagInDataError(tag)
def _import_array(arr):
def _import_array(arr: _Element) -> Union[List[Union[str, List[int], List[ndarray]]], ndarray]:
name = arr[_find_tag(arr, 'id')].text.strip()
index = _find_tag(arr, 'layout')
try:
@ -284,12 +292,12 @@ def _import_array(arr):
_check(False)
def _import_rdata(rd):
def _import_rdata(rd: _Element) -> Tuple[List[ndarray], str, List[int]]:
    """Unpack a replica <array> element into (deltas, name, idx).

    The mask entry returned by _import_array is intentionally discarded.
    """
    rname, ridx, _mask, rdeltas = _import_array(rd)
    return rdeltas, rname, ridx
def _import_cdata(cd):
def _import_cdata(cd: _Element) -> Tuple[str, ndarray, ndarray]:
_check(cd[0].tag == "id")
_check(cd[1][0].text.strip() == "cov")
cov = _import_array(cd[1])
@ -297,7 +305,7 @@ def _import_cdata(cd):
return cd[0].text.strip(), cov, grad
def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
def read_pobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: None=None) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen pobs format.
Tags are not written or recovered automatically.
@ -309,7 +317,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
full_output : bool
If True, a dict containing auxiliary information and the data is returned.
If False, only the data is returned as list.
separatior_insertion: str or int
separator_insertion: str or int
str: replace all occurrences of "separator_insertion" within the replica names
by "|%s" % (separator_insertion) when constructing the names of the replica.
int: Insert the separator "|" at the position given by separator_insertion.
@ -397,7 +405,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
# this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py
def import_dobs_string(content, full_output=False, separator_insertion=True):
def import_dobs_string(content: bytes, full_output: bool=False, separator_insertion: bool=True) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]:
"""Import a list of Obs from a string in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@ -409,7 +417,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True):
full_output : bool
If True, a dict containing auxiliary information and the data is returned.
If False, only the data is returned as list.
separatior_insertion: str, int or bool
separator_insertion: str, int or bool
str: replace all occurrences of "separator_insertion" within the replica names
by "|%s" % (separator_insertion) when constructing the names of the replica.
int: Insert the separator "|" at the position given by separator_insertion.
@ -571,7 +579,7 @@ def import_dobs_string(content, full_output=False, separator_insertion=True):
return res
def read_dobs(fname, full_output=False, gz=True, separator_insertion=True):
def read_dobs(fname: str, full_output: bool=False, gz: bool=True, separator_insertion: bool=True) -> Union[Dict[str, Union[str, Dict[str, str], List[Obs]]], List[Obs]]:
"""Import a list of Obs from an xml.gz file in the Zeuthen dobs format.
Tags are not written or recovered automatically.
@ -618,7 +626,7 @@ def read_dobs(fname, full_output=False, gz=True, separator_insertion=True):
return import_dobs_string(content, full_output, separator_insertion=separator_insertion)
def _dobsdict_to_xmlstring(d):
def _dobsdict_to_xmlstring(d: Dict[str, Any]) -> str:
if isinstance(d, dict):
iters = ''
for k in d:
@ -658,7 +666,7 @@ def _dobsdict_to_xmlstring(d):
return iters
def _dobsdict_to_xmlstring_spaces(d, space=' '):
def _dobsdict_to_xmlstring_spaces(d: Dict[str, Union[Dict[str, Union[Dict[str, str], Dict[str, Union[str, Dict[str, str]]], Dict[str, Union[str, Dict[str, Union[str, List[str]]], List[Dict[str, Union[str, int, List[Dict[str, str]]]]]]]]], Dict[str, Union[Dict[str, str], Dict[str, Union[str, Dict[str, str]]], Dict[str, Union[str, Dict[str, Union[str, List[str]]], List[Dict[str, Union[str, int, List[Dict[str, str]]]]], List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]]]], space: str=' ') -> str:
s = _dobsdict_to_xmlstring(d)
o = ''
c = 0
@ -677,7 +685,7 @@ def _dobsdict_to_xmlstring_spaces(d, space=' '):
return o
def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None):
def create_dobs_string(obsl: List[Obs], name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, who: None=None, enstags: Optional[Dict[Any, Any]]=None) -> str:
"""Generate the string for the export of a list of Obs or structures containing Obs
to a .xml.gz file according to the Zeuthen dobs format.
@ -708,6 +716,8 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
xml_str : str
XML string generated from the data
"""
if symbol is None:
symbol = []
if enstags is None:
enstags = {}
od = {}
@ -866,7 +876,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
return rs
def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=None, enstags=None, gz=True):
def write_dobs(obsl: List[Obs], fname: str, name: str, spec: str='dobs v1.0', origin: str='', symbol: Optional[List[Union[str, Any]]]=None, who: None=None, enstags: None=None, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .xml.gz file
according to the Zeuthen dobs format.
@ -900,6 +910,8 @@ def write_dobs(obsl, fname, name, spec='dobs v1.0', origin='', symbol=[], who=No
-------
None
"""
if symbol is None:
symbol = []
if enstags is None:
enstags = {}

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import rapidjson as json
import gzip
import getpass
@ -12,9 +13,11 @@ from ..covobs import Covobs
from ..correlators import Corr
from ..misc import _assert_equal_properties
from .. import version as pyerrorsversion
from numpy import float32, float64, int64, ndarray
from typing import Any, Dict, List, Optional, Tuple, Union
def create_json_string(ol, description='', indent=1):
def create_json_string(ol: Any, description: Union[str, Dict[str, Union[str, Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], float32]], Dict[str, Dict[str, Dict[str, str]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], Dict[int64, float64]]]]='', indent: int=1) -> str:
"""Generate the string for the export of a list of Obs or structures containing Obs
to a .json(.gz) file
@ -216,7 +219,7 @@ def create_json_string(ol, description='', indent=1):
return json.dumps(d, indent=indent, ensure_ascii=False, default=_jsonifier, write_mode=json.WM_COMPACT)
def dump_to_json(ol, fname, description='', indent=1, gz=True):
def dump_to_json(ol: Union[Corr, List[Union[Obs, List[Obs], Corr, ndarray]], ndarray, List[Union[Obs, List[Obs], ndarray]], List[Obs]], fname: str, description: Union[str, Dict[str, Union[str, Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], float32]], Dict[str, Dict[str, Dict[str, str]]], Dict[str, Union[str, Dict[Optional[Union[int, bool]], str], Dict[int64, float64]]]]='', indent: int=1, gz: bool=True):
"""Export a list of Obs or structures containing Obs to a .json(.gz) file.
Dict keys that are not JSON-serializable such as floats are converted to strings.
@ -258,7 +261,7 @@ def dump_to_json(ol, fname, description='', indent=1, gz=True):
fp.close()
def _parse_json_dict(json_dict, verbose=True, full_output=False):
def _parse_json_dict(json_dict: Dict[str, Any], verbose: bool=True, full_output: bool=False) -> Any:
"""Reconstruct a list of Obs or structures containing Obs from a dict that
was built out of a json string.
@ -470,7 +473,7 @@ def _parse_json_dict(json_dict, verbose=True, full_output=False):
return ol
def import_json_string(json_string, verbose=True, full_output=False):
def import_json_string(json_string: str, verbose: bool=True, full_output: bool=False) -> Union[Obs, List[Obs], Corr]:
"""Reconstruct a list of Obs or structures containing Obs from a json string.
The following structures are supported: Obs, list, numpy.ndarray, Corr
@ -500,7 +503,7 @@ def import_json_string(json_string, verbose=True, full_output=False):
return _parse_json_dict(json.loads(json_string), verbose, full_output)
def load_json(fname, verbose=True, gz=True, full_output=False):
def load_json(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False) -> Any:
"""Import a list of Obs or structures containing Obs from a .json(.gz) file.
The following structures are supported: Obs, list, numpy.ndarray, Corr
@ -545,7 +548,7 @@ def load_json(fname, verbose=True, gz=True, full_output=False):
return _parse_json_dict(d, verbose, full_output)
def _ol_from_dict(ind, reps='DICTOBS'):
def _ol_from_dict(ind: Union[Dict[Optional[Union[int, bool]], str], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], List[str]]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]], reps: str='DICTOBS') -> Union[Tuple[List[Any], Dict[Optional[Union[int, bool]], str]], Tuple[List[Union[Obs, List[Obs], Corr, ndarray]], Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str]]]], Dict[str, Optional[Union[str, List[str], float]]]]]]]:
"""Convert a dictionary of Obs objects to a list and a dictionary that contains
placeholders instead of the Obs objects.
@ -625,7 +628,7 @@ def _ol_from_dict(ind, reps='DICTOBS'):
return ol, nd
def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=True):
def dump_dict_to_json(od: Union[Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str]], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], List[str]]], List[Union[Obs, List[Obs], Corr, ndarray]], Dict[Optional[Union[int, bool]], str], Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]], fname: str, description: Union[str, float32, Dict[int64, float64]]='', indent: int=1, reps: str='DICTOBS', gz: bool=True):
"""Export a dict of Obs or structures containing Obs to a .json(.gz) file
Parameters
@ -665,7 +668,7 @@ def dump_dict_to_json(od, fname, description='', indent=1, reps='DICTOBS', gz=Tr
dump_to_json(ol, fname, description=desc_dict, indent=indent, gz=gz)
def _od_from_list_and_dict(ol, ind, reps='DICTOBS'):
def _od_from_list_and_dict(ol: List[Union[Obs, List[Obs], Corr, ndarray]], ind: Dict[str, Dict[str, Optional[Union[str, Dict[str, Union[int, str]], List[str], float]]]], reps: str='DICTOBS') -> Dict[str, Dict[str, Any]]:
"""Parse a list of Obs or structures containing Obs and an accompanying
dict, where the structures have been replaced by placeholders to a
dict that contains the structures.
@ -728,7 +731,7 @@ def _od_from_list_and_dict(ol, ind, reps='DICTOBS'):
return nd
def load_json_dict(fname, verbose=True, gz=True, full_output=False, reps='DICTOBS'):
def load_json_dict(fname: str, verbose: bool=True, gz: bool=True, full_output: bool=False, reps: str='DICTOBS') -> Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]], str, Dict[str, Union[Dict[str, Union[Obs, List[Obs], Dict[str, Union[int, Obs, Corr, ndarray]]]], Dict[str, Optional[Union[str, List[str], Obs, float]]]]]]]:
"""Import a dict of Obs or structures containing Obs from a .json(.gz) file.
The following structures are supported: Obs, list, numpy.ndarray, Corr

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import os
import fnmatch
import re
@ -8,9 +9,10 @@ import matplotlib.pyplot as plt
from matplotlib import gridspec
from ..obs import Obs
from ..fits import fit_lin
from typing import Dict, Optional
def fit_t0(t2E_dict, fit_range, plot_fit=False, observable='t0'):
def fit_t0(t2E_dict: Dict[float, Obs], fit_range: int, plot_fit: Optional[bool]=False, observable: str='t0') -> Obs:
"""Compute the root of (flow-based) data based on a dictionary that contains
the necessary information in key-value pairs a la (flow time: observable at flow time).

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import os
import fnmatch
import struct
@ -8,9 +9,11 @@ from ..obs import CObs
from ..correlators import Corr
from .misc import fit_t0
from .utils import sort_names
from io import BufferedReader
from typing import Dict, List, Optional, Tuple, Union
def read_rwms(path, prefix, version='2.0', names=None, **kwargs):
def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[List[str]]=None, **kwargs) -> List[Obs]:
"""Read rwms format from given folder structure. Returns a list of length nrw
Parameters
@ -229,7 +232,7 @@ def read_rwms(path, prefix, version='2.0', names=None, **kwargs):
return result
def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent, postfix='ms', **kwargs):
def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, postfix: str='ms', **kwargs) -> Dict[float, Obs]:
"""Extract a dictionary with the flowed Yang-Mills action density from given .ms.dat files.
Returns a dictionary with Obs as values and flow times as keys.
@ -422,7 +425,7 @@ def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent,
return E_dict
def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs):
def extract_t0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs:
"""Extract t0/a^2 from given .ms.dat files. Returns t0 as Obs.
It is assumed that all boundary effects have
@ -495,7 +498,7 @@ def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi
return fit_t0(t2E_dict, fit_range, plot_fit=kwargs.get('plot_fit'))
def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs):
def extract_w0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs:
"""Extract w0/a from given .ms.dat files. Returns w0 as Obs.
It is assumed that all boundary effects have
@ -577,7 +580,7 @@ def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi
return np.sqrt(fit_t0(tdtt2E_dict, fit_range, plot_fit=kwargs.get('plot_fit'), observable='w0'))
def _parse_array_openQCD2(d, n, size, wa, quadrupel=False):
def _parse_array_openQCD2(d: int, n: Tuple[int, int], size: int, wa: Union[Tuple[float, float, float, float, float, float, float, float], Tuple[float, float]], quadrupel: bool=False) -> List[List[float]]:
arr = []
if d == 2:
for i in range(n[0]):
@ -596,7 +599,7 @@ def _parse_array_openQCD2(d, n, size, wa, quadrupel=False):
return arr
def _find_files(path, prefix, postfix, ext, known_files=[]):
def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: Union[str, List[str]]=[]) -> List[str]:
found = []
files = []
@ -636,7 +639,7 @@ def _find_files(path, prefix, postfix, ext, known_files=[]):
return files
def _read_array_openQCD2(fp):
def _read_array_openQCD2(fp: BufferedReader) -> Dict[str, Union[int, Tuple[int, int], List[List[float]]]]:
t = fp.read(4)
d = struct.unpack('i', t)[0]
t = fp.read(4 * d)
@ -662,7 +665,7 @@ def _read_array_openQCD2(fp):
return {'d': d, 'n': n, 'size': size, 'arr': arr}
def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs):
def read_qtop(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", **kwargs) -> Obs:
"""Read the topologial charge based on openQCD gradient flow measurements.
Parameters
@ -715,7 +718,7 @@ def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs):
return _read_flow_obs(path, prefix, c, dtr_cnfg=dtr_cnfg, version=version, obspos=0, **kwargs)
def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs):
def read_gf_coupling(path: str, prefix: str, c: float, dtr_cnfg: int=1, Zeuthen_flow: bool=True, **kwargs) -> Obs:
"""Read the gradient flow coupling based on sfqcd gradient flow measurements. See 1607.06423 for details.
Note: The current implementation only works for c=0.3 and T=L. The definition of the coupling in 1607.06423 requires projection to topological charge zero which is not done within this function but has to be performed in a separate step.
@ -787,7 +790,7 @@ def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs):
return t * t * (5 / 3 * plaq - 1 / 12 * C2x1) / normdict[L]
def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum_t=True, **kwargs):
def _read_flow_obs(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", obspos: int=0, sum_t: bool=True, **kwargs) -> Obs:
"""Read a flow observable based on openQCD gradient flow measurements.
Parameters
@ -1059,7 +1062,7 @@ def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum
return result
def qtop_projection(qtop, target=0):
def qtop_projection(qtop: Obs, target: int=0) -> Obs:
"""Returns the projection to the topological charge sector defined by target.
Parameters
@ -1085,7 +1088,7 @@ def qtop_projection(qtop, target=0):
return reto
def read_qtop_sector(path, prefix, c, target=0, **kwargs):
def read_qtop_sector(path: str, prefix: str, c: float, target: int=0, **kwargs) -> Obs:
"""Constructs reweighting factors to a specified topological sector.
Parameters
@ -1143,7 +1146,7 @@ def read_qtop_sector(path, prefix, c, target=0, **kwargs):
return qtop_projection(qtop, target=target)
def read_ms5_xsf(path, prefix, qc, corr, sep="r", **kwargs):
def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwargs) -> Corr:
"""
Read data from files in the specified directory with the specified prefix and quark combination extension, and return a `Corr` object containing the data.

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import warnings
import gzip
import sqlite3
@ -6,9 +7,11 @@ from ..obs import Obs
from ..correlators import Corr
from .json import create_json_string, import_json_string
import numpy as np
from pandas.core.frame import DataFrame
from pandas.core.series import Series
def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs):
def to_sql(df: DataFrame, table_name: str, db: str, if_exists: str='fail', gz: bool=True, **kwargs):
"""Write DataFrame including Obs or Corr valued columns to sqlite database.
Parameters
@ -34,7 +37,7 @@ def to_sql(df, table_name, db, if_exists='fail', gz=True, **kwargs):
con.close()
def read_sql(sql, db, auto_gamma=False, **kwargs):
def read_sql(sql: str, db: str, auto_gamma: bool=False, **kwargs) -> DataFrame:
"""Execute SQL query on sqlite database and obtain DataFrame including Obs or Corr valued columns.
Parameters
@ -58,7 +61,7 @@ def read_sql(sql, db, auto_gamma=False, **kwargs):
return _deserialize_df(extract_df, auto_gamma=auto_gamma)
def dump_df(df, fname, gz=True):
def dump_df(df: DataFrame, fname: str, gz: bool=True):
"""Exports a pandas DataFrame containing Obs valued columns to a (gzipped) csv file.
Before making use of pandas to_csv functionality Obs objects are serialized via the standardized
@ -97,7 +100,7 @@ def dump_df(df, fname, gz=True):
out.to_csv(fname, index=False)
def load_df(fname, auto_gamma=False, gz=True):
def load_df(fname: str, auto_gamma: bool=False, gz: bool=True) -> DataFrame:
"""Imports a pandas DataFrame from a csv.(gz) file in which Obs objects are serialized as json strings.
Parameters
@ -131,7 +134,7 @@ def load_df(fname, auto_gamma=False, gz=True):
return _deserialize_df(re_import, auto_gamma=auto_gamma)
def _serialize_df(df, gz=False):
def _serialize_df(df: DataFrame, gz: bool=False) -> DataFrame:
"""Serializes all Obs or Corr valued columns into json strings according to the pyerrors json specification.
Parameters
@ -152,7 +155,7 @@ def _serialize_df(df, gz=False):
return out
def _deserialize_df(df, auto_gamma=False):
def _deserialize_df(df: DataFrame, auto_gamma: bool=False) -> DataFrame:
"""Deserializes all pyerrors json strings into Obs or Corr objects according to the pyerrors json specification.
Parameters
@ -188,7 +191,7 @@ def _deserialize_df(df, auto_gamma=False):
return df
def _need_to_serialize(col):
def _need_to_serialize(col: Series) -> bool:
serialize = False
i = 0
while i < len(col) and col[i] is None:

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import os
import fnmatch
import re
@ -5,12 +6,14 @@ import numpy as np # Thinly-wrapped numpy
from ..obs import Obs
from .utils import sort_names, check_idl
import itertools
from numpy import ndarray
from typing import Any, Dict, List, Tuple, Union
sep = "/"
def read_sfcf(path, prefix, name, quarks='.*', corr_type="bi", noffset=0, wf=0, wf2=0, version="1.0c", cfg_separator="n", silent=False, **kwargs):
def read_sfcf(path: str, prefix: str, name: str, quarks: str='.*', corr_type: str="bi", noffset: int=0, wf: int=0, wf2: int=0, version: str="1.0c", cfg_separator: str="n", silent: bool=False, **kwargs) -> List[Obs]:
"""Read sfcf files from given folder structure.
Parameters
@ -75,7 +78,7 @@ def read_sfcf(path, prefix, name, quarks='.*', corr_type="bi", noffset=0, wf=0,
return ret[name][quarks][str(noffset)][str(wf)][str(wf2)]
def read_sfcf_multi(path, prefix, name_list, quarks_list=['.*'], corr_type_list=['bi'], noffset_list=[0], wf_list=[0], wf2_list=[0], version="1.0c", cfg_separator="n", silent=False, keyed_out=False, **kwargs):
def read_sfcf_multi(path: str, prefix: str, name_list: List[str], quarks_list: List[str]=['.*'], corr_type_list: List[str]=['bi'], noffset_list: List[int]=[0], wf_list: List[int]=[0], wf2_list: List[int]=[0], version: str="1.0c", cfg_separator: str="n", silent: bool=False, keyed_out: bool=False, **kwargs) -> Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, List[Obs]]]]]]:
"""Read sfcf files from given folder structure.
Parameters
@ -407,22 +410,22 @@ def read_sfcf_multi(path, prefix, name_list, quarks_list=['.*'], corr_type_list=
return result_dict
def _lists2key(*lists):
def _lists2key(*lists) -> List[str]:
    """Build composite keys from the cartesian product of *lists*.

    Each product tuple is joined with the module-level separator `sep`,
    preserving itertools.product ordering.
    """
    return [sep.join(combo) for combo in itertools.product(*lists)]
def _key2specs(key):
def _key2specs(key: str) -> List[str]:
    """Split a composite key back into its spec components using `sep`."""
    parts = key.split(sep)
    return parts
def _specs2key(*specs):
def _specs2key(*specs) -> str:
    """Join spec components into a single composite key using `sep`."""
    combined = sep.join(specs)
    return combined
def _read_o_file(cfg_path, name, needed_keys, intern, version, im):
def _read_o_file(cfg_path: str, name: str, needed_keys: List[str], intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], version: str, im: int) -> Dict[str, List[float]]:
return_vals = {}
for key in needed_keys:
file = cfg_path + '/' + name
@ -447,7 +450,7 @@ def _read_o_file(cfg_path, name, needed_keys, intern, version, im):
return return_vals
def _extract_corr_type(corr_type):
def _extract_corr_type(corr_type: str) -> Tuple[bool, bool]:
if corr_type == 'bb':
b2b = True
single = True
@ -460,7 +463,7 @@ def _extract_corr_type(corr_type):
return b2b, single
def _find_files(rep_path, prefix, compact, files=[]):
def _find_files(rep_path: str, prefix: str, compact: bool, files: List[Union[range, str, Any]]=[]) -> List[str]:
sub_ls = []
if not files == []:
files.sort(key=lambda x: int(re.findall(r'\d+', x)[-1]))
@ -487,7 +490,7 @@ def _find_files(rep_path, prefix, compact, files=[]):
return files
def _make_pattern(version, name, noffset, wf, wf2, b2b, quarks):
def _make_pattern(version: str, name: str, noffset: str, wf: str, wf2: Union[str, int], b2b: bool, quarks: str) -> str:
if version == "0.0":
pattern = "# " + name + " : offset " + str(noffset) + ", wf " + str(wf)
if b2b:
@ -501,7 +504,7 @@ def _make_pattern(version, name, noffset, wf, wf2, b2b, quarks):
return pattern
def _find_correlator(file_name, version, pattern, b2b, silent=False):
def _find_correlator(file_name: str, version: str, pattern: str, b2b: bool, silent: bool=False) -> Tuple[int, int]:
T = 0
with open(file_name, "r") as my_file:
@ -527,7 +530,7 @@ def _find_correlator(file_name, version, pattern, b2b, silent=False):
return start_read, T
def _read_compact_file(rep_path, cfg_file, intern, needed_keys, im):
def _read_compact_file(rep_path: str, cfg_file: str, intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], needed_keys: List[str], im: int) -> Dict[str, List[float]]:
return_vals = {}
with open(rep_path + cfg_file) as fp:
lines = fp.readlines()
@ -558,7 +561,7 @@ def _read_compact_file(rep_path, cfg_file, intern, needed_keys, im):
return return_vals
def _read_compact_rep(path, rep, sub_ls, intern, needed_keys, im):
def _read_compact_rep(path: str, rep: str, sub_ls: List[str], intern: Dict[str, Dict[str, Union[bool, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Union[int, str]]]]]], int]]], needed_keys: List[str], im: int) -> Dict[str, List[ndarray]]:
rep_path = path + '/' + rep + '/'
no_cfg = len(sub_ls)
@ -580,7 +583,7 @@ def _read_compact_rep(path, rep, sub_ls, intern, needed_keys, im):
return return_vals
def _read_chunk(chunk, gauge_line, cfg_sep, start_read, T, corr_line, b2b, pattern, im, single):
def _read_chunk(chunk: List[str], gauge_line: int, cfg_sep: str, start_read: int, T: int, corr_line: int, b2b: bool, pattern: str, im: int, single: bool) -> Tuple[int, List[float]]:
try:
idl = int(chunk[gauge_line].split(cfg_sep)[-1])
except Exception:
@ -597,7 +600,7 @@ def _read_chunk(chunk, gauge_line, cfg_sep, start_read, T, corr_line, b2b, patte
return idl, data
def _read_append_rep(filename, pattern, b2b, cfg_separator, im, single):
def _read_append_rep(filename: str, pattern: str, b2b: bool, cfg_separator: str, im: int, single: bool) -> Tuple[int, List[int], List[List[float]]]:
with open(filename, 'r') as fp:
content = fp.readlines()
data_starts = []
@ -646,7 +649,7 @@ def _read_append_rep(filename, pattern, b2b, cfg_separator, im, single):
return T, rep_idl, data
def _get_rep_names(ls, ens_name=None):
def _get_rep_names(ls: List[str], ens_name: None=None) -> List[str]:
new_names = []
for entry in ls:
try:
@ -661,7 +664,7 @@ def _get_rep_names(ls, ens_name=None):
return new_names
def _get_appended_rep_names(ls, prefix, name, ens_name=None):
def _get_appended_rep_names(ls: List[str], prefix: str, name: str, ens_name: None=None) -> List[str]:
new_names = []
for exc in ls:
if not fnmatch.fnmatch(exc, prefix + '*.' + name):

View file

@ -1,11 +1,13 @@
"""Utilities for the input"""
from __future__ import annotations
import re
import fnmatch
import os
from typing import List
def sort_names(ll):
def sort_names(ll: List[str]) -> List[str]:
"""Sorts a list of names of replika with searches for `r` and `id` in the replikum string.
If this search fails, a fallback method is used,
where the strings are simply compared and the first diffeing numeral is used for differentiation.