# corrlib/corrlib/meas_io.py
from pyerrors.input import json as pj
import os
import datalad.api as dl
import sqlite3
from .input import sfcf, openQCD
import json
from typing import Optional, Union
from pyerrors import Obs, Corr, dump_object, load_object
from hashlib import sha256
from .tools import cached
import shutil


def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=None):
"""
Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement.
Parameters
----------
path: str
The path to the backlogger folder.
ensemble: str
The ensemble of the measurement.
measurement: dict
Measurements to be captured in the backlogging system.
uuid: str
The uuid of the project.
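
    Examples
    --------
    A usage sketch; the ensemble name, uuid, subkey and parameter file below
    are made up for illustration:

    >>> measurement = {"f_A": {"1/12/12/0.0": corr_obs}}  # corr_obs: pyerrors Corr/Obs
    >>> write_measurement("./corrlib", "A653", measurement,
    ...                   "aaaabbbb-cccc-dddd", "sfcf",
    ...                   parameter_file="parameters/sfcf.in")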
"""
db = os.path.join(path, 'backlogger.db')
dl.unlock(db, dataset=path)
conn = sqlite3.connect(db)
c = conn.cursor()
files = []
    for corr in measurement:
        file_in_archive = os.path.join('.', 'archive', ensemble, corr, uuid + '.json.gz')
        file = os.path.join(path, file_in_archive)
        files.append(file)
        known_meas = {}
        if not os.path.exists(os.path.dirname(file)):
            os.makedirs(os.path.dirname(file))
        elif os.path.exists(file):
            dl.unlock(file, dataset=path)
            known_meas = pj.load_json_dict(file)
if code == "sfcf":
parameters = sfcf.read_param(path, uuid, parameter_file)
pars = {}
subkeys = list(measurement[corr].keys())
for subkey in subkeys:
pars[subkey] = sfcf.get_specs(corr + "/" + subkey, parameters)
elif code == "openQCD":
ms_type = list(measurement.keys())[0]
if ms_type == 'ms1':
parameters = openQCD.read_ms1_param(path, uuid, parameter_file)
pars = {}
subkeys = []
for i in range(len(parameters["rw_fcts"])):
par_list = []
for k in parameters["rw_fcts"][i].keys():
par_list.append(str(parameters["rw_fcts"][i][k]))
subkey = "/".join(par_list)
subkeys.append(subkey)
pars[subkey] = json.dumps(parameters["rw_fcts"][i])
            elif ms_type in ['t0', 't1']:
                if parameter_file is not None:
                    parameters = openQCD.read_ms3_param(path, uuid, parameter_file)
                else:
                    # without a parameter file the run parameters are unknown
                    parameters = {k: "Unknown" for k in ["integrator", "eps", "ntot", "dnms"]}
                pars = {}
                par_list = []
                for k in ["integrator", "eps", "ntot", "dnms"]:
                    par_list.append(str(parameters[k]))
                subkey = "/".join(par_list)
                subkeys = [subkey]
                pars[subkey] = json.dumps(parameters)
        else:
            raise ValueError("Unknown code: " + code)
        for subkey in subkeys:
            # records are addressed as "<file_in_archive>::<sha256 of the parameters>"
            par_hash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
            meas_path = file_in_archive + "::" + par_hash
            known_meas[par_hash] = measurement[corr][subkey]
            if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
                c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path,))
            else:
                c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))",
                          (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file))
            conn.commit()
            pj.dump_dict_to_json(known_meas, file)
    files.append(db)
    conn.close()
    dl.save(files, message="Add measurements to database", dataset=path)


def load_record(path: str, meas_path: str):
"""
Load a list of records by their paths.
Parameters
----------
path: str
Path of the correlator library.
meas_path: str
The path to the correlator in the backlog system.
Returns
-------
co : Corr or Obs
The correlator in question.
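
    Examples
    --------
    A sketch with an illustrative record path; the part after "::" is the
    sha256 hash of the parameter specification:

    >>> corr = load_record("./corrlib",
    ...                    "./archive/A653/f_A/aaaabbbb-cccc-dddd.json.gz::<hash>")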
"""
    return load_records(path, [meas_path])[0]


def load_records(path: str, meas_paths: list[str], preloaded: Optional[dict] = None) -> list[Union[Corr, Obs]]:
"""
Load a list of records by their paths.
Parameters
----------
path: str
Path of the correlator library.
meas_paths: list[str]
A list of the paths to the correlator in the backlog system.
Returns
-------
List
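
    Examples
    --------
    Illustrative only; record paths follow the
    "<file_in_archive>::<parameter_hash>" convention used by write_measurement:

    >>> records = load_records("./corrlib", [
    ...     "./archive/A653/f_A/aaaabbbb-cccc-dddd.json.gz::<hash1>",
    ...     "./archive/A653/f_P/aaaabbbb-cccc-dddd.json.gz::<hash2>",
    ... ])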
"""
    if preloaded is None:
        # a fresh dict per call avoids the mutable-default-argument pitfall
        preloaded = {}
    returned_data: list[Union[Corr, Obs]] = []
    for mpath in meas_paths:
        file, key = mpath.split("::")
        pickle_file = cache_path(path, file, key) + ".p"
        if os.path.exists(pickle_file):
            # serve the record from the on-disk pickle cache
            returned_data.append(load_object(pickle_file))
        else:
            # each archive file is loaded at most once per preloaded dict
            if file not in preloaded:
                preloaded[file] = preload(path, file)
            returned_data.append(preloaded[file][key])
            if cached:
                if not os.path.exists(cache_dir(path, file)):
                    os.makedirs(cache_dir(path, file))
                dump_object(preloaded[file][key], cache_path(path, file, key))
    return returned_data


def cache_dir(path, file):
    """Return the cache directory for an archive file, mirroring the archive layout under .cache/."""
    # drop the leading "." component of file_in_archive
    return os.path.join(path, ".cache", *file.split("/")[1:])


def cache_path(path, file, key):
    """Return the cache path (without the ".p" suffix) of a single record."""
    return os.path.join(cache_dir(path, file), key)
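
# Cache layout sketch (names are hypothetical): in a library at "./corrlib",
# the record "./archive/A653/f_A/<uuid>.json.gz::<hash>" is pickled to
# "./corrlib/.cache/archive/A653/f_A/<uuid>.json.gz/<hash>.p".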


def preload(path: str, file: str):
    """Fetch an archive file via datalad and load it as a dict of records keyed by parameter hash."""
    dl.get(os.path.join(path, file), dataset=path)
    filedict = pj.load_json_dict(os.path.join(path, file))
    return filedict


def drop_record(path: str, meas_path: str):
    """
    Drop a single record from the backlog.

    Removes the database entry and the corresponding sub-entry of the archive
    file; raises a ValueError if either is missing.
    """
    file_in_archive, sub_key = meas_path.split("::")
    file = os.path.join(path, file_in_archive)
    db = os.path.join(path, 'backlogger.db')
    # check the archive file first, so the database is not altered
    # if the measurement is missing from the file
    known_meas = pj.load_json_dict(file)
    if sub_key not in known_meas:
        raise ValueError("This measurement does not exist as a file!")
    dl.unlock(db, dataset=path)
    conn = sqlite3.connect(db)
    c = conn.cursor()
    if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is None:
        conn.close()
        raise ValueError("This measurement does not exist as an entry!")
    c.execute("DELETE FROM backlogs WHERE path = ?", (meas_path,))
    conn.commit()
    conn.close()
    del known_meas[sub_key]
    dl.unlock(file, dataset=path)
    pj.dump_dict_to_json(known_meas, file)
    dl.save([db, file], message="Drop measurement from database", dataset=path)


def drop_cache(path: str):
    """Delete the entire on-disk pickle cache of the library at `path`."""
    cache_root = os.path.join(path, ".cache")
    if os.path.exists(cache_root):
        for entry in os.listdir(cache_root):
            shutil.rmtree(os.path.join(cache_root, entry))