diff --git a/corrlib/__init__.py b/corrlib/__init__.py index afc6776..4e1b364 100644 --- a/corrlib/__init__.py +++ b/corrlib/__init__.py @@ -22,4 +22,3 @@ from .meas_io import load_records as load_records from .find import find_project as find_project from .find import find_record as find_record from .find import list_projects as list_projects -from .config import * diff --git a/corrlib/initialization.py b/corrlib/initialization.py index 130cea8..14bcaf0 100644 --- a/corrlib/initialization.py +++ b/corrlib/initialization.py @@ -63,7 +63,7 @@ def create(path: str) -> None: Create folder of backlogs. """ - dl.create(path) + create(path) _create_db(os.path.join(path, 'backlogger.db')) os.chmod(os.path.join(path, 'backlogger.db'), 0o666) # why does this not work? _create_config(path) diff --git a/corrlib/toml.py b/corrlib/toml.py index 11065fe..7b02e33 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -10,10 +10,12 @@ the import of projects via TOML. import tomllib as toml import shutil + +import datalad.api as dl +from .tracker import save from .input import sfcf, openQCD from .main import import_project, update_aliases from .meas_io import write_measurement -import datalad.api as dl import os from .input.implementations import codes as known_codes from typing import Any @@ -150,7 +152,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: if copy_file: import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) shutil.copy(file, import_file) - dl.save(import_file, message="Import using " + import_file, dataset=path) + save(path, files=[import_file], message="Import using " + import_file) print("File copied to " + import_file) print("Imported project.") return diff --git a/corrlib/tracker.py b/corrlib/tracker.py index fcf3994..7f63e9b 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -3,7 +3,7 @@ from configparser import ConfigParser from .trackers import datalad as dl -def get_tracker(path): +def get_tracker(path: str) -> str: config_path = os.path.join(path, '.corrlib') config = ConfigParser() if os.path.exists(config_path): @@ -12,7 +12,7 @@ def get_tracker(path): return tracker -def get(path, file): +def get(path: str, file: str) -> None: tracker = get_tracker(path) if tracker == 'datalad': dl.get_file(path, file) @@ -21,9 +21,18 @@ def get(path, file): return -def save(path, message, files): +def save(path: str, message: str, files: list[str]) -> None: tracker = get_tracker(path) if tracker == 'datalad': dl.save(files, message=message, dataset=path) else: raise ValueError(f"Tracker {tracker} is not supported.") + + +def create(path: str) -> None: + tracker = get_tracker(path) + if tracker == 'datalad': + dl.create(path) + else: + raise ValueError(f"Tracker {tracker} is not supported.") + return diff --git a/corrlib/trackers/datalad.py b/corrlib/trackers/datalad.py index e9a6e9f..4eccf0f 100644 --- a/corrlib/trackers/datalad.py +++ b/corrlib/trackers/datalad.py @@ -18,3 +18,8 @@ def save(path, message, files= None): files = [os.path.join(path, f) for f in files] dl.save(files, message=message, dataset=path) return + + +def create(path): + dl.create(path) + return