correct mypy issues

Justus Kuhlmann 2025-12-02 12:35:09 +01:00
commit 4546688d97
Signed by: jkuhl
GPG key ID: 00ED992DD79B85A6
8 changed files with 64 additions and 50 deletions

View file

@@ -1,8 +1,9 @@
 from corrlib import cli, __app_name__
 
-def main():
+def main() -> None:
     cli.app(prog_name=__app_name__)
+    return
 
 if __name__ == "__main__":
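
Note (not part of the commit): the pattern in this file recurs throughout the diff. Procedures that return nothing get an explicit `-> None`, which is what mypy demands of every function under `--disallow-untyped-defs` (a plausible setting for a commit titled "correct mypy issues"; the project's actual mypy configuration is not shown here). A minimal sketch:

    def untyped():        # error: Function is missing a return type annotation
        print("hi")

    def typed() -> None:  # accepted; the trailing bare `return` added throughout
        print("hi")       # this commit is stylistic, since falling off the end
        return            # of a function already yields None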

View file

@@ -5,10 +5,11 @@ import pandas as pd
 import numpy as np
 from .input.implementations import codes
 from .tools import k2m, get_file
+from typing import Any, Union, Optional
 
 # this will implement the search functionality
-def _project_lookup_by_alias(db, alias):
+def _project_lookup_by_alias(db: str, alias: str) -> str:
     # this will lookup the project name based on the alias
     conn = sqlite3.connect(db)
     c = conn.cursor()
@@ -19,10 +20,10 @@ def _project_lookup_by_alias(db, alias):
         print("Error: multiple projects found with alias " + alias)
     elif len(results) == 0:
         raise Exception("Error: no project found with alias " + alias)
-    return results[0][0]
+    return str(results[0][0])
 
-def _project_lookup_by_id(db, uuid):
+def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]:
     conn = sqlite3.connect(db)
     c = conn.cursor()
     c.execute(f"SELECT * FROM 'projects' WHERE id = '{uuid}'")
@@ -31,7 +32,8 @@ def _project_lookup_by_id(db, uuid):
     return results
 
-def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None):
+def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
+               created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame:
     project_str = project
     search_expr = f"SELECT * FROM 'backlogs' WHERE name = '{correlator_name}' AND ensemble = '{ensemble}'"
@@ -55,7 +57,7 @@ def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=Non
     return results
 
-def sfcf_filter(results, **kwargs):
+def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
     drops = []
     for ind in range(len(results)):
         result = results.iloc[ind]
@@ -138,24 +140,25 @@ def sfcf_filter(results, **kwargs):
     return results.drop(drops)
 
-def find_record(path, ensemble, correlator_name, code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None, **kwargs):
+def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
+                created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame:
     db = path + '/backlogger.db'
     if code not in codes:
         raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes))
     get_file(path, "backlogger.db")
-    results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after, revision=revision)
+    results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after)
     if code == "sfcf":
         results = sfcf_filter(results, **kwargs)
     print("Found " + str(len(results)) + " result" + ("s" if len(results)>1 else ""))
     return results.reset_index()
 
-def find_project(path, name):
+def find_project(path: str, name: str) -> str:
     get_file(path, "backlogger.db")
     return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name)
 
-def list_projects(path):
+def list_projects(path: str) -> list[tuple[str, str]]:
     db = path + '/backlogger.db'
     get_file(path, "backlogger.db")
     conn = sqlite3.connect(db)
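
Note (not part of the commit): two recurring fixes in this file are standard mypy idioms. First, since mypy 0.990 implicit Optional is off by default, so a parameter that defaults to `None` must be annotated `Optional[...]`. Second, `sqlite3` rows come back untyped (`fetchall()` returns `list[Any]`), so a function declared `-> str` wraps the fetched value in `str(...)`; under `--warn-return-any`, returning the raw value would be flagged. A rough sketch of both, using a hypothetical table:

    import sqlite3
    from typing import Optional

    def lookup(db: str, alias: str, project: Optional[str] = None) -> str:
        # `project: str = None` is rejected by modern mypy: a None default
        # requires an explicit Optional[str] annotation.
        conn = sqlite3.connect(db)
        c = conn.cursor()
        c.execute("SELECT name FROM projects WHERE alias = ?", (alias,))
        results = c.fetchall()  # list[Any]: column types are unknown to mypy
        conn.close()
        # str(...) pins the declared return type; with --warn-return-any,
        # returning results[0][0] (an Any) would be reported.
        return str(results[0][0])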

View file

@@ -5,7 +5,7 @@ import git
 GITMODULES_FILE = '.gitmodules'
 
-def move_submodule(repo_path, old_path, new_path):
+def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
     """
     Move a submodule to a new location.
@@ -41,3 +41,4 @@ def move_submodule(repo_path, old_path, new_path):
     repo.git.add('.gitmodules')
     # save new state of the dataset
     dl.save(repo_path, message=f"Move module from {old_path} to {new_path}", dataset=repo_path)
+    return

View file

@@ -3,7 +3,7 @@ import datalad.api as dl
 import os
 
-def _create_db(db):
+def _create_db(db: str) -> None:
     """
     Create the database file and the table.
@@ -32,9 +32,10 @@ def _create_db(db):
         updated_at TEXT)''')
     conn.commit()
     conn.close()
+    return
 
-def create(path):
+def create(path: str) -> None:
     """
     Create folder of backlogs.
@@ -50,3 +51,4 @@ def create(path):
     fp.write(".cache")
     fp.close()
     dl.save(path, dataset=path, message="Initialize backlogger directory.")
+    return

View file

@@ -6,10 +6,10 @@ from .git_tools import move_submodule
 import shutil
 from .find import _project_lookup_by_id
 from .tools import list2str, str2list, get_file
-from typing import Union
+from typing import Union, Optional
 
-def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None):
+def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None:
     """
     Create a new project entry in the database.
@@ -33,10 +33,10 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
         raise ValueError("Project already imported, use update_project() instead.")
     dl.unlock(db, dataset=path)
-    alias_str = None
+    alias_str = ""
     if aliases is not None:
         alias_str = list2str(aliases)
-    tag_str = None
+    tag_str = ""
     if tags is not None:
         tag_str = list2str(tags)
     c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code))
@@ -45,7 +45,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
     dl.save(db, message="Added entry for project " + uuid + " to database", dataset=path)
 
-def update_project_data(path, uuid, prop, value = None):
+def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None:
     get_file(path, "backlogger.db")
     conn = sqlite3.connect(os.path.join(path, "backlogger.db"))
     c = conn.cursor()
@@ -55,7 +55,7 @@ def update_project_data(path, uuid, prop, value = None):
     return
 
-def update_aliases(path: str, uuid: str, aliases: list[str]):
+def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
     db = os.path.join(path, "backlogger.db")
     get_file(path, "backlogger.db")
     known_data = _project_lookup_by_id(db, uuid)[0]
@@ -82,7 +82,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]):
     return
 
-def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None, isDataset: bool=True):
+def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str:
     """
     Parameters
     ----------
@@ -117,7 +117,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
     dl.install(path=tmp_path, source=url, dataset=path)
     tmp_ds = dl.Dataset(tmp_path)
     conf = dlc.ConfigManager(tmp_ds)
-    uuid = conf.get("datalad.dataset.id")
+    uuid = str(conf.get("datalad.dataset.id"))
     if not uuid:
         raise ValueError("The dataset does not have a uuid!")
     if not os.path.exists(path + "/projects/" + uuid):
@@ -142,9 +142,10 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
     return uuid
 
-def drop_project_data(path: str, uuid: str, path_in_project: str = ""):
+def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None:
     """
     Drop (parts of) a prject to free up diskspace
     """
     dl.drop(path + "/projects/" + uuid + "/" + path_in_project)
+    return
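
Note (not part of the commit): the switch from `alias_str = None` to `alias_str = ""` keeps the variable a plain `str`. mypy infers a variable's type from its assignments, so initializing with `None` and later assigning a string makes it `Optional[str]`, and that Optional then leaks into every use downstream; starting from `""` avoids the widening. The `str(conf.get("datalad.dataset.id"))` wrapper has the same flavor: the config value comes back untyped, and `str(...)` gives it a concrete type before it is used to build paths. A sketch of the first point, assuming default mypy settings:

    from typing import Optional

    def joined_or_none(aliases: Optional[list[str]]) -> str:
        alias_str = None                 # inferred as Optional[str] overall
        if aliases is not None:
            alias_str = ",".join(aliases)
        return alias_str                 # error: Incompatible return value type
                                         # (got "Optional[str]", expected "str")

    def joined(aliases: Optional[list[str]]) -> str:
        alias_str = ""                   # inferred as str; "" stands for "no aliases"
        if aliases is not None:
            alias_str = ",".join(aliases)
        return alias_str                 # accepted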

View file

@@ -9,9 +9,10 @@ from pyerrors import Obs, Corr, dump_object, load_object
 from hashlib import sha256
 from .tools import cached, get_file
 import shutil
+from typing import Any
 
-def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=None):
+def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None:
     """
     Write a measurement to the backlog.
     If the file for the measurement already exists, update the measurement.
@@ -97,7 +98,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
     dl.save(files, message="Add measurements to database", dataset=path)
 
-def load_record(path: str, meas_path: str):
+def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
     """
     Load a list of records by their paths.
@@ -116,7 +117,7 @@ def load_record(path: str, meas_path: str):
     return load_records(path, [meas_path])[0]
 
-def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]:
+def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]:
     """
     Load a list of records by their paths.
@@ -138,7 +139,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
             needed_data[file] = []
         key = mpath.split("::")[1]
         needed_data[file].append(key)
-    returned_data: list = []
+    returned_data: list[Any] = []
     for file in needed_data.keys():
         for key in list(needed_data[file]):
             if os.path.exists(cache_path(path, file, key) + ".p"):
@@ -154,7 +155,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
     return returned_data
 
-def cache_dir(path, file):
+def cache_dir(path: str, file: str) -> str:
     cache_path_list = [path]
     cache_path_list.append(".cache")
     cache_path_list.extend(file.split("/")[1:])
@@ -162,19 +163,19 @@ def cache_dir(path, file):
     return cache_path
 
-def cache_path(path, file, key):
+def cache_path(path: str, file: str, key: str) -> str:
     cache_path = os.path.join(cache_dir(path, file), key)
     return cache_path
 
-def preload(path: str, file: str):
+def preload(path: str, file: str) -> dict[str, Any]:
     get_file(path, file)
-    filedict = pj.load_json_dict(os.path.join(path, file))
+    filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file))
     print("> read file")
     return filedict
 
-def drop_record(path: str, meas_path: str):
+def drop_record(path: str, meas_path: str) -> None:
     file_in_archive = meas_path.split("::")[0]
     file = os.path.join(path, file_in_archive)
     db = os.path.join(path, 'backlogger.db')
@@ -199,7 +200,9 @@ def drop_record(path: str, meas_path: str):
     else:
         raise ValueError("This measurement does not exist as a file!")
 
-def drop_cache(path: str):
+
+def drop_cache(path: str) -> None:
     cache_dir = os.path.join(path, ".cache")
     for f in os.listdir(cache_dir):
         shutil.rmtree(os.path.join(cache_dir, f))
+    return
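
Note (not part of the commit): `returned_data: list = []` becoming `list[Any]` touches two related checks: an unannotated empty literal draws "Need type annotation", and a bare `list` is rejected once `--disallow-any-generics` (part of `--strict`) is on. A sketch:

    from typing import Any

    def a() -> list[Any]:
        items = []             # error: Need type annotation for "items"
        return items

    def b() -> list[Any]:
        items: list = []       # ok by default; --disallow-any-generics rejects
        return items           # the bare generic ("Missing type parameters")

    def c() -> list[Any]:
        items: list[Any] = []  # accepted in both modes
        return items

One thing the new annotation does not change: `preloaded: dict[str, Any] = {}` is still a mutable default argument, created once at definition time and shared between calls. mypy accepts it; whether that sharing is intended is a separate question from typing.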

View file

@@ -16,15 +16,16 @@ from .meas_io import write_measurement
 import datalad.api as dl
 import os
 from .input.implementations import codes as known_codes
+from typing import Any
 
-def replace_string(string: str, name: str, val: str):
+def replace_string(string: str, name: str, val: str) -> str:
     if '{' + name + '}' in string:
         n = string.replace('{' + name + '}', val)
         return n
     else:
         return string
 
-def replace_in_meas(measurements: dict, vars: dict[str, str]):
+def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str]) -> dict[str, dict[str, Any]]:
     # replace global variables
     for name, value in vars.items():
         for m in measurements.keys():
@@ -36,7 +37,7 @@ def replace_in_meas(measurements: dict, vars: dict[str, str]):
                     measurements[m][key][i] = replace_string(measurements[m][key][i], name, value)
     return measurements
 
-def fill_cons(measurements, constants):
+def fill_cons(measurements: dict[str, dict[str, Any]], constants: dict[str, str]) -> dict[str, dict[str, Any]]:
     for m in measurements.keys():
         for name, val in constants.items():
             if name not in measurements[m].keys():
@@ -44,7 +45,7 @@ def fill_cons(measurements, constants):
     return measurements
 
-def check_project_data(d: dict) -> None:
+def check_project_data(d: dict[str, dict[str, str]]) -> None:
     if 'project' not in d.keys() or 'measurements' not in d.keys() or len(list(d.keys())) > 4:
         raise ValueError('There should only be maximally be four keys on the top level, "project" and "measurements" are mandatory, "contants" is optional!')
     project_data = d['project']
@@ -57,7 +58,7 @@ def check_project_data(d: dict) -> None:
     return
 
-def check_measurement_data(measurements: dict, code: str) -> None:
+def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) -> None:
     var_names: list[str] = []
     if code == "sfcf":
         var_names = ["path", "ensemble", "param_file", "version", "prefix", "cfg_seperator", "names"]
@@ -91,14 +92,14 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
     with open(file, 'rb') as fp:
         toml_dict = toml.load(fp)
     check_project_data(toml_dict)
-    project: dict = toml_dict['project']
+    project: dict[str, Any] = toml_dict['project']
     if project['code'] not in known_codes:
         raise ValueError('Code' + project['code'] + 'has no import implementation!')
-    measurements: dict = toml_dict['measurements']
+    measurements: dict[str, dict[str, Any]] = toml_dict['measurements']
     measurements = fill_cons(measurements, toml_dict['constants'] if 'constants' in toml_dict else {})
     measurements = replace_in_meas(measurements, toml_dict['replace'] if 'replace' in toml_dict else {})
     check_measurement_data(measurements, project['code'])
-    aliases = project.get('aliases', None)
+    aliases = project.get('aliases', [])
     uuid = project.get('uuid', None)
     if uuid is not None:
         if not os.path.exists(path + "/projects/" + uuid):
@@ -133,16 +134,16 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
             for rwp in ["integrator", "eps", "ntot", "dnms"]:
                 param[rwp] = "Unknown"
             param['type'] = 't0'
-            measurement = openQCD.extract_t0(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
-                                             fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
+            measurement = openQCD.extract_t0(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
+                                             fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
         elif md['measurement'] == 't1':
             if 'param_file' in md:
                 param = openQCD.read_ms3_param(path, uuid, md['param_file'])
                 param['type'] = 't1'
-            measurement = openQCD.extract_t1(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
-                                             fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
+            measurement = openQCD.extract_t1(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
+                                             fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
-        write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None))
+        write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else ''))
     if not os.path.exists(os.path.join(path, "toml_imports", uuid)):
         os.makedirs(os.path.join(path, "toml_imports", uuid))
@@ -155,7 +156,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
     return
 
-def reimport_project(path, uuid):
+def reimport_project(path: str, uuid: str) -> None:
     """
     Reimport an existing project using the files that are already available for this project.
@@ -173,6 +174,7 @@ def reimport_project(path, uuid):
     return
 
-def update_project(path, uuid):
+def update_project(path: str, uuid: str) -> None:
     dl.update(how='merge', follow='sibling', dataset=os.path.join(path, "projects", uuid))
     # reimport_project(path, uuid)
+    return
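
Note (not part of the commit): the `str(...)`/`int(...)` wrappers around the `md[...]` lookups follow from how TOML loads: the parsed document is a `dict[str, Any]`, so every value read from `md` is `Any`. Coercing at the call site both converts the runtime value and hands the (presumably now annotated) extractor a concrete type; in the same spirit, the fallbacks move from `None` to `''`/`[]` so they fit non-Optional parameters. A rough sketch with a hypothetical callee:

    from typing import Any

    def extract(prefix: str, dtr_read: int, fit_range: int = 5) -> None:
        ...  # hypothetical stand-in for openQCD.extract_t0/extract_t1

    md: dict[str, Any] = {"prefix": "run1", "dtr_read": 4}

    # Each md[...] value is Any; str()/int() pin concrete types and
    # convert at runtime instead of passing Any through unchecked.
    extract(str(md["prefix"]), int(md["dtr_read"]),
            fit_range=int(md.get("fit_range", 5)))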

View file

@@ -2,10 +2,10 @@ import os
 import datalad.api as dl
 
-def str2list(string: str):
+def str2list(string: str) -> list[str]:
     return string.split(",")
 
-def list2str(mylist):
+def list2str(mylist: list[str]) -> str:
     s = ",".join(mylist)
     return s
@@ -19,10 +19,11 @@ def k2m(k: float) -> float:
     return (1/(2*k))-4
 
-def get_file(path: str, file: str):
+def get_file(path: str, file: str) -> None:
     if file == "backlogger.db":
         print("Downloading database...")
     else:
         print("Downloading data...")
     dl.get(os.path.join(path, file), dataset=path)
     print("> downloaded file")
+    return
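
Note (not part of the commit): annotating even trivial helpers like `get_file` with `-> None` pays off beyond satisfying `--disallow-untyped-defs`: mypy can then flag any caller that tries to use a return value that does not exist. A small sketch with a hypothetical caller:

    def get_file(path: str, file: str) -> None:
        ...

    data = get_file("/archive", "backlogger.db")
    # error: "get_file" does not return a value  [func-returns-value]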