add checks, bugfix

This commit is contained in:
Justus Kuhlmann 2025-03-30 18:03:13 +00:00
parent a2b96becc1
commit f716835bce

View file

@ -17,7 +17,7 @@ import datalad.api as dl
import os
def check_project_data(d):
def check_project_data(d: dict) -> None:
if 'project' not in d.keys() or 'measurements' not in d.keys() or len(list(d.keys())) > 2:
raise ValueError('There should only be two key on the top level, "project" and "measurements"!')
project_data = d['project']
@ -30,7 +30,17 @@ def check_project_data(d):
return
def import_toml(path, file, copy_file=True):
def check_measurement_data(measurements: dict) -> None:
var_names: list[str] = ["path", "ensemble", "param_file", "version", "prefix", "cfg_seperator", "names"]
for mname, md in measurements.items():
for var_name in var_names:
if var_name not in md.keys():
raise ImportError("Measurment '" + mname + "' does not possess nessecary variable '" + var_name + "'. \
Please add this to the measurements definition.")
return
def import_toml(path: str, file: str, copy_file: bool=True) -> None:
"""
Import a project decribed by a .toml file.
@ -45,9 +55,10 @@ def import_toml(path, file, copy_file=True):
with open(file, 'rb') as fp:
toml_dict = toml.load(fp)
check_project_data(toml_dict)
project = toml_dict['project']
measurements = toml_dict['measurements']
uuid = import_project(path, project['url'])
project: dict = toml_dict['project']
measurements: dict = toml_dict['measurements']
check_measurement_data(measurements)
uuid = import_project(path, project['url'], aliases = project.get('aliases', None))
for mname, md in measurements.items():
print("Import measurement: " + mname)
ensemble = md['ensemble']
@ -64,7 +75,7 @@ def import_toml(path, file, copy_file=True):
if not os.path.exists(os.path.join(path, "toml_imports", uuid)):
os.makedirs(os.path.join(path, "toml_imports", uuid))
if copy_file:
import_file = os.path.join(path, "toml_imports", uuid, file)
import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1])
shutil.copy(file, import_file)
dl.save(import_file, message="Import using " + import_file, dataset=path)
print("File copied to " + import_file)