Fix/merge idx (#172)

* Fix: Corrected merging of idls

* Fix: Computation of drho in cases where tau_int is large compared to the chain length

* Removed unnecessary imports

* Refactor list comparisons in obs.py
This commit is contained in:
s-kuberski 2023-04-28 19:14:51 +02:00 committed by GitHub
parent 1184a0fe76
commit 65a9128a7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 46 deletions

View file

@ -1,8 +1,6 @@
import warnings import warnings
import hashlib import hashlib
import pickle import pickle
from math import gcd
from functools import reduce
import numpy as np import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy import autograd.numpy as anp # Thinly-wrapped numpy
from autograd import jacobian from autograd import jacobian
@ -280,7 +278,7 @@ class Obs:
def _compute_drho(i):
    # Estimate the uncertainty of the normalized autocorrelation function
    # rho at separation i and store it in self.e_drho[e_name][i].
    #
    # NOTE(review): the concatenated second term mirrors e_rho around the
    # origin; the slice bound `None if i - w_max // 2 <= 0 else ...` stops
    # the reversed slice from wrapping past the start of the array when
    # tau_int is large compared to the chain length (the `<= 0` instead of
    # `< 0` is the fix introduced by this commit). Presumably this
    # implements the standard Gamma-method drho estimator — confirm
    # against the reference used by this project.
    tmp = (self.e_rho[e_name][i + 1:w_max]
           + np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 <= 0 else 2 * (i - w_max // 2):-1],
                             self.e_rho[e_name][1:max(1, w_max - 2 * i)]])
           - 2 * self.e_rho[e_name][i] * self.e_rho[e_name][1:w_max - i])
    # Combine the fluctuations of all contributing separations into a
    # single standard error, normalized by the ensemble size e_N.
    self.e_drho[e_name][i] = np.sqrt(np.sum(tmp ** 2) / e_N)
@ -1022,7 +1020,7 @@ def _expand_deltas(deltas, idx, shape, gapsize):
def _merge_idx(idl):
    """Returns the union of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.
    """
    # Fast path: if every entry of idl is identical, the union is just the
    # first entry (this also preserves range objects as ranges).
    if _check_lists_equal(idl):
        return idl[0]

    idunion = sorted(set().union(*idl))

    # Check whether idunion can be expressed as range.
    # The step is inferred from the first two elements; the IndexError
    # guard covers unions with fewer than two elements (e.g. entries that
    # are not ==-equal but share a single configuration), mirroring the
    # identical guard in _intersection_idx.
    try:
        idrange = range(idunion[0], idunion[-1] + 1, idunion[1] - idunion[0])
        idtest = [list(idrange), idunion]
        if _check_lists_equal(idtest):
            return idrange
    except IndexError:
        pass

    return idunion
def _intersection_idx(idl):
    """Returns the intersection of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.
    """
    # Shortcut: when all entries of idl compare equal, the intersection is
    # the first entry itself (keeps range objects as ranges).
    if _check_lists_equal(idl):
        return idl[0]

    # Intersect all entries pairwise instead of one bulk set.intersection.
    common = set(idl[0])
    for entry in idl[1:]:
        common &= set(entry)
    idinter = sorted(common)

    # Try to compress the sorted intersection into an equivalent range;
    # the step is inferred from the first two elements.
    try:
        step = idinter[1] - idinter[0]
        candidate = range(idinter[0], idinter[-1] + 1, step)
        if _check_lists_equal([list(candidate), idinter]):
            return candidate
    except IndexError:
        # Fewer than two common elements: no step can be inferred.
        pass

    return idinter
def _expand_deltas_for_merge(deltas, idx, shape, new_idx): def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
@ -1299,13 +1286,8 @@ def _reduce_deltas(deltas, idx_old, idx_new):
if type(idx_old) is range and type(idx_new) is range: if type(idx_old) is range and type(idx_new) is range:
if idx_old == idx_new: if idx_old == idx_new:
return deltas return deltas
# Use groupby to efficiently check whether all elements of idx_old and idx_new are identical if _check_lists_equal([idx_old, idx_new]):
try:
g = groupby([idx_old, idx_new])
if next(g, True) and not next(g, False):
return deltas return deltas
except Exception:
pass
indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1] indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1]
if len(indices) < len(idx_new): if len(indices) < len(idx_new):
raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old') raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old')
@ -1650,3 +1632,18 @@ def _determine_gap(o, e_content, e_name):
raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps) raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps)
return gap return gap
def _check_lists_equal(idl):
'''
Use groupby to efficiently check whether all elements of idl are identical.
Returns True if all elements are equal, otherwise False.
Parameters
----------
idl : list of lists, ranges or np.ndarrays
'''
g = groupby([np.nditer(el) if isinstance(el, np.ndarray) else el for el in idl])
if next(g, True) and not next(g, False):
return True
return False

View file

@ -537,7 +537,21 @@ def test_correlate():
def test_merge_idx(): def test_merge_idx():
assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10) assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50) assert isinstance(pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]), range)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 50)
assert isinstance(pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]), range)
assert pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 1)
assert isinstance(pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]), range)
assert pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]) == range(1, 100, 1)
assert isinstance(pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]), range)
for j in range(5):
idll = [range(1, int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
for j in range(5):
idll = [range(int(round(np.random.uniform(1, 28))), int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)] idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)]
new_idx = pe.obs._merge_idx(idl) new_idx = pe.obs._merge_idx(idl)
@ -550,10 +564,21 @@ def test_intersection_idx():
assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100) assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100)
assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10) assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10)
assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50) assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250) assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 250)
assert pe.obs._intersection_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 2)
idll = [range(1, 100, 2), range(5, 105, 1)]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)
idll = [range(1, 100, 2), list(range(5, 105, 1))]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)
for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]: for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]:
assert list(pe.obs._intersection_idx(ids)) == pe.obs._intersection_idx([list(o) for o in ids]) interlist = pe.obs._intersection_idx([list(o) for o in ids])
listinter = list(pe.obs._intersection_idx(ids))
assert len(interlist) == len(listinter)
assert all([o in listinter for o in interlist])
assert all([o in interlist for o in listinter])
def test_merge_intersection(): def test_merge_intersection():
@ -733,6 +758,15 @@ def test_gamma_method_irregular():
with pytest.raises(Exception): with pytest.raises(Exception):
my_obs.gm() my_obs.gm()
# check cases where tau is large compared to the chain length
N = 15
for i in range(10):
arr = np.random.normal(1, .2, size=N)
for rho in .1 * np.arange(20):
carr = gen_autocorrelated_array(arr, rho)
a = pe.Obs([carr], ['a'])
a.gm()
def test_irregular_gapped_dtauint(): def test_irregular_gapped_dtauint():
my_idl = list(range(0, 5010, 10)) my_idl = list(range(0, 5010, 10))