corrlib/corrlib/find.py
2025-05-15 16:18:28 +00:00

167 lines
7.2 KiB
Python

import sqlite3
import datalad.api as dl
import os
import json
import pandas as pd
import numpy as np
from .input.implementations import codes
from .tools import k2m
# this will implement the search functionality
def _project_lookup_by_alias(db, alias):
    """Look up a project by its alias.

    Parameters
    ----------
    db : str
        Path to the SQLite database file containing the ``projects`` table.
    alias : str
        Alias to search for; compared for equality with the ``alias`` column.

    Returns
    -------
    object
        The first column of the first matching row (presumably the project
        id -- verify against the table schema).

    Raises
    ------
    Exception
        If no project with the given alias exists.
    """
    conn = sqlite3.connect(db)
    try:
        # Bind the alias instead of formatting it into the statement, so
        # quotes or SQL fragments inside the alias are treated as plain data.
        results = conn.execute(
            "SELECT * FROM 'projects' WHERE alias = ?", (str(alias),)
        ).fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()
    if not results:
        raise Exception(f"Error: no project found with alias {alias}")
    if len(results) > 1:
        # Ambiguous aliases are only reported; the first match is still used.
        print(f"Error: multiple projects found with alias {alias}")
    return results[0][0]
def _project_lookup_by_id(db, uuid):
    """Return all rows of the ``projects`` table whose ``id`` equals ``uuid``.

    Parameters
    ----------
    db : str
        Path to the SQLite database file containing the ``projects`` table.
    uuid : str
        Identifier to search for.

    Returns
    -------
    list of tuple
        The matching rows with all columns; empty if nothing matches.
    """
    conn = sqlite3.connect(db)
    try:
        # Bind the id instead of formatting it into the statement, so the
        # value can never alter the SQL that is executed.
        return conn.execute(
            "SELECT * FROM 'projects' WHERE id = ?", (str(uuid),)
        ).fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()
def _db_lookup(db, ensemble, correlator_name, code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None):
    """Query the ``backlogs`` table for records matching the given criteria.

    Parameters
    ----------
    db : str
        Path to the SQLite database file containing the ``backlogs`` table.
    ensemble : str
        Value the ``ensemble`` column has to equal.
    correlator_name : str
        Value the ``name`` column has to equal.
    code : str
        Value the ``code`` column has to equal; skipped when falsy.
    project : str, optional
        Value the ``project`` column has to equal.
    parameters : str, optional
        Value the ``parameters`` column has to equal.
    created_before, created_after : optional
        Exclusive upper/lower bound for the ``created_at`` column.
    updated_before, updated_after : optional
        Exclusive upper/lower bound for the ``updated_at`` column.
    revision : optional
        Accepted for interface compatibility; not used by the query.

    Returns
    -------
    pandas.DataFrame
        One row per matching record, with all columns of ``backlogs``.
    """
    clauses = ["name = ?", "ensemble = ?"]
    values = [str(correlator_name), str(ensemble)]
    optional_filters = (
        ("project = ?", project),
        ("code = ?", code),
        ("parameters = ?", parameters),
        ("created_at < ?", created_before),
        ("created_at > ?", created_after),
        ("updated_at < ?", updated_before),
        ("updated_at > ?", updated_after),
    )
    for clause, value in optional_filters:
        if value:
            clauses.append(clause)
            # Values are bound as text, matching the quoted literals the
            # query was previously formatted with, but without allowing the
            # value to change the structure of the statement.
            values.append(str(value))
    search_expr = "SELECT * FROM 'backlogs' WHERE " + " AND ".join(clauses)
    conn = sqlite3.connect(db)
    try:
        return pd.read_sql(search_expr, conn, params=values)
    finally:
        # Release the connection even if the query fails.
        conn.close()
def _matches_value_or_range(requested, actual):
    """Check ``actual`` against a scalar target or a ``[low, high]`` range."""
    if isinstance(requested, (list, tuple)):
        if len(requested) != 2:
            # Not a valid range specification; such a filter is ignored.
            return True
        low, high = requested
        # Both bounds are inclusive.
        return not (low > actual or high < actual)
    return bool(np.isclose(requested, actual))


def _thetas_match(requested, quarks):
    """Check the requested pair of thetas against both quarks, in either order."""
    first, second = quarks[0]['thetas'], quarks[1]['thetas']
    straight = requested[0] == first and requested[1] == second
    swapped = requested[0] == second and requested[1] == first
    return bool(straight or swapped)


def _wf_matches(requested, stored):
    """Compare the first wavefunction contribution of ``requested`` and ``stored``."""
    # Careful: this is not safe when multiple contributions are present,
    # because only the entry at index 0 is inspected.
    return bool(
        np.isclose(requested[0][0], stored[0][0], rtol=1e-8)
        and np.isclose(requested[0][1][0], stored[0][1][0], rtol=1e-8)
        and np.isclose(requested[0][1][1], stored[0][1][1], rtol=1e-8)
    )


def _sfcf_record_matches(param, filters):
    """Return True if the decoded sfcf parameters satisfy every given filter."""
    if 'offset' in filters and filters['offset'] != param['offset']:
        return False
    if 'quark_kappas' in filters:
        kappas = filters['quark_kappas']
        if not (np.isclose(kappas[0], param['quarks'][0]['mass'])
                and np.isclose(kappas[1], param['quarks'][1]['mass'])):
            return False
    if 'quark_masses' in filters:
        masses = filters['quark_masses']
        if not (np.isclose(masses[0], k2m(param['quarks'][0]['mass']))
                and np.isclose(masses[1], k2m(param['quarks'][1]['mass']))):
            return False
    # (keyword, index of the quark, whether the stored value goes through k2m)
    single_quark_filters = (
        ('qk1', 0, False),
        ('qk2', 1, False),
        ('qm1', 0, True),
        ('qm2', 1, True),
    )
    for key, quark_index, convert in single_quark_filters:
        if key not in filters:
            continue
        actual = param['quarks'][quark_index]['mass']
        if convert:
            actual = k2m(actual)
        if not _matches_value_or_range(filters[key], actual):
            return False
    if 'quark_thetas' in filters:
        if not _thetas_match(filters['quark_thetas'], param['quarks']):
            return False
    for key in ('wf1', 'wf2'):
        if key in filters and not _wf_matches(filters[key], param[key]):
            return False
    return True


def sfcf_filter(results, **kwargs):
    """Remove sfcf records whose parameters do not match the given filters.

    Rows whose ``code`` is not ``'sfcf'`` are always kept. For sfcf rows the
    ``parameters`` column is decoded as JSON and compared with the filters.

    Parameters
    ----------
    results : pandas.DataFrame
        Records with at least the columns ``code`` and ``parameters``.
    **kwargs
        Supported filters:

        offset
            Has to equal the stored ``offset``.
        quark_kappas
            Pair of values compared with the stored ``mass`` of both quarks.
        quark_masses
            Pair of values compared with ``k2m`` of the stored ``mass`` of
            both quarks.
        qk1, qk2
            Scalar, or inclusive ``[low, high]`` range (list or tuple), for
            the stored ``mass`` of the first/second quark.
        qm1, qm2
            Same as ``qk1``/``qk2``, applied to ``k2m`` of the stored value.
        quark_thetas
            Pair of thetas; has to match the two quarks in either order.
        wf1, wf2
            Wavefunction specification; only the first contribution is
            compared.

    Returns
    -------
    pandas.DataFrame
        The rows of ``results`` that pass all filters, with their original
        index labels.
    """
    keep = np.ones(len(results), dtype=bool)
    for position, (_, record) in enumerate(results.iterrows()):
        if record['code'] != 'sfcf':
            continue
        keep[position] = _sfcf_record_matches(json.loads(record['parameters']), kwargs)
    # Select by position so the result is correct for any index, not only
    # for a default 0..n-1 index.
    return results.iloc[np.flatnonzero(keep)]
def find_record(path, ensemble, correlator_name, code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None, **kwargs):
    """Search the backlogger database below ``path`` for matching records.

    Parameters
    ----------
    path : str or os.PathLike
        Directory containing ``backlogger.db``; also passed to datalad as the
        dataset.
    ensemble : str
        Ensemble the record has to belong to.
    correlator_name : str
        Name of the correlator.
    code : str
        Code that produced the record; has to be one of ``codes``.
    project, parameters, created_before, created_after, updated_before, updated_after, revision : optional
        Forwarded unchanged to ``_db_lookup``.
    **kwargs
        Additional filters, forwarded to ``sfcf_filter`` when ``code`` is
        ``"sfcf"``.

    Returns
    -------
    pandas.DataFrame
        The matching records with a fresh index (``reset_index`` is applied).

    Raises
    ------
    ValueError
        If ``code`` is not one of the known codes.
    """
    if code not in codes:
        raise ValueError(f"Code {code} unknown, take one of the following: " + ", ".join(codes))
    db = os.path.join(path, 'backlogger.db')
    if os.path.exists(db):
        # Fetch the database through datalad before it is queried.
        dl.get(db, dataset=path)
    results = _db_lookup(
        db,
        ensemble,
        correlator_name,
        code,
        project,
        parameters=parameters,
        created_before=created_before,
        created_after=created_after,
        updated_before=updated_before,
        updated_after=updated_after,
        revision=revision,
    )
    if code == "sfcf":
        results = sfcf_filter(results, **kwargs)
    print("Found " + str(len(results)) + " results")
    return results.reset_index()
def find_project(db, name):
    """Resolve a project alias via ``_project_lookup_by_alias``.

    Parameters
    ----------
    db : str
        Path to the SQLite database file.
    name : str
        Alias of the project.

    Returns
    -------
    object
        Whatever ``_project_lookup_by_alias`` returns for this alias.
    """
    project = _project_lookup_by_alias(db, name)
    return project
def list_projects(path):
    """List the id and aliases of every project in the backlogger database.

    Parameters
    ----------
    path : str or os.PathLike
        Directory containing ``backlogger.db``.

    Returns
    -------
    list of tuple
        One ``(id, aliases)`` tuple per row of the ``projects`` table.
    """
    # os.path.join also accepts path-like objects, unlike string concatenation.
    db = os.path.join(path, 'backlogger.db')
    conn = sqlite3.connect(db)
    try:
        return conn.execute("SELECT id,aliases FROM projects").fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()