corrlib/corrlib/find.py
Justus Kuhlmann 6d1f8f7f1b
Some checks failed
Mypy / mypy (push) Successful in 1m14s
Pytest / pytest (3.12) (push) Successful in 1m20s
Pytest / pytest (3.13) (push) Successful in 1m13s
Pytest / pytest (3.14) (push) Successful in 1m15s
Ruff / ruff (push) Failing after 1m1s
add NotImplemented warning for openQCD filter
2026-04-10 10:28:28 +02:00

388 lines
14 KiB
Python

import sqlite3
import os
import json
import pandas as pd
import numpy as np
from .input.implementations import codes
from .tools import k2m, get_db_file
from .tracker import get
from .integrity import check_time_validity
from typing import Any, Optional, Union
from pathlib import Path
import datetime as dt
from collections.abc import Callable
import warnings
def _project_lookup_by_alias(db: Path, alias: str) -> str:
    """
    Look up a project's UUID by its (human-readable) alias.

    Parameters
    ----------
    db: Path
        The database to look up the project.
    alias: str
        The alias to look up. It is compared verbatim against the ``aliases`` column.

    Returns
    -------
    uuid: str
        The first column of the first matching row, converted to ``str``.
        If several projects match, a message is printed and the first row is used.

    Raises
    ------
    Exception
        If no project with the given alias is found.
    """
    conn = sqlite3.connect(db)
    try:
        # Bind the alias as a parameter so that quotes inside it cannot alter the query.
        results = conn.execute("SELECT * FROM 'projects' WHERE aliases = ?", (alias,)).fetchall()
    finally:
        # Release the connection even if the query itself fails.
        conn.close()
    if len(results) > 1:
        print("Error: multiple projects found with alias " + alias)
    elif len(results) == 0:
        raise Exception("Error: no project found with alias " + alias)
    return str(results[0][0])
def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]:
    """
    Return the project information available in the database by UUID.

    Parameters
    ----------
    db: Path
        The database to look up the project.
    uuid: str
        The uuid of the project in question. It is compared verbatim against the ``id`` column.

    Returns
    -------
    results: list
        All rows of the ``projects`` table whose ``id`` equals ``uuid``
        (an empty list if there is none).
    """
    conn = sqlite3.connect(db)
    try:
        # Bind the uuid as a parameter so that quotes inside it cannot alter the query.
        results = conn.execute("SELECT * FROM 'projects' WHERE id = ?", (uuid,)).fetchall()
    finally:
        # Release the connection even if the query itself fails.
        conn.close()
    return results
def _time_filter(results: pd.DataFrame, created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame:
    """
    Filter the results from the database in terms of the creation and update times.

    Every row needs the columns ``created_at`` and ``updated_at`` holding
    timestamps in ``datetime.datetime.isoformat``.

    Parameters
    ----------
    results: pd.DataFrame
        The dataframe holding the unfiltered results from the database.
    created_before: str
        Constraint on the creation date in datetime.datetime.isoformat.
        A record is dropped if it was created strictly after this time.
    created_after: str
        Constraint on the creation date in datetime.datetime.isoformat.
        A record is dropped if it was created strictly before this time.
    updated_before: str
        Constraint on the date of the last update in datetime.datetime.isoformat.
        A record is dropped if it was last updated strictly after this time.
    updated_after: str
        Constraint on the date of the last update in datetime.datetime.isoformat.
        A record is dropped if it was last updated strictly before this time.

    Returns
    -------
    results: pd.DataFrame
        The rows that satisfy all given constraints, with their original index labels.

    Raises
    ------
    ValueError
        If ``check_time_validity`` rejects the time stamps of a row, or if a
        constraint is not a valid ISO timestamp.

    Notes
    -----
    NOTE(review): as implemented, a record whose time stamp is exactly equal to
    a bound is kept, i.e. the bounds are inclusive.
    """
    def _parse_bound(bound: Optional[Any]) -> Optional[dt.datetime]:
        return None if bound is None else dt.datetime.fromisoformat(bound)

    # Parse every bound once up front instead of once per row.
    latest_created = _parse_bound(created_before)
    earliest_created = _parse_bound(created_after)
    latest_updated = _parse_bound(updated_before)
    earliest_updated = _parse_bound(updated_after)

    keep = []
    for pos in range(len(results)):
        row = results.iloc[pos]
        created_at = dt.datetime.fromisoformat(row['created_at'])
        updated_at = dt.datetime.fromisoformat(row['updated_at'])
        if not check_time_validity(created_at=created_at, updated_at=updated_at):
            raise ValueError('Time stamps not valid for result with path', row["path"])
        if latest_created is not None and latest_created < created_at:
            continue
        if earliest_created is not None and earliest_created > created_at:
            continue
        if latest_updated is not None and latest_updated < updated_at:
            continue
        if earliest_updated is not None and earliest_updated > updated_at:
            continue
        keep.append(pos)
    # Select by position: the loop counts positions, which only coincide with
    # the index labels for a default RangeIndex.
    return results.iloc[keep]
def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None) -> pd.DataFrame:
    """
    Look up a correlator record in the database by the data given to the method.

    The query runs against the ``backlogs`` table. ``correlator_name`` and
    ``ensemble`` are always part of the condition; ``project``, ``code`` and
    ``parameters`` are only added when they are truthy (not ``None`` or empty).

    Parameters
    ----------
    db: Path
        The database to look up the record.
    ensemble: str
        The ensemble the record is associated with.
    correlator_name: str
        The name of the correlator in question.
    code: str
        The name of the code which was used to calculate the correlator.
    project: str, optional
        The UUID of the project the correlator was calculated in.
    parameters: str, optional
        The exact parameter string of the measurement; it is compared verbatim
        against the ``parameters`` column.

    Returns
    -------
    results: pd.DataFrame
        A pandas DataFrame holding the information received from the DB query.
    """
    conditions = ["name = ?", "ensemble = ?"]
    values = [correlator_name, ensemble]
    for column, value in (("project", project), ("code", code), ("parameters", parameters)):
        if value:
            conditions.append(f"{column} = ?")
            values.append(value)
    # Only fixed column names are interpolated; all user-supplied values are bound
    # as parameters so that quotes inside them cannot alter the query.
    search_expr = "SELECT * FROM 'backlogs' WHERE " + " AND ".join(conditions)
    conn = sqlite3.connect(db)
    try:
        results = pd.read_sql(search_expr, conn, params=values)
    finally:
        # Release the connection even if the query itself fails.
        conn.close()
    return results
def _sfcf_value_mismatch(constraint: Any, actual: Any, label: str) -> bool:
    """Return True if ``actual`` violates ``constraint``, given as a single value or a ``[low, high]`` range."""
    if isinstance(constraint, (list, tuple)):
        if len(constraint) != 2:
            raise ValueError(f"{label} has to have length 2")
        low, high = constraint
        # The range is inclusive at both ends.
        return bool(low > actual or high < actual)
    return not np.isclose(constraint, actual)


def _sfcf_drop(param: dict[str, Any], **kwargs: Any) -> bool:
    """
    Decide whether a record with the given SFCF parameters has to be dropped.

    Only the keywords that are actually present in ``kwargs`` are checked.

    Parameters
    ----------
    param: dict
        The parameters of the record. Depending on the keywords given, the keys
        ``offset``, ``quarks`` (two entries, each with ``mass`` and ``thetas``),
        ``wf1`` and ``wf2`` are read.
    offset: optional
        Has to be equal to ``param['offset']``.
    quark_kappas: optional
        Two values, compared in order with the ``mass`` entries of the two quarks.
    quark_masses: optional
        Two values, compared in order with ``k2m`` of the ``mass`` entries of the two quarks.
    qk1, qk2: optional
        Constraint on the ``mass`` entry of the first/second quark, either a
        single value or a ``[low, high]`` range (inclusive).
    qm1, qm2: optional
        Same as ``qk1``/``qk2``, but compared against ``k2m`` of the ``mass`` entry.
    quark_thetas: optional
        Two theta entries; they have to agree with the thetas of the two quarks
        in either order.
    wf1, wf2: optional
        Compared with ``param['wf1']``/``param['wf2']``; only the entries
        ``[0][0]``, ``[0][1][0]`` and ``[0][1][1]`` are checked.

    Returns
    -------
    drop: bool
        True if at least one of the given constraints is violated.

    Raises
    ------
    ValueError
        If a range given for ``qk1``, ``qk2``, ``qm1`` or ``qm2`` does not have length 2.
    """
    def stored_kappa(quark: int) -> Any:
        return param['quarks'][quark]['mass']

    if 'offset' in kwargs and kwargs['offset'] != param['offset']:
        return True
    if 'quark_kappas' in kwargs:
        kappas = kwargs['quark_kappas']
        if not (np.isclose(kappas[0], stored_kappa(0)) and np.isclose(kappas[1], stored_kappa(1))):
            return True
    if 'quark_masses' in kwargs:
        masses = kwargs['quark_masses']
        if not (np.isclose(masses[0], k2m(stored_kappa(0))) and np.isclose(masses[1], k2m(stored_kappa(1)))):
            return True
    # (keyword, quark position, name used in the error message, compare as mass?)
    single_quark_checks = (
        ('qk1', 0, "quark_kappa1", False),
        ('qk2', 1, "quark_kappa2", False),
        ('qm1', 0, "quark_mass1", True),
        ('qm2', 1, "quark_mass2", True),
    )
    for key, quark, label, as_mass in single_quark_checks:
        if key not in kwargs:
            continue
        actual = k2m(stored_kappa(quark)) if as_mass else stored_kappa(quark)
        if _sfcf_value_mismatch(kwargs[key], actual, label):
            return True
    if 'quark_thetas' in kwargs:
        wanted = kwargs['quark_thetas']
        stored = (param['quarks'][0]['thetas'], param['quarks'][1]['thetas'])
        same_order = wanted[0] == stored[0] and wanted[1] == stored[1]
        swapped_order = wanted[0] == stored[1] and wanted[1] == stored[0]
        # Keep the record if the pair matches in either order; drop it otherwise.
        if not (same_order or swapped_order):
            return True
    # careful, this is not safe when multiple contributions are present!
    for key in ('wf1', 'wf2'):
        if key not in kwargs:
            continue
        wanted_wf = kwargs[key]
        stored_wf = param[key]
        if not (np.isclose(wanted_wf[0][0], stored_wf[0][0], 1e-8)
                and np.isclose(wanted_wf[0][1][0], stored_wf[0][1][0], 1e-8)
                and np.isclose(wanted_wf[0][1][1], stored_wf[0][1][1], 1e-8)):
            return True
    return False
def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
    r"""
    Filter method for the Database entries holding SFCF calculations.

    The ``parameters`` column of every row is parsed as JSON and checked against
    the given keywords by ``_sfcf_drop``.

    Parameters
    ----------
    results: pd.DataFrame
        The unfiltered pandas DataFrame holding the entries from the database.
    offset: list[float], optional
    quark_kappas: list[float], optional
    quark_masses: list[float], optional
    qk1: float or list[float], optional
        Mass parameter $\kappa_1$ of the first quark, or a range ``[low, high]``.
    qk2: float or list[float], optional
        Mass parameter $\kappa_2$ of the second quark, or a range ``[low, high]``.
    qm1: float or list[float], optional
        Bare quark mass $m_1$ of the first quark, or a range ``[low, high]``.
    qm2: float or list[float], optional
        Bare quark mass $m_2$ of the second quark, or a range ``[low, high]``.
    quark_thetas: list[list[float]], optional
    wf1: optional
    wf2: optional

    Returns
    -------
    results: pd.DataFrame
        The filtered DataFrame, only holding the records that fit to the parameters given.
    """
    keep = [
        pos
        for pos in range(len(results))
        if not _sfcf_drop(json.loads(results.iloc[pos]['parameters']), **kwargs)
    ]
    # Select by position: the frame may arrive with gaps in its index labels
    # (e.g. after a previous filter), so positions must not be used as labels.
    return results.iloc[keep]
def openQCD_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
    """
    Filter for parameters of openQCD.

    No filtering is implemented yet: a warning is issued and the input is
    returned unchanged.

    Parameters
    ----------
    results: pd.DataFrame
        The unfiltered list of results from the database.
    kwargs:
        Accepted for interface compatibility with the other code-specific
        filters; currently ignored.

    Returns
    -------
    results: pd.DataFrame
        The very same DataFrame that was passed in.
    """
    warnings.warn("A filter for openQCD parameters is not implemented yet.", Warning)
    return results
def _code_filter(results: pd.DataFrame, code: str, **kwargs: Any) -> pd.DataFrame:
    """
    Dispatch to the filter that belongs to the code which produced the records.

    Currently the codes "sfcf" and "openQCD" are handled. The keywords that
    can be used are described in the code-specific filters.

    Parameters
    ----------
    results: pd.DataFrame
        The unfiltered list of results from the database.
    code: str
        The name of the code that produced the record at hand.
    kwargs:
        The keyword arguments that are handed over to the code-specific filter.

    Returns
    -------
    results: pd.DataFrame
        The filtered results.

    Raises
    ------
    ValueError
        If ``code`` is neither "sfcf" nor "openQCD".
    """
    # Reject unknown codes first, then pick the matching filter.
    if code not in ("sfcf", "openQCD"):
        raise ValueError(f"Code {code} is not known.")
    if code == "sfcf":
        return sfcf_filter(results, **kwargs)
    return openQCD_filter(results, **kwargs)
def find_record(path: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
                created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None,
                revision: Optional[str]=None,
                customFilter: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
                **kwargs: Any) -> pd.DataFrame:
    """
    Find the records in the library that match the given criteria.

    The database is queried for the ensemble, correlator name, code and,
    optionally, project and parameter string. The result is then narrowed down
    by the time constraints (if any are given), by the code-specific filter and
    finally by ``customFilter``. The number of results is printed.

    Parameters
    ----------
    path: Path
        The path of the library.
    ensemble: str
        The ensemble the record is associated with.
    correlator_name: str
        The name of the correlator in question.
    code: str
        The name of the code which was used to calculate the correlator.
    project: str, optional
        The UUID of the project the correlator was calculated in.
    parameters: str, optional
        The exact parameter string as held in the database.
    created_before, created_after, updated_before, updated_after: str, optional
        Time constraints in datetime.datetime.isoformat, handed to ``_time_filter``.
    revision: str, optional
        Currently not used by this function.
    customFilter: Callable, optional
        A function that takes the filtered DataFrame and returns a DataFrame;
        it is applied after all other filters.
    kwargs:
        Keywords that are handed over to the code-specific filter.

    Returns
    -------
    results: pd.DataFrame
        The matching records, with the index reset (``reset_index()`` with its
        defaults, so the previous index is kept as a column).

    Raises
    ------
    ValueError
        If ``code`` is not one of the known codes.
    """
    db_file = get_db_file(path)
    db = path / db_file
    if code not in codes:
        raise ValueError("Code " + code + " unknown, take one of the following: " + ", ".join(codes))
    # NOTE(review): presumably makes the database file available locally — confirm in .tracker
    get(path, db_file)
    results = _db_lookup(db, ensemble, correlator_name, code, project, parameters=parameters)
    time_constraints = (created_before, created_after, updated_before, updated_after)
    if any(constraint is not None for constraint in time_constraints):
        results = _time_filter(results, created_before, created_after, updated_before, updated_after)
    results = _code_filter(results, code, **kwargs)
    if customFilter is not None:
        results = customFilter(results)
    # Use the singular only for exactly one result ("0 results", "1 result", "2 results").
    print("Found " + str(len(results)) + " result" + ("" if len(results) == 1 else "s"))
    return results.reset_index()
def find_project(path: Path, name: str) -> str:
    """
    Find a project by its human-readable name.

    Parameters
    ----------
    path: Path
        The path of the library.
    name: str
        The name (alias) of the project to look for in the library.

    Returns
    -------
    uuid: str
        The uuid of the project in question.
    """
    database_name = get_db_file(path)
    get(path, database_name)
    database = path / database_name
    return _project_lookup_by_alias(database, name)
def list_projects(path: Path) -> list[tuple[str, str]]:
    """
    List all projects known to the library.

    Parameters
    ----------
    path: Path
        The path of the library.

    Returns
    -------
    results: list[tuple[str, str]]
        One ``(id, aliases)`` tuple per row of the ``projects`` table.
    """
    db_file = get_db_file(path)
    get(path, db_file)
    conn = sqlite3.connect(os.path.join(path, db_file))
    try:
        results = conn.execute("SELECT id,aliases FROM projects").fetchall()
    finally:
        # Release the connection even if the query itself fails.
        conn.close()
    return results