From 0626b34337686ca25f904ffeb47b8d0a767d952d Mon Sep 17 00:00:00 2001 From: Justus Kuhlmann Date: Thu, 4 Dec 2025 14:31:53 +0100 Subject: [PATCH] implement dynamic db name from config --- corrlib/find.py | 19 +++++++++++-------- corrlib/main.py | 28 ++++++++++++++++------------ corrlib/meas_io.py | 16 +++++++++------- corrlib/toml.py | 3 +++ corrlib/tools.py | 21 +++++++++++++++++++++ 5 files changed, 60 insertions(+), 27 deletions(-) diff --git a/corrlib/find.py b/corrlib/find.py index ac38044..901c09c 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -4,9 +4,10 @@ import json import pandas as pd import numpy as np from .input.implementations import codes -from .tools import k2m +from .tools import k2m, get_db_file from .tracker import get from typing import Any, Optional + # this will implement the search functionality @@ -143,10 +144,11 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: - db = path + '/backlogger.db' + db_file = get_db_file(path) + db = os.path.join(path, db_file) if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) - get(path, "backlogger.db") + get(path, db_file) results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after) if code == "sfcf": results = sfcf_filter(results, **kwargs) @@ -155,14 +157,15 @@ def find_record(path: str, ensemble: str, correlator_name: str, code: str, proje def find_project(path: str, name: str) -> str: - get(path, "backlogger.db") - return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name) + db_file = get_db_file(path) + get(path, db_file) + return _project_lookup_by_alias(os.path.join(path, db_file), name) def list_projects(path: str) -> list[tuple[str, str]]: - db = path + '/backlogger.db' - get(path, "backlogger.db") - conn = sqlite3.connect(db) + db_file = get_db_file(path) + get(path, db_file) + conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() c.execute("SELECT id,aliases FROM projects") results = c.fetchall() diff --git a/corrlib/main.py b/corrlib/main.py index dfed6ea..c4f2832 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -5,7 +5,7 @@ import os from .git_tools import move_submodule import shutil from .find import _project_lookup_by_id -from .tools import list2str, str2list +from .tools import list2str, str2list, get_db_file from .tracker import get, save from typing import Union, Optional @@ -25,8 +25,9 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni code: str (optional) The code that was used to create the measurements. """ - db = path + "/backlogger.db" - get(path, "backlogger.db") + db_file = get_db_file(path) + db = os.path.join(path, db_file) + get(path, db_file) conn = sqlite3.connect(db) c = conn.cursor() known_projects = c.execute("SELECT * FROM projects WHERE id=?", (uuid,)) @@ -43,12 +44,13 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code)) conn.commit() conn.close() - save(path, message="Added entry for project " + uuid + " to database", files=["backlogger.db"]) + save(path, message="Added entry for project " + uuid + " to database", files=[db_file]) def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None: - get(path, "backlogger.db") - conn = sqlite3.connect(os.path.join(path, "backlogger.db")) + db_file = get_db_file(path) + get(path, db_file) + conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() c.execute(f"UPDATE projects SET '{prop}' = '{value}' WHERE id == '{uuid}'") conn.commit() @@ -57,8 +59,8 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: - db = os.path.join(path, "backlogger.db") - get(path, "backlogger.db") + db_file = get_db_file(path) + get(path, db_file) known_data = _project_lookup_by_id(db, uuid)[0] known_aliases = known_data[1] @@ -77,9 +79,10 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: if not len(new_alias_list) == len(known_alias_list): alias_str = list2str(new_alias_list) + db = os.path.join(path, db_file) dl.unlock(db, dataset=path) update_project_data(path, uuid, "aliases", alias_str) - save(path, message="Updated aliases for project " + uuid, files=["backlogger.db"]) + save(path, message="Updated aliases for project " + uuid, files=[db_file]) return @@ -122,13 +125,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti if not uuid: raise ValueError("The dataset does not have a uuid!") if not os.path.exists(path + "/projects/" + uuid): - db = path + "/backlogger.db" - get(path, "backlogger.db") + db_file = get_db_file(path) + db = os.path.join(path, db_file) + get(path, db_file) dl.unlock(db, dataset=path) create_project(path, uuid, owner, tags, aliases, code) move_submodule(path, 'projects/tmp', 'projects/' + uuid) os.mkdir(path + '/import_scripts/' + uuid) - save(path, message="Import project from " + url, files=['projects/' + uuid, 'backlogger.db']) + save(path, message="Import project from " + url, files=['projects/' + uuid, db_file]) else: dl.drop(tmp_path, reckless='kill') shutil.rmtree(tmp_path) diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index 645746d..d17750e 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -7,7 +7,7 @@ import json from typing import Union from pyerrors import Obs, Corr, dump_object, load_object from hashlib import sha256 -from .tools import cached +from .tools import get_db_file, cache_enabled from .tracker import get, save import shutil from typing import Any @@ -29,8 +29,9 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, uuid: str The uuid of the project. """ - db = os.path.join(path, 'backlogger.db') - get(path, "backlogger.db") + db_file = get_db_file(path) + db = os.path.join(path, db_file) + get(path, db_file) dl.unlock(db, dataset=path) conn = sqlite3.connect(db) c = conn.cursor() @@ -94,7 +95,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file)) conn.commit() pj.dump_dict_to_json(known_meas, file) - files.append(path + '/backlogger.db') + files.append(os.path.join(path, db_file)) conn.close() save(path, message="Add measurements to database", files=files) @@ -149,7 +150,7 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = { if file not in preloaded: preloaded[file] = preload(path, file) returned_data.append(preloaded[file][key]) - if cached: + if cache_enabled(path): if not os.path.exists(cache_dir(path, file)): os.makedirs(cache_dir(path, file)) dump_object(preloaded[file][key], cache_path(path, file, key)) @@ -179,8 +180,9 @@ def preload(path: str, file: str) -> dict[str, Any]: def drop_record(path: str, meas_path: str) -> None: file_in_archive = meas_path.split("::")[0] file = os.path.join(path, file_in_archive) - db = os.path.join(path, 'backlogger.db') - get(path, 'backlogger.db') + db_file = get_db_file(path) + db = os.path.join(path, db_file) + get(path, db_file) sub_key = meas_path.split("::")[1] dl.unlock(db, dataset=path) conn = sqlite3.connect(db) diff --git a/corrlib/toml.py b/corrlib/toml.py index 7b02e33..c1c4d5b 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -20,6 +20,7 @@ import os from .input.implementations import codes as known_codes from typing import Any + def replace_string(string: str, name: str, val: str) -> str: if '{' + name + '}' in string: n = string.replace('{' + name + '}', val) @@ -27,6 +28,7 @@ def replace_string(string: str, name: str, val: str) -> str: else: return string + def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str]) -> dict[str, dict[str, Any]]: # replace global variables for name, value in vars.items(): @@ -39,6 +41,7 @@ def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str measurements[m][key][i] = replace_string(measurements[m][key][i], name, value) return measurements + def fill_cons(measurements: dict[str, dict[str, Any]], constants: dict[str, str]) -> dict[str, dict[str, Any]]: for m in measurements.keys(): for name, val in constants.items(): diff --git a/corrlib/tools.py b/corrlib/tools.py index 77bfd2e..9c39d7c 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -2,6 +2,8 @@ import os from configparser import ConfigParser from typing import Any +CONFIG_FILENAME = ".corrlib" + def str2list(string: str) -> list[str]: return string.split(",") @@ -31,3 +33,22 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: with open(config_path, 'w') as configfile: config.write(configfile) return + + +def get_db_file(path: str) -> str: + config_path = os.path.join(path, CONFIG_FILENAME) + config = ConfigParser() + if os.path.exists(config_path): + config.read(config_path) + db_file = config.get('paths', 'db', fallback='backlogger.db') + return db_file + + +def cache_enabled(path: str) -> bool: + config_path = os.path.join(path, CONFIG_FILENAME) + config = ConfigParser() + if os.path.exists(config_path): + config.read(config_path) + cached_str = config.get('core', 'cached', fallback='True') + cached_bool = cached_str == ('True') + return cached_bool