Merge pull request 'refactor/data_backend' (#12) from refactor/data_backend into develop
All checks were successful
Mypy / mypy (push) Successful in 48s
Pytest / pytest (3.12) (push) Successful in 52s
Pytest / pytest (3.14) (push) Successful in 47s
Ruff / ruff (push) Successful in 34s
Pytest / pytest (3.13) (push) Successful in 48s

Reviewed-on: https://www.kuhl-mann.de/git/git/jkuhl/corrlib/pulls/12
Justus Kuhlmann 2025-12-04 15:47:45 +01:00
commit 155e6d952e
11 changed files with 358 additions and 90 deletions

View file

@@ -26,9 +26,9 @@ def update(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
    uuid: str = typer.Argument(),
) -> None:
    """
    Update a project by it's UUID.
    """
@@ -43,7 +43,7 @@ def list(
        "-d",
    ),
    entities: str = typer.Argument('ensembles'),
) -> None:
    """
    List entities. (ensembles, projects)
    """
@@ -72,10 +72,10 @@ def alias_add(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
    uuid: str = typer.Argument(),
    alias: str = typer.Argument(),
) -> None:
    """
    Add an alias to a project UUID.
    """
@@ -90,11 +90,11 @@ def find(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
    ensemble: str = typer.Argument(),
    corr: str = typer.Argument(),
    code: str = typer.Argument(),
) -> None:
    """
    Find a record in the backlog at hand. Through specifying it's ensemble and the measured correlator.
    """
@@ -108,15 +108,15 @@ def importer(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
    files: str = typer.Argument(
    ),
    copy_file: bool = typer.Option(
        bool(True),
        "--save",
        "-s",
    ),
) -> None:
    """
    Import a project from a .toml-file via CLI.
    """
@@ -152,12 +152,17 @@ def init(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
+    tracker: str = typer.Option(
+        str('datalad'),
+        "--tracker",
+        "-t",
+    ),
) -> None:
    """
    Initialize a new backlog-database.
    """
-    create(path)
+    create(path, tracker)
    return
@@ -167,8 +172,8 @@ def drop_cache(
        str('./corrlib'),
        "--dataset",
        "-d",
    ),
) -> None:
    """
    Drop the currect cache directory of the dataset.
    """
@@ -185,6 +190,6 @@ def main(
        help="Show the application's version and exit.",
        callback=_version_callback,
        is_eager=True,
    )
) -> None:
    return

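For reference, the new --tracker/-t option on init can be exercised through typer's test runner. A minimal sketch, assuming the Typer app object is importable as corrlib.cli (the CLI module name is not visible in this diff):

# Sketch only: the import path of the Typer app is an assumption.
from typer.testing import CliRunner
from corrlib.cli import app  # assumed location of the Typer application

runner = CliRunner()
# initialize a backlog without datalad tracking, via the option added in this PR
result = runner.invoke(app, ["init", "--dataset", "./corrlib_test", "--tracker", "None"])
assert result.exit_code == 0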
View file

@@ -4,8 +4,10 @@ import json
import pandas as pd
import numpy as np
from .input.implementations import codes
-from .tools import k2m, get_file
+from .tools import k2m, get_db_file
+from .tracker import get
from typing import Any, Optional

# this will implement the search functionality
@@ -142,10 +144,11 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
                created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame:
-    db = path + '/backlogger.db'
+    db_file = get_db_file(path)
+    db = os.path.join(path, db_file)
    if code not in codes:
        raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes))
-    get_file(path, "backlogger.db")
+    get(path, db_file)
    results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after)
    if code == "sfcf":
        results = sfcf_filter(results, **kwargs)
@@ -154,14 +157,15 @@ def find_record(path: str, ensemble: str, correlator_name: str, code: str, proje
def find_project(path: str, name: str) -> str:
-    get_file(path, "backlogger.db")
-    return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name)
+    db_file = get_db_file(path)
+    get(path, db_file)
+    return _project_lookup_by_alias(os.path.join(path, db_file), name)

def list_projects(path: str) -> list[tuple[str, str]]:
-    db = path + '/backlogger.db'
-    get_file(path, "backlogger.db")
-    conn = sqlite3.connect(db)
+    db_file = get_db_file(path)
+    get(path, db_file)
+    conn = sqlite3.connect(os.path.join(path, db_file))
    c = conn.cursor()
    c.execute("SELECT id,aliases FROM projects")
    results = c.fetchall()

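With the database file now resolved through get_db_file() and fetched through the tracker wrapper, a lookup session could look like the following sketch (ensemble and correlator names are placeholders; the corrlib package name is taken from the new file below):

# Sketch only: the lookup values are placeholders.
from corrlib.find import find_record, list_projects

path = "./corrlib"
# list_projects() now reads the database name from .corrlib instead of a hard-coded backlogger.db
for project_id, aliases in list_projects(path):
    print(project_id, aliases)

results = find_record(path, ensemble="A654", correlator_name="f_A", code="sfcf")
print(results.head())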
View file

@@ -1,5 +1,5 @@
import os
-import datalad.api as dl
+from .tracker import save
import git

GITMODULES_FILE = '.gitmodules'
@@ -40,5 +40,6 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
    repo = git.Repo(repo_path)
    repo.git.add('.gitmodules')
    # save new state of the dataset
-    dl.save(repo_path, message=f"Move module from {old_path} to {new_path}", dataset=repo_path)
+    save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path])
    return

View file

@@ -1,6 +1,7 @@
+from configparser import ConfigParser
import sqlite3
-import datalad.api as dl
import os
+from .tracker import save, init

def _create_db(db: str) -> None:
@@ -35,20 +36,52 @@ def _create_db(db: str) -> None:
    return

-def create(path: str) -> None:
+def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser:
+    """
+    Create the config file for backlogger.
+    """
+    config = ConfigParser()
+    config['core'] = {
+        'version': '1.0',
+        'tracker': tracker,
+        'cached': str(cached),
+    }
+    config['paths'] = {
+        'db': 'backlogger.db',
+        'projects_path': 'projects',
+        'archive_path': 'archive',
+        'toml_imports_path': 'toml_imports',
+        'import_scripts_path': 'import_scripts',
+    }
+    return config
+
+
+def _write_config(path: str, config: ConfigParser) -> None:
+    """
+    Write the config file to disk.
+    """
+    with open(os.path.join(path, '.corrlib'), 'w') as configfile:
+        config.write(configfile)
+    return
+
+
+def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None:
    """
    Create folder of backlogs.
    """
-    dl.create(path)
-    _create_db(path + '/backlogger.db')
-    os.chmod(path + '/backlogger.db', 0o666) # why does this not work?
-    os.makedirs(path + '/projects')
-    os.makedirs(path + '/archive')
-    os.makedirs(path + '/toml_imports')
-    os.makedirs(path + '/import_scripts/template.py')
-    with open(path + "/.gitignore", "w") as fp:
+    config = _create_config(path, tracker, cached)
+    init(path, tracker)
+    _write_config(path, config)
+    _create_db(os.path.join(path, config['paths']['db']))
+    os.chmod(os.path.join(path, config['paths']['db']), 0o666)
+    os.makedirs(os.path.join(path, config['paths']['projects_path']))
+    os.makedirs(os.path.join(path, config['paths']['archive_path']))
+    os.makedirs(os.path.join(path, config['paths']['toml_imports_path']))
+    os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py'))
+    with open(os.path.join(path, ".gitignore"), "w") as fp:
        fp.write(".cache")
        fp.close()
-    dl.save(path, dataset=path, message="Initialize backlogger directory.")
+    save(path, message="Initialized correlator library")
    return

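Under the defaults written by _create_config(), creating a backlog and inspecting its .corrlib file could look like this sketch, mirroring the new test_init_config test further below (the module exposing create() is assumed to be importable as shown; the tests import it as init):

# Sketch only: the import path of create() is an assumption.
from configparser import ConfigParser
from corrlib.initialization import create

create("./corrlib_test", tracker="None")   # 'None' skips datalad and just creates the directory tree

cfg = ConfigParser()
cfg.read("./corrlib_test/.corrlib")
print(cfg.get("core", "tracker"))   # "None"
print(cfg.get("paths", "db"))       # "backlogger.db"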
View file

@@ -5,7 +5,8 @@ import os
from .git_tools import move_submodule
import shutil
from .find import _project_lookup_by_id
-from .tools import list2str, str2list, get_file
+from .tools import list2str, str2list, get_db_file
+from .tracker import get, save, unlock, clone, drop
from typing import Union, Optional
@@ -24,15 +25,16 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
    code: str (optional)
        The code that was used to create the measurements.
    """
-    db = path + "/backlogger.db"
-    get_file(path, "backlogger.db")
+    db_file = get_db_file(path)
+    db = os.path.join(path, db_file)
+    get(path, db_file)
    conn = sqlite3.connect(db)
    c = conn.cursor()
    known_projects = c.execute("SELECT * FROM projects WHERE id=?", (uuid,))
    if known_projects.fetchone():
        raise ValueError("Project already imported, use update_project() instead.")
-    dl.unlock(db, dataset=path)
+    unlock(path, db_file)
    alias_str = ""
    if aliases is not None:
        alias_str = list2str(aliases)
@@ -42,12 +44,13 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
    c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code))
    conn.commit()
    conn.close()
-    dl.save(db, message="Added entry for project " + uuid + " to database", dataset=path)
+    save(path, message="Added entry for project " + uuid + " to database", files=[db_file])

def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None:
-    get_file(path, "backlogger.db")
-    conn = sqlite3.connect(os.path.join(path, "backlogger.db"))
+    db_file = get_db_file(path)
+    get(path, db_file)
+    conn = sqlite3.connect(os.path.join(path, db_file))
    c = conn.cursor()
    c.execute(f"UPDATE projects SET '{prop}' = '{value}' WHERE id == '{uuid}'")
    conn.commit()
@@ -56,8 +59,9 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None]
def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
-    db = os.path.join(path, "backlogger.db")
-    get_file(path, "backlogger.db")
+    db_file = get_db_file(path)
+    db = os.path.join(path, db_file)
+    get(path, db_file)

    known_data = _project_lookup_by_id(db, uuid)[0]
    known_aliases = known_data[1]
@@ -76,9 +80,9 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
    if not len(new_alias_list) == len(known_alias_list):
        alias_str = list2str(new_alias_list)
-        dl.unlock(db, dataset=path)
+        unlock(path, db_file)
        update_project_data(path, uuid, "aliases", alias_str)
-        dl.save(db, dataset=path)
+        save(path, message="Updated aliases for project " + uuid, files=[db_file])
    return
@@ -108,26 +112,21 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti
    in order to receive a uuid and have a consistent interface.
    """
-    tmp_path = path + '/projects/tmp'
-    if not isDataset:
-        dl.create(tmp_path, dataset=path)
-        shutil.copytree(url + "/*", path + '/projects/tmp/')
-        dl.save(tmp_path, dataset=path)
-    else:
-        dl.install(path=tmp_path, source=url, dataset=path)
+    tmp_path = os.path.join(path, 'projects/tmp')
+    clone(path, source=url, target=tmp_path)
    tmp_ds = dl.Dataset(tmp_path)
    conf = dlc.ConfigManager(tmp_ds)
    uuid = str(conf.get("datalad.dataset.id"))
    if not uuid:
        raise ValueError("The dataset does not have a uuid!")
    if not os.path.exists(path + "/projects/" + uuid):
-        db = path + "/backlogger.db"
-        get_file(path, "backlogger.db")
-        dl.unlock(db, dataset=path)
+        db_file = get_db_file(path)
+        get(path, db_file)
+        unlock(path, db_file)
        create_project(path, uuid, owner, tags, aliases, code)
        move_submodule(path, 'projects/tmp', 'projects/' + uuid)
        os.mkdir(path + '/import_scripts/' + uuid)
-        dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
+        save(path, message="Import project from " + url, files=['projects/' + uuid, db_file])
    else:
        dl.drop(tmp_path, reckless='kill')
        shutil.rmtree(tmp_path)
@@ -144,8 +143,8 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti
def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None:
    """
-    Drop (parts of) a prject to free up diskspace
+    Drop (parts of) a project to free up diskspace
    """
-    dl.drop(path + "/projects/" + uuid + "/" + path_in_project)
+    drop(path + "/projects/" + uuid + "/" + path_in_project)
    return

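As a usage illustration of the refactored project handling, a sketch with placeholder URL, path and UUID:

# Sketch only: URL, path and UUID are placeholders.
from corrlib.main import import_project, drop_project_data

path = "./corrlib"
# clone() from the new tracker module replaces the old dl.create()/dl.install() branches
import_project(path, "https://example.org/lattice/my_measurements", aliases=["my_measurements"])

# later, parts of an imported project can be dropped again through the drop() wrapper
drop_project_data(path, "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", "raw_data")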
View file

@@ -1,13 +1,13 @@
from pyerrors.input import json as pj
import os
-import datalad.api as dl
import sqlite3
from .input import sfcf,openQCD
import json
from typing import Union
from pyerrors import Obs, Corr, dump_object, load_object
from hashlib import sha256
-from .tools import cached, get_file
+from .tools import get_db_file, cache_enabled
+from .tracker import get, save, unlock
import shutil
from typing import Any
@@ -28,9 +28,10 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
    uuid: str
        The uuid of the project.
    """
-    db = os.path.join(path, 'backlogger.db')
-    get_file(path, "backlogger.db")
-    dl.unlock(db, dataset=path)
+    db_file = get_db_file(path)
+    db = os.path.join(path, db_file)
+    get(path, db_file)
+    unlock(path, db_file)
    conn = sqlite3.connect(db)
    c = conn.cursor()
    files = []
@@ -43,7 +44,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
            os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr))
        else:
            if os.path.exists(file):
-                dl.unlock(file, dataset=path)
+                unlock(path, file_in_archive)
                known_meas = pj.load_json_dict(file)
        if code == "sfcf":
            parameters = sfcf.read_param(path, uuid, parameter_file)
@@ -93,9 +94,9 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
                      (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file))
            conn.commit()
        pj.dump_dict_to_json(known_meas, file)
-    files.append(path + '/backlogger.db')
+    files.append(os.path.join(path, db_file))
    conn.close()
-    dl.save(files, message="Add measurements to database", dataset=path)
+    save(path, message="Add measurements to database", files=files)

def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
@@ -148,7 +149,7 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {
        if file not in preloaded:
            preloaded[file] = preload(path, file)
        returned_data.append(preloaded[file][key])
-        if cached:
+        if cache_enabled(path):
            if not os.path.exists(cache_dir(path, file)):
                os.makedirs(cache_dir(path, file))
            dump_object(preloaded[file][key], cache_path(path, file, key))
@@ -169,7 +170,7 @@ def cache_path(path: str, file: str, key: str) -> str:
def preload(path: str, file: str) -> dict[str, Any]:
-    get_file(path, file)
+    get(path, file)
    filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file))
    print("> read file")
    return filedict
@@ -178,10 +179,11 @@ def preload(path: str, file: str) -> dict[str, Any]:
def drop_record(path: str, meas_path: str) -> None:
    file_in_archive = meas_path.split("::")[0]
    file = os.path.join(path, file_in_archive)
-    db = os.path.join(path, 'backlogger.db')
-    get_file(path, 'backlogger.db')
+    db_file = get_db_file(path)
+    db = os.path.join(path, db_file)
+    get(path, db_file)
    sub_key = meas_path.split("::")[1]
-    dl.unlock(db, dataset=path)
+    unlock(path, db_file)
    conn = sqlite3.connect(db)
    c = conn.cursor()
    if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path, )).fetchone() is not None:
@@ -193,9 +195,9 @@ def drop_record(path: str, meas_path: str) -> None:
        known_meas = pj.load_json_dict(file)
        if sub_key in known_meas:
            del known_meas[sub_key]
-            dl.unlock(file, dataset=path)
+            unlock(path, file_in_archive)
            pj.dump_dict_to_json(known_meas, file)
-            dl.save([db, file], message="Drop measurements to database", dataset=path)
+            save(path, message="Drop measurements to database", files=[db, file])
            return
    else:
        raise ValueError("This measurement does not exist as a file!")

View file

@@ -10,14 +10,17 @@ the import of projects via TOML.
import tomllib as toml
import shutil
-import datalad.api as dl
+from .tracker import save
from .input import sfcf, openQCD
from .main import import_project, update_aliases
from .meas_io import write_measurement
+import datalad.api as dl
import os
from .input.implementations import codes as known_codes
from typing import Any

def replace_string(string: str, name: str, val: str) -> str:
    if '{' + name + '}' in string:
        n = string.replace('{' + name + '}', val)
@@ -25,6 +28,7 @@ def replace_string(string: str, name: str, val: str) -> str:
    else:
        return string

def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str]) -> dict[str, dict[str, Any]]:
    # replace global variables
    for name, value in vars.items():
@@ -37,6 +41,7 @@ def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str
                    measurements[m][key][i] = replace_string(measurements[m][key][i], name, value)
    return measurements

def fill_cons(measurements: dict[str, dict[str, Any]], constants: dict[str, str]) -> dict[str, dict[str, Any]]:
    for m in measurements.keys():
        for name, val in constants.items():
@@ -150,7 +155,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
    if copy_file:
        import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1])
        shutil.copy(file, import_file)
-        dl.save(import_file, message="Import using " + import_file, dataset=path)
+        save(path, files=[import_file], message="Import using " + import_file)
        print("File copied to " + import_file)
    print("Imported project.")
    return

View file

@@ -1,5 +1,8 @@
import os
-import datalad.api as dl
+from configparser import ConfigParser
+from typing import Any
+
+CONFIG_FILENAME = ".corrlib"

def str2list(string: str) -> list[str]:
@@ -19,11 +22,33 @@ def k2m(k: float) -> float:
    return (1/(2*k))-4

-def get_file(path: str, file: str) -> None:
-    if file == "backlogger.db":
-        print("Downloading database...")
-    else:
-        print("Downloading data...")
-    dl.get(os.path.join(path, file), dataset=path)
-    print("> downloaded file")
+def set_config(path: str, section: str, option: str, value: Any) -> None:
+    config_path = os.path.join(path, '.corrlib')
+    config = ConfigParser()
+    if os.path.exists(config_path):
+        config.read(config_path)
+    if not config.has_section(section):
+        config.add_section(section)
+    config.set(section, option, value)
+    with open(config_path, 'w') as configfile:
+        config.write(configfile)
    return

+def get_db_file(path: str) -> str:
+    config_path = os.path.join(path, CONFIG_FILENAME)
+    config = ConfigParser()
+    if os.path.exists(config_path):
+        config.read(config_path)
+    db_file = config.get('paths', 'db', fallback='backlogger.db')
+    return db_file

+def cache_enabled(path: str) -> bool:
+    config_path = os.path.join(path, CONFIG_FILENAME)
+    config = ConfigParser()
+    if os.path.exists(config_path):
+        config.read(config_path)
+    cached_str = config.get('core', 'cached', fallback='True')
+    cached_bool = cached_str == ('True')
+    return cached_bool

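The new config helpers replace the old get_file() download helper; reading and changing settings could look like this sketch:

# Sketch only: illustrative values.
from corrlib.tools import get_db_file, cache_enabled, set_config

path = "./corrlib"
print(get_db_file(path))      # 'backlogger.db' unless overridden under [paths] in .corrlib
print(cache_enabled(path))    # True unless [core] cached is set to something other than 'True'

# switch the local object cache off; set_config() creates the section if it is missing
set_config(path, "core", "cached", "False")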
corrlib/tracker.py (new file, 169 lines)
View file

@@ -0,0 +1,169 @@
import os
from configparser import ConfigParser
import datalad.api as dl
from typing import Optional
import shutil
from .tools import get_db_file


def get_tracker(path: str) -> str:
    """
    Get the tracker used in the dataset located at path.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.

    Returns
    -------
    tracker: str
        The tracker used in the dataset.
    """
    config_path = os.path.join(path, '.corrlib')
    config = ConfigParser()
    if os.path.exists(config_path):
        config.read(config_path)
    else:
        raise FileNotFoundError(f"No config file found in {path}.")
    tracker = config.get('core', 'tracker', fallback='datalad')
    return tracker


def get(path: str, file: str) -> None:
    """
    Wrapper function to get a file from the dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    file: str
        The file to get.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        if file == get_db_file(path):
            print("Downloading database...")
        else:
            print("Downloading data...")
        dl.get(os.path.join(path, file), dataset=path)
        print("> downloaded file")
    elif tracker == 'None':
        pass
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return


def save(path: str, message: str, files: Optional[list[str]]=None) -> None:
    """
    Wrapper function to save a file to the dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    message: str
        The commit message.
    files: list[str], optional
        The files to save. If None, all changes are saved.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        if files is not None:
            files = [os.path.join(path, f) for f in files]
        dl.save(files, message=message, dataset=path)
    elif tracker == 'None':
        Warning("Tracker 'None' does not implement save.")
        pass
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")


def init(path: str, tracker: str='datalad') -> None:
    """
    Initialize a dataset at the specified path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to initialize the dataset.
    tracker: str
        The tracker to use. Currently only 'datalad' and 'None' are supported.
    """
    if tracker == 'datalad':
        dl.create(path)
    elif tracker == 'None':
        os.makedirs(path, exist_ok=True)
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return


def unlock(path: str, file: str) -> None:
    """
    Wrapper function to unlock a file in the dataset located at path with the specified tracker.

    Parameters
    ----------
    path : str
        The path to the backlogger folder.
    file : str
        The file to unlock.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        dl.unlock(file, dataset=path)
    elif tracker == 'None':
        Warning("Tracker 'None' does not implement unlock.")
        pass
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return


def clone(path: str, source: str, target: str) -> None:
    """
    Wrapper function to clone a dataset from source to target with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    source: str
        The source dataset to clone.
    target: str
        The target path to clone the dataset to.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        dl.clone(target=target, source=source, dataset=path)
    elif tracker == 'None':
        os.makedirs(path, exist_ok=True)
        # Implement a simple clone by copying files
        shutil.copytree(source, target, dirs_exist_ok=False)
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return


def drop(path: str, reckless: Optional[str]=None) -> None:
    """
    Wrapper function to drop data from a dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    reckless: Optional[str]
        The datalad's reckless option for dropping data.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        dl.drop(path, reckless=reckless)
    elif tracker == 'None':
        Warning("Tracker 'None' does not implement drop.")
        pass
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return

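All datalad calls now go through these wrappers, which dispatch on the tracker value stored in .corrlib. A usage sketch:

# Sketch only: illustrative values.
from corrlib import tracker

path = "./corrlib"
print(tracker.get_tracker(path))   # 'datalad' or 'None', read from the .corrlib config

tracker.get(path, "backlogger.db")       # datalad: dl.get(); 'None': no-op
tracker.unlock(path, "backlogger.db")    # datalad: dl.unlock(); 'None': no-op
tracker.save(path, message="update database", files=["backlogger.db"])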
View file

@@ -9,6 +9,31 @@ def test_init_folders(tmp_path):
    assert os.path.exists(str(dataset_path / "backlogger.db"))


+def test_init_folders_no_tracker(tmp_path):
+    dataset_path = tmp_path / "test_dataset"
+    init.create(str(dataset_path), tracker="None")
+    assert os.path.exists(str(dataset_path))
+    assert os.path.exists(str(dataset_path / "backlogger.db"))
+
+
+def test_init_config(tmp_path):
+    dataset_path = tmp_path / "test_dataset"
+    init.create(str(dataset_path), tracker="None")
+    config_path = dataset_path / ".corrlib"
+    assert os.path.exists(str(config_path))
+    from configparser import ConfigParser
+    config = ConfigParser()
+    config.read(str(config_path))
+    assert config.get("core", "tracker") == "None"
+    assert config.get("core", "version") == "1.0"
+    assert config.get("core", "cached") == "True"
+    assert config.get("paths", "db") == "backlogger.db"
+    assert config.get("paths", "projects_path") == "projects"
+    assert config.get("paths", "archive_path") == "archive"
+    assert config.get("paths", "toml_imports_path") == "toml_imports"
+    assert config.get("paths", "import_scripts_path") == "import_scripts"
+
+
def test_init_db(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    init.create(str(dataset_path))
@@ -24,7 +49,7 @@ def test_init_db(tmp_path):
    table_names = [table[0] for table in tables]
    for expected_table in expected_tables:
        assert expected_table in table_names

    cursor.execute("SELECT * FROM projects;")
    projects = cursor.fetchall()
    assert len(projects) == 0
@@ -47,7 +72,7 @@ def test_init_db(tmp_path):
    project_column_names = [col[1] for col in project_columns]
    for expected_col in expected_project_columns:
        assert expected_col in project_column_names

    cursor.execute("PRAGMA table_info('backlogs');")
    backlog_columns = cursor.fetchall()
    expected_backlog_columns = [