Fix/merge idx (#172)

* Fix: Corrected merging of idls

* Fix: Computation of drho in cases where tau_int is large compared to the chain length

* Removed unnecessary imports

* Refactor list comparisons in obs.py
This commit is contained in:
s-kuberski 2023-04-28 19:14:51 +02:00 committed by GitHub
parent 1184a0fe76
commit 65a9128a7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 46 deletions

View file

@ -537,7 +537,21 @@ def test_correlate():
def test_merge_idx():
assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50)
assert isinstance(pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]), range)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 50)
assert isinstance(pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]), range)
assert pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 1)
assert isinstance(pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]), range)
assert pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]) == range(1, 100, 1)
assert isinstance(pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]), range)
for j in range(5):
idll = [range(1, int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
for j in range(5):
idll = [range(int(round(np.random.uniform(1, 28))), int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))
idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)]
new_idx = pe.obs._merge_idx(idl)
@ -550,10 +564,21 @@ def test_intersection_idx():
assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100)
assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10)
assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 250)
assert pe.obs._intersection_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 2)
idll = [range(1, 100, 2), range(5, 105, 1)]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)
idll = [range(1, 100, 2), list(range(5, 105, 1))]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)
for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]:
assert list(pe.obs._intersection_idx(ids)) == pe.obs._intersection_idx([list(o) for o in ids])
interlist = pe.obs._intersection_idx([list(o) for o in ids])
listinter = list(pe.obs._intersection_idx(ids))
assert len(interlist) == len(listinter)
assert all([o in listinter for o in interlist])
assert all([o in interlist for o in listinter])
def test_merge_intersection():
@ -733,6 +758,15 @@ def test_gamma_method_irregular():
with pytest.raises(Exception):
my_obs.gm()
# check cases where tau is large compared to the chain length
N = 15
for i in range(10):
arr = np.random.normal(1, .2, size=N)
for rho in .1 * np.arange(20):
carr = gen_autocorrelated_array(arr, rho)
a = pe.Obs([carr], ['a'])
a.gm()
def test_irregular_gapped_dtauint():
my_idl = list(range(0, 5010, 10))