Merge pull request 'test/more' (#9) from test/more into develop
Reviewed-on: https://www.kuhl-mann.de/git/git/jkuhl/corrlib/pulls/9
This commit is contained in:
commit
04559cc95f
15 changed files with 1201 additions and 42 deletions
6
.github/workflows/pytest.yaml
vendored
6
.github/workflows/pytest.yaml
vendored
|
|
@ -20,6 +20,10 @@ jobs:
|
||||||
env:
|
env:
|
||||||
UV_CACHE_DIR: /tmp/.uv-cache
|
UV_CACHE_DIR: /tmp/.uv-cache
|
||||||
steps:
|
steps:
|
||||||
|
- name: Install git-annex
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y git-annex
|
||||||
- name: Check out the repository
|
- name: Check out the repository
|
||||||
uses: https://github.com/RouxAntoine/checkout@v4.1.8
|
uses: https://github.com/RouxAntoine/checkout@v4.1.8
|
||||||
with:
|
with:
|
||||||
|
|
@ -32,4 +36,4 @@ jobs:
|
||||||
- name: Install corrlib
|
- name: Install corrlib
|
||||||
run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }}
|
run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }}
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: uv run pytest tests
|
run: uv run pytest --cov=corrlib tests
|
||||||
|
|
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -4,4 +4,5 @@ __pycache__
|
||||||
test.ipynb
|
test.ipynb
|
||||||
.vscode
|
.vscode
|
||||||
.venv
|
.venv
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
|
.coverage
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
The aim of this project is to extend pyerrors to be able to collect measurements from different projects and make them easily accessable to
|
The aim of this project is to extend pyerrors to be able to collect measurements from different projects and make them easily accessable to
|
||||||
the research group. The idea is to build a database, in which the researcher can easily search for measurements on a correlator basis,
|
the research group. The idea is to build a database, in which the researcher can easily search for measurements on a correlator basis,
|
||||||
which may be reusable.
|
which may be reusable.
|
||||||
As a standard to store the measurements, we will use the .json.gz format from pyerrors.
|
As a standard to store the measurements, we will use the .json.gz format from pyerrors.
|
||||||
|
|
@ -15,8 +15,10 @@ For now, we are interested in collecting primary IObservables only, as these are
|
||||||
|
|
||||||
__app_name__ = "corrlib"
|
__app_name__ = "corrlib"
|
||||||
|
|
||||||
from .main import *
|
|
||||||
from .import input as input
|
from .import input as input
|
||||||
from .initialization import *
|
from .initialization import create as create
|
||||||
from .meas_io import *
|
from .meas_io import load_record as load_record
|
||||||
from .find import *
|
from .meas_io import load_records as load_records
|
||||||
|
from .find import find_project as find_project
|
||||||
|
from .find import find_record as find_record
|
||||||
|
from .find import list_projects as list_projects
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from .tools import str2list
|
||||||
from .main import update_aliases
|
from .main import update_aliases
|
||||||
from .meas_io import drop_cache as mio_drop_cache
|
from .meas_io import drop_cache as mio_drop_cache
|
||||||
import os
|
import os
|
||||||
from importlib.metadata import version, PackageNotFoundError
|
from importlib.metadata import version
|
||||||
|
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import datalad.api as dl
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -16,7 +15,7 @@ def _project_lookup_by_alias(db, alias):
|
||||||
c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'")
|
c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'")
|
||||||
results = c.fetchall()
|
results = c.fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
if len(results) > 1:
|
if len(results)>1:
|
||||||
print("Error: multiple projects found with alias " + alias)
|
print("Error: multiple projects found with alias " + alias)
|
||||||
elif len(results) == 0:
|
elif len(results) == 0:
|
||||||
raise Exception("Error: no project found with alias " + alias)
|
raise Exception("Error: no project found with alias " + alias)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,6 @@
|
||||||
Import functions for different codes.
|
Import functions for different codes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import sfcf
|
from . import sfcf as sfcf
|
||||||
from . import openQCD
|
from . import openQCD as openQCD
|
||||||
from . import implementations
|
from . import implementations as implementations
|
||||||
|
|
|
||||||
|
|
@ -130,14 +130,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
|
||||||
dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
|
dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
|
||||||
else:
|
else:
|
||||||
dl.drop(tmp_path, reckless='kill')
|
dl.drop(tmp_path, reckless='kill')
|
||||||
shutil.rmtree(tmp_path)
|
shutil.rmtree(tmp_path)
|
||||||
if aliases is not None:
|
if aliases is not None:
|
||||||
if isinstance(aliases, str):
|
if isinstance(aliases, str):
|
||||||
alias_list = [aliases]
|
alias_list = [aliases]
|
||||||
else:
|
else:
|
||||||
alias_list = aliases
|
alias_list = aliases
|
||||||
update_aliases(path, uuid, alias_list)
|
update_aliases(path, uuid, alias_list)
|
||||||
|
|
||||||
# make this more concrete
|
# make this more concrete
|
||||||
return uuid
|
return uuid
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
|
||||||
pars = {}
|
pars = {}
|
||||||
subkeys = []
|
subkeys = []
|
||||||
for i in range(len(parameters["rw_fcts"])):
|
for i in range(len(parameters["rw_fcts"])):
|
||||||
par_list = []
|
par_list = []
|
||||||
for k in parameters["rw_fcts"][i].keys():
|
for k in parameters["rw_fcts"][i].keys():
|
||||||
par_list.append(str(parameters["rw_fcts"][i][k]))
|
par_list.append(str(parameters["rw_fcts"][i][k]))
|
||||||
subkey = "/".join(par_list)
|
subkey = "/".join(par_list)
|
||||||
|
|
@ -79,12 +79,12 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
|
||||||
subkey = "/".join(par_list)
|
subkey = "/".join(par_list)
|
||||||
subkeys = [subkey]
|
subkeys = [subkey]
|
||||||
pars[subkey] = json.dumps(parameters)
|
pars[subkey] = json.dumps(parameters)
|
||||||
for subkey in subkeys:
|
for subkey in subkeys:
|
||||||
parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
|
parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
|
||||||
meas_path = file_in_archive + "::" + parHash
|
meas_path = file_in_archive + "::" + parHash
|
||||||
|
|
||||||
known_meas[parHash] = measurement[corr][subkey]
|
known_meas[parHash] = measurement[corr][subkey]
|
||||||
|
|
||||||
if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
|
if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
|
||||||
c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, ))
|
c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, ))
|
||||||
else:
|
else:
|
||||||
|
|
@ -107,7 +107,7 @@ def load_record(path: str, meas_path: str):
|
||||||
Path of the correlator library.
|
Path of the correlator library.
|
||||||
meas_path: str
|
meas_path: str
|
||||||
The path to the correlator in the backlog system.
|
The path to the correlator in the backlog system.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
co : Corr or Obs
|
co : Corr or Obs
|
||||||
|
|
@ -126,7 +126,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
|
||||||
Path of the correlator library.
|
Path of the correlator library.
|
||||||
meas_paths: list[str]
|
meas_paths: list[str]
|
||||||
A list of the paths to the correlator in the backlog system.
|
A list of the paths to the correlator in the backlog system.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List
|
List
|
||||||
|
|
|
||||||
|
|
@ -26,4 +26,3 @@ def get_file(path: str, file: str):
|
||||||
print("Downloading data...")
|
print("Downloading data...")
|
||||||
dl.get(os.path.join(path, file), dataset=path)
|
dl.get(os.path.join(path, file), dataset=path)
|
||||||
print("> downloaded file")
|
print("> downloaded file")
|
||||||
|
|
||||||
|
|
@ -1 +1,34 @@
|
||||||
__version__ = "0.2.3"
|
# file generated by setuptools-scm
|
||||||
|
# don't change, don't track in version control
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
|
"__version_tuple__",
|
||||||
|
"version",
|
||||||
|
"version_tuple",
|
||||||
|
"__commit_id__",
|
||||||
|
"commit_id",
|
||||||
|
]
|
||||||
|
|
||||||
|
TYPE_CHECKING = False
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||||
|
COMMIT_ID = Union[str, None]
|
||||||
|
else:
|
||||||
|
VERSION_TUPLE = object
|
||||||
|
COMMIT_ID = object
|
||||||
|
|
||||||
|
version: str
|
||||||
|
__version__: str
|
||||||
|
__version_tuple__: VERSION_TUPLE
|
||||||
|
version_tuple: VERSION_TUPLE
|
||||||
|
commit_id: COMMIT_ID
|
||||||
|
__commit_id__: COMMIT_ID
|
||||||
|
|
||||||
|
__version__ = version = '0.2.4.dev14+g602324f84.d20251202'
|
||||||
|
__version_tuple__ = version_tuple = (0, 2, 4, 'dev14', 'g602324f84.d20251202')
|
||||||
|
|
||||||
|
__commit_id__ = commit_id = 'g602324f84'
|
||||||
|
|
|
||||||
|
|
@ -27,10 +27,19 @@ include = ["corrlib", "corrlib.*"]
|
||||||
write_to = "corrlib/version.py"
|
write_to = "corrlib/version.py"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
ignore = ["F403"]
|
ignore = ["E501"]
|
||||||
|
extend-select = [
|
||||||
|
"YTT",
|
||||||
|
"E",
|
||||||
|
"W",
|
||||||
|
"F",
|
||||||
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
"mypy>=1.19.0",
|
||||||
"pytest>=9.0.1",
|
"pytest>=9.0.1",
|
||||||
|
"pytest-cov>=7.0.0",
|
||||||
"pytest-pretty>=1.3.0",
|
"pytest-pretty>=1.3.0",
|
||||||
|
"ruff>=0.14.7",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
91
tests/cli_test.py
Normal file
91
tests/cli_test.py
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
from corrlib.cli import app
|
||||||
|
import os
|
||||||
|
import sqlite3 as sql
|
||||||
|
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_version():
|
||||||
|
result = runner.invoke(app, ["--version"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "corrlib" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_folders(tmp_path):
|
||||||
|
dataset_path = tmp_path / "test_dataset"
|
||||||
|
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert os.path.exists(str(dataset_path))
|
||||||
|
assert os.path.exists(str(dataset_path / "backlogger.db"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_db(tmp_path):
|
||||||
|
dataset_path = tmp_path / "test_dataset"
|
||||||
|
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert os.path.exists(str(dataset_path / "backlogger.db"))
|
||||||
|
conn = sql.connect(str(dataset_path / "backlogger.db"))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||||
|
tables = cursor.fetchall()
|
||||||
|
expected_tables = [
|
||||||
|
'projects',
|
||||||
|
'backlogs',
|
||||||
|
]
|
||||||
|
table_names = [table[0] for table in tables]
|
||||||
|
for expected_table in expected_tables:
|
||||||
|
assert expected_table in table_names
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM projects;")
|
||||||
|
projects = cursor.fetchall()
|
||||||
|
assert len(projects) == 0
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM backlogs;")
|
||||||
|
backlogs = cursor.fetchall()
|
||||||
|
assert len(backlogs) == 0
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA table_info('projects');")
|
||||||
|
project_columns = cursor.fetchall()
|
||||||
|
expected_project_columns = [
|
||||||
|
"id",
|
||||||
|
"aliases",
|
||||||
|
"customTags",
|
||||||
|
"owner",
|
||||||
|
"code",
|
||||||
|
"created_at",
|
||||||
|
"updated_at"
|
||||||
|
]
|
||||||
|
project_column_names = [col[1] for col in project_columns]
|
||||||
|
for expected_col in expected_project_columns:
|
||||||
|
assert expected_col in project_column_names
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA table_info('backlogs');")
|
||||||
|
backlog_columns = cursor.fetchall()
|
||||||
|
expected_backlog_columns = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"ensemble",
|
||||||
|
"code",
|
||||||
|
"path",
|
||||||
|
"project",
|
||||||
|
"customTags",
|
||||||
|
"parameters",
|
||||||
|
"parameter_file",
|
||||||
|
"created_at",
|
||||||
|
"updated_at"
|
||||||
|
]
|
||||||
|
backlog_column_names = [col[1] for col in backlog_columns]
|
||||||
|
for expected_col in expected_backlog_columns:
|
||||||
|
assert expected_col in backlog_column_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_list(tmp_path):
|
||||||
|
dataset_path = tmp_path / "test_dataset"
|
||||||
|
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"])
|
||||||
|
assert result.exit_code == 0
|
||||||
68
tests/test_initialization.py
Normal file
68
tests/test_initialization.py
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
import corrlib.initialization as init
|
||||||
|
import os
|
||||||
|
import sqlite3 as sql
|
||||||
|
|
||||||
|
def test_init_folders(tmp_path):
|
||||||
|
dataset_path = tmp_path / "test_dataset"
|
||||||
|
init.create(str(dataset_path))
|
||||||
|
assert os.path.exists(str(dataset_path))
|
||||||
|
assert os.path.exists(str(dataset_path / "backlogger.db"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_db(tmp_path):
|
||||||
|
dataset_path = tmp_path / "test_dataset"
|
||||||
|
init.create(str(dataset_path))
|
||||||
|
assert os.path.exists(str(dataset_path / "backlogger.db"))
|
||||||
|
conn = sql.connect(str(dataset_path / "backlogger.db"))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||||
|
tables = cursor.fetchall()
|
||||||
|
expected_tables = [
|
||||||
|
'projects',
|
||||||
|
'backlogs',
|
||||||
|
]
|
||||||
|
table_names = [table[0] for table in tables]
|
||||||
|
for expected_table in expected_tables:
|
||||||
|
assert expected_table in table_names
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM projects;")
|
||||||
|
projects = cursor.fetchall()
|
||||||
|
assert len(projects) == 0
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM backlogs;")
|
||||||
|
backlogs = cursor.fetchall()
|
||||||
|
assert len(backlogs) == 0
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA table_info('projects');")
|
||||||
|
project_columns = cursor.fetchall()
|
||||||
|
expected_project_columns = [
|
||||||
|
"id",
|
||||||
|
"aliases",
|
||||||
|
"customTags",
|
||||||
|
"owner",
|
||||||
|
"code",
|
||||||
|
"created_at",
|
||||||
|
"updated_at"
|
||||||
|
]
|
||||||
|
project_column_names = [col[1] for col in project_columns]
|
||||||
|
for expected_col in expected_project_columns:
|
||||||
|
assert expected_col in project_column_names
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA table_info('backlogs');")
|
||||||
|
backlog_columns = cursor.fetchall()
|
||||||
|
expected_backlog_columns = [
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"ensemble",
|
||||||
|
"code",
|
||||||
|
"path",
|
||||||
|
"project",
|
||||||
|
"customTags",
|
||||||
|
"parameters",
|
||||||
|
"parameter_file",
|
||||||
|
"created_at",
|
||||||
|
"updated_at"
|
||||||
|
]
|
||||||
|
backlog_column_names = [col[1] for col in backlog_columns]
|
||||||
|
for expected_col in expected_backlog_columns:
|
||||||
|
assert expected_col in backlog_column_names
|
||||||
|
|
@ -4,15 +4,21 @@ from corrlib import tools as tl
|
||||||
|
|
||||||
|
|
||||||
def test_m2k():
|
def test_m2k():
|
||||||
assert tl.m2k(0.1) == 1/(2*0.1+8)
|
for m in [0.1, 0.5, 1.0]:
|
||||||
assert tl.m2k(0.5) == 1/(2*0.5+8)
|
expected_k = 1 / (2 * m + 8)
|
||||||
assert tl.m2k(1.0) == 1/(2*1.0+8)
|
assert tl.m2k(m) == expected_k
|
||||||
|
|
||||||
|
|
||||||
def test_k2m():
|
def test_k2m():
|
||||||
assert tl.k2m(0.1) == (1/(2*0.1))-4
|
for m in [0.1, 0.5, 1.0]:
|
||||||
assert tl.k2m(0.5) == (1/(2*0.5))-4
|
assert tl.k2m(m) == (1/(2*m))-4
|
||||||
assert tl.k2m(1.0) == (1/(2*1.0))-4
|
|
||||||
|
|
||||||
|
def test_k2m_m2k():
|
||||||
|
for m in [0.1, 0.5, 1.0]:
|
||||||
|
k = tl.m2k(m)
|
||||||
|
m_converted = tl.k2m(k)
|
||||||
|
assert abs(m - m_converted) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
def test_str2list():
|
def test_str2list():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue