Merge branch 'develop' into fix/cache

This commit is contained in:
Justus Kuhlmann 2026-02-18 10:12:22 +01:00
commit 15bf399a89
Signed by: jkuhl
GPG key ID: 00ED992DD79B85A6
27 changed files with 4016 additions and 265 deletions

30
.github/workflows/mypy.yaml vendored Normal file
View file

@ -0,0 +1,30 @@
name: Mypy
on:
  push:
  pull_request:
  workflow_dispatch:

jobs:
  mypy:
    runs-on: ubuntu-latest
    env:
      UV_CACHE_DIR: /tmp/.uv-cache
    steps:
      # git-annex is required because the repository is a DataLad dataset.
      - name: Install git-annex
        run: |
          sudo apt-get update
          sudo apt-get install -y git-annex
      - name: Check out the repository
        uses: https://github.com/RouxAntoine/checkout@v4.1.8
        with:
          show-progress: true
      - name: Install uv
        uses: astral-sh/setup-uv@v7
        with:
          # This workflow defines no strategy.matrix, so referencing
          # ${{ matrix.python-version }} would expand to an empty string.
          # Pin the interpreter explicitly to match the sync step below.
          python-version: "3.12"
          enable-cache: true
      - name: Install corrlib
        run: uv sync --locked --all-extras --dev --python "3.12"
      # Step name fixed: this runs the mypy type check, not the test suite.
      - name: Run mypy
        run: uv run mypy corrlib

39
.github/workflows/pytest.yaml vendored Normal file
View file

@ -0,0 +1,39 @@
name: Pytest
on:
push:
pull_request:
workflow_dispatch:
# Monthly scheduled run: 04:00 UTC on the 1st of every month.
schedule:
- cron: '0 4 1 * *'
jobs:
pytest:
# Run the test suite once per supported Python version.
strategy:
matrix:
python-version:
- "3.12"
- "3.13"
- "3.14"
runs-on: ubuntu-latest
env:
UV_CACHE_DIR: /tmp/.uv-cache
steps:
# git-annex is required because the repository is a DataLad dataset.
- name: Install git-annex
run: |
sudo apt-get update
sudo apt-get install -y git-annex
- name: Check out the repository
uses: https://github.com/RouxAntoine/checkout@v4.1.8
with:
show-progress: true
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
# Install the project and all extras/dev dependencies from the lock file.
- name: Install corrlib
run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }}
- name: Run tests
run: uv run pytest --cov=corrlib tests

30
.github/workflows/ruff.yaml vendored Normal file
View file

@ -0,0 +1,30 @@
name: Ruff
on:
  push:
  pull_request:
  workflow_dispatch:

jobs:
  ruff:
    runs-on: ubuntu-latest
    env:
      UV_CACHE_DIR: /tmp/.uv-cache
    steps:
      # git-annex is required because the repository is a DataLad dataset.
      - name: Install git-annex
        run: |
          sudo apt-get update
          sudo apt-get install -y git-annex
      - name: Check out the repository
        uses: https://github.com/RouxAntoine/checkout@v4.1.8
        with:
          show-progress: true
      - name: Install uv
        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
      - name: Install corrlib
        run: uv sync --locked --all-extras --dev --python "3.12"
      # Step name fixed: this runs the ruff linter, not the test suite.
      - name: Run ruff check
        run: uv run ruff check corrlib

4
.gitignore vendored
View file

@ -3,3 +3,7 @@ __pycache__
*.egg-info
test.ipynb
test_ds
.vscode
.venv
.pytest_cache
.coverage

5
.gitmodules vendored
View file

@ -1,5 +0,0 @@
[submodule "projects/tmp"]
path = projects/tmp
url = git@kuhl-mann.de:lattice/charm_SF_data.git
datalad-id = 5f402163-77f2-470e-b6f1-64d7bf9f87d4
datalad-url = git@kuhl-mann.de:lattice/charm_SF_data.git

View file

@ -15,10 +15,10 @@ For now, we are interested in collecting primary IObservables only, as these are
__app_name__ = "corrlib"
from .main import *
from .import input as input
from .initialization import *
from .meas_io import *
from .cache_io import *
from .find import *
from .version import __version__
from .initialization import create as create
from .meas_io import load_record as load_record
from .meas_io import load_records as load_records
from .find import find_project as find_project
from .find import find_record as find_record
from .find import list_projects as list_projects

View file

@ -1,8 +1,9 @@
from corrlib import cli, __app_name__
def main():
def main() -> None:
cli.app(prog_name=__app_name__)
return
if __name__ == "__main__":

View file

@ -1,6 +1,6 @@
from typing import Optional
import typer
from corrlib import __app_name__, __version__
from corrlib import __app_name__
from .initialization import create
from .toml import import_tomls, update_project, reimport_project
from .find import find_record, list_projects
@ -8,6 +8,7 @@ from .tools import str2list
from .main import update_aliases
from .cache_io import drop_cache_files as cio_drop_cache_files
import os
from importlib.metadata import version
app = typer.Typer()
@ -15,7 +16,7 @@ app = typer.Typer()
def _version_callback(value: bool) -> None:
if value:
typer.echo(f"{__app_name__} v{__version__}")
print(__app_name__, version(__app_name__))
raise typer.Exit()
@ -133,6 +134,9 @@ def reimporter(
),
ident: str = typer.Argument()
) -> None:
"""
Reimport the toml file identified by the ident string.
"""
uuid = ident.split("::")[0]
if len(ident.split("::")) > 1:
toml_file = os.path.join(path, "toml_imports", ident.split("::")[1])
@ -152,11 +156,16 @@ def init(
"--dataset",
"-d",
),
tracker: str = typer.Option(
str('datalad'),
"--tracker",
"-t",
),
) -> None:
"""
Initialize a new backlog-database.
"""
create(path)
create(path, tracker)
return

View file

@ -1,16 +1,30 @@
import sqlite3
import datalad.api as dl
import os
import json
import pandas as pd
import numpy as np
from .input.implementations import codes
from .tools import k2m, get_file
# this will implement the search functionality
from .tools import k2m, get_db_file
from .tracker import get
from typing import Any, Optional
def _project_lookup_by_alias(db, alias):
# this will lookup the project name based on the alias
def _project_lookup_by_alias(db: str, alias: str) -> str:
"""
Look up a project's UUID by its (human-readable) alias.
Parameters
----------
db: str
The database to look up the project.
alias: str
The alias to look up.
Returns
-------
uuid: str
The UUID of the project with the given alias.
"""
conn = sqlite3.connect(db)
c = conn.cursor()
c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'")
@ -20,10 +34,25 @@ def _project_lookup_by_alias(db, alias):
print("Error: multiple projects found with alias " + alias)
elif len(results) == 0:
raise Exception("Error: no project found with alias " + alias)
return results[0][0]
return str(results[0][0])
def _project_lookup_by_id(db, uuid):
def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]:
"""
Return the project information available in the database by UUID.
Parameters
----------
db: str
The database to look up the project.
uuid: str
The uuid of the project in question.
Returns
-------
results: list
The row of the project in the database.
"""
conn = sqlite3.connect(db)
c = conn.cursor()
c.execute(f"SELECT * FROM 'projects' WHERE id = '{uuid}'")
@ -32,7 +61,40 @@ def _project_lookup_by_id(db, uuid):
return results
def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None):
def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame:
"""
Look up a correlator record in the database by the data given to the method.
Parameters
----------
db: str
The database to look up the record.
ensemble: str
The ensemble the record is associated with.
correlator_name: str
The name of the correlator in question.
code: str
The name of the code which was used to calculate the correlator.
project: str, optional
The UUID of the project the correlator was calculated in.
parameters: str, optional
A dictionary holding the exact parameters for the measurement that are held in the database.
created_before: str, optional
Timestamp string before which the measurement has been created.
created_after: str, optional
Timestamp string after which the measurement has been created.
updated_before: str, optional
Timestamp string before which the measurement has been updated.
updated_after: str, optional
Timestamp string after which the measurement has been updated.
Returns
-------
results: pd.DataFrame
A pandas DataFrame holding the information received from the DB query.
"""
project_str = project
search_expr = f"SELECT * FROM 'backlogs' WHERE name = '{correlator_name}' AND ensemble = '{ensemble}'"
@ -56,11 +118,38 @@ def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=Non
return results
def sfcf_filter(results, **kwargs):
def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
"""
Filter method for the Database entries holding SFCF calculations.
Parameters
----------
results: pd.DataFrame
The unfiltered pandas DataFrame holding the entries from the database.
offset: list[float], optional
quark_kappas: list[float]
quarks_masses: list[float]
qk1: float, optional
Mass parameter $\kappa_1$ of the first quark.
qk2: float, optional
Mass parameter $\kappa_2$ of the second quark.
qm1: float, optional
Bare quark mass $m_1$ of the first quark.
qm2: float, optional
Bare quark mass $m_2$ of the second quark.
quarks_thetas: list[list[float]], optional
wf1: optional
wf2: optional
Results
-------
results: pd.DataFrame
The filtered DataFrame, only holding the records that fit to the parameters given.
"""
drops = []
for ind in range(len(results)):
result = results.iloc[ind]
if result['code'] == 'sfcf':
param = json.loads(result['parameters'])
if 'offset' in kwargs:
if kwargs.get('offset') != param['offset']:
@ -139,27 +228,62 @@ def sfcf_filter(results, **kwargs):
return results.drop(drops)
def find_record(path, ensemble, correlator_name, code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None, **kwargs):
db = path + '/backlogger.db'
def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame:
db_file = get_db_file(path)
db = os.path.join(path, db_file)
if code not in codes:
raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes))
get_file(path, "backlogger.db")
results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after, revision=revision)
get(path, db_file)
results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after)
if code == "sfcf":
results = sfcf_filter(results, **kwargs)
elif code == "openQCD":
pass
else:
raise Exception
print("Found " + str(len(results)) + " result" + ("s" if len(results)>1 else ""))
return results.reset_index()
def find_project(path, name):
get_file(path, "backlogger.db")
return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name)
def find_project(path: str, name: str) -> str:
"""
Find a project by its human-readable name.
Parameters
----------
path: str
The path of the library.
name: str
The name of the project to look for in the library.
Returns
-------
uuid: str
The uuid of the project in question.
"""
db_file = get_db_file(path)
get(path, db_file)
return _project_lookup_by_alias(os.path.join(path, db_file), name)
def list_projects(path):
db = path + '/backlogger.db'
get_file(path, "backlogger.db")
conn = sqlite3.connect(db)
def list_projects(path: str) -> list[tuple[str, str]]:
"""
List all projects known to the library.
Parameters
----------
path: str
The path of the library.
Returns
-------
results: list[Any]
The projects known to the library.
"""
db_file = get_db_file(path)
get(path, db_file)
conn = sqlite3.connect(os.path.join(path, db_file))
c = conn.cursor()
c.execute("SELECT id,aliases FROM projects")
results = c.fetchall()

View file

@ -1,11 +1,11 @@
import os
import datalad.api as dl
from .tracker import save
import git
GITMODULES_FILE = '.gitmodules'
def move_submodule(repo_path, old_path, new_path):
def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
"""
Move a submodule to a new location.
@ -40,4 +40,6 @@ def move_submodule(repo_path, old_path, new_path):
repo = git.Repo(repo_path)
repo.git.add('.gitmodules')
# save new state of the dataset
dl.save(repo_path, message=f"Move module from {old_path} to {new_path}", dataset=repo_path)
save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path])
return

View file

@ -1,12 +1,17 @@
from configparser import ConfigParser
import sqlite3
import datalad.api as dl
import os
from .tracker import save, init
def _create_db(db):
def _create_db(db: str) -> None:
"""
Create the database file and the table.
Parameters
----------
db: str
Path of the database file.
"""
conn = sqlite3.connect(db)
c = conn.cursor()
@ -33,21 +38,84 @@ def _create_db(db):
updated_at TEXT)''')
conn.commit()
conn.close()
return
def create(path):
def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser:
"""
Create the config file construction for backlogger.
Parameters
----------
path: str
The path of the libaray to create.
tracker: str
Type of the tracker to use for the library (only DataLad is supported at the moment).
cached: bool
Whether or not the library will create a cache folder for multiple reads when downloaded.
Returns
-------
config: ConfigParser
Cpnfig parser with the default configuration printed.
"""
config = ConfigParser()
config['core'] = {
'version': '1.0',
'tracker': tracker,
'cached': str(cached),
}
config['paths'] = {
'db': 'backlogger.db',
'projects_path': 'projects',
'archive_path': 'archive',
'toml_imports_path': 'toml_imports',
'import_scripts_path': 'import_scripts',
}
return config
def _write_config(path: str, config: ConfigParser) -> None:
"""
Write the config file to disk.
Parameters
----------
path: str
The path of the libaray to create.
config: ConfigParser
The configuration to be used as a ConfigParser, e.g. generated by _create_config.
"""
with open(os.path.join(path, '.corrlib'), 'w') as configfile:
config.write(configfile)
return
def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None:
"""
Create folder of backlogs.
Parameters
----------
path: str
The path at which the library will be created.
tracker: str, optional
The tracker to use for the library. The default is DataLad, which is also the only one that is supported at the moment.
cached: bool, optional
Whether or not the library will be cached. By default, it does cache already read entries.
"""
dl.create(path)
_create_db(path + '/backlogger.db')
os.chmod(path + '/backlogger.db', 0o666) # why does this not work?
os.makedirs(path + '/projects')
os.makedirs(path + '/archive')
os.makedirs(path + '/toml_imports')
os.makedirs(path + '/import_scripts/template.py')
with open(path + "/.gitignore", "w") as fp:
config = _create_config(path, tracker, cached)
init(path, tracker)
_write_config(path, config)
_create_db(os.path.join(path, config['paths']['db']))
os.chmod(os.path.join(path, config['paths']['db']), 0o666)
os.makedirs(os.path.join(path, config['paths']['projects_path']))
os.makedirs(os.path.join(path, config['paths']['archive_path']))
os.makedirs(os.path.join(path, config['paths']['toml_imports_path']))
os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py'))
with open(os.path.join(path, ".gitignore"), "w") as fp:
fp.write(".cache")
fp.close()
dl.save(path, dataset=path, message="Initialize backlogger directory.")
save(path, message="Initialized correlator library")
return

View file

@ -2,6 +2,6 @@
Import functions for different codes.
"""
from . import sfcf
from . import openQCD
from . import implementations
from . import sfcf as sfcf
from . import openQCD as openQCD
from . import implementations as implementations

View file

@ -1,2 +1,2 @@
# List of supported input implementations
codes = ['sfcf', 'openQCD']

View file

@ -2,10 +2,28 @@ import pyerrors.input.openQCD as input
import datalad.api as dl
import os
import fnmatch
from typing import Any
from typing import Any, Optional
def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
"""
Read the parameters for ms1 measurements from a parameter file in the project.
Parameters
----------
path: str
The path to the backlogger folder.
project: str
The project from which to read the parameter file.
file_in_project: str
The path to the parameter file within the project.
Returns
-------
param: dict[str, Any]
The parameters read from the file.
"""
file = os.path.join(path, "projects", project, file_in_project)
ds = os.path.join(path, "projects", project)
dl.get(file, dataset=ds)
@ -52,6 +70,24 @@ def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, A
def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
"""
Read the parameters for ms3 measurements from a parameter file in the project.
Parameters
----------
path: str
The path to the backlogger folder.
project: str
The project from which to read the parameter file.
file_in_project: str
The path to the parameter file within the project.
Returns
-------
param: dict[str, Any]
The parameters read from the file.
"""
file = os.path.join(path, "projects", project, file_in_project)
ds = os.path.join(path, "projects", project)
dl.get(file, dataset=ds)
@ -67,7 +103,37 @@ def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, A
return param
def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
"""
Read reweighting factor measurements from the project.
Parameters
----------
path: str
The path to the backlogger folder.
project: str
The project from which to read the measurements.
dir_in_project: str
The directory within the project where the measurements are located.
param: dict[str, Any]
The parameters for the measurements.
prefix: str
The prefix of the measurement files.
postfix: str
The postfix of the measurement files.
version: str
The version of the openQCD used.
names: list[str]
Specific names for the replica of the ensemble the measurement file belongs to.
files: list[str]
Specific files to read.
Returns
-------
rw_dict: dict[str, dict[str, Any]]
The reweighting factor measurements read from the files.
"""
dataset = os.path.join(path, "projects", project)
directory = os.path.join(dataset, dir_in_project)
if files is None:
@ -94,7 +160,43 @@ def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any
return rw_dict
def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str=None, names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
"""
Extract t0 measurements from the project.
Parameters
----------
path: str
The path to the backlogger folder.
project: str
The project from which to read the measurements.
dir_in_project: str
The directory within the project where the measurements are located.
param: dict[str, Any]
The parameters for the measurements.
prefix: str
The prefix of the measurement files.
dtr_read: int
The dtr_read parameter for the extraction.
xmin: int
The xmin parameter for the extraction.
spatial_extent: int
The spatial_extent parameter for the extraction.
fit_range: int
The fit_range parameter for the extraction.
postfix: str
The postfix of the measurement files.
names: list[str]
Specific names for the replica of the ensemble the measurement file belongs to.
files: list[str]
Specific files to read.
Returns
-------
t0_dict: dict
Dictionary of t0 values in the pycorrlib style, with the parameters at hand.
"""
dataset = os.path.join(path, "projects", project)
directory = os.path.join(dataset, dir_in_project)
if files is None:
@ -132,7 +234,43 @@ def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, An
return t0_dict
def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = None, names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
"""
Extract t1 measurements from the project.
Parameters
----------
path: str
The path to the backlogger folder.
project: str
The project from which to read the measurements.
dir_in_project: str
The directory within the project where the measurements are located.
param: dict[str, Any]
The parameters for the measurements.
prefix: str
The prefix of the measurement files.
dtr_read: int
The dtr_read parameter for the extraction.
xmin: int
The xmin parameter for the extraction.
spatial_extent: int
The spatial_extent parameter for the extraction.
fit_range: int
The fit_range parameter for the extraction.
postfix: str
The postfix of the measurement files.
names: list[str]
Specific names for the replica of the ensemble the measurement file belongs to.
files: list[str]
Specific files to read.
Returns
-------
t1_dict: dict
Dictionary of t1 values in the pycorrlib style, with the parameters at hand.
"""
directory = os.path.join(path, "projects", project, dir_in_project)
if files is None:
files = []
@ -161,7 +299,7 @@ def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, An
for k in ["integrator", "eps", "ntot", "dnms"]:
par_list.append(str(param[k]))
pars = "/".join(par_list)
t0_dict: dict[str, Any] = {}
t0_dict[param["type"]] = {}
t0_dict[param["type"]][pars] = t0
return t0_dict
t1_dict: dict[str, Any] = {}
t1_dict[param["type"]] = {}
t1_dict[param["type"]][pars] = t0
return t1_dict

View file

@ -5,7 +5,7 @@ import os
from typing import Any
bi_corrs: list = ["f_P", "fP", "f_p",
bi_corrs: list[str] = ["f_P", "fP", "f_p",
"g_P", "gP", "g_p",
"fA0", "f_A", "f_a",
"gA0", "g_A", "g_a",
@ -43,7 +43,7 @@ bi_corrs: list = ["f_P", "fP", "f_p",
"l3A2", "l3_A2", "g_av23",
]
bb_corrs: list = [
bb_corrs: list[str] = [
'F1',
'F_1',
'f_1',
@ -64,7 +64,7 @@ bb_corrs: list = [
'F_sPdP_d',
]
bib_corrs: list = [
bib_corrs: list[str] = [
'F_V0',
'K_V0',
]
@ -184,7 +184,7 @@ def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
return params
def _map_params(params: dict, spec_list: list) -> dict[str, Any]:
def _map_params(params: dict[str, Any], spec_list: list[str]) -> dict[str, Any]:
"""
Map the extracted parameters to the extracted data.
@ -228,7 +228,25 @@ def _map_params(params: dict, spec_list: list) -> dict[str, Any]:
return new_specs
def get_specs(key, parameters, sep='/') -> str:
def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str:
"""
Get specification from the parameter file for a specific key in the read measurements
Parameters
----------
key: str
The key for which the parameters are to be looked up.
parameters: dict[str, Any]
The dictionary with the parameters from the parameter file.
sep: str
Separator string for the key. (default="/")
Return
------
s: str
json string holding the parameters.
"""
key_parts = key.split(sep)
if corr_types[key_parts[0]] == 'bi':
param = _map_params(parameters, key_parts[1:-1])
@ -238,7 +256,7 @@ def get_specs(key, parameters, sep='/') -> str:
return s
def read_data(path, project, dir_in_project, prefix, param, version='1.0c', cfg_seperator='n', sep='/', **kwargs) -> dict:
def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]:
"""
Extract the data from the sfcf file.

View file

@ -5,11 +5,12 @@ import os
from .git_tools import move_submodule
import shutil
from .find import _project_lookup_by_id
from .tools import list2str, str2list, get_file
from typing import Union
from .tools import list2str, str2list, get_db_file
from .tracker import get, save, unlock, clone, drop
from typing import Union, Optional
def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None):
def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None:
"""
Create a new project entry in the database.
@ -24,30 +25,48 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
code: str (optional)
The code that was used to create the measurements.
"""
db = path + "/backlogger.db"
get_file(path, "backlogger.db")
db_file = get_db_file(path)
db = os.path.join(path, db_file)
get(path, db_file)
conn = sqlite3.connect(db)
c = conn.cursor()
known_projects = c.execute("SELECT * FROM projects WHERE id=?", (uuid,))
if known_projects.fetchone():
raise ValueError("Project already imported, use update_project() instead.")
dl.unlock(db, dataset=path)
alias_str = None
unlock(path, db_file)
alias_str = ""
if aliases is not None:
alias_str = list2str(aliases)
tag_str = None
tag_str = ""
if tags is not None:
tag_str = list2str(tags)
c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code))
conn.commit()
conn.close()
dl.save(db, message="Added entry for project " + uuid + " to database", dataset=path)
save(path, message="Added entry for project " + uuid + " to database", files=[db_file])
return
def update_project_data(path, uuid, prop, value = None):
get_file(path, "backlogger.db")
conn = sqlite3.connect(os.path.join(path, "backlogger.db"))
def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None:
"""
Update/Edit a project entry in the database.
Thin wrapper around sql3 call.
Parameters
----------
path: str
The path to the backlogger folder.
uuid: str
The uuid of the project.
prop: str
Property of the entry to edit
value: str or None
Value to se `prop` to.
"""
db_file = get_db_file(path)
get(path, db_file)
conn = sqlite3.connect(os.path.join(path, db_file))
c = conn.cursor()
c.execute(f"UPDATE projects SET '{prop}' = '{value}' WHERE id == '{uuid}'")
conn.commit()
@ -55,9 +74,10 @@ def update_project_data(path, uuid, prop, value = None):
return
def update_aliases(path: str, uuid: str, aliases: list[str]):
db = os.path.join(path, "backlogger.db")
get_file(path, "backlogger.db")
def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
db_file = get_db_file(path)
db = os.path.join(path, db_file)
get(path, db_file)
known_data = _project_lookup_by_id(db, uuid)[0]
known_aliases = known_data[1]
@ -76,14 +96,16 @@ def update_aliases(path: str, uuid: str, aliases: list[str]):
if not len(new_alias_list) == len(known_alias_list):
alias_str = list2str(new_alias_list)
dl.unlock(db, dataset=path)
unlock(path, db_file)
update_project_data(path, uuid, "aliases", alias_str)
dl.save(db, dataset=path)
save(path, message="Updated aliases for project " + uuid, files=[db_file])
return
def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None, isDataset: bool=True):
def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str:
"""
Import a datalad dataset into the backlogger.
Parameters
----------
@ -91,43 +113,35 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
The url of the project to import. This can be any url that datalad can handle.
path: str
The path to the backlogger folder.
aliases: list[str]
Custom name of the project, alias of the project.
code: str
owner: str, optional
Person responsible for the maintenance of the project to be imported.
tags: list[str], optional
Custom tags of the imported project.
aliases: list[str], optional
Custom names of the project, alias of the project.
code: str, optional
Code that was used to create the measurements.
Import a datalad dataset into the backlogger.
Parameters
----------
path: str
The path to the backlogger directory.
url: str
The url of the project to import. This can be any url that datalad can handle.
Also supported are non-datalad datasets, which will be converted to datalad datasets,
in order to receive a uuid and have a consistent interface.
Returns
-------
uuid: str
The unique identifier of the imported project.
"""
tmp_path = path + '/projects/tmp'
if not isDataset:
dl.create(tmp_path, dataset=path)
shutil.copytree(url + "/*", path + '/projects/tmp/')
dl.save(tmp_path, dataset=path)
else:
dl.install(path=tmp_path, source=url, dataset=path)
tmp_path = os.path.join(path, 'projects/tmp')
clone(path, source=url, target=tmp_path)
tmp_ds = dl.Dataset(tmp_path)
conf = dlc.ConfigManager(tmp_ds)
uuid = conf.get("datalad.dataset.id")
uuid = str(conf.get("datalad.dataset.id"))
if not uuid:
raise ValueError("The dataset does not have a uuid!")
if not os.path.exists(path + "/projects/" + uuid):
db = path + "/backlogger.db"
get_file(path, "backlogger.db")
dl.unlock(db, dataset=path)
db_file = get_db_file(path)
get(path, db_file)
unlock(path, db_file)
create_project(path, uuid, owner, tags, aliases, code)
move_submodule(path, 'projects/tmp', 'projects/' + uuid)
os.mkdir(path + '/import_scripts/' + uuid)
dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
save(path, message="Import project from " + url, files=['projects/' + uuid, db_file])
else:
dl.drop(tmp_path, reckless='kill')
shutil.rmtree(tmp_path)
@ -142,9 +156,19 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
return uuid
def drop_project_data(path: str, uuid: str, path_in_project: str = ""):
def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None:
"""
Drop (parts of) a prject to free up diskspace
"""
dl.drop(path + "/projects/" + uuid + "/" + path_in_project)
Drop (parts of) a project to free up diskspace
Parameters
----------
path: str
Path of the library.
uuid: str
The UUID of the project from which data is to be dropped.
path_in_project: str, optional
If set, only the given path within the project is dropped.
"""
drop(path + "/projects/" + uuid + "/" + path_in_project)
return

View file

@ -1,17 +1,19 @@
from pyerrors.input import json as pj
import os
import datalad.api as dl
import sqlite3
from .input import sfcf,openQCD
import json
from typing import Union, Optional
from typing import Union, Optional,Any
from pyerrors import Obs, Corr, load_object, dump_object
from hashlib import sha256, sha1
from .tools import cached, get_file, record2name_key, name_key2record, make_version_hash
from .tools import record2name_key, name_key2record, make_version_hash
from .cache_io import is_in_cache, cache_path, cache_dir, get_version_hash
from .tools import get_db_file, cache_enabled
from .tracker import get, save, unlock
import shutil
def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: Optional[str]=None):
def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None:
"""
Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement.
@ -26,10 +28,15 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: O
Measurements to be captured in the backlogging system.
uuid: str
The uuid of the project.
code: str
Name of the code that was used for the project.
parameter_file: str
The parameter file used for the measurement.
"""
db = os.path.join(path, 'backlogger.db')
get_file(path, "backlogger.db")
dl.unlock(db, dataset=path)
db_file = get_db_file(path)
db = os.path.join(path, db_file)
get(path, db_file)
unlock(path, db_file)
conn = sqlite3.connect(db)
c = conn.cursor()
files = []
@ -42,7 +49,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: O
os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr))
else:
if os.path.exists(file):
dl.unlock(file, dataset=path)
unlock(path, file_in_archive)
known_meas = pj.load_json_dict(file)
if code == "sfcf":
parameters = sfcf.read_param(path, uuid, parameter_file)
@ -92,13 +99,13 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: O
(corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file))
c.execute("UPDATE backlogs SET current_version = ?, updated_at = datetime('now') WHERE path = ?", (data_hash, meas_path))
pj.dump_dict_to_json(known_meas, file)
conn.commit()
files.append(db)
files.append(os.path.join(path, db_file))
conn.close()
dl.save(files, message="Add measurements to database", dataset=path)
save(path, message="Add measurements to database", files=files)
return
def load_record(path: str, meas_path: str):
def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
"""
Load a list of records by their paths.
@ -117,7 +124,7 @@ def load_record(path: str, meas_path: str):
return load_records(path, [meas_path])[0]
def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]:
def load_records(path: str, record_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]:
"""
Load a list of records by their paths.
@ -127,10 +134,13 @@ def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Uni
Path of the correlator library.
meas_paths: list[str]
A list of the paths to the correlator in the backlog system.
preloaded: dict[str, Any]
    The data that is already preloaded. Of interest if data has already been loaded in the same script.
Returns
-------
List
returned_data: list
The loaded records.
"""
needed_data: dict[str, list[str]] = {}
for rpath in record_paths:
@ -138,7 +148,7 @@ def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Uni
if file not in needed_data.keys():
needed_data[file] = []
needed_data[file].append(key)
returned_data: list = []
returned_data: list[Any] = []
for file in needed_data.keys():
for key in list(needed_data[file]):
record = name_key2record(file, key)
@ -149,7 +159,7 @@ def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Uni
if file not in preloaded:
preloaded[file] = preload(path, file)
returned_data.append(preloaded[file][key])
if cached:
if cache_enabled(path):
if not is_in_cache(path, record):
file, key = record2name_key(record)
if not os.path.exists(cache_dir(path, file)):
@ -159,19 +169,46 @@ def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Uni
return returned_data
def preload(path: str, file: str):
get_file(path, file)
filedict = pj.load_json_dict(os.path.join(path, file))
def preload(path: str, file: str) -> dict[str, Any]:
    """
    Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method.

    Parameters
    ----------
    path: str
        The path of the library.
    file: str
        The file within the library to be loaded.

    Returns
    -------
    filedict: dict[str, Any]
        The data read from the file.
    """
    # Ensure the (possibly annexed) file content is locally present before reading.
    get(path, file)
    filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file))
    print("> read file")
    return filedict
def drop_record(path: str, meas_path: str):
file_in_archive, sub_key = record2name_key(meas_path)
def drop_record(path: str, meas_path: str) -> None:
"""
Drop a record by its path.
Parameters
----------
path: str
The path of the library.
meas_path: str
The measurement path as noted in the database.
"""
file_in_archive = meas_path.split("::")[0]
file = os.path.join(path, file_in_archive)
db = os.path.join(path, 'backlogger.db')
get_file(path, 'backlogger.db')
dl.unlock(db, dataset=path)
db_file = get_db_file(path)
db = os.path.join(path, db_file)
get(path, db_file)
sub_key = meas_path.split("::")[1]
unlock(path, db_file)
conn = sqlite3.connect(db)
c = conn.cursor()
if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path, )).fetchone() is not None:
@ -183,10 +220,24 @@ def drop_record(path: str, meas_path: str):
known_meas = pj.load_json_dict(file)
if sub_key in known_meas:
del known_meas[sub_key]
dl.unlock(file, dataset=path)
unlock(path, file_in_archive)
pj.dump_dict_to_json(known_meas, file)
dl.save([db, file], message="Drop measurements to database", dataset=path)
save(path, message="Drop measurements to database", files=[db, file])
return
else:
raise ValueError("This measurement does not exist as a file!")
def drop_cache(path: str) -> None:
    """
    Drop the contents of the cache directory of the library.

    Removes every entry (files and subdirectories) inside ``<path>/.cache``
    while keeping the ``.cache`` directory itself. A missing cache directory
    is ignored, so the call is safe on libraries that were never cached.

    Parameters
    ----------
    path: str
        The path of the library.
    """
    cache_root = os.path.join(path, ".cache")
    if not os.path.isdir(cache_root):
        # Nothing to do for libraries without a cache.
        return
    for entry in os.listdir(cache_root):
        full_path = os.path.join(cache_root, entry)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
        else:
            # shutil.rmtree raises NotADirectoryError on plain files.
            os.remove(full_path)
    return

View file

@ -10,22 +10,48 @@ the import of projects via TOML.
import tomllib as toml
import shutil
import datalad.api as dl
from .tracker import save
from .input import sfcf, openQCD
from .main import import_project, update_aliases
from .meas_io import write_measurement
import datalad.api as dl
import os
from .input.implementations import codes as known_codes
from typing import Any
def replace_string(string: str, name: str, val: str):
def replace_string(string: str, name: str, val: str) -> str:
    """
    Replace a placeholder {name} with a value in a string.

    Parameters
    ----------
    string: str
        String in which the placeholders are to be replaced.
    name: str
        The name of the placeholder.
    val: str
        The value the placeholder is to be replaced with.

    Returns
    -------
    replaced: str
        The string with every occurrence of ``{name}`` substituted by
        ``val``; the input unchanged if the placeholder does not occur.
    """
    placeholder = "{" + name + "}"
    # str.replace is a no-op when the placeholder is absent.
    return string.replace(placeholder, val)
def replace_in_meas(measurements: dict, vars: dict[str, str]):
# replace global variables
def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str]) -> dict[str, dict[str, Any]]:
"""
Replace placeholders in the definitions for a measurement.
Parameters
----------
measurements: dict[str, dict[str, Any]]
The measurements read from the toml file.
vars: dict[str, str]
Simple key:value dictionary with the keys to be replaced by the values.
"""
for name, value in vars.items():
for m in measurements.keys():
for key in measurements[m].keys():
@ -36,7 +62,18 @@ def replace_in_meas(measurements: dict, vars: dict[str, str]):
measurements[m][key][i] = replace_string(measurements[m][key][i], name, value)
return measurements
def fill_cons(measurements, constants):
def fill_cons(measurements: dict[str, dict[str, Any]], constants: dict[str, str]) -> dict[str, dict[str, Any]]:
"""
Fill in defined constants into the measurements of the toml-file.
Parameters
----------
measurements: dict[str, dict[str, Any]]
The measurements read from the toml file.
constants: dict[str, str]
Simple key:value dictionary with the keys to be replaced by the values.
"""
for m in measurements.keys():
for name, val in constants.items():
if name not in measurements[m].keys():
@ -44,7 +81,15 @@ def fill_cons(measurements, constants):
return measurements
def check_project_data(d: dict) -> None:
def check_project_data(d: dict[str, dict[str, str]]) -> None:
"""
Check the data given in the toml import file for the project we want to import.
Parameters
----------
d: dict
The dictionary holding the data necessary to import the project.
"""
if 'project' not in d.keys() or 'measurements' not in d.keys() or len(list(d.keys())) > 4:
raise ValueError('There should only be maximally be four keys on the top level, "project" and "measurements" are mandatory, "contants" is optional!')
project_data = d['project']
@ -57,7 +102,17 @@ def check_project_data(d: dict) -> None:
return
def check_measurement_data(measurements: dict, code: str) -> None:
def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) -> None:
"""
Check syntax of the measurements we want to import.
Parameters
----------
measurements: dict[str, dict[str, str]]
The dictionary holding the necessary data to import the project.
code: str
The code used for the project.
"""
var_names: list[str] = []
if code == "sfcf":
var_names = ["path", "ensemble", "param_file", "version", "prefix", "cfg_seperator", "names"]
@ -72,8 +127,21 @@ def check_measurement_data(measurements: dict, code: str) -> None:
def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None:
    """
    Import multiple toml files.

    Parameters
    ----------
    path: str
        Path to the backlog directory.
    files: list[str]
        Path to the description files.
    copy_files: bool, optional
        Whether the toml-files will be copied into the library. Default is True.
    """
    # Delegate each description file to the single-file importer.
    for description_file in files:
        import_toml(path, description_file, copy_files)
    return
def import_toml(path: str, file: str, copy_file: bool=True) -> None:
@ -86,19 +154,21 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
Path to the backlog directory.
file: str
Path to the description file.
copy_file: bool, optional
Whether the toml-files will be copied into the library. Default is True.
"""
print("Import project as decribed in " + file)
with open(file, 'rb') as fp:
toml_dict = toml.load(fp)
check_project_data(toml_dict)
project: dict = toml_dict['project']
project: dict[str, Any] = toml_dict['project']
if project['code'] not in known_codes:
raise ValueError('Code' + project['code'] + 'has no import implementation!')
measurements: dict = toml_dict['measurements']
measurements: dict[str, dict[str, Any]] = toml_dict['measurements']
measurements = fill_cons(measurements, toml_dict['constants'] if 'constants' in toml_dict else {})
measurements = replace_in_meas(measurements, toml_dict['replace'] if 'replace' in toml_dict else {})
check_measurement_data(measurements, project['code'])
aliases = project.get('aliases', None)
aliases = project.get('aliases', [])
uuid = project.get('uuid', None)
if uuid is not None:
if not os.path.exists(path + "/projects/" + uuid):
@ -133,29 +203,29 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
for rwp in ["integrator", "eps", "ntot", "dnms"]:
param[rwp] = "Unknown"
param['type'] = 't0'
measurement = openQCD.extract_t0(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
measurement = openQCD.extract_t0(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
elif md['measurement'] == 't1':
if 'param_file' in md:
param = openQCD.read_ms3_param(path, uuid, md['param_file'])
param['type'] = 't1'
measurement = openQCD.extract_t1(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
measurement = openQCD.extract_t1(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None))
write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else ''))
if not os.path.exists(os.path.join(path, "toml_imports", uuid)):
os.makedirs(os.path.join(path, "toml_imports", uuid))
if copy_file:
import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1])
shutil.copy(file, import_file)
dl.save(import_file, message="Import using " + import_file, dataset=path)
save(path, files=[import_file], message="Import using " + import_file)
print("File copied to " + import_file)
print("Imported project.")
return
def reimport_project(path, uuid):
def reimport_project(path: str, uuid: str) -> None:
"""
Reimport an existing project using the files that are already available for this project.
@ -173,6 +243,17 @@ def reimport_project(path, uuid):
return
def update_project(path, uuid):
def update_project(path: str, uuid: str) -> None:
    """
    Update all entries associated with a given project.

    Parameters
    ----------
    path: str
        The path of the library.
    uuid: str
        The unique identifier of the project to be updated.
    """
    # Pull the latest state of the project subdataset from its registered
    # sibling and merge it into the local checkout.
    dl.update(how='merge', follow='sibling', dataset=os.path.join(path, "projects", uuid))
    # NOTE(review): re-importing after the update is currently disabled;
    # confirm whether measurements should be refreshed here.
    # reimport_project(path, uuid)
    return

View file

@ -1,45 +1,201 @@
import os
import datalad.api as dl
import hashlib
from configparser import ConfigParser
from typing import Any
def str2list(string):
CONFIG_FILENAME = ".corrlib"
cached: bool = True
def str2list(string: str) -> list[str]:
    """
    Convert a comma-separated string to a list.

    Parameters
    ----------
    string: str
        The string holding a comma-separated list.

    Returns
    -------
    parts: list[str]
        The list of strings that was held by the comma-separated string.
    """
    parts = string.split(",")
    return parts
def list2str(mylist):
def list2str(mylist: list[str]) -> str:
    """
    Convert a list to a comma-separated string.

    Parameters
    ----------
    mylist: list[str]
        A list of strings to be concatenated.

    Returns
    -------
    joined: str
        The string holding a comma-separated list.
    """
    return ",".join(mylist)
cached = True
def m2k(m):
def m2k(m: float) -> float:
    """
    Convert the bare quark mass $m$ to the inverse mass parameter $kappa$.

    Parameters
    ----------
    m: float
        Bare quark mass.

    Returns
    -------
    k: float
        The corresponding $kappa$.
    """
    # kappa = 1 / (2 m + 8); inverse of k2m.
    return 1.0 / (2.0 * m + 8.0)
def k2m(k):
def k2m(k: float) -> float:
"""
Convert from the inverse bare quark parameter $kappa$ to the bare quark mass $m$.
Parameters
----------
k: float
Inverse bare quark mass parameter $kappa$.
Returns
-------
m: float
The corresponing bare quark mass.
"""
return (1/(2*k))-4
def get_file(path, file):
if file == "backlogger.db":
def get_file(path: str, file: str) -> None:
    """
    Fetch a file of the library via datalad.

    Parameters
    ----------
    path: str
        The path of the library.
    file: str
        The file (relative to the library root) to fetch.
    """
    is_database = file == get_db_file(path)
    print("Downloading database..." if is_database else "Downloading data...")
    dl.get(os.path.join(path, file), dataset=path)
    print("> downloaded file")
    return
def record2name_key(record_path: str):
def record2name_key(record_path: str) -> tuple[str, str]:
    """
    Convert a record to a pair of name and key.

    The record format is ``<file>::<key>``. Only the first ``::`` is treated
    as the separator, so keys that themselves contain ``::`` survive a
    round trip through name_key2record.

    Parameters
    ----------
    record_path: str
        Record identifier of the form ``<file>::<key>``.

    Returns
    -------
    name: str
        The file part of the record.
    key: str
        The key part of the record.
    """
    # maxsplit=1 keeps any further "::" inside the key intact
    # (a plain split dropped everything after a second "::").
    file, key = record_path.split("::", 1)
    return file, key
def name_key2record(name: str, key: str):
def name_key2record(name: str, key: str) -> str:
    """
    Convert a name and a key to a record name.

    Parameters
    ----------
    name: str
        The file part of the record.
    key: str
        The key part of the record.

    Returns
    -------
    record: str
        The combined record identifier ``<name>::<key>``.
    """
    return "::".join((name, key))
def make_version_hash(path, record):
def make_version_hash(path: str, record: str) -> str:
    """
    Compute a version hash for the file a record lives in.

    The hash is the SHA-1 digest of the file's current content, so it
    changes whenever any measurement stored in that file is updated.

    Parameters
    ----------
    path: str
        The path of the library.
    record: str
        Record identifier of the form ``<file>::<key>``; only the file part
        is used here.

    Returns
    -------
    file_hash: str
        Hex digest of the file content.
    """
    # The key part of the record is not needed for a whole-file hash.
    file, key = record2name_key(record)
    with open(os.path.join(path, file), 'rb') as fp:
        file_hash = hashlib.file_digest(fp, 'sha1').hexdigest()
    return file_hash
def set_config(path: str, section: str, option: str, value: Any) -> None:
    """
    Set a configuration parameter for the library.

    Parameters
    ----------
    path: str
        The path of the library.
    section: str
        The section within the configuration file.
    option: str
        The option to be set to value.
    value: Any
        The value the option is set to; converted to its string
        representation, since ConfigParser only stores strings.
    """
    config_path = os.path.join(path, '.corrlib')
    config = ConfigParser()
    if os.path.exists(config_path):
        # Preserve any existing settings when rewriting the file.
        config.read(config_path)
    if not config.has_section(section):
        config.add_section(section)
    # ConfigParser.set raises TypeError for non-string values, but the
    # signature advertises Any - coerce explicitly.
    config.set(section, option, str(value))
    with open(config_path, 'w') as configfile:
        config.write(configfile)
    return
def get_db_file(path: str) -> str:
    """
    Get the database file associated with the library at the given path.

    Parameters
    ----------
    path: str
        The path of the library.

    Returns
    -------
    db_file: str
        The file holding the database; 'backlogger.db' when not configured.
    """
    config = ConfigParser()
    config_path = os.path.join(path, CONFIG_FILENAME)
    if os.path.exists(config_path):
        config.read(config_path)
    # Fall back to the historical default when 'paths.db' is not set.
    return config.get('paths', 'db', fallback='backlogger.db')
def cache_enabled(path: str) -> bool:
    """
    Check whether caching is enabled for the library.

    Fallback is True.

    Parameters
    ----------
    path: str
        The path of the library.

    Returns
    -------
    cached_bool: bool
        Whether the given library uses the cache.
    """
    config_path = os.path.join(path, CONFIG_FILENAME)
    config = ConfigParser()
    if os.path.exists(config_path):
        config.read(config_path)
    # getboolean accepts the usual ini spellings (true/yes/on/1, any case),
    # whereas the previous string comparison only recognized exactly 'True'.
    return config.getboolean('core', 'cached', fallback=True)

169
corrlib/tracker.py Normal file
View file

@ -0,0 +1,169 @@
import os
import shutil
import warnings
from configparser import ConfigParser
from typing import Optional

import datalad.api as dl

from .tools import get_db_file
def get_tracker(path: str) -> str:
    """
    Get the tracker used in the dataset located at path.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.

    Returns
    -------
    tracker: str
        The tracker used in the dataset; 'datalad' when not configured.

    Raises
    ------
    FileNotFoundError
        If the library has no '.corrlib' configuration file.
    """
    config_path = os.path.join(path, '.corrlib')
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"No config file found in {path}.")
    config = ConfigParser()
    config.read(config_path)
    return config.get('core', 'tracker', fallback='datalad')
def get(path: str, file: str) -> None:
    """
    Wrapper function to get a file from the dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    file: str
        The file to get.

    Raises
    ------
    ValueError
        If the configured tracker is not supported.
    """
    tracker = get_tracker(path)
    if tracker == 'None':
        # Untracked libraries keep all content on disk; nothing to fetch.
        return
    if tracker != 'datalad':
        raise ValueError(f"Tracker {tracker} is not supported.")
    if file == get_db_file(path):
        message = "Downloading database..."
    else:
        message = "Downloading data..."
    print(message)
    dl.get(os.path.join(path, file), dataset=path)
    print("> downloaded file")
    return
def save(path: str, message: str, files: Optional[list[str]]=None) -> None:
    """
    Wrapper function to save a file to the dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    message: str
        The commit message.
    files: list[str], optional
        The files to save. If None, all changes are saved.

    Raises
    ------
    ValueError
        If the configured tracker is not supported.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        if files is not None:
            # datalad resolves paths relative to the working directory,
            # so anchor them at the dataset root.
            files = [os.path.join(path, f) for f in files]
        dl.save(files, message=message, dataset=path)
    elif tracker == 'None':
        # The previous bare Warning(...) constructed an exception object
        # without raising or reporting it; actually emit the warning.
        warnings.warn("Tracker 'None' does not implement save.")
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return
def init(path: str, tracker: str='datalad') -> None:
    """
    Initialize a dataset at the specified path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to initialize the dataset.
    tracker: str
        The tracker to use. Currently only 'datalad' and 'None' are supported.

    Raises
    ------
    ValueError
        If the requested tracker is not supported.
    """
    if tracker == 'None':
        # Plain directory, no version control.
        os.makedirs(path, exist_ok=True)
    elif tracker == 'datalad':
        dl.create(path)
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return
def unlock(path: str, file: str) -> None:
    """
    Wrapper function to unlock a file in the dataset located at path with the specified tracker.

    Parameters
    ----------
    path : str
        The path to the backlogger folder.
    file : str
        The file to unlock.

    Raises
    ------
    ValueError
        If the configured tracker is not supported.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        dl.unlock(file, dataset=path)
    elif tracker == 'None':
        # The previous bare Warning(...) was a no-op constructor call;
        # actually emit the warning.
        warnings.warn("Tracker 'None' does not implement unlock.")
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return
def clone(path: str, source: str, target: str) -> None:
    """
    Wrapper function to clone a dataset from source to target with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    source: str
        The source dataset to clone.
    target: str
        The target path to clone the dataset to.

    Raises
    ------
    ValueError
        If the configured tracker is not supported.
    """
    tracker = get_tracker(path)
    if tracker == 'None':
        os.makedirs(path, exist_ok=True)
        # Implement a simple clone by copying files
        shutil.copytree(source, target, dirs_exist_ok=False)
    elif tracker == 'datalad':
        dl.clone(target=target, source=source, dataset=path)
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return
def drop(path: str, reckless: Optional[str]=None) -> None:
    """
    Wrapper function to drop data from a dataset located at path with the specified tracker.

    Parameters
    ----------
    path: str
        The path to the backlogger folder.
    reckless: Optional[str]
        The datalad's reckless option for dropping data.

    Raises
    ------
    ValueError
        If the configured tracker is not supported.
    """
    tracker = get_tracker(path)
    if tracker == 'datalad':
        dl.drop(path, reckless=reckless)
    elif tracker == 'None':
        # The previous bare Warning(...) was a no-op constructor call;
        # actually emit the warning.
        warnings.warn("Tracker 'None' does not implement drop.")
    else:
        raise ValueError(f"Tracker {tracker} is not supported.")
    return

View file

@ -1 +1,34 @@
__version__ = "0.2.3"
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.2.4.dev14+g602324f84.d20251202'
__version_tuple__ = version_tuple = (0, 2, 4, 'dev14', 'g602324f84.d20251202')
__commit_id__ = commit_id = 'g602324f84'

View file

@ -1,6 +1,52 @@
[build-system]
requires = ["setuptools >= 63.0.0", "wheel"]
requires = ["setuptools >= 63.0.0", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
requires-python = ">=3.10"
name = "corrlib"
dynamic = ["version"]
dependencies = [
"gitpython>=3.1.45",
'pyerrors>=2.11.1',
"datalad>=1.1.0",
'typer>=0.12.5',
]
description = "Python correlation library"
authors = [
{ name = 'Justus Kuhlmann', email = 'j_kuhl19@uni-muenster.de'}
]
[project.scripts]
pcl = "corrlib.cli:app"
[tool.setuptools.packages.find]
include = ["corrlib", "corrlib.*"]
[tool.setuptools_scm]
write_to = "corrlib/version.py"
[tool.ruff.lint]
ignore = ["F403"]
ignore = ["E501"]
extend-select = [
"YTT",
"E",
"W",
"F",
]
[tool.mypy]
strict = true
implicit_reexport = false
follow_untyped_imports = false
ignore_missing_imports = true
[dependency-groups]
dev = [
"mypy>=1.19.0",
"pandas-stubs>=2.3.3.251201",
"pytest>=9.0.1",
"pytest-cov>=7.0.0",
"pytest-pretty>=1.3.0",
"ruff>=0.14.7",
]

91
tests/cli_test.py Normal file
View file

@ -0,0 +1,91 @@
from typer.testing import CliRunner
from corrlib.cli import app
import os
import sqlite3 as sql
runner = CliRunner()
def test_version():
    # `pcl --version` should exit cleanly and mention the package name.
    result = runner.invoke(app, ["--version"])
    assert result.exit_code == 0
    assert "corrlib" in result.output
def test_init_folders(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
assert os.path.exists(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_db(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
assert os.path.exists(str(dataset_path / "backlogger.db"))
conn = sql.connect(str(dataset_path / "backlogger.db"))
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
expected_tables = [
'projects',
'backlogs',
]
table_names = [table[0] for table in tables]
for expected_table in expected_tables:
assert expected_table in table_names
cursor.execute("SELECT * FROM projects;")
projects = cursor.fetchall()
assert len(projects) == 0
cursor.execute("SELECT * FROM backlogs;")
backlogs = cursor.fetchall()
assert len(backlogs) == 0
cursor.execute("PRAGMA table_info('projects');")
project_columns = cursor.fetchall()
expected_project_columns = [
"id",
"aliases",
"customTags",
"owner",
"code",
"created_at",
"updated_at"
]
project_column_names = [col[1] for col in project_columns]
for expected_col in expected_project_columns:
assert expected_col in project_column_names
cursor.execute("PRAGMA table_info('backlogs');")
backlog_columns = cursor.fetchall()
expected_backlog_columns = [
"id",
"name",
"ensemble",
"code",
"path",
"project",
"customTags",
"parameters",
"parameter_file",
"created_at",
"updated_at"
]
backlog_column_names = [col[1] for col in backlog_columns]
for expected_col in expected_backlog_columns:
assert expected_col in backlog_column_names
def test_list(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"])
assert result.exit_code == 0
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"])
assert result.exit_code == 0

View file

@ -14,4 +14,4 @@ def test_toml_check_measurement_data():
"names": ['list', 'of', 'names']
}
}
t.check_measurement_data(measurements)
t.check_measurement_data(measurements, "sfcf")

View file

@ -0,0 +1,93 @@
import corrlib.initialization as init
import os
import sqlite3 as sql
def test_init_folders(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path))
assert os.path.exists(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_folders_no_tracker(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path), tracker="None")
assert os.path.exists(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_config(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path), tracker="None")
config_path = dataset_path / ".corrlib"
assert os.path.exists(str(config_path))
from configparser import ConfigParser
config = ConfigParser()
config.read(str(config_path))
assert config.get("core", "tracker") == "None"
assert config.get("core", "version") == "1.0"
assert config.get("core", "cached") == "True"
assert config.get("paths", "db") == "backlogger.db"
assert config.get("paths", "projects_path") == "projects"
assert config.get("paths", "archive_path") == "archive"
assert config.get("paths", "toml_imports_path") == "toml_imports"
assert config.get("paths", "import_scripts_path") == "import_scripts"
def test_init_db(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
conn = sql.connect(str(dataset_path / "backlogger.db"))
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
expected_tables = [
'projects',
'backlogs',
]
table_names = [table[0] for table in tables]
for expected_table in expected_tables:
assert expected_table in table_names
cursor.execute("SELECT * FROM projects;")
projects = cursor.fetchall()
assert len(projects) == 0
cursor.execute("SELECT * FROM backlogs;")
backlogs = cursor.fetchall()
assert len(backlogs) == 0
cursor.execute("PRAGMA table_info('projects');")
project_columns = cursor.fetchall()
expected_project_columns = [
"id",
"aliases",
"customTags",
"owner",
"code",
"created_at",
"updated_at"
]
project_column_names = [col[1] for col in project_columns]
for expected_col in expected_project_columns:
assert expected_col in project_column_names
cursor.execute("PRAGMA table_info('backlogs');")
backlog_columns = cursor.fetchall()
expected_backlog_columns = [
"id",
"name",
"ensemble",
"code",
"path",
"project",
"customTags",
"parameters",
"parameter_file",
"created_at",
"updated_at"
]
backlog_column_names = [col[1] for col in backlog_columns]
for expected_col in expected_backlog_columns:
assert expected_col in backlog_column_names

31
tests/tools_test.py Normal file
View file

@ -0,0 +1,31 @@
from corrlib import tools as tl
def test_m2k():
    # Spot-check the mass -> kappa conversion against the defining formula.
    for mass in [0.1, 0.5, 1.0]:
        assert tl.m2k(mass) == 1 / (2 * mass + 8)
def test_k2m():
    # Spot-check the kappa -> mass conversion against the defining formula.
    for kappa in [0.1, 0.5, 1.0]:
        assert tl.k2m(kappa) == (1 / (2 * kappa)) - 4
def test_k2m_m2k():
    # Round trip: converting to kappa and back reproduces the mass.
    for mass in [0.1, 0.5, 1.0]:
        assert abs(tl.k2m(tl.m2k(mass)) - mass) < 1e-9
def test_str2list():
    # Comma-separated strings split into their parts.
    assert tl.str2list("a,b,c") == list("abc")
    assert tl.str2list("1,2,3") == ["1", "2", "3"]
def test_list2str():
    # Lists join back into comma-separated strings.
    for parts, expected in [(["a", "b", "c"], "a,b,c"), (["1", "2", "3"], "1,2,3")]:
        assert tl.list2str(parts) == expected

2518
uv.lock generated Normal file

File diff suppressed because it is too large Load diff