From f716835bcea81396d70b576ce7d8ecbc62dedbb8 Mon Sep 17 00:00:00 2001 From: Justus Kuhlmann Date: Sun, 30 Mar 2025 18:03:13 +0000 Subject: [PATCH] add checks, bugfix --- corrlib/toml.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/corrlib/toml.py b/corrlib/toml.py index 68fbaf6..04e0937 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -17,7 +17,7 @@ import datalad.api as dl import os -def check_project_data(d): +def check_project_data(d: dict) -> None: if 'project' not in d.keys() or 'measurements' not in d.keys() or len(list(d.keys())) > 2: raise ValueError('There should only be two key on the top level, "project" and "measurements"!') project_data = d['project'] @@ -30,7 +30,17 @@ def check_project_data(d): return -def import_toml(path, file, copy_file=True): +def check_measurement_data(measurements: dict) -> None: + var_names: list[str] = ["path", "ensemble", "param_file", "version", "prefix", "cfg_seperator", "names"] + for mname, md in measurements.items(): + for var_name in var_names: + if var_name not in md.keys(): + raise ImportError("Measurment '" + mname + "' does not possess nessecary variable '" + var_name + "'. \ + Please add this to the measurements definition.") + return + + +def import_toml(path: str, file: str, copy_file: bool=True) -> None: """ Import a project decribed by a .toml file. @@ -45,9 +55,10 @@ def import_toml(path, file, copy_file=True): with open(file, 'rb') as fp: toml_dict = toml.load(fp) check_project_data(toml_dict) - project = toml_dict['project'] - measurements = toml_dict['measurements'] - uuid = import_project(path, project['url']) + project: dict = toml_dict['project'] + measurements: dict = toml_dict['measurements'] + check_measurement_data(measurements) + uuid = import_project(path, project['url'], aliases = project.get('aliases', None)) for mname, md in measurements.items(): print("Import measurement: " + mname) ensemble = md['ensemble'] @@ -64,7 +75,7 @@ def import_toml(path, file, copy_file=True): if not os.path.exists(os.path.join(path, "toml_imports", uuid)): os.makedirs(os.path.join(path, "toml_imports", uuid)) if copy_file: - import_file = os.path.join(path, "toml_imports", uuid, file) + import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) shutil.copy(file, import_file) dl.save(import_file, message="Import using " + import_file, dataset=path) print("File copied to " + import_file)