From d5a48b91f07b7ec31750dee81e39af656d7bb614 Mon Sep 17 00:00:00 2001 From: Justus Kuhlmann Date: Thu, 4 Dec 2025 11:07:33 +0100 Subject: [PATCH] implement save method --- corrlib/find.py | 8 ++++---- corrlib/trackers/datalad.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/corrlib/find.py b/corrlib/find.py index 6aa795c..57a4d1a 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -6,7 +6,7 @@ import pandas as pd import numpy as np from .input.implementations import codes from .tools import k2m -from .tracker import get_file +from .tracker import get # this will implement the search functionality @@ -144,7 +144,7 @@ def find_record(path, ensemble, correlator_name, code, project=None, parameters= db = path + '/backlogger.db' if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) - get_file(path, "backlogger.db") + get(path, "backlogger.db") results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after, revision=revision) if code == "sfcf": results = sfcf_filter(results, **kwargs) @@ -153,13 +153,13 @@ def find_record(path, ensemble, correlator_name, code, project=None, parameters= def find_project(path, name): - get_file(path, "backlogger.db") + get(path, "backlogger.db") return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name) def list_projects(path): db = path + '/backlogger.db' - get_file(path, "backlogger.db") + get(path, "backlogger.db") conn = sqlite3.connect(db) c = conn.cursor() c.execute("SELECT id,aliases FROM projects") diff --git a/corrlib/trackers/datalad.py b/corrlib/trackers/datalad.py index 5d3deaa..e9a6e9f 100644 --- a/corrlib/trackers/datalad.py +++ b/corrlib/trackers/datalad.py @@ -2,10 +2,19 @@ import datalad.api as dl import os -def get_file(path, file): +def get(path, file): if file == "backlogger.db": print("Downloading database...") else: print("Downloading data...") dl.get(os.path.join(path, file), dataset=path) - print("> downloaded file") \ No newline at end of file + print("> downloaded file") + + +def save(path, message, files= None): + if files is None: + files = path + else: + files = [os.path.join(path, f) for f in files] + dl.save(files, message=message, dataset=path) + return