[Feat] Added type hints to input modules

This commit is contained in:
Fabian Joswig 2024-12-25 11:14:34 +01:00
parent 3db8eb2989
commit 9fe375a747
7 changed files with 96 additions and 68 deletions

View file

@ -1,3 +1,4 @@
from __future__ import annotations
import os
import fnmatch
import struct
@ -8,9 +9,11 @@ from ..obs import CObs
from ..correlators import Corr
from .misc import fit_t0
from .utils import sort_names
from io import BufferedReader
from typing import Dict, List, Optional, Tuple, Union
def read_rwms(path, prefix, version='2.0', names=None, **kwargs):
def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[List[str]]=None, **kwargs) -> List[Obs]:
"""Read rwms format from given folder structure. Returns a list of length nrw
Parameters
@ -229,7 +232,7 @@ def read_rwms(path, prefix, version='2.0', names=None, **kwargs):
return result
def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent, postfix='ms', **kwargs):
def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, postfix: str='ms', **kwargs) -> Dict[float, Obs]:
"""Extract a dictionary with the flowed Yang-Mills action density from given .ms.dat files.
Returns a dictionary with Obs as values and flow times as keys.
@ -422,7 +425,7 @@ def _extract_flowed_energy_density(path, prefix, dtr_read, xmin, spatial_extent,
return E_dict
def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs):
def extract_t0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs:
"""Extract t0/a^2 from given .ms.dat files. Returns t0 as Obs.
It is assumed that all boundary effects have
@ -495,7 +498,7 @@ def extract_t0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi
return fit_t0(t2E_dict, fit_range, plot_fit=kwargs.get('plot_fit'))
def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfix='ms', c=0.3, **kwargs):
def extract_w0(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int=5, postfix: str='ms', c: Union[float, int]=0.3, **kwargs) -> Obs:
"""Extract w0/a from given .ms.dat files. Returns w0 as Obs.
It is assumed that all boundary effects have
@ -577,7 +580,7 @@ def extract_w0(path, prefix, dtr_read, xmin, spatial_extent, fit_range=5, postfi
return np.sqrt(fit_t0(tdtt2E_dict, fit_range, plot_fit=kwargs.get('plot_fit'), observable='w0'))
def _parse_array_openQCD2(d, n, size, wa, quadrupel=False):
def _parse_array_openQCD2(d: int, n: Tuple[int, int], size: int, wa: Union[Tuple[float, float, float, float, float, float, float, float], Tuple[float, float]], quadrupel: bool=False) -> List[List[float]]:
arr = []
if d == 2:
for i in range(n[0]):
@ -596,7 +599,7 @@ def _parse_array_openQCD2(d, n, size, wa, quadrupel=False):
return arr
def _find_files(path, prefix, postfix, ext, known_files=[]):
def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: Union[str, List[str]]=[]) -> List[str]:
found = []
files = []
@ -636,7 +639,7 @@ def _find_files(path, prefix, postfix, ext, known_files=[]):
return files
def _read_array_openQCD2(fp):
def _read_array_openQCD2(fp: BufferedReader) -> Dict[str, Union[int, Tuple[int, int], List[List[float]]]]:
t = fp.read(4)
d = struct.unpack('i', t)[0]
t = fp.read(4 * d)
@ -662,7 +665,7 @@ def _read_array_openQCD2(fp):
return {'d': d, 'n': n, 'size': size, 'arr': arr}
def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs):
def read_qtop(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", **kwargs) -> Obs:
"""Read the topologial charge based on openQCD gradient flow measurements.
Parameters
@ -715,7 +718,7 @@ def read_qtop(path, prefix, c, dtr_cnfg=1, version="openQCD", **kwargs):
return _read_flow_obs(path, prefix, c, dtr_cnfg=dtr_cnfg, version=version, obspos=0, **kwargs)
def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs):
def read_gf_coupling(path: str, prefix: str, c: float, dtr_cnfg: int=1, Zeuthen_flow: bool=True, **kwargs) -> Obs:
"""Read the gradient flow coupling based on sfqcd gradient flow measurements. See 1607.06423 for details.
Note: The current implementation only works for c=0.3 and T=L. The definition of the coupling in 1607.06423 requires projection to topological charge zero which is not done within this function but has to be performed in a separate step.
@ -787,7 +790,7 @@ def read_gf_coupling(path, prefix, c, dtr_cnfg=1, Zeuthen_flow=True, **kwargs):
return t * t * (5 / 3 * plaq - 1 / 12 * C2x1) / normdict[L]
def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum_t=True, **kwargs):
def _read_flow_obs(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", obspos: int=0, sum_t: bool=True, **kwargs) -> Obs:
"""Read a flow observable based on openQCD gradient flow measurements.
Parameters
@ -1059,7 +1062,7 @@ def _read_flow_obs(path, prefix, c, dtr_cnfg=1, version="openQCD", obspos=0, sum
return result
def qtop_projection(qtop, target=0):
def qtop_projection(qtop: Obs, target: int=0) -> Obs:
"""Returns the projection to the topological charge sector defined by target.
Parameters
@ -1085,7 +1088,7 @@ def qtop_projection(qtop, target=0):
return reto
def read_qtop_sector(path, prefix, c, target=0, **kwargs):
def read_qtop_sector(path: str, prefix: str, c: float, target: int=0, **kwargs) -> Obs:
"""Constructs reweighting factors to a specified topological sector.
Parameters
@ -1143,7 +1146,7 @@ def read_qtop_sector(path, prefix, c, target=0, **kwargs):
return qtop_projection(qtop, target=target)
def read_ms5_xsf(path, prefix, qc, corr, sep="r", **kwargs):
def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwargs) -> Corr:
"""
Read data from files in the specified directory with the specified prefix and quark combination extension, and return a `Corr` object containing the data.