refactor io

This commit is contained in:
Justus Kuhlmann 2025-11-20 17:13:17 +01:00
commit 4f3e78177e
Signed by: jkuhl
GPG key ID: 00ED992DD79B85A6
3 changed files with 46 additions and 24 deletions

33
corrlib/cache_io.py Normal file
View file

@ -0,0 +1,33 @@
from typing import Union, Optional
import os
import shutil
def drop_cache_files(path: str, fs: Optional[list[str]]=None):
cache_dir = os.path.join(path, ".cache")
if fs is None:
fs = os.listdir(cache_dir)
for f in fs:
shutil.rmtree(os.path.join(cache_dir, f))
def cache_dir(path, file):
cache_path_list = [path]
cache_path_list.append(".cache")
cache_path_list.extend(file.split("/")[1:])
cache_path = "/".join(cache_path_list)
return cache_path
def cache_path(path, file, hash, key):
cache_path = os.path.join(cache_dir(path, file), hash, key)
return cache_path
def is_in_cache(path, record, hash):
if os.file.exists(cache_path(path, file, hash, key)):
return True
else:
return False

View file

@ -4,14 +4,15 @@ import datalad.api as dl
import sqlite3 import sqlite3
from .input import sfcf,openQCD from .input import sfcf,openQCD
import json import json
from typing import Union from typing import Union, Optional
from pyerrors import Obs, Corr, dump_object, load_object from pyerrors import Obs, Corr, dump_object, load_object
from hashlib import sha256, sha1 from hashlib import sha256, sha1
from .tools import cached from .tools import cached, record2name_key
import shutil import shutil
from .caching import cache_path, cache_dir
def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=None): def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: Optional[str]=None):
""" """
Write a measurement to the backlog. Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement. If the file for the measurement already exists, update the measurement.
@ -115,7 +116,7 @@ def load_record(path: str, meas_path: str):
return load_records(path, [meas_path])[0] return load_records(path, [meas_path])[0]
def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]: def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]:
""" """
Load a list of records by their paths. Load a list of records by their paths.
@ -131,11 +132,10 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
List List
""" """
needed_data: dict[str, list[str]] = {} needed_data: dict[str, list[str]] = {}
for mpath in meas_paths: for rpath in record_paths:
file = mpath.split("::")[0] file, key = record2name_key(rpath)
if file not in needed_data.keys(): if file not in needed_data.keys():
needed_data[file] = [] needed_data[file] = []
key = mpath.split("::")[1]
needed_data[file].append(key) needed_data[file].append(key)
returned_data: list = [] returned_data: list = []
for file in needed_data.keys(): for file in needed_data.keys():
@ -153,19 +153,6 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
return returned_data return returned_data
def cache_dir(path, file):
cache_path_list = [path]
cache_path_list.append(".cache")
cache_path_list.extend(file.split("/")[1:])
cache_path = "/".join(cache_path_list)
return cache_path
def cache_path(path, file, key):
cache_path = os.path.join(cache_dir(path, file), key)
return cache_path
def preload(path: str, file: str): def preload(path: str, file: str):
dl.get(os.path.join(path, file), dataset=path) dl.get(os.path.join(path, file), dataset=path)
filedict = pj.load_json_dict(os.path.join(path, file)) filedict = pj.load_json_dict(os.path.join(path, file))
@ -196,7 +183,3 @@ def drop_record(path: str, meas_path: str):
else: else:
raise ValueError("This measurement does not exist as a file!") raise ValueError("This measurement does not exist as a file!")
def drop_cache(path: str):
cache_dir = os.path.join(path, ".cache")
for f in os.listdir(cache_dir):
shutil.rmtree(os.path.join(cache_dir, f))

View file

@ -16,3 +16,9 @@ def m2k(m):
def k2m(k): def k2m(k):
return (1/(2*k))-4 return (1/(2*k))-4
def record2name_key(record_path: str):
file = record_path.split("::")[0]
key = record_path.split("::")[1]
return file, key