From 96cdec46e236583ab8e73563d22d53380aa8f720 Mon Sep 17 00:00:00 2001 From: Justus Kuhlmann Date: Sat, 29 Mar 2025 14:04:52 +0000 Subject: [PATCH] annotate read_ms5_xsf --- pyerrors/input/openQCD.py | 171 ++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 92 deletions(-) diff --git a/pyerrors/input/openQCD.py b/pyerrors/input/openQCD.py index 339c8bfd..98ae202e 100644 --- a/pyerrors/input/openQCD.py +++ b/pyerrors/input/openQCD.py @@ -10,10 +10,19 @@ from ..correlators import Corr from .misc import fit_t0 from .utils import sort_names from io import BufferedReader -from typing import Optional, Union +from typing import Optional, Union, TypedDict, Unpack -def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[str]]=None, **kwargs) -> list[Obs]: +class rwms_kwargs(TypedDict): + files: list[str] + postfix: str + r_start: list[Union[int]] + r_stop: list[Union[int]] + r_step: int + + + +def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[str]]=None, **kwargs: Unpack[rwms_kwargs]) -> list[Obs]: """Read rwms format from given folder structure. Returns a list of length nrw Parameters @@ -27,7 +36,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s version : str version of openQCD, default 2.0 names : list - list of names that is assigned to the data according according + list of names that is assigned to the data according to the order in the file list. Use careful, if you do not provide file names! r_start : list list which contains the first config to be read for each replicum @@ -53,39 +62,24 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s if version not in known_oqcd_versions: raise Exception('Unknown openQCD version defined!') print("Working with openQCD version " + version) - if 'postfix' in kwargs: - postfix = kwargs.get('postfix') - else: - postfix = '' + postfix: str = kwargs.get('postfix', '') - if 'files' in kwargs: - known_files = kwargs.get('files') - else: - known_files = [] + known_files: list[str] = kwargs.get('files', []) + ls = _find_files(path, prefix, postfix, 'dat', known_files=known_files) replica = len(ls) - if 'r_start' in kwargs: - r_start = kwargs.get('r_start') - if len(r_start) != replica: - raise Exception('r_start does not match number of replicas') - r_start = [o if o else None for o in r_start] - else: - r_start = [None] * replica + r_start: list[Union[int, None]] = kwargs.get('r_start', [None] * replica) + if len(r_start) != replica: + raise Exception('r_start does not match number of replicas') - if 'r_stop' in kwargs: - r_stop = kwargs.get('r_stop') - if len(r_stop) != replica: - raise Exception('r_stop does not match number of replicas') - else: - r_stop = [None] * replica + r_stop: list[Union[int, None]] = kwargs.get('r_stop', [None] * replica) + if len(r_stop) != replica: + raise Exception('r_stop does not match number of replicas') - if 'r_step' in kwargs: - r_step = kwargs.get('r_step') - else: - r_step = 1 + r_step: int = kwargs.get('r_step', 1) print('Read reweighting factors from', prefix[:-1], ',', replica, 'replica', end='') @@ -110,14 +104,14 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s print_err = 1 print() - deltas = [] + deltas: list[list[float]] = [] - configlist = [] + configlist: list[list[int]] = [] r_start_index = [] r_stop_index = [] for rep in range(replica): - tmp_array = [] + tmp_array: list[list] = [] with open(path + '/' + ls[rep], 'rb') as fp: t = fp.read(4) # number of reweighting factors @@ -144,7 +138,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s for i in range(nrw): nfct.append(1) - nsrc = [] + nsrc: list[int] = [] for i in range(nrw): t = fp.read(4) nsrc.append(struct.unpack('i', t)[0]) @@ -161,11 +155,12 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s configlist[-1].append(config_no) for i in range(nrw): if (version == '2.0'): + tmpd: dict = _read_array_openQCD2(fp) tmpd = _read_array_openQCD2(fp) - tmpd = _read_array_openQCD2(fp) - tmp_rw = tmpd['arr'] + tmp_rw: list[float] = tmpd['arr'] + tmp_n: list[int] = tmpd['n'] tmp_nfct = 1.0 - for j in range(tmpd['n'][0]): + for j in range(tmp_n[0]): tmp_nfct *= np.mean(np.exp(-np.asarray(tmp_rw[j]))) if print_err: print(config_no, i, j, @@ -179,7 +174,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s for j in range(nfct[i]): t = fp.read(8 * nsrc[i]) t = fp.read(8 * nsrc[i]) - tmp_rw = struct.unpack('d' * nsrc[i], t) + tmp_rw: list[float] = struct.unpack('d' * nsrc[i], t) tmp_nfct *= np.mean(np.exp(-np.asarray(tmp_rw))) if print_err: print(config_no, i, j, @@ -232,7 +227,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s return result -def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, postfix: str='ms', **kwargs) -> dict[float, Obs]: +def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: int, spatial_extent: int, postfix: str='ms', **kwargs: Unpack[rwms_kwargs]) -> dict[float, Obs]: """Extract a dictionary with the flowed Yang-Mills action density from given .ms.dat files. Returns a dictionary with Obs as values and flow times as keys. @@ -319,18 +314,18 @@ def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: print('Extract flowed Yang-Mills action density from', prefix, ',', replica, 'replica') - if 'names' in kwargs: - rep_names = kwargs.get('names') - else: + + rep_names: list[str] = kwargs.get('names', []) + if len(rep_names) == 0: rep_names = [] for entry in ls: truncated_entry = entry.split('.')[0] idx = truncated_entry.index('r') rep_names.append(truncated_entry[:idx] + '|' + truncated_entry[idx:]) - Ysum = [] + Ysum: list = [] - configlist = [] + configlist: list[list[int]] = [] r_start_index = [] r_stop_index = [] @@ -413,7 +408,7 @@ def _extract_flowed_energy_density(path: str, prefix: str, dtr_read: int, xmin: idl = [range(configlist[rep][r_start_index[rep]], configlist[rep][r_stop_index[rep]] + 1, r_step) for rep in range(replica)] E_dict = {} for n in range(nn + 1): - samples = [] + samples: list[list[float]] = [] for nrep, rep in enumerate(Ysum): samples.append([]) for cnfg in rep: @@ -599,7 +594,7 @@ def _parse_array_openQCD2(d: int, n: tuple[int, int], size: int, wa: Union[tuple return arr -def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: Union[str, list[str]]=[]) -> list[str]: +def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: list[str]=[]) -> list[str]: found = [] files = [] @@ -1146,7 +1141,7 @@ def read_qtop_sector(path: str, prefix: str, c: float, target: int=0, **kwargs) return qtop_projection(qtop, target=target) -def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwargs) -> Corr: +def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwargs) -> Union[Corr, CObs]: """ Read data from files in the specified directory with the specified prefix and quark combination extension, and return a `Corr` object containing the data. @@ -1188,9 +1183,7 @@ def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwa If there is an error unpacking binary data. """ - # found = [] files = [] - names = [] # test if the input is correct if qc not in ['dd', 'ud', 'du', 'uu']: @@ -1199,15 +1192,13 @@ def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwa if corr not in ["gS", "gP", "gA", "gV", "gVt", "lA", "lV", "lVt", "lT", "lTt", "g1", "l1"]: raise Exception("Unknown correlator!") - if "files" in kwargs: - known_files = kwargs.get("files") - else: - known_files = [] + known_files: list[str] = kwargs.get("files", []) + expected_idl = kwargs.get('idl', []) + files = _find_files(path, prefix, "ms5_xsf_" + qc, "dat", known_files=known_files) - if "names" in kwargs: - names = kwargs.get("names") - else: + names: list[str] = kwargs.get("names", []) + if len(names) == 0: for f in files: if not sep == "": se = f.split(".")[0] @@ -1216,31 +1207,30 @@ def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwa names.append(se.split(sep)[0] + "|r" + se.split(sep)[1]) else: names.append(prefix) - if 'idl' in kwargs: - expected_idl = kwargs.get('idl') + names = sorted(names) files = sorted(files) - cnfgs = [] - realsamples = [] - imagsamples = [] + cnfgs: list[list[int]] = [] + realsamples: list[list[list[float]]] = [] + imagsamples: list[list[list[float]]] = [] repnum = 0 for file in files: with open(path + "/" + file, "rb") as fp: - t = fp.read(8) - kappa = struct.unpack('d', t)[0] - t = fp.read(8) - csw = struct.unpack('d', t)[0] - t = fp.read(8) - dF = struct.unpack('d', t)[0] - t = fp.read(8) - zF = struct.unpack('d', t)[0] + tmp_bytes = fp.read(8) + kappa: float = struct.unpack('d', tmp_bytes)[0] + tmp_bytes = fp.read(8) + csw: float = struct.unpack('d', tmp_bytes)[0] + tmp_bytes = fp.read(8) + dF: float = struct.unpack('d', tmp_bytes)[0] + tmp_bytes = fp.read(8) + zF: float = struct.unpack('d', tmp_bytes)[0] - t = fp.read(4) - tmax = struct.unpack('i', t)[0] - t = fp.read(4) - bnd = struct.unpack('i', t)[0] + tmp_bytes = fp.read(4) + tmax: int = struct.unpack('i', tmp_bytes)[0] + tmp_bytes = fp.read(4) + bnd: int = struct.unpack('i', tmp_bytes)[0] placesBI = ["gS", "gP", "gA", "gV", @@ -1252,22 +1242,22 @@ def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwa # the chunks have the following structure: # confignumber, 10x timedependent complex correlators as doubles, 2x timeindependent complex correlators as doubles - chunksize = 4 + (8 * 2 * tmax * 10) + (8 * 2 * 2) packstr = '=i' + ('d' * 2 * tmax * 10) + ('d' * 2 * 2) + chunksize = struct.calcsize(packstr) cnfgs.append([]) realsamples.append([]) imagsamples.append([]) - for t in range(tmax): + for time in range(tmax): realsamples[repnum].append([]) imagsamples[repnum].append([]) if 'idl' in kwargs: left_idl = set(expected_idl[repnum]) while True: - cnfgt = fp.read(chunksize) - if not cnfgt: + cnfg_bytes = fp.read(chunksize) + if not cnfg_bytes: break - asascii = struct.unpack(packstr, cnfgt) - cnfg = asascii[0] + asascii = struct.unpack(packstr, cnfg_bytes) + cnfg: int = asascii[0] idl_wanted = True if 'idl' in kwargs: idl_wanted = (cnfg in expected_idl[repnum]) @@ -1280,24 +1270,21 @@ def read_ms5_xsf(path: str, prefix: str, qc: str, corr: str, sep: str="r", **kwa else: tmpcorr = asascii[1 + 2 * tmax * len(placesBI) + 2 * placesBB.index(corr):1 + 2 * tmax * len(placesBI) + 2 * placesBB.index(corr) + 2] - corrres = [[], []] + corrres: list[list[float]] = [[], []] for i in range(len(tmpcorr)): corrres[i % 2].append(tmpcorr[i]) - for t in range(int(len(tmpcorr) / 2)): - realsamples[repnum][t].append(corrres[0][t]) - for t in range(int(len(tmpcorr) / 2)): - imagsamples[repnum][t].append(corrres[1][t]) - if 'idl' in kwargs: - left_idl = list(left_idl) - if expected_idl[repnum] == left_idl: - raise ValueError("None of the idls searched for were found in replikum of file " + file) - elif len(left_idl) > 0: - warnings.warn('Could not find idls ' + str(left_idl) + ' in replikum of file ' + file, UserWarning) + for time in range(int(len(tmpcorr) / 2)): + realsamples[repnum][time].append(corrres[0][time]) + for time in range(int(len(tmpcorr) / 2)): + imagsamples[repnum][time].append(corrres[1][time]) + if len(expected_idl) > 0: + left_idl_list = list(left_idl) + if expected_idl[repnum] == left_idl_list: + raise ValueError("None of the idls searched for were found in replicum of file " + file) + elif len(left_idl_list) > 0: + warnings.warn('Could not find idls ' + str(left_idl) + ' in replicum of file ' + file, UserWarning) repnum += 1 - s = "Read correlator " + corr + " from " + str(repnum) + " replika with idls" + str(realsamples[0][t]) - for rep in range(1, repnum): - s += ", " + str(realsamples[rep][t]) - print(s) + print("Read correlator " + corr + " from " + str(repnum) + " replica with idls") print("Asserted run parameters:\n T:", tmax, "kappa:", kappa, "csw:", csw, "dF:", dF, "zF:", zF, "bnd:", bnd) # we have the data now... but we need to re format the whole thing and put it into Corr objects.