diff --git a/corrlib/cli.py b/corrlib/cli.py index 4e1b65e..b28692a 100644 --- a/corrlib/cli.py +++ b/corrlib/cli.py @@ -11,6 +11,7 @@ from .meas_io import load_record as mio_load_record import os from pyerrors import Corr from importlib.metadata import version +from pathlib import Path app = typer.Typer() @@ -24,8 +25,8 @@ def _version_callback(value: bool) -> None: @app.command() def update( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -40,8 +41,8 @@ def update( @app.command() def lister( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -52,8 +53,8 @@ def lister( """ if entities in ['ensembles', 'Ensembles','ENSEMBLES']: print("Ensembles:") - for item in os.listdir(path + "/archive"): - if os.path.isdir(os.path.join(path + "/archive", item)): + for item in os.listdir(path / "archive"): + if os.path.isdir(path / "archive" / item): print(item) elif entities == 'projects': results = list_projects(path) @@ -71,8 +72,8 @@ def lister( @app.command() def alias_add( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -89,8 +90,8 @@ def alias_add( @app.command() def find( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -116,8 +117,8 @@ def find( @app.command() def stat( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -136,8 +137,8 @@ def stat( @app.command() def importer( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -159,8 +160,8 @@ def importer( @app.command() def reimporter( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -183,8 +184,8 @@ def reimporter( @app.command() def init( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -203,8 +204,8 @@ def init( @app.command() def drop_cache( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), diff --git a/corrlib/find.py b/corrlib/find.py index 022a3f5..faef5db 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -7,9 +7,10 @@ from .input.implementations import codes from .tools import k2m, get_db_file from .tracker import get from typing import Any, Optional +from pathlib import Path -def _project_lookup_by_alias(db: str, alias: str) -> str: +def _project_lookup_by_alias(db: Path, alias: str) -> str: """ Lookup a projects UUID by its (human-readable) alias. @@ -37,7 +38,7 @@ def _project_lookup_by_alias(db: str, alias: str) -> str: return str(results[0][0]) -def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: +def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, str]]: """ Return the project information available in the database by UUID. @@ -61,7 +62,7 @@ def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: return results -def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, +def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: """ Look up a correlator record in the database by the data given to the method. @@ -228,10 +229,10 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: return results.drop(drops) -def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, +def find_record(path: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) get(path, db_file) @@ -246,7 +247,7 @@ def find_record(path: str, ensemble: str, correlator_name: str, code: str, proje return results.reset_index() -def find_project(path: str, name: str) -> str: +def find_project(path: Path, name: str) -> str: """ Find a project by it's human readable name. @@ -264,10 +265,10 @@ def find_project(path: str, name: str) -> str: """ db_file = get_db_file(path) get(path, db_file) - return _project_lookup_by_alias(os.path.join(path, db_file), name) + return _project_lookup_by_alias(path / db_file, name) -def list_projects(path: str) -> list[tuple[str, str]]: +def list_projects(path: Path) -> list[tuple[str, str]]: """ List all projects known to the library. diff --git a/corrlib/git_tools.py b/corrlib/git_tools.py index c6e7522..d77f109 100644 --- a/corrlib/git_tools.py +++ b/corrlib/git_tools.py @@ -1,27 +1,28 @@ import os from .tracker import save import git +from pathlib import Path GITMODULES_FILE = '.gitmodules' -def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: +def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None: """ Move a submodule to a new location. Parameters ---------- - repo_path: str + repo_path: Path Path to the repository. - old_path: str + old_path: Path The old path of the module. - new_path: str + new_path: Path The new path of the module. """ - os.rename(os.path.join(repo_path, old_path), os.path.join(repo_path, new_path)) + os.rename(repo_path / old_path, repo_path / new_path) - gitmodules_file_path = os.path.join(repo_path, GITMODULES_FILE) + gitmodules_file_path = repo_path / GITMODULES_FILE # update paths in .gitmodules with open(gitmodules_file_path, 'r') as file: @@ -29,8 +30,8 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: updated_lines = [] for line in lines: - if old_path in line: - line = line.replace(old_path, new_path) + if str(old_path) in line: + line = line.replace(str(old_path), str(new_path)) updated_lines.append(line) with open(gitmodules_file_path, 'w') as file: @@ -40,6 +41,6 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: repo = git.Repo(repo_path) repo.git.add('.gitmodules') # save new state of the dataset - save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path]) + save(repo_path, message=f"Move module from {old_path} to {new_path}", files=[Path('.gitmodules'), repo_path]) return diff --git a/corrlib/initialization.py b/corrlib/initialization.py index bb71db6..c06a201 100644 --- a/corrlib/initialization.py +++ b/corrlib/initialization.py @@ -2,9 +2,10 @@ from configparser import ConfigParser import sqlite3 import os from .tracker import save, init +from pathlib import Path -def _create_db(db: str) -> None: +def _create_db(db: Path) -> None: """ Create the database file and the table. @@ -40,7 +41,7 @@ def _create_db(db: str) -> None: return -def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: +def _create_config(path: Path, tracker: str, cached: bool) -> ConfigParser: """ Create the config file construction for backlogger. @@ -75,7 +76,7 @@ def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: return config -def _write_config(path: str, config: ConfigParser) -> None: +def _write_config(path: Path, config: ConfigParser) -> None: """ Write the config file to disk. @@ -91,7 +92,7 @@ def _write_config(path: str, config: ConfigParser) -> None: return -def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: +def create(path: Path, tracker: str = 'datalad', cached: bool = True) -> None: """ Create folder of backlogs. @@ -107,13 +108,13 @@ def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: config = _create_config(path, tracker, cached) init(path, tracker) _write_config(path, config) - _create_db(os.path.join(path, config['paths']['db'])) - os.chmod(os.path.join(path, config['paths']['db']), 0o666) - os.makedirs(os.path.join(path, config['paths']['projects_path'])) - os.makedirs(os.path.join(path, config['paths']['archive_path'])) - os.makedirs(os.path.join(path, config['paths']['toml_imports_path'])) - os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py')) - with open(os.path.join(path, ".gitignore"), "w") as fp: + _create_db(path / config['paths']['db']) + os.chmod(path / config['paths']['db'], 0o666) + os.makedirs(path / config['paths']['projects_path']) + os.makedirs(path / config['paths']['archive_path']) + os.makedirs(path / config['paths']['toml_imports_path']) + os.makedirs(path / config['paths']['import_scripts_path'] / 'template.py') + with open(path / ".gitignore", "w") as fp: fp.write(".cache") fp.close() save(path, message="Initialized correlator library") diff --git a/corrlib/input/openQCD.py b/corrlib/input/openQCD.py index 71ebec6..a3bce6f 100644 --- a/corrlib/input/openQCD.py +++ b/corrlib/input/openQCD.py @@ -3,9 +3,10 @@ import datalad.api as dl import os import fnmatch from typing import Any, Optional +from pathlib import Path -def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: +def read_ms1_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms1 measurements from a parameter file in the project. @@ -69,7 +70,7 @@ def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, A return param -def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: +def read_ms3_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms3 measurements from a parameter file in the project. @@ -103,7 +104,7 @@ def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, A return param -def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def read_rwms(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Read reweighting factor measurements from the project. @@ -160,7 +161,7 @@ def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any return rw_dict -def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t0(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t0 measurements from the project. @@ -234,7 +235,7 @@ def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, An return t0_dict -def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t1(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t1 measurements from the project. diff --git a/corrlib/input/sfcf.py b/corrlib/input/sfcf.py index 8b6e1a3..acd8261 100644 --- a/corrlib/input/sfcf.py +++ b/corrlib/input/sfcf.py @@ -4,6 +4,7 @@ import json import os from typing import Any from fnmatch import fnmatch +from pathlib import Path bi_corrs: list[str] = ["f_P", "fP", "f_p", @@ -80,7 +81,7 @@ for c in bib_corrs: corr_types[c] = 'bib' -def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: +def read_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters from the sfcf file. @@ -96,7 +97,7 @@ def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: """ - file = path + "/projects/" + project + '/' + file_in_project + file = path / "projects" / project / file_in_project dl.get(file, dataset=path) with open(file, 'r') as f: lines = f.readlines() @@ -257,7 +258,7 @@ def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str: return s -def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: +def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: """ Extract the data from the sfcf file. diff --git a/corrlib/main.py b/corrlib/main.py index 88b99b3..831b69d 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -8,9 +8,10 @@ from .find import _project_lookup_by_id from .tools import list2str, str2list, get_db_file from .tracker import get, save, unlock, clone, drop from typing import Union, Optional +from pathlib import Path -def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: +def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: """ Create a new project entry in the database. @@ -48,7 +49,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni return -def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None: +def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None] = None) -> None: """ Update/Edit a project entry in the database. Thin wrapper around sql3 call. @@ -74,9 +75,9 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] return -def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: +def update_aliases(path: Path, uuid: str, aliases: list[str]) -> None: db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file get(path, db_file) known_data = _project_lookup_by_id(db, uuid)[0] known_aliases = known_data[1] @@ -102,7 +103,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: return -def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: +def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: """ Import a datalad dataset into the backlogger. @@ -134,14 +135,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti uuid = str(conf.get("datalad.dataset.id")) if not uuid: raise ValueError("The dataset does not have a uuid!") - if not os.path.exists(path + "/projects/" + uuid): + if not os.path.exists(path / "projects" / uuid): db_file = get_db_file(path) get(path, db_file) unlock(path, db_file) create_project(path, uuid, owner, tags, aliases, code) - move_submodule(path, 'projects/tmp', 'projects/' + uuid) - os.mkdir(path + '/import_scripts/' + uuid) - save(path, message="Import project from " + url, files=['projects/' + uuid, db_file]) + move_submodule(path, Path('projects/tmp'), Path('projects') / uuid) + os.mkdir(path / 'import_scripts' / uuid) + save(path, message="Import project from " + url, files=[Path(f'projects/{uuid}'), db_file]) else: dl.drop(tmp_path, reckless='kill') shutil.rmtree(tmp_path) @@ -156,7 +157,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti return uuid -def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: +def drop_project_data(path: Path, uuid: str, path_in_project: str = "") -> None: """ Drop (parts of) a project to free up diskspace @@ -169,6 +170,5 @@ def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: path_pn_project: str, optional If set, only the given path within the project is dropped. """ - drop(path + "/projects/" + uuid + "/" + path_in_project) + drop(path / "projects" / uuid / path_in_project) return - diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index 8e5855d..be80b6f 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -10,9 +10,13 @@ from .tools import get_db_file, cache_enabled from .tracker import get, save, unlock import shutil from typing import Any +from pathlib import Path -def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None: +CACHE_DIR = ".cache" + + +def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None: """ Write a measurement to the backlog. If the file for the measurement already exists, update the measurement. @@ -33,7 +37,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, The parameter file used for the measurement. """ db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file files_to_save = [] @@ -44,11 +48,11 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, conn = sqlite3.connect(db) c = conn.cursor() for corr in measurement.keys(): - file_in_archive = os.path.join('.', 'archive', ensemble, corr, uuid + '.json.gz') - file = os.path.join(path, file_in_archive) + file_in_archive = Path('.') / 'archive' / ensemble / corr / str(uuid + '.json.gz') + file = path / file_in_archive known_meas = {} - if not os.path.exists(os.path.join(path, '.', 'archive', ensemble, corr)): - os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr)) + if not os.path.exists(path / 'archive' / ensemble / corr): + os.makedirs(path / 'archive' / ensemble / corr) files_to_save.append(file_in_archive) else: if os.path.exists(file): @@ -99,7 +103,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, pars[subkey] = json.dumps(parameters) for subkey in subkeys: parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() - meas_path = file_in_archive + "::" + parHash + meas_path = str(file_in_archive) + "::" + parHash known_meas[parHash] = measurement[corr][subkey] @@ -115,7 +119,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, return -def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: +def load_record(path: Path, meas_path: str) -> Union[Corr, Obs]: """ Load a list of records by their paths. @@ -134,7 +138,7 @@ def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: return load_records(path, [meas_path])[0] -def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: +def load_records(path: Path, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: """ Load a list of records by their paths. @@ -162,11 +166,11 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = { returned_data: list[Any] = [] for file in needed_data.keys(): for key in list(needed_data[file]): - if os.path.exists(cache_path(path, file, key) + ".p"): - returned_data.append(load_object(cache_path(path, file, key) + ".p")) + if os.path.exists(str(cache_path(path, file, key)) + ".p"): + returned_data.append(load_object(str(cache_path(path, file, key)) + ".p")) else: if file not in preloaded: - preloaded[file] = preload(path, file) + preloaded[file] = preload(path, Path(file)) returned_data.append(preloaded[file][key]) if cache_enabled(path): if not os.path.exists(cache_dir(path, file)): @@ -175,7 +179,7 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = { return returned_data -def cache_dir(path: str, file: str) -> str: +def cache_dir(path: Path, file: str) -> Path: """ Returns the directory corresponding to the cache for the given file. @@ -190,14 +194,14 @@ def cache_dir(path: str, file: str) -> str: cache_path: str The path holding the cached data for the given file. """ - cache_path_list = [path] - cache_path_list.append(".cache") - cache_path_list.extend(file.split("/")[1:]) - cache_path = "/".join(cache_path_list) + cache_path_list = file.split("/")[1:] + cache_path = path / CACHE_DIR + for directory in cache_path_list: + cache_path /= directory return cache_path -def cache_path(path: str, file: str, key: str) -> str: +def cache_path(path: Path, file: str, key: str) -> Path: """ Parameters ---------- @@ -213,11 +217,11 @@ def cache_path(path: str, file: str, key: str) -> str: cache_path: str The path at which the measurement of the given file and key is cached. """ - cache_path = os.path.join(cache_dir(path, file), key) + cache_path = cache_dir(path, file) / key return cache_path -def preload(path: str, file: str) -> dict[str, Any]: +def preload(path: Path, file: Path) -> dict[str, Any]: """ Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method. @@ -234,12 +238,12 @@ def preload(path: str, file: str) -> dict[str, Any]: The data read from the file. """ get(path, file) - filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file)) + filedict: dict[str, Any] = pj.load_json_dict(path / file) print("> read file") return filedict -def drop_record(path: str, meas_path: str) -> None: +def drop_record(path: Path, meas_path: str) -> None: """ Drop a record by it's path. @@ -251,9 +255,9 @@ def drop_record(path: str, meas_path: str) -> None: The measurement path as noted in the database. """ file_in_archive = meas_path.split("::")[0] - file = os.path.join(path, file_in_archive) + file = path / file_in_archive db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file get(path, db_file) sub_key = meas_path.split("::")[1] unlock(path, db_file) @@ -268,7 +272,7 @@ def drop_record(path: str, meas_path: str) -> None: known_meas = pj.load_json_dict(file) if sub_key in known_meas: del known_meas[sub_key] - unlock(path, file_in_archive) + unlock(path, Path(file_in_archive)) pj.dump_dict_to_json(known_meas, file) save(path, message="Drop measurements to database", files=[db, file]) return @@ -276,7 +280,7 @@ def drop_record(path: str, meas_path: str) -> None: raise ValueError("This measurement does not exist as a file!") -def drop_cache(path: str) -> None: +def drop_cache(path: Path) -> None: """ Drop the cache directory of the library. @@ -285,7 +289,7 @@ def drop_cache(path: str) -> None: path: str The path of the library. """ - cache_dir = os.path.join(path, ".cache") + cache_dir = path / ".cache" for f in os.listdir(cache_dir): - shutil.rmtree(os.path.join(cache_dir, f)) + shutil.rmtree(cache_dir / f) return diff --git a/corrlib/toml.py b/corrlib/toml.py index feafaf6..add3739 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -19,6 +19,7 @@ from .meas_io import write_measurement import os from .input.implementations import codes as known_codes from typing import Any +from pathlib import Path def replace_string(string: str, name: str, val: str) -> str: @@ -126,7 +127,7 @@ def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) - return -def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: +def import_tomls(path: Path, files: list[str], copy_files: bool=True) -> None: """ Import multiple toml files. @@ -144,7 +145,7 @@ def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: return -def import_toml(path: str, file: str, copy_file: bool=True) -> None: +def import_toml(path: Path, file: str, copy_file: bool=True) -> None: """ Import a project decribed by a .toml file. @@ -171,7 +172,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: aliases = project.get('aliases', []) uuid = project.get('uuid', None) if uuid is not None: - if not os.path.exists(path + "/projects/" + uuid): + if not os.path.exists(path / "projects" / uuid): uuid = import_project(path, project['url'], aliases=aliases) else: update_aliases(path, uuid, aliases) @@ -213,18 +214,18 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None)) print(mname + " imported.") - if not os.path.exists(os.path.join(path, "toml_imports", uuid)): - os.makedirs(os.path.join(path, "toml_imports", uuid)) + if not os.path.exists(path / "toml_imports" / uuid): + os.makedirs(path / "toml_imports" / uuid) if copy_file: - import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) + import_file = path / "toml_imports" / uuid / file.split("/")[-1] shutil.copy(file, import_file) - save(path, files=[import_file], message="Import using " + import_file) - print("File copied to " + import_file) + save(path, files=[import_file], message=f"Import using {import_file}") + print(f"File copied to {import_file}") print("Imported project.") return -def reimport_project(path: str, uuid: str) -> None: +def reimport_project(path: Path, uuid: str) -> None: """ Reimport an existing project using the files that are already available for this project. @@ -235,14 +236,14 @@ def reimport_project(path: str, uuid: str) -> None: uuid: str uuid of the project that is to be reimported. """ - config_path = "/".join([path, "import_scripts", uuid]) + config_path = path / "import_scripts" / uuid for p, filenames, dirnames in os.walk(config_path): for fname in filenames: import_toml(path, os.path.join(config_path, fname), copy_file=False) return -def update_project(path: str, uuid: str) -> None: +def update_project(path: Path, uuid: str) -> None: """ Update all entries associated with a given project. diff --git a/corrlib/tools.py b/corrlib/tools.py index 118b094..93f0678 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -1,6 +1,7 @@ import os from configparser import ConfigParser from typing import Any +from pathlib import Path CONFIG_FILENAME = ".corrlib" cached: bool = True @@ -73,7 +74,7 @@ def k2m(k: float) -> float: return (1/(2*k))-4 -def set_config(path: str, section: str, option: str, value: Any) -> None: +def set_config(path: Path, section: str, option: str, value: Any) -> None: """ Set configuration parameters for the library. @@ -88,7 +89,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: value: Any The value we set the option to. """ - config_path = os.path.join(path, '.corrlib') + config_path = os.path.join(path, CONFIG_FILENAME) config = ConfigParser() if os.path.exists(config_path): config.read(config_path) @@ -100,7 +101,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: return -def get_db_file(path: str) -> str: +def get_db_file(path: Path) -> Path: """ Get the database file associated with the library at the given path. @@ -118,11 +119,13 @@ def get_db_file(path: str) -> str: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) - db_file = config.get('paths', 'db', fallback='backlogger.db') + else: + raise FileNotFoundError("Configuration file not found.") + db_file = Path(config.get('paths', 'db', fallback='backlogger.db')) return db_file -def cache_enabled(path: str) -> bool: +def cache_enabled(path: Path) -> bool: """ Check, whether the library is cached. Fallback is true. @@ -141,6 +144,10 @@ def cache_enabled(path: str) -> bool: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) + else: + raise FileNotFoundError("Configuration file not found.") cached_str = config.get('core', 'cached', fallback='True') + if cached_str not in ['True', 'False']: + raise ValueError(f"String {cached_str} is not a valid option, only True and False are allowed!") cached_bool = cached_str == ('True') return cached_bool diff --git a/corrlib/tracker.py b/corrlib/tracker.py index e535b03..a6e9bf4 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -4,9 +4,10 @@ import datalad.api as dl from typing import Optional import shutil from .tools import get_db_file +from pathlib import Path -def get_tracker(path: str) -> str: +def get_tracker(path: Path) -> str: """ Get the tracker used in the dataset located at path. @@ -30,7 +31,7 @@ def get_tracker(path: str) -> str: return tracker -def get(path: str, file: str) -> None: +def get(path: Path, file: Path) -> None: """ Wrapper function to get a file from the dataset located at path with the specified tracker. @@ -56,7 +57,7 @@ def get(path: str, file: str) -> None: return -def save(path: str, message: str, files: Optional[list[str]]=None) -> None: +def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None: """ Wrapper function to save a file to the dataset located at path with the specified tracker. @@ -72,7 +73,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None: tracker = get_tracker(path) if tracker == 'datalad': if files is not None: - files = [os.path.join(path, f) for f in files] + files = [path / f for f in files] dl.save(files, message=message, dataset=path) elif tracker == 'None': Warning("Tracker 'None' does not implement save.") @@ -81,7 +82,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None: raise ValueError(f"Tracker {tracker} is not supported.") -def init(path: str, tracker: str='datalad') -> None: +def init(path: Path, tracker: str='datalad') -> None: """ Initialize a dataset at the specified path with the specified tracker. @@ -101,7 +102,7 @@ def init(path: str, tracker: str='datalad') -> None: return -def unlock(path: str, file: str) -> None: +def unlock(path: Path, file: Path) -> None: """ Wrapper function to unlock a file in the dataset located at path with the specified tracker. @@ -123,7 +124,7 @@ def unlock(path: str, file: str) -> None: return -def clone(path: str, source: str, target: str) -> None: +def clone(path: Path, source: str, target: str) -> None: """ Wrapper function to clone a dataset from source to target with the specified tracker. Parameters @@ -147,7 +148,7 @@ def clone(path: str, source: str, target: str) -> None: return -def drop(path: str, reckless: Optional[str]=None) -> None: +def drop(path: Path, reckless: Optional[str]=None) -> None: """ Wrapper function to drop data from a dataset located at path with the specified tracker. diff --git a/tests/cli_test.py b/tests/cli_test.py index d4a4045..cba0a10 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -2,18 +2,19 @@ from typer.testing import CliRunner from corrlib.cli import app import os import sqlite3 as sql +from pathlib import Path runner = CliRunner() -def test_version(): +def test_version() -> None: result = runner.invoke(app, ["--version"]) assert result.exit_code == 0 assert "corrlib" in result.output -def test_init_folders(tmp_path): +def test_init_folders(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -21,7 +22,7 @@ def test_init_folders(tmp_path): assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_db(tmp_path): +def test_init_db(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -81,7 +82,7 @@ def test_init_db(tmp_path): assert expected_col in backlog_column_names -def test_list(tmp_path): +def test_list(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 diff --git a/tests/import_project_test.py b/tests/import_project_test.py index 2dea06f..685d2cf 100644 --- a/tests/import_project_test.py +++ b/tests/import_project_test.py @@ -1,7 +1,7 @@ import corrlib.toml as t -def test_toml_check_measurement_data(): +def test_toml_check_measurement_data() -> None: measurements = { "a": { diff --git a/tests/sfcf_in_test.py b/tests/sfcf_in_test.py index 72921e7..7ebc94a 100644 --- a/tests/sfcf_in_test.py +++ b/tests/sfcf_in_test.py @@ -1,7 +1,7 @@ import corrlib.input.sfcf as input import json -def test_get_specs(): +def test_get_specs() -> None: parameters = { 'crr': [ 'f_P', 'f_A' @@ -26,4 +26,4 @@ def test_get_specs(): key = "f_P/q1 q2/1/0/0" specs = json.loads(input.get_specs(key, parameters)) assert specs['quarks'] == ['a', 'b'] - assert specs['wf1'][0] == [1, [0, 0]] \ No newline at end of file + assert specs['wf1'][0] == [1, [0, 0]] diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 1ea0ece..d78fb15 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -1,24 +1,25 @@ import corrlib.initialization as init import os import sqlite3 as sql +from pathlib import Path -def test_init_folders(tmp_path): +def test_init_folders(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path)) + init.create(dataset_path) assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_folders_no_tracker(tmp_path): +def test_init_folders_no_tracker(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path), tracker="None") + init.create(dataset_path, tracker="None") assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_config(tmp_path): +def test_init_config(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path), tracker="None") + init.create(dataset_path, tracker="None") config_path = dataset_path / ".corrlib" assert os.path.exists(str(config_path)) from configparser import ConfigParser @@ -34,9 +35,9 @@ def test_init_config(tmp_path): assert config.get("paths", "import_scripts_path") == "import_scripts" -def test_init_db(tmp_path): +def test_init_db(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path)) + init.create(dataset_path) assert os.path.exists(str(dataset_path / "backlogger.db")) conn = sql.connect(str(dataset_path / "backlogger.db")) cursor = conn.cursor() diff --git a/tests/tools_test.py b/tests/tools_test.py index ee76f1c..541674f 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,31 +1,84 @@ - - from corrlib import tools as tl +from configparser import ConfigParser +from pathlib import Path +import pytest -def test_m2k(): +def test_m2k() -> None: for m in [0.1, 0.5, 1.0]: expected_k = 1 / (2 * m + 8) assert tl.m2k(m) == expected_k -def test_k2m(): +def test_k2m() -> None: for m in [0.1, 0.5, 1.0]: assert tl.k2m(m) == (1/(2*m))-4 -def test_k2m_m2k(): +def test_k2m_m2k() -> None: for m in [0.1, 0.5, 1.0]: k = tl.m2k(m) m_converted = tl.k2m(k) assert abs(m - m_converted) < 1e-9 -def test_str2list(): +def test_str2list() -> None: assert tl.str2list("a,b,c") == ["a", "b", "c"] assert tl.str2list("1,2,3") == ["1", "2", "3"] -def test_list2str(): +def test_list2str() -> None: assert tl.list2str(["a", "b", "c"]) == "a,b,c" assert tl.list2str(["1", "2", "3"]) == "1,2,3" + + +def test_set_config(tmp_path: Path) -> None: + section = "core" + option = "test_option" + value = "test_value" + # config is not yet available + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option', fallback="not the value") == "test_value" + # now, a config file is already present + section = "core" + option = "test_option2" + value = "test_value2" + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option2', fallback="not the value") == "test_value2" + # update option 2 + section = "core" + option = "test_option2" + value = "test_value3" + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option2', fallback="not the value") == "test_value3" + + +def test_get_db_file(tmp_path: Path) -> None: + section = "paths" + option = "db" + value = "test_value" + # config is not yet available + tl.set_config(tmp_path, section, option, value) + assert tl.get_db_file(tmp_path) == Path("test_value") + + +def test_cache_enabled(tmp_path: Path) -> None: + section = "core" + option = "cached" + # config is not yet available + tl.set_config(tmp_path, section, option, "True") + assert tl.cache_enabled(tmp_path) + tl.set_config(tmp_path, section, option, "False") + assert not tl.cache_enabled(tmp_path) + tl.set_config(tmp_path, section, option, "lalala") + with pytest.raises(ValueError): + tl.cache_enabled(tmp_path)