diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index 594e668..6a50f7e 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -8,7 +8,7 @@ from typing import Union from pyerrors import Obs, Corr, dump_object, load_object from hashlib import sha256 from .tools import cached -from .tracker import get_file +from .tracker import get import shutil @@ -29,7 +29,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No The uuid of the project. """ db = os.path.join(path, 'backlogger.db') - get_file(path, "backlogger.db") + get(path, "backlogger.db") dl.unlock(db, dataset=path) conn = sqlite3.connect(db) c = conn.cursor() @@ -169,7 +169,7 @@ def cache_path(path, file, key): def preload(path: str, file: str): - get_file(path, file) + get(path, file) filedict = pj.load_json_dict(os.path.join(path, file)) print("> read file") return filedict @@ -179,7 +179,7 @@ def drop_record(path: str, meas_path: str): file_in_archive = meas_path.split("::")[0] file = os.path.join(path, file_in_archive) db = os.path.join(path, 'backlogger.db') - get_file(path, 'backlogger.db') + get(path, 'backlogger.db') sub_key = meas_path.split("::")[1] dl.unlock(db, dataset=path) conn = sqlite3.connect(db) diff --git a/corrlib/tools.py b/corrlib/tools.py index 6bc8b27..1aaa4d3 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -1,5 +1,5 @@ import os -import datalad.api as dl +from configparser import ConfigParser def str2list(string): @@ -17,4 +17,16 @@ def m2k(m): def k2m(k): return (1/(2*k))-4 - \ No newline at end of file + + +def set_config(path, section, option, value): + config_path = os.path.join(path, '.corrlib') + config = ConfigParser() + if os.path.exists(config_path): + config.read(config_path) + if not config.has_section(section): + config.add_section(section) + config.set(section, option, value) + with open(config_path, 'w') as configfile: + config.write(configfile) + return diff --git a/corrlib/tracker.py b/corrlib/tracker.py index 7230386..fcf3994 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -12,7 +12,7 @@ def get_tracker(path): return tracker -def get_file(path, file): +def get(path, file): tracker = get_tracker(path) if tracker == 'datalad': dl.get_file(path, file) @@ -20,3 +20,10 @@ def get_file(path, file): raise ValueError(f"Tracker {tracker} is not supported.") return + +def save(path, message, files): + tracker = get_tracker(path) + if tracker == 'datalad': + dl.save(files, message=message, dataset=path) + else: + raise ValueError(f"Tracker {tracker} is not supported.")