simplify internal method for read_rwms

This commit is contained in:
Justus Kuhlmann 2025-11-03 13:45:56 +00:00
commit 33be7c2ecb

View file

@ -13,15 +13,7 @@ from io import BufferedReader
from typing import Optional, Union, TypedDict from typing import Optional, Union, TypedDict
class rwms_kwargs(TypedDict): def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[str]]=None, **kwargs) -> list[Obs]:
files: list[str]
postfix: str
r_start: list[int]
r_stop: list[int]
r_step: int
def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[str]]=None, **kwargs: rwms_kwargs) -> list[Obs]:
"""Read rwms format from given folder structure. Returns a list of length nrw """Read rwms format from given folder structure. Returns a list of length nrw
Parameters Parameters
@ -69,11 +61,11 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s
replica = len(ls) replica = len(ls)
r_start: list[Union[int, None]] = kwargs.get('r_start', [None] * replica) r_start: list[int] = kwargs.get('r_start', [0] * replica)
if len(r_start) != replica: if len(r_start) != replica:
raise Exception('r_start does not match number of replicas') raise Exception('r_start does not match number of replicas')
r_stop: list[Union[int, None]] = kwargs.get('r_stop', [None] * replica) r_stop: list[int] = kwargs.get('r_stop', [-1] * replica)
if len(r_stop) != replica: if len(r_stop) != replica:
raise Exception('r_stop does not match number of replicas') raise Exception('r_stop does not match number of replicas')
@ -153,10 +145,8 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s
configlist[-1].append(config_no) configlist[-1].append(config_no)
for i in range(nrw): for i in range(nrw):
if (version == '2.0'): if (version == '2.0'):
tmpd: dict = _read_array_openQCD2(fp) tmp_n, tmp_rw = _read_array_openQCD2(fp)
tmpd = _read_array_openQCD2(fp) tmp_n, tmp_rw = _read_array_openQCD2(fp)
tmp_rw: list[float] = tmpd['arr']
tmp_n: list[int] = tmpd['n']
tmp_nfct = 1.0 tmp_nfct = 1.0
for j in range(tmp_n[0]): for j in range(tmp_n[0]):
tmp_nfct *= np.mean(np.exp(-np.asarray(tmp_rw[j]))) tmp_nfct *= np.mean(np.exp(-np.asarray(tmp_rw[j])))
@ -189,7 +179,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s
offset = configlist[-1][0] - 1 offset = configlist[-1][0] - 1
configlist[-1] = [item - offset for item in configlist[-1]] configlist[-1] = [item - offset for item in configlist[-1]]
if r_start[rep] is None: if r_start[rep] == 0:
r_start_index.append(0) r_start_index.append(0)
else: else:
try: try:
@ -198,7 +188,7 @@ def read_rwms(path: str, prefix: str, version: str='2.0', names: Optional[list[s
raise Exception('Config %d not in file with range [%d, %d]' % ( raise Exception('Config %d not in file with range [%d, %d]' % (
r_start[rep], configlist[-1][0], configlist[-1][-1])) from None r_start[rep], configlist[-1][0], configlist[-1][-1])) from None
if r_stop[rep] is None: if r_stop[rep] == -1:
r_stop_index.append(len(configlist[-1]) - 1) r_stop_index.append(len(configlist[-1]) - 1)
else: else:
try: try:
@ -618,7 +608,7 @@ def _find_files(path: str, prefix: str, postfix: str, ext: str, known_files: lis
return files return files
def _read_array_openQCD2(fp: BufferedReader) -> dict[str, Union[int, tuple[int, int], list[list[float]]]]: def _read_array_openQCD2(fp: BufferedReader) -> tuple[tuple[int, int], list[list[float]]]:
t = fp.read(4) t = fp.read(4)
d = struct.unpack('i', t)[0] d = struct.unpack('i', t)[0]
t = fp.read(4 * d) t = fp.read(4 * d)
@ -641,7 +631,7 @@ def _read_array_openQCD2(fp: BufferedReader) -> dict[str, Union[int, tuple[int,
tmp = struct.unpack('%d%s' % (m, types), t) tmp = struct.unpack('%d%s' % (m, types), t)
arr = _parse_array_openQCD2(d, n, size, tmp, quadrupel=True) arr = _parse_array_openQCD2(d, n, size, tmp, quadrupel=True)
return {'d': d, 'n': n, 'size': size, 'arr': arr} return n, arr
def read_qtop(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", **kwargs) -> Obs: def read_qtop(path: str, prefix: str, c: float, dtr_cnfg: int=1, version: str="openQCD", **kwargs) -> Obs: