feat: _intersection_idx and _collapse_deltas_for_merge together with

tests added.
This commit is contained in:
Fabian Joswig 2022-04-08 11:14:58 +01:00
parent 5e7753a66d
commit 934d091249
2 changed files with 75 additions and 0 deletions

View file

@ -972,6 +972,33 @@ def _merge_idx(idl):
return sorted(set().union(*idl)) return sorted(set().union(*idl))
def _intersection_idx(idl):
"""Returns the intersection of all lists in idl as sorted list
Parameters
----------
idl : list
List of lists or ranges.
"""
# Use groupby to efficiently check whether all elements of idl are identical
try:
g = groupby(idl)
if next(g, True) and not next(g, False):
return idl[0]
except Exception:
pass
if np.all([type(idx) is range for idx in idl]):
if len(set([idx[0] for idx in idl])) == 1:
idstart = max([idx.start for idx in idl])
idstop = min([idx.stop for idx in idl])
idstep = max([idx.step for idx in idl])
return range(idstart, idstop, idstep)
return sorted(set().intersection(*idl))
def _expand_deltas_for_merge(deltas, idx, shape, new_idx): def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
"""Expand deltas defined on idx to the list of configs that is defined by new_idx. """Expand deltas defined on idx to the list of configs that is defined by new_idx.
New, empty entries are filled by 0. If idx and new_idx are of type range, the smallest New, empty entries are filled by 0. If idx and new_idx are of type range, the smallest
@ -999,6 +1026,34 @@ def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))])
def _collapse_deltas_for_merge(deltas, idx, shape, new_idx):
"""Collapse deltas defined on idx to the list of configs that is defined by new_idx.
If idx and new_idx are of type range, the smallest
common divisor of the step sizes is used as new step size.
Parameters
----------
deltas : list
List of fluctuations
idx : list
List or range of configs on which the deltas are defined.
Has to be a subset of new_idx and has to be sorted in ascending order.
shape : list
Number of configs in idx.
new_idx : list
List of configs that defines the new range, has to be sorted in ascending order.
"""
if type(idx) is range and type(new_idx) is range:
if idx == new_idx:
return deltas
ret = np.zeros(new_idx[-1] - new_idx[0] + 1)
for i in range(shape):
if idx[i] in new_idx:
ret[idx[i] - new_idx[0]] = deltas[i]
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))])
def _filter_zeroes(deltas, idx, eps=Obs.filter_eps): def _filter_zeroes(deltas, idx, eps=Obs.filter_eps):
"""Filter out all configurations with vanishing fluctuation such that they do not """Filter out all configurations with vanishing fluctuation such that they do not
contribute to the error estimate anymore. Returns the new deltas and contribute to the error estimate anymore. Returns the new deltas and

View file

@ -515,6 +515,26 @@ def test_merge_idx():
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50) assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50)
def test_intersection_idx():
assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100)
assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10)
assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250)
def test_intersection_collapse():
range1 = range(1, 2000, 2)
range2 = range(2, 2001, 8)
obs1 = pe.Obs([np.random.normal(1.0, 0.1, len(range1))], ["ens"], idl=[range1])
obs_merge = obs1 + pe.Obs([np.random.normal(1.0, 0.1, len(range2))], ["ens"], idl=[range2])
intersection = pe.obs._intersection_idx([o.idl["ens"] for o in [obs1, obs_merge]])
coll = pe.obs._collapse_deltas_for_merge(obs_merge.deltas["ens"], obs_merge.idl["ens"], len(obs_merge.idl["ens"]), range1)
assert np.all(coll == obs1.deltas["ens"])
def test_irregular_error_propagation(): def test_irregular_error_propagation():
obs_list = [pe.Obs([np.random.rand(100)], ['t']), obs_list = [pe.Obs([np.random.rand(100)], ['t']),
pe.Obs([np.random.rand(50)], ['t'], idl=[range(1, 100, 2)]), pe.Obs([np.random.rand(50)], ['t'], idl=[range(1, 100, 2)]),