mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 06:40:24 +01:00
Fix/merge idx (#172)
* Fix: Corrected merging of idls. * Fix: Computation of drho in cases where tau_int is large compared to the chain length. * Removed unnecessary imports. * Refactored list comparisons in obs.py.
This commit is contained in:
parent
1184a0fe76
commit
65a9128a7d
2 changed files with 77 additions and 46 deletions
|
@ -1,8 +1,6 @@
|
|||
import warnings
|
||||
import hashlib
|
||||
import pickle
|
||||
from math import gcd
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
||||
from autograd import jacobian
|
||||
|
@ -280,7 +278,7 @@ class Obs:
|
|||
|
||||
def _compute_drho(i):
|
||||
tmp = (self.e_rho[e_name][i + 1:w_max]
|
||||
+ np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 < 0 else 2 * (i - w_max // 2):-1],
|
||||
+ np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 <= 0 else 2 * (i - w_max // 2):-1],
|
||||
self.e_rho[e_name][1:max(1, w_max - 2 * i)]])
|
||||
- 2 * self.e_rho[e_name][i] * self.e_rho[e_name][1:w_max - i])
|
||||
self.e_drho[e_name][i] = np.sqrt(np.sum(tmp ** 2) / e_N)
|
||||
|
@ -1022,7 +1020,7 @@ def _expand_deltas(deltas, idx, shape, gapsize):
|
|||
|
||||
|
||||
def _merge_idx(idl):
    """Returns the union of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.

    Returns
    -------
    range or list
        The union of all entries of idl; a ``range`` when the union is
        expressible as one, otherwise a sorted list.
    """
    # Fast path: if every entry of idl is identical, the union is any one entry.
    if _check_lists_equal(idl):
        return idl[0]

    idunion = sorted(set().union(*idl))

    # Check whether idunion can be expressed as range.
    # The step probe idunion[1] - idunion[0] needs at least two elements; a
    # union with fewer (e.g. a list and a range holding the same single
    # configuration, which compare unequal above) previously raised an
    # unguarded IndexError. Guard it like _intersection_idx does.
    try:
        idrange = range(idunion[0], idunion[-1] + 1, idunion[1] - idunion[0])
        idtest = [list(idrange), idunion]
        if _check_lists_equal(idtest):
            return idrange
    except IndexError:
        pass

    return idunion
|
||||
|
||||
|
||||
def _intersection_idx(idl):
    """Returns the intersection of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.

    Returns
    -------
    range or list
        The intersection of all entries of idl; a ``range`` when the
        intersection is expressible as one, otherwise a sorted list.
    """
    # Trivial case: all entries are the same sequence.
    if _check_lists_equal(idl):
        return idl[0]

    # Accumulate the common configurations across all entries.
    common = set(idl[0])
    for ids in idl[1:]:
        common.intersection_update(ids)
    idinter = sorted(common)

    # Try to express the intersection as an equivalent range object.
    # Fewer than two common elements raise IndexError on the step probe,
    # in which case the plain sorted list is returned below.
    try:
        step = idinter[1] - idinter[0]
        candidate = range(idinter[0], idinter[-1] + 1, step)
        if _check_lists_equal([list(candidate), idinter]):
            return candidate
    except IndexError:
        pass

    return idinter
|
||||
|
||||
|
||||
def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
|
||||
|
@ -1299,13 +1286,8 @@ def _reduce_deltas(deltas, idx_old, idx_new):
|
|||
if type(idx_old) is range and type(idx_new) is range:
|
||||
if idx_old == idx_new:
|
||||
return deltas
|
||||
# Use groupby to efficiently check whether all elements of idx_old and idx_new are identical
|
||||
try:
|
||||
g = groupby([idx_old, idx_new])
|
||||
if next(g, True) and not next(g, False):
|
||||
return deltas
|
||||
except Exception:
|
||||
pass
|
||||
if _check_lists_equal([idx_old, idx_new]):
|
||||
return deltas
|
||||
indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1]
|
||||
if len(indices) < len(idx_new):
|
||||
raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old')
|
||||
|
@ -1650,3 +1632,18 @@ def _determine_gap(o, e_content, e_name):
|
|||
raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps)
|
||||
|
||||
return gap
|
||||
|
||||
|
||||
def _check_lists_equal(idl):
|
||||
'''
|
||||
Use groupby to efficiently check whether all elements of idl are identical.
|
||||
Returns True if all elements are equal, otherwise False.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idl : list of lists, ranges or np.ndarrays
|
||||
'''
|
||||
g = groupby([np.nditer(el) if isinstance(el, np.ndarray) else el for el in idl])
|
||||
if next(g, True) and not next(g, False):
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -537,7 +537,21 @@ def test_correlate():
|
|||
|
||||
def test_merge_idx():
|
||||
assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10)
|
||||
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50)
|
||||
assert isinstance(pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]), range)
|
||||
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 50)
|
||||
assert isinstance(pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]), range)
|
||||
assert pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 1)
|
||||
assert isinstance(pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]), range)
|
||||
assert pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]) == range(1, 100, 1)
|
||||
assert isinstance(pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]), range)
|
||||
|
||||
for j in range(5):
|
||||
idll = [range(1, int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
|
||||
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
|
||||
|
||||
for j in range(5):
|
||||
idll = [range(int(round(np.random.uniform(1, 28))), int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
|
||||
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
|
||||
|
||||
idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)]
|
||||
new_idx = pe.obs._merge_idx(idl)
|
||||
|
@ -550,10 +564,21 @@ def test_intersection_idx():
|
|||
assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100)
|
||||
assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10)
|
||||
assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50)
|
||||
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250)
|
||||
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 250)
|
||||
assert pe.obs._intersection_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 2)
|
||||
idll = [range(1, 100, 2), range(5, 105, 1)]
|
||||
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
|
||||
assert isinstance(pe.obs._intersection_idx(idll), range)
|
||||
idll = [range(1, 100, 2), list(range(5, 105, 1))]
|
||||
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
|
||||
assert isinstance(pe.obs._intersection_idx(idll), range)
|
||||
|
||||
for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]:
|
||||
assert list(pe.obs._intersection_idx(ids)) == pe.obs._intersection_idx([list(o) for o in ids])
|
||||
interlist = pe.obs._intersection_idx([list(o) for o in ids])
|
||||
listinter = list(pe.obs._intersection_idx(ids))
|
||||
assert len(interlist) == len(listinter)
|
||||
assert all([o in listinter for o in interlist])
|
||||
assert all([o in interlist for o in listinter])
|
||||
|
||||
|
||||
def test_merge_intersection():
|
||||
|
@ -733,6 +758,15 @@ def test_gamma_method_irregular():
|
|||
with pytest.raises(Exception):
|
||||
my_obs.gm()
|
||||
|
||||
# check cases where tau is large compared to the chain length
|
||||
N = 15
|
||||
for i in range(10):
|
||||
arr = np.random.normal(1, .2, size=N)
|
||||
for rho in .1 * np.arange(20):
|
||||
carr = gen_autocorrelated_array(arr, rho)
|
||||
a = pe.Obs([carr], ['a'])
|
||||
a.gm()
|
||||
|
||||
|
||||
def test_irregular_gapped_dtauint():
|
||||
my_idl = list(range(0, 5010, 10))
|
||||
|
|
Loading…
Add table
Reference in a new issue