correct mypy issues

This commit is contained in:
Justus Kuhlmann 2025-12-02 12:35:09 +01:00
commit 4546688d97
Signed by: jkuhl
GPG key ID: 00ED992DD79B85A6
8 changed files with 64 additions and 50 deletions

View file

@ -9,9 +9,10 @@ from pyerrors import Obs, Corr, dump_object, load_object
from hashlib import sha256
from .tools import cached, get_file
import shutil
from typing import Any
def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=None):
def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None:
"""
Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement.
@ -97,7 +98,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
dl.save(files, message="Add measurements to database", dataset=path)
def load_record(path: str, meas_path: str):
def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
"""
Load a list of records by their paths.
@ -116,7 +117,7 @@ def load_record(path: str, meas_path: str):
return load_records(path, [meas_path])[0]
def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]:
def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]:
"""
Load a list of records by their paths.
@ -138,7 +139,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
needed_data[file] = []
key = mpath.split("::")[1]
needed_data[file].append(key)
returned_data: list = []
returned_data: list[Any] = []
for file in needed_data.keys():
for key in list(needed_data[file]):
if os.path.exists(cache_path(path, file, key) + ".p"):
@ -154,7 +155,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
return returned_data
def cache_dir(path, file):
def cache_dir(path: str, file: str) -> str:
cache_path_list = [path]
cache_path_list.append(".cache")
cache_path_list.extend(file.split("/")[1:])
@ -162,19 +163,19 @@ def cache_dir(path, file):
return cache_path
def cache_path(path, file, key):
def cache_path(path: str, file: str, key: str) -> str:
cache_path = os.path.join(cache_dir(path, file), key)
return cache_path
def preload(path: str, file: str):
def preload(path: str, file: str) -> dict[str, Any]:
get_file(path, file)
filedict = pj.load_json_dict(os.path.join(path, file))
filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file))
print("> read file")
return filedict
def drop_record(path: str, meas_path: str):
def drop_record(path: str, meas_path: str) -> None:
file_in_archive = meas_path.split("::")[0]
file = os.path.join(path, file_in_archive)
db = os.path.join(path, 'backlogger.db')
@ -199,7 +200,9 @@ def drop_record(path: str, meas_path: str):
else:
raise ValueError("This measurement does not exist as a file!")
def drop_cache(path: str):
def drop_cache(path: str) -> None:
cache_dir = os.path.join(path, ".cache")
for f in os.listdir(cache_dir):
shutil.rmtree(os.path.join(cache_dir, f))
return