diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 9262d08d..2209c431 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -4,7 +4,7 @@ import autograd.numpy as anp import matplotlib.pyplot as plt import scipy.linalg from .obs import Obs, reweight, correlate, CObs -from .misc import dump_object +from .misc import dump_object, _assert_equal_properties from .fits import least_squares from .linalg import eigh, inv, cholesky from .roots import find_root @@ -42,10 +42,8 @@ class Corr: if not isinstance(data_input, list): raise TypeError('Corr__init__ expects a list of timeslices.') - # data_input can have multiple shapes. The simplest one is a list of Obs. - # We check, if this is the case if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]): - + _assert_equal_properties(data_input) self.content = [np.asarray([item]) for item in data_input] self.N = 1 diff --git a/pyerrors/misc.py b/pyerrors/misc.py index edbdc369..740aff9a 100644 --- a/pyerrors/misc.py +++ b/pyerrors/misc.py @@ -85,4 +85,3 @@ def _assert_equal_properties(ol, otype=Obs): raise Exception("All Obs in list have to be defined on the same set of configs.") if not ol[0].idl == o.idl: raise Exception("All Obs in list have to be defined on the same set of configurations.") - diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 155fc61b..55d0a977 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -120,6 +120,24 @@ def test_padded_correlator(): [o for o in my_corr] +def test_corr_exceptions(): + obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test']) + obs_b= pe.Obs([np.random.normal(0.1, 0.1, 99)], ['test']) + with pytest.raises(Exception): + pe.Corr([obs_a, obs_b]) + + obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test']) + obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'], idl=[range(1, 200, 2)]) + with pytest.raises(Exception): + pe.Corr([obs_a, obs_b]) + + obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test']) + obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test2']) + with pytest.raises(Exception): + pe.Corr([obs_a, obs_b]) + + + def test_utility(): corr_content = [] for t in range(8):