feat: When initializing 1d correlators it is now checked whether all obs

are defined on the same ensembles.
This commit is contained in:
Fabian Joswig 2022-01-19 11:03:45 +00:00
parent 78ff4bb117
commit c3ba07280b
3 changed files with 20 additions and 5 deletions

View file

@ -4,7 +4,7 @@ import autograd.numpy as anp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import scipy.linalg import scipy.linalg
from .obs import Obs, reweight, correlate, CObs from .obs import Obs, reweight, correlate, CObs
from .misc import dump_object from .misc import dump_object, _assert_equal_properties
from .fits import least_squares from .fits import least_squares
from .linalg import eigh, inv, cholesky from .linalg import eigh, inv, cholesky
from .roots import find_root from .roots import find_root
@ -42,10 +42,8 @@ class Corr:
if not isinstance(data_input, list): if not isinstance(data_input, list):
raise TypeError('Corr__init__ expects a list of timeslices.') raise TypeError('Corr__init__ expects a list of timeslices.')
# data_input can have multiple shapes. The simplest one is a list of Obs.
# We check, if this is the case
if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]): if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]):
_assert_equal_properties(data_input)
self.content = [np.asarray([item]) for item in data_input] self.content = [np.asarray([item]) for item in data_input]
self.N = 1 self.N = 1

View file

@ -85,4 +85,3 @@ def _assert_equal_properties(ol, otype=Obs):
raise Exception("All Obs in list have to be defined on the same set of configs.") raise Exception("All Obs in list have to be defined on the same set of configs.")
if not ol[0].idl == o.idl: if not ol[0].idl == o.idl:
raise Exception("All Obs in list have to be defined on the same set of configurations.") raise Exception("All Obs in list have to be defined on the same set of configurations.")

View file

@ -120,6 +120,24 @@ def test_padded_correlator():
[o for o in my_corr] [o for o in my_corr]
def test_corr_exceptions():
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 99)], ['test'])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'], idl=[range(1, 200, 2)])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test2'])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
def test_utility(): def test_utility():
corr_content = [] corr_content = []
for t in range(8): for t in range(8):