diff --git a/corrlib/main.py b/corrlib/main.py index d2cbc6a..f6e3e0e 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -6,7 +6,7 @@ from .git_tools import move_submodule import shutil from .find import _project_lookup_by_id from .tools import list2str, str2list, get_db_file -from .tracker import get, save, unlock, init, clone, drop +from .tracker import get, save from typing import Union, Optional @@ -34,7 +34,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni if known_projects.fetchone(): raise ValueError("Project already imported, use update_project() instead.") - unlock(path, db_file) + dl.unlock(db, dataset=path) alias_str = "" if aliases is not None: alias_str = list2str(aliases) @@ -80,7 +80,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: if not len(new_alias_list) == len(known_alias_list): alias_str = list2str(new_alias_list) - unlock(path, db_file) + dl.unlock(db, dataset=path) update_project_data(path, uuid, "aliases", alias_str) save(path, message="Updated aliases for project " + uuid, files=[db_file]) return @@ -113,7 +113,12 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti """ tmp_path = os.path.join(path, 'projects/tmp') - clone(path, source=url, target=tmp_path) + if not isDataset: + dl.create(tmp_path, dataset=path) + shutil.copytree(url + "/*", path + '/projects/tmp/') + save(path, message="Created temporary project dataset", files=['projects/tmp']) + else: + dl.install(path=tmp_path, source=url, dataset=path) tmp_ds = dl.Dataset(tmp_path) conf = dlc.ConfigManager(tmp_ds) uuid = str(conf.get("datalad.dataset.id")) @@ -121,8 +126,9 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti raise ValueError("The dataset does not have a uuid!") if not os.path.exists(path + "/projects/" + uuid): db_file = get_db_file(path) + db = os.path.join(path, db_file) get(path, db_file) - unlock(path, db_file) + dl.unlock(db, dataset=path) create_project(path, uuid, owner, tags, aliases, code) move_submodule(path, 'projects/tmp', 'projects/' + uuid) os.mkdir(path + '/import_scripts/' + uuid) @@ -145,6 +151,6 @@ def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: """ Drop (parts of) a project to free up diskspace """ - drop(path + "/projects/" + uuid + "/" + path_in_project) + dl.drop(path + "/projects/" + uuid + "/" + path_in_project) return diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index aec7891..d17750e 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -8,7 +8,7 @@ from typing import Union from pyerrors import Obs, Corr, dump_object, load_object from hashlib import sha256 from .tools import get_db_file, cache_enabled -from .tracker import get, save, unlock +from .tracker import get, save import shutil from typing import Any @@ -32,7 +32,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, db_file = get_db_file(path) db = os.path.join(path, db_file) get(path, db_file) - unlock(path, db_file) + dl.unlock(db, dataset=path) conn = sqlite3.connect(db) c = conn.cursor() files = [] @@ -45,7 +45,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr)) else: if os.path.exists(file): - unlock(path, file_in_archive) + dl.unlock(file, dataset=path) known_meas = pj.load_json_dict(file) if code == "sfcf": parameters = sfcf.read_param(path, uuid, parameter_file) @@ -184,7 +184,7 @@ def drop_record(path: str, meas_path: str) -> None: db = os.path.join(path, db_file) get(path, db_file) sub_key = meas_path.split("::")[1] - unlock(path, db_file) + dl.unlock(db, dataset=path) conn = sqlite3.connect(db) c = conn.cursor() if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path, )).fetchone() is not None: @@ -196,7 +196,7 @@ def drop_record(path: str, meas_path: str) -> None: known_meas = pj.load_json_dict(file) if sub_key in known_meas: del known_meas[sub_key] - unlock(path, file_in_archive) + dl.unlock(file, dataset=path) pj.dump_dict_to_json(known_meas, file) save(path, message="Drop measurements to database", files=[db, file]) return diff --git a/corrlib/tracker.py b/corrlib/tracker.py index 565263a..d41d2d8 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -1,9 +1,7 @@ import os from configparser import ConfigParser -import datalad.api as dl +from .trackers import datalad as dl from typing import Optional -import shutil -from .tools import get_db_file def get_tracker(path: str) -> str: @@ -20,12 +18,7 @@ def get_tracker(path: str) -> str: def get(path: str, file: str) -> None: tracker = get_tracker(path) if tracker == 'datalad': - if file == get_db_file(path): - print("Downloading database...") - else: - print("Downloading data...") - dl.get(os.path.join(path, file), dataset=path) - print("> downloaded file") + dl.get(path, file) elif tracker == 'None': pass else: @@ -36,9 +29,7 @@ def get(path: str, file: str) -> None: def save(path: str, message: str, files: Optional[list[str]]=None) -> None: tracker = get_tracker(path) if tracker == 'datalad': - if files is not None: - files = [os.path.join(path, f) for f in files] - dl.save(files, message=message, dataset=path) + dl.save(path, message, files) elif tracker == 'None': pass else: @@ -53,38 +44,3 @@ def init(path: str, tracker: str='datalad') -> None: else: raise ValueError(f"Tracker {tracker} is not supported.") return - - -def unlock(path: str, file: str) -> None: - tracker = get_tracker(path) - if tracker == 'datalad': - dl.unlock(file, dataset=path) - elif tracker == 'None': - pass - else: - raise ValueError(f"Tracker {tracker} is not supported.") - return - - -def clone(path: str, source: str, target: str) -> None: - tracker = get_tracker(path) - if tracker == 'datalad': - dl.clone(target=target, source=source, dataset=path) - elif tracker == 'None': - os.makedirs(path, exist_ok=True) - # Implement a simple clone by copying files - shutil.copytree(source, target, dirs_exist_ok=False) - else: - raise ValueError(f"Tracker {tracker} is not supported.") - return - - -def drop(path: str, reckless: Optional[str]=None) -> None: - tracker = get_tracker(path) - if tracker == 'datalad': - dl.drop(path, reckless=reckless) - elif tracker == 'None': - shutil.rmtree(path) - else: - raise ValueError(f"Tracker {tracker} is not supported.") - return \ No newline at end of file diff --git a/corrlib/trackers/datalad.py b/corrlib/trackers/datalad.py new file mode 100644 index 0000000..c4e3e70 --- /dev/null +++ b/corrlib/trackers/datalad.py @@ -0,0 +1,25 @@ +import datalad.api as dl +import os +from typing import Optional + + +def get(path: str, file: str) -> None: + if file == "backlogger.db": + print("Downloading database...") + else: + print("Downloading data...") + dl.get(os.path.join(path, file), dataset=path) + print("> downloaded file") + return + + +def save(path: str, message: str, files: Optional[list[str]]=None) -> None: + if files is not None: + files = [os.path.join(path, f) for f in files] + dl.save(files, message=message, dataset=path) + return + + +def create(path: str) -> None: + dl.create(path) + return