Merge pull request 'tests/tools' (#22) from tests/tools into develop
All checks were successful
Mypy / mypy (push) Successful in 1m14s
Pytest / pytest (3.12) (push) Successful in 1m17s
Pytest / pytest (3.13) (push) Successful in 1m11s
Ruff / ruff (push) Successful in 1m4s
Pytest / pytest (3.14) (push) Successful in 1m18s

Reviewed-on: https://www.kuhl-mann.de/git/git/jkuhl/corrlib/pulls/22
This commit is contained in:
Justus Kuhlmann 2026-03-23 16:26:07 +01:00
commit 99ec6afdfc
16 changed files with 216 additions and 142 deletions

View file

@ -11,6 +11,7 @@ from .meas_io import load_record as mio_load_record
import os import os
from pyerrors import Corr from pyerrors import Corr
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path
app = typer.Typer() app = typer.Typer()
@ -24,8 +25,8 @@ def _version_callback(value: bool) -> None:
@app.command() @app.command()
def update( def update(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -40,8 +41,8 @@ def update(
@app.command() @app.command()
def lister( def lister(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -52,8 +53,8 @@ def lister(
""" """
if entities in ['ensembles', 'Ensembles','ENSEMBLES']: if entities in ['ensembles', 'Ensembles','ENSEMBLES']:
print("Ensembles:") print("Ensembles:")
for item in os.listdir(path + "/archive"): for item in os.listdir(path / "archive"):
if os.path.isdir(os.path.join(path + "/archive", item)): if os.path.isdir(path / "archive" / item):
print(item) print(item)
elif entities == 'projects': elif entities == 'projects':
results = list_projects(path) results = list_projects(path)
@ -71,8 +72,8 @@ def lister(
@app.command() @app.command()
def alias_add( def alias_add(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -89,8 +90,8 @@ def alias_add(
@app.command() @app.command()
def find( def find(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -116,8 +117,8 @@ def find(
@app.command() @app.command()
def stat( def stat(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -136,8 +137,8 @@ def stat(
@app.command() @app.command()
def importer( def importer(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -159,8 +160,8 @@ def importer(
@app.command() @app.command()
def reimporter( def reimporter(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -183,8 +184,8 @@ def reimporter(
@app.command() @app.command()
def init( def init(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),
@ -203,8 +204,8 @@ def init(
@app.command() @app.command()
def drop_cache( def drop_cache(
path: str = typer.Option( path: Path = typer.Option(
str('./corrlib'), Path('./corrlib'),
"--dataset", "--dataset",
"-d", "-d",
), ),

View file

@ -7,9 +7,10 @@ from .input.implementations import codes
from .tools import k2m, get_db_file from .tools import k2m, get_db_file
from .tracker import get from .tracker import get
from typing import Any, Optional from typing import Any, Optional
from pathlib import Path
def _project_lookup_by_alias(db: str, alias: str) -> str: def _project_lookup_by_alias(db: Path, alias: str) -> str:
""" """
Lookup a projects UUID by its (human-readable) alias. Lookup a projects UUID by its (human-readable) alias.
@ -37,7 +38,7 @@ def _project_lookup_by_alias(db: str, alias: str) -> str:
return str(results[0][0]) return str(results[0][0])
def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, str]]:
""" """
Return the project information available in the database by UUID. Return the project information available in the database by UUID.
@ -61,7 +62,7 @@ def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]:
return results return results
def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame:
""" """
Look up a correlator record in the database by the data given to the method. Look up a correlator record in the database by the data given to the method.
@ -228,10 +229,10 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
return results.drop(drops) return results.drop(drops)
def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, def find_record(path: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame:
db_file = get_db_file(path) db_file = get_db_file(path)
db = os.path.join(path, db_file) db = path / db_file
if code not in codes: if code not in codes:
raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes))
get(path, db_file) get(path, db_file)
@ -246,7 +247,7 @@ def find_record(path: str, ensemble: str, correlator_name: str, code: str, proje
return results.reset_index() return results.reset_index()
def find_project(path: str, name: str) -> str: def find_project(path: Path, name: str) -> str:
""" """
Find a project by it's human readable name. Find a project by it's human readable name.
@ -264,10 +265,10 @@ def find_project(path: str, name: str) -> str:
""" """
db_file = get_db_file(path) db_file = get_db_file(path)
get(path, db_file) get(path, db_file)
return _project_lookup_by_alias(os.path.join(path, db_file), name) return _project_lookup_by_alias(path / db_file, name)
def list_projects(path: str) -> list[tuple[str, str]]: def list_projects(path: Path) -> list[tuple[str, str]]:
""" """
List all projects known to the library. List all projects known to the library.

View file

@ -1,27 +1,28 @@
import os import os
from .tracker import save from .tracker import save
import git import git
from pathlib import Path
GITMODULES_FILE = '.gitmodules' GITMODULES_FILE = '.gitmodules'
def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None:
""" """
Move a submodule to a new location. Move a submodule to a new location.
Parameters Parameters
---------- ----------
repo_path: str repo_path: Path
Path to the repository. Path to the repository.
old_path: str old_path: Path
The old path of the module. The old path of the module.
new_path: str new_path: Path
The new path of the module. The new path of the module.
""" """
os.rename(os.path.join(repo_path, old_path), os.path.join(repo_path, new_path)) os.rename(repo_path / old_path, repo_path / new_path)
gitmodules_file_path = os.path.join(repo_path, GITMODULES_FILE) gitmodules_file_path = repo_path / GITMODULES_FILE
# update paths in .gitmodules # update paths in .gitmodules
with open(gitmodules_file_path, 'r') as file: with open(gitmodules_file_path, 'r') as file:
@ -29,8 +30,8 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
updated_lines = [] updated_lines = []
for line in lines: for line in lines:
if old_path in line: if str(old_path) in line:
line = line.replace(old_path, new_path) line = line.replace(str(old_path), str(new_path))
updated_lines.append(line) updated_lines.append(line)
with open(gitmodules_file_path, 'w') as file: with open(gitmodules_file_path, 'w') as file:
@ -40,6 +41,6 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
repo = git.Repo(repo_path) repo = git.Repo(repo_path)
repo.git.add('.gitmodules') repo.git.add('.gitmodules')
# save new state of the dataset # save new state of the dataset
save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path]) save(repo_path, message=f"Move module from {old_path} to {new_path}", files=[Path('.gitmodules'), repo_path])
return return

View file

@ -2,9 +2,10 @@ from configparser import ConfigParser
import sqlite3 import sqlite3
import os import os
from .tracker import save, init from .tracker import save, init
from pathlib import Path
def _create_db(db: str) -> None: def _create_db(db: Path) -> None:
""" """
Create the database file and the table. Create the database file and the table.
@ -40,7 +41,7 @@ def _create_db(db: str) -> None:
return return
def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: def _create_config(path: Path, tracker: str, cached: bool) -> ConfigParser:
""" """
Create the config file construction for backlogger. Create the config file construction for backlogger.
@ -75,7 +76,7 @@ def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser:
return config return config
def _write_config(path: str, config: ConfigParser) -> None: def _write_config(path: Path, config: ConfigParser) -> None:
""" """
Write the config file to disk. Write the config file to disk.
@ -91,7 +92,7 @@ def _write_config(path: str, config: ConfigParser) -> None:
return return
def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: def create(path: Path, tracker: str = 'datalad', cached: bool = True) -> None:
""" """
Create folder of backlogs. Create folder of backlogs.
@ -107,13 +108,13 @@ def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None:
config = _create_config(path, tracker, cached) config = _create_config(path, tracker, cached)
init(path, tracker) init(path, tracker)
_write_config(path, config) _write_config(path, config)
_create_db(os.path.join(path, config['paths']['db'])) _create_db(path / config['paths']['db'])
os.chmod(os.path.join(path, config['paths']['db']), 0o666) os.chmod(path / config['paths']['db'], 0o666)
os.makedirs(os.path.join(path, config['paths']['projects_path'])) os.makedirs(path / config['paths']['projects_path'])
os.makedirs(os.path.join(path, config['paths']['archive_path'])) os.makedirs(path / config['paths']['archive_path'])
os.makedirs(os.path.join(path, config['paths']['toml_imports_path'])) os.makedirs(path / config['paths']['toml_imports_path'])
os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py')) os.makedirs(path / config['paths']['import_scripts_path'] / 'template.py')
with open(os.path.join(path, ".gitignore"), "w") as fp: with open(path / ".gitignore", "w") as fp:
fp.write(".cache") fp.write(".cache")
fp.close() fp.close()
save(path, message="Initialized correlator library") save(path, message="Initialized correlator library")

View file

@ -3,9 +3,10 @@ import datalad.api as dl
import os import os
import fnmatch import fnmatch
from typing import Any, Optional from typing import Any, Optional
from pathlib import Path
def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: def read_ms1_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]:
""" """
Read the parameters for ms1 measurements from a parameter file in the project. Read the parameters for ms1 measurements from a parameter file in the project.
@ -69,7 +70,7 @@ def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, A
return param return param
def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: def read_ms3_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]:
""" """
Read the parameters for ms3 measurements from a parameter file in the project. Read the parameters for ms3 measurements from a parameter file in the project.
@ -103,7 +104,7 @@ def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, A
return param return param
def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: def read_rwms(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
""" """
Read reweighting factor measurements from the project. Read reweighting factor measurements from the project.
@ -160,7 +161,7 @@ def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any
return rw_dict return rw_dict
def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: def extract_t0(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
""" """
Extract t0 measurements from the project. Extract t0 measurements from the project.
@ -234,7 +235,7 @@ def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, An
return t0_dict return t0_dict
def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: def extract_t1(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
""" """
Extract t1 measurements from the project. Extract t1 measurements from the project.

View file

@ -4,6 +4,7 @@ import json
import os import os
from typing import Any from typing import Any
from fnmatch import fnmatch from fnmatch import fnmatch
from pathlib import Path
bi_corrs: list[str] = ["f_P", "fP", "f_p", bi_corrs: list[str] = ["f_P", "fP", "f_p",
@ -80,7 +81,7 @@ for c in bib_corrs:
corr_types[c] = 'bib' corr_types[c] = 'bib'
def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: def read_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]:
""" """
Read the parameters from the sfcf file. Read the parameters from the sfcf file.
@ -96,7 +97,7 @@ def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
""" """
file = path + "/projects/" + project + '/' + file_in_project file = path / "projects" / project / file_in_project
dl.get(file, dataset=path) dl.get(file, dataset=path)
with open(file, 'r') as f: with open(file, 'r') as f:
lines = f.readlines() lines = f.readlines()
@ -257,7 +258,7 @@ def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str:
return s return s
def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]:
""" """
Extract the data from the sfcf file. Extract the data from the sfcf file.

View file

@ -8,9 +8,10 @@ from .find import _project_lookup_by_id
from .tools import list2str, str2list, get_db_file from .tools import list2str, str2list, get_db_file
from .tracker import get, save, unlock, clone, drop from .tracker import get, save, unlock, clone, drop
from typing import Union, Optional from typing import Union, Optional
from pathlib import Path
def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None:
""" """
Create a new project entry in the database. Create a new project entry in the database.
@ -48,7 +49,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
return return
def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None: def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None] = None) -> None:
""" """
Update/Edit a project entry in the database. Update/Edit a project entry in the database.
Thin wrapper around sql3 call. Thin wrapper around sql3 call.
@ -74,9 +75,9 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None]
return return
def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: def update_aliases(path: Path, uuid: str, aliases: list[str]) -> None:
db_file = get_db_file(path) db_file = get_db_file(path)
db = os.path.join(path, db_file) db = path / db_file
get(path, db_file) get(path, db_file)
known_data = _project_lookup_by_id(db, uuid)[0] known_data = _project_lookup_by_id(db, uuid)[0]
known_aliases = known_data[1] known_aliases = known_data[1]
@ -102,7 +103,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
return return
def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str:
""" """
Import a datalad dataset into the backlogger. Import a datalad dataset into the backlogger.
@ -134,14 +135,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti
uuid = str(conf.get("datalad.dataset.id")) uuid = str(conf.get("datalad.dataset.id"))
if not uuid: if not uuid:
raise ValueError("The dataset does not have a uuid!") raise ValueError("The dataset does not have a uuid!")
if not os.path.exists(path + "/projects/" + uuid): if not os.path.exists(path / "projects" / uuid):
db_file = get_db_file(path) db_file = get_db_file(path)
get(path, db_file) get(path, db_file)
unlock(path, db_file) unlock(path, db_file)
create_project(path, uuid, owner, tags, aliases, code) create_project(path, uuid, owner, tags, aliases, code)
move_submodule(path, 'projects/tmp', 'projects/' + uuid) move_submodule(path, Path('projects/tmp'), Path('projects') / uuid)
os.mkdir(path + '/import_scripts/' + uuid) os.mkdir(path / 'import_scripts' / uuid)
save(path, message="Import project from " + url, files=['projects/' + uuid, db_file]) save(path, message="Import project from " + url, files=[Path(f'projects/{uuid}'), db_file])
else: else:
dl.drop(tmp_path, reckless='kill') dl.drop(tmp_path, reckless='kill')
shutil.rmtree(tmp_path) shutil.rmtree(tmp_path)
@ -156,7 +157,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti
return uuid return uuid
def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: def drop_project_data(path: Path, uuid: str, path_in_project: str = "") -> None:
""" """
Drop (parts of) a project to free up diskspace Drop (parts of) a project to free up diskspace
@ -169,6 +170,5 @@ def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None:
path_pn_project: str, optional path_pn_project: str, optional
If set, only the given path within the project is dropped. If set, only the given path within the project is dropped.
""" """
drop(path + "/projects/" + uuid + "/" + path_in_project) drop(path / "projects" / uuid / path_in_project)
return return

View file

@ -10,9 +10,13 @@ from .tools import get_db_file, cache_enabled
from .tracker import get, save, unlock from .tracker import get, save, unlock
import shutil import shutil
from typing import Any from typing import Any
from pathlib import Path
def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None: CACHE_DIR = ".cache"
def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None:
""" """
Write a measurement to the backlog. Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement. If the file for the measurement already exists, update the measurement.
@ -33,7 +37,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
The parameter file used for the measurement. The parameter file used for the measurement.
""" """
db_file = get_db_file(path) db_file = get_db_file(path)
db = os.path.join(path, db_file) db = path / db_file
files_to_save = [] files_to_save = []
@ -44,11 +48,11 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
conn = sqlite3.connect(db) conn = sqlite3.connect(db)
c = conn.cursor() c = conn.cursor()
for corr in measurement.keys(): for corr in measurement.keys():
file_in_archive = os.path.join('.', 'archive', ensemble, corr, uuid + '.json.gz') file_in_archive = Path('.') / 'archive' / ensemble / corr / str(uuid + '.json.gz')
file = os.path.join(path, file_in_archive) file = path / file_in_archive
known_meas = {} known_meas = {}
if not os.path.exists(os.path.join(path, '.', 'archive', ensemble, corr)): if not os.path.exists(path / 'archive' / ensemble / corr):
os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr)) os.makedirs(path / 'archive' / ensemble / corr)
files_to_save.append(file_in_archive) files_to_save.append(file_in_archive)
else: else:
if os.path.exists(file): if os.path.exists(file):
@ -99,7 +103,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
pars[subkey] = json.dumps(parameters) pars[subkey] = json.dumps(parameters)
for subkey in subkeys: for subkey in subkeys:
parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
meas_path = file_in_archive + "::" + parHash meas_path = str(file_in_archive) + "::" + parHash
known_meas[parHash] = measurement[corr][subkey] known_meas[parHash] = measurement[corr][subkey]
@ -115,7 +119,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str,
return return
def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: def load_record(path: Path, meas_path: str) -> Union[Corr, Obs]:
""" """
Load a list of records by their paths. Load a list of records by their paths.
@ -134,7 +138,7 @@ def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
return load_records(path, [meas_path])[0] return load_records(path, [meas_path])[0]
def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: def load_records(path: Path, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]:
""" """
Load a list of records by their paths. Load a list of records by their paths.
@ -162,11 +166,11 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {
returned_data: list[Any] = [] returned_data: list[Any] = []
for file in needed_data.keys(): for file in needed_data.keys():
for key in list(needed_data[file]): for key in list(needed_data[file]):
if os.path.exists(cache_path(path, file, key) + ".p"): if os.path.exists(str(cache_path(path, file, key)) + ".p"):
returned_data.append(load_object(cache_path(path, file, key) + ".p")) returned_data.append(load_object(str(cache_path(path, file, key)) + ".p"))
else: else:
if file not in preloaded: if file not in preloaded:
preloaded[file] = preload(path, file) preloaded[file] = preload(path, Path(file))
returned_data.append(preloaded[file][key]) returned_data.append(preloaded[file][key])
if cache_enabled(path): if cache_enabled(path):
if not os.path.exists(cache_dir(path, file)): if not os.path.exists(cache_dir(path, file)):
@ -175,7 +179,7 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {
return returned_data return returned_data
def cache_dir(path: str, file: str) -> str: def cache_dir(path: Path, file: str) -> Path:
""" """
Returns the directory corresponding to the cache for the given file. Returns the directory corresponding to the cache for the given file.
@ -190,14 +194,14 @@ def cache_dir(path: str, file: str) -> str:
cache_path: str cache_path: str
The path holding the cached data for the given file. The path holding the cached data for the given file.
""" """
cache_path_list = [path] cache_path_list = file.split("/")[1:]
cache_path_list.append(".cache") cache_path = path / CACHE_DIR
cache_path_list.extend(file.split("/")[1:]) for directory in cache_path_list:
cache_path = "/".join(cache_path_list) cache_path /= directory
return cache_path return cache_path
def cache_path(path: str, file: str, key: str) -> str: def cache_path(path: Path, file: str, key: str) -> Path:
""" """
Parameters Parameters
---------- ----------
@ -213,11 +217,11 @@ def cache_path(path: str, file: str, key: str) -> str:
cache_path: str cache_path: str
The path at which the measurement of the given file and key is cached. The path at which the measurement of the given file and key is cached.
""" """
cache_path = os.path.join(cache_dir(path, file), key) cache_path = cache_dir(path, file) / key
return cache_path return cache_path
def preload(path: str, file: str) -> dict[str, Any]: def preload(path: Path, file: Path) -> dict[str, Any]:
""" """
Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method. Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method.
@ -234,12 +238,12 @@ def preload(path: str, file: str) -> dict[str, Any]:
The data read from the file. The data read from the file.
""" """
get(path, file) get(path, file)
filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file)) filedict: dict[str, Any] = pj.load_json_dict(path / file)
print("> read file") print("> read file")
return filedict return filedict
def drop_record(path: str, meas_path: str) -> None: def drop_record(path: Path, meas_path: str) -> None:
""" """
Drop a record by it's path. Drop a record by it's path.
@ -251,9 +255,9 @@ def drop_record(path: str, meas_path: str) -> None:
The measurement path as noted in the database. The measurement path as noted in the database.
""" """
file_in_archive = meas_path.split("::")[0] file_in_archive = meas_path.split("::")[0]
file = os.path.join(path, file_in_archive) file = path / file_in_archive
db_file = get_db_file(path) db_file = get_db_file(path)
db = os.path.join(path, db_file) db = path / db_file
get(path, db_file) get(path, db_file)
sub_key = meas_path.split("::")[1] sub_key = meas_path.split("::")[1]
unlock(path, db_file) unlock(path, db_file)
@ -268,7 +272,7 @@ def drop_record(path: str, meas_path: str) -> None:
known_meas = pj.load_json_dict(file) known_meas = pj.load_json_dict(file)
if sub_key in known_meas: if sub_key in known_meas:
del known_meas[sub_key] del known_meas[sub_key]
unlock(path, file_in_archive) unlock(path, Path(file_in_archive))
pj.dump_dict_to_json(known_meas, file) pj.dump_dict_to_json(known_meas, file)
save(path, message="Drop measurements to database", files=[db, file]) save(path, message="Drop measurements to database", files=[db, file])
return return
@ -276,7 +280,7 @@ def drop_record(path: str, meas_path: str) -> None:
raise ValueError("This measurement does not exist as a file!") raise ValueError("This measurement does not exist as a file!")
def drop_cache(path: str) -> None: def drop_cache(path: Path) -> None:
""" """
Drop the cache directory of the library. Drop the cache directory of the library.
@ -285,7 +289,7 @@ def drop_cache(path: str) -> None:
path: str path: str
The path of the library. The path of the library.
""" """
cache_dir = os.path.join(path, ".cache") cache_dir = path / ".cache"
for f in os.listdir(cache_dir): for f in os.listdir(cache_dir):
shutil.rmtree(os.path.join(cache_dir, f)) shutil.rmtree(cache_dir / f)
return return

View file

@ -19,6 +19,7 @@ from .meas_io import write_measurement
import os import os
from .input.implementations import codes as known_codes from .input.implementations import codes as known_codes
from typing import Any from typing import Any
from pathlib import Path
def replace_string(string: str, name: str, val: str) -> str: def replace_string(string: str, name: str, val: str) -> str:
@ -126,7 +127,7 @@ def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) -
return return
def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: def import_tomls(path: Path, files: list[str], copy_files: bool=True) -> None:
""" """
Import multiple toml files. Import multiple toml files.
@ -144,7 +145,7 @@ def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None:
return return
def import_toml(path: str, file: str, copy_file: bool=True) -> None: def import_toml(path: Path, file: str, copy_file: bool=True) -> None:
""" """
Import a project decribed by a .toml file. Import a project decribed by a .toml file.
@ -171,7 +172,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
aliases = project.get('aliases', []) aliases = project.get('aliases', [])
uuid = project.get('uuid', None) uuid = project.get('uuid', None)
if uuid is not None: if uuid is not None:
if not os.path.exists(path + "/projects/" + uuid): if not os.path.exists(path / "projects" / uuid):
uuid = import_project(path, project['url'], aliases=aliases) uuid = import_project(path, project['url'], aliases=aliases)
else: else:
update_aliases(path, uuid, aliases) update_aliases(path, uuid, aliases)
@ -213,18 +214,18 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None)) write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None))
print(mname + " imported.") print(mname + " imported.")
if not os.path.exists(os.path.join(path, "toml_imports", uuid)): if not os.path.exists(path / "toml_imports" / uuid):
os.makedirs(os.path.join(path, "toml_imports", uuid)) os.makedirs(path / "toml_imports" / uuid)
if copy_file: if copy_file:
import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) import_file = path / "toml_imports" / uuid / file.split("/")[-1]
shutil.copy(file, import_file) shutil.copy(file, import_file)
save(path, files=[import_file], message="Import using " + import_file) save(path, files=[import_file], message=f"Import using {import_file}")
print("File copied to " + import_file) print(f"File copied to {import_file}")
print("Imported project.") print("Imported project.")
return return
def reimport_project(path: str, uuid: str) -> None: def reimport_project(path: Path, uuid: str) -> None:
""" """
Reimport an existing project using the files that are already available for this project. Reimport an existing project using the files that are already available for this project.
@ -235,14 +236,14 @@ def reimport_project(path: str, uuid: str) -> None:
uuid: str uuid: str
uuid of the project that is to be reimported. uuid of the project that is to be reimported.
""" """
config_path = "/".join([path, "import_scripts", uuid]) config_path = path / "import_scripts" / uuid
for p, filenames, dirnames in os.walk(config_path): for p, filenames, dirnames in os.walk(config_path):
for fname in filenames: for fname in filenames:
import_toml(path, os.path.join(config_path, fname), copy_file=False) import_toml(path, os.path.join(config_path, fname), copy_file=False)
return return
def update_project(path: str, uuid: str) -> None: def update_project(path: Path, uuid: str) -> None:
""" """
Update all entries associated with a given project. Update all entries associated with a given project.

View file

@ -1,6 +1,7 @@
import os import os
from configparser import ConfigParser from configparser import ConfigParser
from typing import Any from typing import Any
from pathlib import Path
CONFIG_FILENAME = ".corrlib" CONFIG_FILENAME = ".corrlib"
cached: bool = True cached: bool = True
@ -73,7 +74,7 @@ def k2m(k: float) -> float:
return (1/(2*k))-4 return (1/(2*k))-4
def set_config(path: str, section: str, option: str, value: Any) -> None: def set_config(path: Path, section: str, option: str, value: Any) -> None:
""" """
Set configuration parameters for the library. Set configuration parameters for the library.
@ -88,7 +89,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None:
value: Any value: Any
The value we set the option to. The value we set the option to.
""" """
config_path = os.path.join(path, '.corrlib') config_path = os.path.join(path, CONFIG_FILENAME)
config = ConfigParser() config = ConfigParser()
if os.path.exists(config_path): if os.path.exists(config_path):
config.read(config_path) config.read(config_path)
@ -100,7 +101,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None:
return return
def get_db_file(path: str) -> str: def get_db_file(path: Path) -> Path:
""" """
Get the database file associated with the library at the given path. Get the database file associated with the library at the given path.
@ -118,11 +119,13 @@ def get_db_file(path: str) -> str:
config = ConfigParser() config = ConfigParser()
if os.path.exists(config_path): if os.path.exists(config_path):
config.read(config_path) config.read(config_path)
db_file = config.get('paths', 'db', fallback='backlogger.db') else:
raise FileNotFoundError("Configuration file not found.")
db_file = Path(config.get('paths', 'db', fallback='backlogger.db'))
return db_file return db_file
def cache_enabled(path: str) -> bool: def cache_enabled(path: Path) -> bool:
""" """
Check, whether the library is cached. Check, whether the library is cached.
Fallback is true. Fallback is true.
@ -141,6 +144,10 @@ def cache_enabled(path: str) -> bool:
config = ConfigParser() config = ConfigParser()
if os.path.exists(config_path): if os.path.exists(config_path):
config.read(config_path) config.read(config_path)
else:
raise FileNotFoundError("Configuration file not found.")
cached_str = config.get('core', 'cached', fallback='True') cached_str = config.get('core', 'cached', fallback='True')
if cached_str not in ['True', 'False']:
raise ValueError(f"String {cached_str} is not a valid option, only True and False are allowed!")
cached_bool = cached_str == ('True') cached_bool = cached_str == ('True')
return cached_bool return cached_bool

View file

@ -4,9 +4,10 @@ import datalad.api as dl
from typing import Optional from typing import Optional
import shutil import shutil
from .tools import get_db_file from .tools import get_db_file
from pathlib import Path
def get_tracker(path: str) -> str: def get_tracker(path: Path) -> str:
""" """
Get the tracker used in the dataset located at path. Get the tracker used in the dataset located at path.
@ -30,7 +31,7 @@ def get_tracker(path: str) -> str:
return tracker return tracker
def get(path: str, file: str) -> None: def get(path: Path, file: Path) -> None:
""" """
Wrapper function to get a file from the dataset located at path with the specified tracker. Wrapper function to get a file from the dataset located at path with the specified tracker.
@ -56,7 +57,7 @@ def get(path: str, file: str) -> None:
return return
def save(path: str, message: str, files: Optional[list[str]]=None) -> None: def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None:
""" """
Wrapper function to save a file to the dataset located at path with the specified tracker. Wrapper function to save a file to the dataset located at path with the specified tracker.
@ -72,7 +73,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None:
tracker = get_tracker(path) tracker = get_tracker(path)
if tracker == 'datalad': if tracker == 'datalad':
if files is not None: if files is not None:
files = [os.path.join(path, f) for f in files] files = [path / f for f in files]
dl.save(files, message=message, dataset=path) dl.save(files, message=message, dataset=path)
elif tracker == 'None': elif tracker == 'None':
Warning("Tracker 'None' does not implement save.") Warning("Tracker 'None' does not implement save.")
@ -81,7 +82,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None:
raise ValueError(f"Tracker {tracker} is not supported.") raise ValueError(f"Tracker {tracker} is not supported.")
def init(path: str, tracker: str='datalad') -> None: def init(path: Path, tracker: str='datalad') -> None:
""" """
Initialize a dataset at the specified path with the specified tracker. Initialize a dataset at the specified path with the specified tracker.
@ -101,7 +102,7 @@ def init(path: str, tracker: str='datalad') -> None:
return return
def unlock(path: str, file: str) -> None: def unlock(path: Path, file: Path) -> None:
""" """
Wrapper function to unlock a file in the dataset located at path with the specified tracker. Wrapper function to unlock a file in the dataset located at path with the specified tracker.
@ -123,7 +124,7 @@ def unlock(path: str, file: str) -> None:
return return
def clone(path: str, source: str, target: str) -> None: def clone(path: Path, source: str, target: str) -> None:
""" """
Wrapper function to clone a dataset from source to target with the specified tracker. Wrapper function to clone a dataset from source to target with the specified tracker.
Parameters Parameters
@ -147,7 +148,7 @@ def clone(path: str, source: str, target: str) -> None:
return return
def drop(path: str, reckless: Optional[str]=None) -> None: def drop(path: Path, reckless: Optional[str]=None) -> None:
""" """
Wrapper function to drop data from a dataset located at path with the specified tracker. Wrapper function to drop data from a dataset located at path with the specified tracker.

View file

@ -2,18 +2,19 @@ from typer.testing import CliRunner
from corrlib.cli import app from corrlib.cli import app
import os import os
import sqlite3 as sql import sqlite3 as sql
from pathlib import Path
runner = CliRunner() runner = CliRunner()
def test_version() -> None:
    """`corrlib --version` exits with code 0 and reports the package name."""
    result = runner.invoke(app, ["--version"])
    assert result.exit_code == 0
    assert "corrlib" in result.output
def test_init_folders(tmp_path): def test_init_folders(tmp_path: Path) -> None:
dataset_path = tmp_path / "test_dataset" dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0 assert result.exit_code == 0
@ -21,7 +22,7 @@ def test_init_folders(tmp_path):
assert os.path.exists(str(dataset_path / "backlogger.db")) assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_db(tmp_path): def test_init_db(tmp_path: Path) -> None:
dataset_path = tmp_path / "test_dataset" dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0 assert result.exit_code == 0
@ -81,7 +82,7 @@ def test_init_db(tmp_path):
assert expected_col in backlog_column_names assert expected_col in backlog_column_names
def test_list(tmp_path): def test_list(tmp_path: Path) -> None:
dataset_path = tmp_path / "test_dataset" dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0 assert result.exit_code == 0

View file

@ -1,7 +1,7 @@
import corrlib.toml as t import corrlib.toml as t
def test_toml_check_measurement_data(): def test_toml_check_measurement_data() -> None:
measurements = { measurements = {
"a": "a":
{ {

View file

@ -1,7 +1,7 @@
import corrlib.input.sfcf as input import corrlib.input.sfcf as input
import json import json
def test_get_specs(): def test_get_specs() -> None:
parameters = { parameters = {
'crr': [ 'crr': [
'f_P', 'f_A' 'f_P', 'f_A'
@ -26,4 +26,4 @@ def test_get_specs():
key = "f_P/q1 q2/1/0/0" key = "f_P/q1 q2/1/0/0"
specs = json.loads(input.get_specs(key, parameters)) specs = json.loads(input.get_specs(key, parameters))
assert specs['quarks'] == ['a', 'b'] assert specs['quarks'] == ['a', 'b']
assert specs['wf1'][0] == [1, [0, 0]] assert specs['wf1'][0] == [1, [0, 0]]

View file

@ -1,24 +1,25 @@
import corrlib.initialization as init import corrlib.initialization as init
import os import os
import sqlite3 as sql import sqlite3 as sql
from pathlib import Path
def test_init_folders(tmp_path: Path) -> None:
    """Creating a dataset produces the dataset directory and its database file."""
    dataset_path = tmp_path / "test_dataset"
    init.create(dataset_path)
    assert os.path.exists(str(dataset_path))
    assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_folders_no_tracker(tmp_path: Path) -> None:
    """Dataset creation also works with the 'None' tracker (no datalad backend)."""
    dataset_path = tmp_path / "test_dataset"
    init.create(dataset_path, tracker="None")
    assert os.path.exists(str(dataset_path))
    assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_config(tmp_path): def test_init_config(tmp_path: Path) -> None:
dataset_path = tmp_path / "test_dataset" dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path), tracker="None") init.create(dataset_path, tracker="None")
config_path = dataset_path / ".corrlib" config_path = dataset_path / ".corrlib"
assert os.path.exists(str(config_path)) assert os.path.exists(str(config_path))
from configparser import ConfigParser from configparser import ConfigParser
@ -34,9 +35,9 @@ def test_init_config(tmp_path):
assert config.get("paths", "import_scripts_path") == "import_scripts" assert config.get("paths", "import_scripts_path") == "import_scripts"
def test_init_db(tmp_path): def test_init_db(tmp_path: Path) -> None:
dataset_path = tmp_path / "test_dataset" dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path)) init.create(dataset_path)
assert os.path.exists(str(dataset_path / "backlogger.db")) assert os.path.exists(str(dataset_path / "backlogger.db"))
conn = sql.connect(str(dataset_path / "backlogger.db")) conn = sql.connect(str(dataset_path / "backlogger.db"))
cursor = conn.cursor() cursor = conn.cursor()

View file

@ -1,31 +1,84 @@
from corrlib import tools as tl from corrlib import tools as tl
from configparser import ConfigParser
from pathlib import Path
import pytest
def test_m2k() -> None:
    """Bare mass -> hopping parameter: kappa = 1 / (2m + 8)."""
    for m in [0.1, 0.5, 1.0]:
        expected_k = 1 / (2 * m + 8)
        assert tl.m2k(m) == expected_k
def test_k2m() -> None:
    """Hopping parameter -> bare mass: m = 1/(2*kappa) - 4."""
    for m in [0.1, 0.5, 1.0]:
        assert tl.k2m(m) == (1 / (2 * m)) - 4
def test_k2m_m2k() -> None:
    """Round trip m -> kappa -> m is the identity up to float rounding."""
    for m in [0.1, 0.5, 1.0]:
        k = tl.m2k(m)
        m_converted = tl.k2m(k)
        assert abs(m - m_converted) < 1e-9
def test_str2list() -> None:
    """Comma-separated strings split into lists of string elements."""
    assert tl.str2list("a,b,c") == ["a", "b", "c"]
    assert tl.str2list("1,2,3") == ["1", "2", "3"]
def test_list2str() -> None:
    """Lists of strings join back into a comma-separated string."""
    assert tl.list2str(["a", "b", "c"]) == "a,b,c"
    assert tl.list2str(["1", "2", "3"]) == "1,2,3"
def test_set_config(tmp_path: Path) -> None:
    """set_config creates the config file on first use, appends new options, and overwrites existing ones."""
    def _read_option(option: str) -> str:
        parser = ConfigParser()
        parser.read(tmp_path / '.corrlib')
        return parser.get('core', option, fallback="not the value")

    # First call: no config file exists yet, so it must be created.
    tl.set_config(tmp_path, "core", "test_option", "test_value")
    assert _read_option("test_option") == "test_value"
    # Second call: the file already exists; a new option is appended.
    tl.set_config(tmp_path, "core", "test_option2", "test_value2")
    assert _read_option("test_option2") == "test_value2"
    # Third call: an existing option is overwritten in place.
    tl.set_config(tmp_path, "core", "test_option2", "test_value3")
    assert _read_option("test_option2") == "test_value3"
def test_get_db_file(tmp_path: Path) -> None:
    """get_db_file returns the configured 'paths'/'db' entry as a Path."""
    # Writing the option also creates the config file, which get_db_file requires.
    tl.set_config(tmp_path, "paths", "db", "test_value")
    assert tl.get_db_file(tmp_path) == Path("test_value")
def test_cache_enabled(tmp_path: Path) -> None:
    """cache_enabled reflects the 'core'/'cached' option and rejects non-boolean strings."""
    def _set_cached(flag: str) -> None:
        tl.set_config(tmp_path, "core", "cached", flag)

    # Writing the option creates the config file on first use.
    _set_cached("True")
    assert tl.cache_enabled(tmp_path)
    _set_cached("False")
    assert not tl.cache_enabled(tmp_path)
    # Anything other than the literal strings 'True'/'False' must raise.
    _set_cached("lalala")
    with pytest.raises(ValueError):
        tl.cache_enabled(tmp_path)