Fix/merge idx (#172)

* Fix: Corrected merging of idls

* Fix: Computation of drho in cases where tau_int is large compared to the chain length

* Removed unnecessary imports

* Refactor list comparisons in obs.py
This commit is contained in:
s-kuberski 2023-04-28 19:14:51 +02:00 committed by GitHub
parent 1184a0fe76
commit 65a9128a7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 46 deletions

View file

@ -1,8 +1,6 @@
import warnings
import hashlib
import pickle
from math import gcd
from functools import reduce
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
from autograd import jacobian
@ -280,7 +278,7 @@ class Obs:
def _compute_drho(i):
tmp = (self.e_rho[e_name][i + 1:w_max]
+ np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 < 0 else 2 * (i - w_max // 2):-1],
+ np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 <= 0 else 2 * (i - w_max // 2):-1],
self.e_rho[e_name][1:max(1, w_max - 2 * i)]])
- 2 * self.e_rho[e_name][i] * self.e_rho[e_name][1:w_max - i])
self.e_drho[e_name][i] = np.sqrt(np.sum(tmp ** 2) / e_N)
@ -1022,7 +1020,7 @@ def _expand_deltas(deltas, idx, shape, gapsize):
def _merge_idx(idl):
"""Returns the union of all lists in idl as sorted list
"""Returns the union of all lists in idl as range or sorted list
Parameters
----------
@ -1030,26 +1028,22 @@ def _merge_idx(idl):
List of lists or ranges.
"""
# Use groupby to efficiently check whether all elements of idl are identical
try:
g = groupby(idl)
if next(g, True) and not next(g, False):
return idl[0]
except Exception:
pass
if _check_lists_equal(idl):
return idl[0]
if np.all([type(idx) is range for idx in idl]):
if len(set([idx[0] for idx in idl])) == 1:
idstart = min([idx.start for idx in idl])
idstop = max([idx.stop for idx in idl])
idstep = min([idx.step for idx in idl])
return range(idstart, idstop, idstep)
idunion = sorted(set().union(*idl))
return sorted(set().union(*idl))
# Check whether idunion can be expressed as range
idrange = range(idunion[0], idunion[-1] + 1, idunion[1] - idunion[0])
idtest = [list(idrange), idunion]
if _check_lists_equal(idtest):
return idrange
return idunion
def _intersection_idx(idl):
"""Returns the intersection of all lists in idl as sorted list
"""Returns the intersection of all lists in idl as range or sorted list
Parameters
----------
@ -1057,28 +1051,21 @@ def _intersection_idx(idl):
List of lists or ranges.
"""
def _lcm(*args):
"""Returns the lowest common multiple of args.
if _check_lists_equal(idl):
return idl[0]
From python 3.9 onwards the math library contains an lcm function."""
return reduce(lambda a, b: a * b // gcd(a, b), args)
idinter = sorted(set.intersection(*[set(o) for o in idl]))
# Use groupby to efficiently check whether all elements of idl are identical
# Check whether idinter can be expressed as range
try:
g = groupby(idl)
if next(g, True) and not next(g, False):
return idl[0]
except Exception:
idrange = range(idinter[0], idinter[-1] + 1, idinter[1] - idinter[0])
idtest = [list(idrange), idinter]
if _check_lists_equal(idtest):
return idrange
except IndexError:
pass
if np.all([type(idx) is range for idx in idl]):
if len(set([idx[0] for idx in idl])) == 1:
idstart = max([idx.start for idx in idl])
idstop = min([idx.stop for idx in idl])
idstep = _lcm(*[idx.step for idx in idl])
return range(idstart, idstop, idstep)
return sorted(set.intersection(*[set(o) for o in idl]))
return idinter
def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
@ -1299,13 +1286,8 @@ def _reduce_deltas(deltas, idx_old, idx_new):
if type(idx_old) is range and type(idx_new) is range:
if idx_old == idx_new:
return deltas
# Use groupby to efficiently check whether all elements of idx_old and idx_new are identical
try:
g = groupby([idx_old, idx_new])
if next(g, True) and not next(g, False):
return deltas
except Exception:
pass
if _check_lists_equal([idx_old, idx_new]):
return deltas
indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1]
if len(indices) < len(idx_new):
raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old')
@ -1650,3 +1632,18 @@ def _determine_gap(o, e_content, e_name):
raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps)
return gap
def _check_lists_equal(idl):
'''
Use groupby to efficiently check whether all elements of idl are identical.
Returns True if all elements are equal, otherwise False.
Parameters
----------
idl : list of lists, ranges or np.ndarrays
'''
g = groupby([np.nditer(el) if isinstance(el, np.ndarray) else el for el in idl])
if next(g, True) and not next(g, False):
return True
return False