Merge pull request 'test/more' (#9) from test/more into develop
All checks were successful
Pytest / pytest (3.12) (push) Successful in 51s
Pytest / pytest (3.13) (push) Successful in 48s
Pytest / pytest (3.14) (push) Successful in 47s

Reviewed-on: https://www.kuhl-mann.de/git/git/jkuhl/corrlib/pulls/9
Justus Kuhlmann 2025-12-02 10:33:37 +01:00
commit 04559cc95f
15 changed files with 1201 additions and 42 deletions

(pytest CI workflow)

@@ -20,6 +20,10 @@ jobs:
     env:
       UV_CACHE_DIR: /tmp/.uv-cache
     steps:
+      - name: Install git-annex
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y git-annex
       - name: Check out the repository
         uses: https://github.com/RouxAntoine/checkout@v4.1.8
         with:
@@ -32,4 +36,4 @@ jobs:
       - name: Install corrlib
         run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }}
       - name: Run tests
-        run: uv run pytest tests
+        run: uv run pytest --cov=corrlib tests

.gitignore (vendored, 3 changes)

@@ -4,4 +4,5 @@ __pycache__
 test.ipynb
 .vscode
 .venv
 .pytest_cache
+.coverage

corrlib/__init__.py

@@ -1,5 +1,5 @@
 """
 The aim of this project is to extend pyerrors to be able to collect measurements from different projects and make them easily accessible to
 the research group. The idea is to build a database, in which the researcher can easily search for measurements on a correlator basis,
 which may be reusable.
 As a standard to store the measurements, we will use the .json.gz format from pyerrors.
@@ -15,8 +15,10 @@ For now, we are interested in collecting primary IObservables only, as these are
 __app_name__ = "corrlib"

-from .main import *
 from .import input as input
-from .initialization import *
-from .meas_io import *
-from .find import *
+from .initialization import create as create
+from .meas_io import load_record as load_record
+from .meas_io import load_records as load_records
+from .find import find_project as find_project
+from .find import find_record as find_record
+from .find import list_projects as list_projects
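
The wildcard imports are replaced by explicit "name as name" re-exports, the idiom that ruff and mypy recognize as an intentional public API. A minimal usage sketch of the re-exported names; the library path is a made-up placeholder and the argument shapes are assumptions, not taken from this PR:

import corrlib

corrlib.create("/data/corrlib")                    # re-exported from corrlib.initialization
projects = corrlib.list_projects("/data/corrlib")  # re-exported from corrlib.find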

corrlib/cli.py

@@ -8,7 +8,7 @@ from .tools import str2list
 from .main import update_aliases
 from .meas_io import drop_cache as mio_drop_cache
 import os
-from importlib.metadata import version, PackageNotFoundError
+from importlib.metadata import version

 app = typer.Typer()

@@ -1,5 +1,4 @@
 import sqlite3
-import datalad.api as dl
 import os
 import json
 import pandas as pd
@@ -16,7 +15,7 @@ def _project_lookup_by_alias(db, alias):
     c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'")
     results = c.fetchall()
     conn.close()
-    if len(results)>1:
+    if len(results) > 1:
         print("Error: multiple projects found with alias " + alias)
     elif len(results) == 0:
         raise Exception("Error: no project found with alias " + alias)
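
An aside on the lookup above: the alias is spliced into the SQL via an f-string. A sketch of the same query using sqlite3 parameter binding instead, which sidesteps quoting issues (a suggestion, not part of this PR):

import sqlite3

def project_by_alias(db_path: str, alias: str):
    # The '?' placeholder lets sqlite3 escape the alias value itself.
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(
            "SELECT * FROM projects WHERE alias = ?", (alias,)
        ).fetchall()
    finally:
        conn.close()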

corrlib/input/__init__.py

@@ -2,6 +2,6 @@
 Import functions for different codes.
 """

-from . import sfcf
-from . import openQCD
-from . import implementations
+from . import sfcf as sfcf
+from . import openQCD as openQCD
+from . import implementations as implementations

(whitespace-only changes in this hunk; shown as context)

@@ -130,14 +130,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
         dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
     else:
         dl.drop(tmp_path, reckless='kill')
         shutil.rmtree(tmp_path)
     if aliases is not None:
         if isinstance(aliases, str):
             alias_list = [aliases]
         else:
             alias_list = aliases
         update_aliases(path, uuid, alias_list)
     # make this more concrete
     return uuid

corrlib/meas_io.py (whitespace-only changes in the hunks below; shown as context)

@@ -58,7 +58,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
         pars = {}
         subkeys = []
         for i in range(len(parameters["rw_fcts"])):
             par_list = []
             for k in parameters["rw_fcts"][i].keys():
                 par_list.append(str(parameters["rw_fcts"][i][k]))
             subkey = "/".join(par_list)
@@ -79,12 +79,12 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
             subkey = "/".join(par_list)
             subkeys = [subkey]
             pars[subkey] = json.dumps(parameters)
         for subkey in subkeys:
             parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
             meas_path = file_in_archive + "::" + parHash
             known_meas[parHash] = measurement[corr][subkey]
             if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
                 c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, ))
             else:
@@ -107,7 +107,7 @@ def load_record(path: str, meas_path: str):
         Path of the correlator library.
     meas_path: str
         The path to the correlator in the backlog system.

     Returns
     -------
     co : Corr or Obs
@@ -126,7 +126,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
         Path of the correlator library.
     meas_paths: list[str]
         A list of the paths to the correlator in the backlog system.

     Returns
     -------
     List
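
For orientation, a hedged usage sketch of the two loaders documented above; the library path is a placeholder, and the measurement path only illustrates the file_in_archive + "::" + parHash scheme from write_measurement:

from corrlib import load_record, load_records

lib = "/data/corrlib"                                    # hypothetical library path
meas = "projects/<uuid>/corr.json.gz::<parameter-hash>"  # shape only, not a real path

co = load_record(lib, meas)       # a single Corr or Obs
cos = load_records(lib, [meas])   # list variant; also accepts a preloaded cache dict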

@@ -26,4 +26,3 @@ def get_file(path: str, file: str):
     print("Downloading data...")
     dl.get(os.path.join(path, file), dataset=path)
     print("> downloaded file")
-

corrlib/version.py

@@ -1 +1,34 @@
-__version__ = "0.2.3"
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
+else:
+    VERSION_TUPLE = object
+    COMMIT_ID = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
+
+__version__ = version = '0.2.4.dev14+g602324f84.d20251202'
+__version_tuple__ = version_tuple = (0, 2, 4, 'dev14', 'g602324f84.d20251202')
+
+__commit_id__ = commit_id = 'g602324f84'
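
Since corrlib/version.py is now generated by setuptools-scm, the version should be read, not edited; a short sketch using exactly the names defined in the generated file above:

from corrlib.version import __version__, version_tuple

print(__version__)    # '0.2.4.dev14+g602324f84.d20251202'
print(version_tuple)  # (0, 2, 4, 'dev14', 'g602324f84.d20251202')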

pyproject.toml

@@ -27,10 +27,19 @@ include = ["corrlib", "corrlib.*"]
 write_to = "corrlib/version.py"

 [tool.ruff.lint]
-ignore = ["F403"]
+ignore = ["E501"]
+extend-select = [
+    "YTT",
+    "E",
+    "W",
+    "F",
+]

 [dependency-groups]
 dev = [
+    "mypy>=1.19.0",
     "pytest>=9.0.1",
+    "pytest-cov>=7.0.0",
     "pytest-pretty>=1.3.0",
+    "ruff>=0.14.7",
 ]
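
For orientation: F403 (star imports) no longer needs to be ignored because the wildcard imports are gone, E501 (line length) is ignored instead, and the pycodestyle (E/W), pyflakes (F), and flake8-2020 (YTT) rule families are selected. A tiny illustration of what the selected families flag (illustrative snippet, not from the repository):

x=1          # E225: missing whitespace around operator (pycodestyle E)
import json  # F401: unused import (pyflakes F), if json is never used
# Overlong lines would be E501, which this config deliberately ignores.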

tests/cli_test.py (new file, 91 lines)

@@ -0,0 +1,91 @@
from typer.testing import CliRunner
from corrlib.cli import app
import os
import sqlite3 as sql

runner = CliRunner()


def test_version():
    result = runner.invoke(app, ["--version"])
    assert result.exit_code == 0
    assert "corrlib" in result.output


def test_init_folders(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
    assert result.exit_code == 0
    assert os.path.exists(str(dataset_path))
    assert os.path.exists(str(dataset_path / "backlogger.db"))


def test_init_db(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
    assert result.exit_code == 0
    assert os.path.exists(str(dataset_path / "backlogger.db"))
    # Inspect the freshly created database directly.
    conn = sql.connect(str(dataset_path / "backlogger.db"))
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    expected_tables = [
        'projects',
        'backlogs',
    ]
    table_names = [table[0] for table in tables]
    for expected_table in expected_tables:
        assert expected_table in table_names
    # Both tables start out empty.
    cursor.execute("SELECT * FROM projects;")
    projects = cursor.fetchall()
    assert len(projects) == 0
    cursor.execute("SELECT * FROM backlogs;")
    backlogs = cursor.fetchall()
    assert len(backlogs) == 0
    # Check the schema of both tables column by column.
    cursor.execute("PRAGMA table_info('projects');")
    project_columns = cursor.fetchall()
    expected_project_columns = [
        "id",
        "aliases",
        "customTags",
        "owner",
        "code",
        "created_at",
        "updated_at"
    ]
    project_column_names = [col[1] for col in project_columns]
    for expected_col in expected_project_columns:
        assert expected_col in project_column_names
    cursor.execute("PRAGMA table_info('backlogs');")
    backlog_columns = cursor.fetchall()
    expected_backlog_columns = [
        "id",
        "name",
        "ensemble",
        "code",
        "path",
        "project",
        "customTags",
        "parameters",
        "parameter_file",
        "created_at",
        "updated_at"
    ]
    backlog_column_names = [col[1] for col in backlog_columns]
    for expected_col in expected_backlog_columns:
        assert expected_col in backlog_column_names


def test_list(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
    assert result.exit_code == 0
    result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"])
    assert result.exit_code == 0
    result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"])
    assert result.exit_code == 0

(new test file for corrlib.initialization, 68 lines)

@@ -0,0 +1,68 @@
import corrlib.initialization as init
import os
import sqlite3 as sql


def test_init_folders(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    init.create(str(dataset_path))
    assert os.path.exists(str(dataset_path))
    assert os.path.exists(str(dataset_path / "backlogger.db"))


def test_init_db(tmp_path):
    dataset_path = tmp_path / "test_dataset"
    init.create(str(dataset_path))
    assert os.path.exists(str(dataset_path / "backlogger.db"))
    # Same schema checks as in the CLI test, but via the Python API.
    conn = sql.connect(str(dataset_path / "backlogger.db"))
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    expected_tables = [
        'projects',
        'backlogs',
    ]
    table_names = [table[0] for table in tables]
    for expected_table in expected_tables:
        assert expected_table in table_names
    cursor.execute("SELECT * FROM projects;")
    projects = cursor.fetchall()
    assert len(projects) == 0
    cursor.execute("SELECT * FROM backlogs;")
    backlogs = cursor.fetchall()
    assert len(backlogs) == 0
    cursor.execute("PRAGMA table_info('projects');")
    project_columns = cursor.fetchall()
    expected_project_columns = [
        "id",
        "aliases",
        "customTags",
        "owner",
        "code",
        "created_at",
        "updated_at"
    ]
    project_column_names = [col[1] for col in project_columns]
    for expected_col in expected_project_columns:
        assert expected_col in project_column_names
    cursor.execute("PRAGMA table_info('backlogs');")
    backlog_columns = cursor.fetchall()
    expected_backlog_columns = [
        "id",
        "name",
        "ensemble",
        "code",
        "path",
        "project",
        "customTags",
        "parameters",
        "parameter_file",
        "created_at",
        "updated_at"
    ]
    backlog_column_names = [col[1] for col in backlog_columns]
    for expected_col in expected_backlog_columns:
        assert expected_col in backlog_column_names

(test file for corrlib.tools)

@@ -4,15 +4,21 @@ from corrlib import tools as tl

 def test_m2k():
-    assert tl.m2k(0.1) == 1/(2*0.1+8)
-    assert tl.m2k(0.5) == 1/(2*0.5+8)
-    assert tl.m2k(1.0) == 1/(2*1.0+8)
+    for m in [0.1, 0.5, 1.0]:
+        expected_k = 1 / (2 * m + 8)
+        assert tl.m2k(m) == expected_k


 def test_k2m():
-    assert tl.k2m(0.1) == (1/(2*0.1))-4
-    assert tl.k2m(0.5) == (1/(2*0.5))-4
-    assert tl.k2m(1.0) == (1/(2*1.0))-4
+    for m in [0.1, 0.5, 1.0]:
+        assert tl.k2m(m) == (1/(2*m))-4
+
+
+def test_k2m_m2k():
+    for m in [0.1, 0.5, 1.0]:
+        k = tl.m2k(m)
+        m_converted = tl.k2m(k)
+        assert abs(m - m_converted) < 1e-9


 def test_str2list():
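
The new round-trip test is justified algebraically: with k = 1/(2m + 8) = 1/(2(m + 4)), the inverse gives 1/(2k) - 4 = (m + 4) - 4 = m. A self-contained sketch of that identity, with local stand-ins for the corrlib functions:

def m2k(m: float) -> float:
    return 1 / (2 * m + 8)   # kappa = 1 / (2 * (m + 4))

def k2m(k: float) -> float:
    return 1 / (2 * k) - 4

for m in [0.1, 0.5, 1.0]:
    # exact up to floating-point rounding
    assert abs(k2m(m2k(m)) - m) < 1e-12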

uv.lock (generated, 973 changes)

File diff suppressed because it is too large.