From 65a9128a7de28bcf3c45e37db88002635cdd58ae Mon Sep 17 00:00:00 2001 From: s-kuberski Date: Fri, 28 Apr 2023 19:14:51 +0200 Subject: [PATCH] Fix/merge idx (#172) * Fix: Corrected merging of idls * Fix: Computation of drho in cases where tau_int is large compared to the chain length * Removed unnecessary imports * Refactor list comparisons in obs.py --- pyerrors/obs.py | 83 +++++++++++++++++++++++------------------------ tests/obs_test.py | 40 +++++++++++++++++++++-- 2 files changed, 77 insertions(+), 46 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index c8f4c3e2..8819ac1a 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1,8 +1,6 @@ import warnings import hashlib import pickle -from math import gcd -from functools import reduce import numpy as np import autograd.numpy as anp # Thinly-wrapped numpy from autograd import jacobian @@ -280,7 +278,7 @@ class Obs: def _compute_drho(i): tmp = (self.e_rho[e_name][i + 1:w_max] - + np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 < 0 else 2 * (i - w_max // 2):-1], + + np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 <= 0 else 2 * (i - w_max // 2):-1], self.e_rho[e_name][1:max(1, w_max - 2 * i)]]) - 2 * self.e_rho[e_name][i] * self.e_rho[e_name][1:w_max - i]) self.e_drho[e_name][i] = np.sqrt(np.sum(tmp ** 2) / e_N) @@ -1022,7 +1020,7 @@ def _expand_deltas(deltas, idx, shape, gapsize): def _merge_idx(idl): - """Returns the union of all lists in idl as sorted list + """Returns the union of all lists in idl as range or sorted list Parameters ---------- @@ -1030,26 +1028,22 @@ def _merge_idx(idl): List of lists or ranges. """ - # Use groupby to efficiently check whether all elements of idl are identical - try: - g = groupby(idl) - if next(g, True) and not next(g, False): - return idl[0] - except Exception: - pass + if _check_lists_equal(idl): + return idl[0] - if np.all([type(idx) is range for idx in idl]): - if len(set([idx[0] for idx in idl])) == 1: - idstart = min([idx.start for idx in idl]) - idstop = max([idx.stop for idx in idl]) - idstep = min([idx.step for idx in idl]) - return range(idstart, idstop, idstep) + idunion = sorted(set().union(*idl)) - return sorted(set().union(*idl)) + # Check whether idunion can be expressed as range + idrange = range(idunion[0], idunion[-1] + 1, idunion[1] - idunion[0]) + idtest = [list(idrange), idunion] + if _check_lists_equal(idtest): + return idrange + + return idunion def _intersection_idx(idl): - """Returns the intersection of all lists in idl as sorted list + """Returns the intersection of all lists in idl as range or sorted list Parameters ---------- @@ -1057,28 +1051,21 @@ def _intersection_idx(idl): List of lists or ranges. """ - def _lcm(*args): - """Returns the lowest common multiple of args. + if _check_lists_equal(idl): + return idl[0] - From python 3.9 onwards the math library contains an lcm function.""" - return reduce(lambda a, b: a * b // gcd(a, b), args) + idinter = sorted(set.intersection(*[set(o) for o in idl])) - # Use groupby to efficiently check whether all elements of idl are identical + # Check whether idinter can be expressed as range try: - g = groupby(idl) - if next(g, True) and not next(g, False): - return idl[0] - except Exception: + idrange = range(idinter[0], idinter[-1] + 1, idinter[1] - idinter[0]) + idtest = [list(idrange), idinter] + if _check_lists_equal(idtest): + return idrange + except IndexError: pass - if np.all([type(idx) is range for idx in idl]): - if len(set([idx[0] for idx in idl])) == 1: - idstart = max([idx.start for idx in idl]) - idstop = min([idx.stop for idx in idl]) - idstep = _lcm(*[idx.step for idx in idl]) - return range(idstart, idstop, idstep) - - return sorted(set.intersection(*[set(o) for o in idl])) + return idinter def _expand_deltas_for_merge(deltas, idx, shape, new_idx): @@ -1299,13 +1286,8 @@ def _reduce_deltas(deltas, idx_old, idx_new): if type(idx_old) is range and type(idx_new) is range: if idx_old == idx_new: return deltas - # Use groupby to efficiently check whether all elements of idx_old and idx_new are identical - try: - g = groupby([idx_old, idx_new]) - if next(g, True) and not next(g, False): - return deltas - except Exception: - pass + if _check_lists_equal([idx_old, idx_new]): + return deltas indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1] if len(indices) < len(idx_new): raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old') @@ -1650,3 +1632,18 @@ def _determine_gap(o, e_content, e_name): raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps) return gap + + +def _check_lists_equal(idl): + ''' + Use groupby to efficiently check whether all elements of idl are identical. + Returns True if all elements are equal, otherwise False. + + Parameters + ---------- + idl : list of lists, ranges or np.ndarrays + ''' + g = groupby([np.nditer(el) if isinstance(el, np.ndarray) else el for el in idl]) + if next(g, True) and not next(g, False): + return True + return False diff --git a/tests/obs_test.py b/tests/obs_test.py index 02d5b4ac..5b294326 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -537,7 +537,21 @@ def test_correlate(): def test_merge_idx(): assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10) - assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50) + assert isinstance(pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]), range) + assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 50) + assert isinstance(pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]), range) + assert pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 1) + assert isinstance(pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]), range) + assert pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]) == range(1, 100, 1) + assert isinstance(pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]), range) + + for j in range(5): + idll = [range(1, int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)] + assert pe.obs._merge_idx(idll) == sorted(set().union(*idll)) + + for j in range(5): + idll = [range(int(round(np.random.uniform(1, 28))), int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)] + assert pe.obs._merge_idx(idll) == sorted(set().union(*idll)) idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)] new_idx = pe.obs._merge_idx(idl) @@ -550,10 +564,21 @@ def test_intersection_idx(): assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100) assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10) assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50) - assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250) + assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 250) + assert pe.obs._intersection_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 2) + idll = [range(1, 100, 2), range(5, 105, 1)] + assert pe.obs._intersection_idx(idll) == range(5, 100, 2) + assert isinstance(pe.obs._intersection_idx(idll), range) + idll = [range(1, 100, 2), list(range(5, 105, 1))] + assert pe.obs._intersection_idx(idll) == range(5, 100, 2) + assert isinstance(pe.obs._intersection_idx(idll), range) for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]: - assert list(pe.obs._intersection_idx(ids)) == pe.obs._intersection_idx([list(o) for o in ids]) + interlist = pe.obs._intersection_idx([list(o) for o in ids]) + listinter = list(pe.obs._intersection_idx(ids)) + assert len(interlist) == len(listinter) + assert all([o in listinter for o in interlist]) + assert all([o in interlist for o in listinter]) def test_merge_intersection(): @@ -733,6 +758,15 @@ def test_gamma_method_irregular(): with pytest.raises(Exception): my_obs.gm() + # check cases where tau is large compared to the chain length + N = 15 + for i in range(10): + arr = np.random.normal(1, .2, size=N) + for rho in .1 * np.arange(20): + carr = gen_autocorrelated_array(arr, rho) + a = pe.Obs([carr], ['a']) + a.gm() + def test_irregular_gapped_dtauint(): my_idl = list(range(0, 5010, 10))