Merge branch 'develop' into documentation

This commit is contained in:
fjosw 2022-01-19 11:04:37 +00:00
commit 2d72c1ef6a
5 changed files with 38 additions and 18 deletions

View file

@ -4,7 +4,7 @@ import autograd.numpy as anp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import scipy.linalg import scipy.linalg
from .obs import Obs, reweight, correlate, CObs from .obs import Obs, reweight, correlate, CObs
from .misc import dump_object from .misc import dump_object, _assert_equal_properties
from .fits import least_squares from .fits import least_squares
from .linalg import eigh, inv, cholesky from .linalg import eigh, inv, cholesky
from .roots import find_root from .roots import find_root
@ -42,10 +42,8 @@ class Corr:
if not isinstance(data_input, list): if not isinstance(data_input, list):
raise TypeError('Corr__init__ expects a list of timeslices.') raise TypeError('Corr__init__ expects a list of timeslices.')
# data_input can have multiple shapes. The simplest one is a list of Obs.
# We check, if this is the case
if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]): if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]):
_assert_equal_properties(data_input)
self.content = [np.asarray([item]) for item in data_input] self.content = [np.asarray([item]) for item in data_input]
self.N = 1 self.N = 1

View file

@ -9,6 +9,7 @@ import warnings
from ..obs import Obs from ..obs import Obs
from ..covobs import Covobs from ..covobs import Covobs
from ..correlators import Corr from ..correlators import Corr
from ..misc import _assert_equal_properties
from .. import version as pyerrorsversion from .. import version as pyerrorsversion
@ -104,20 +105,6 @@ def create_json_string(ol, description='', indent=1):
dl.append(ed) dl.append(ed)
return dl return dl
def _assert_equal_properties(ol, otype=Obs):
for o in ol:
if not isinstance(o, otype):
raise Exception("Wrong data type in list.")
for o in ol[1:]:
if not ol[0].is_merged == o.is_merged:
raise Exception("All Obs in list have to be defined on the same set of configs.")
if not ol[0].reweighted == o.reweighted:
raise Exception("All Obs in list have to have the same property 'reweighted'.")
if not ol[0].e_content == o.e_content:
raise Exception("All Obs in list have to be defined on the same set of configs.")
if not ol[0].idl == o.idl:
raise Exception("All Obs in list have to be defined on the same set of configurations.")
def write_Obs_to_dict(o): def write_Obs_to_dict(o):
d = {} d = {}
d['type'] = 'Obs' d['type'] = 'Obs'

View file

@ -70,3 +70,18 @@ def gen_correlated_data(means, cov, name, tau=0.5, samples=1000):
data.append(np.sqrt(1 - a ** 2) * rand[i] + a * data[-1]) data.append(np.sqrt(1 - a ** 2) * rand[i] + a * data[-1])
corr_data = np.array(data) - np.mean(data, axis=0) + means corr_data = np.array(data) - np.mean(data, axis=0) + means
return [Obs([dat], [name]) for dat in corr_data.T] return [Obs([dat], [name]) for dat in corr_data.T]
def _assert_equal_properties(ol, otype=Obs):
for o in ol:
if not isinstance(o, otype):
raise Exception("Wrong data type in list.")
for o in ol[1:]:
if not ol[0].is_merged == o.is_merged:
raise Exception("All Obs in list have to be defined on the same set of configs.")
if not ol[0].reweighted == o.reweighted:
raise Exception("All Obs in list have to have the same property 'reweighted'.")
if not ol[0].e_content == o.e_content:
raise Exception("All Obs in list have to be defined on the same set of configs.")
if not ol[0].idl == o.idl:
raise Exception("All Obs in list have to be defined on the same set of configurations.")

View file

@ -120,6 +120,24 @@ def test_padded_correlator():
[o for o in my_corr] [o for o in my_corr]
def test_corr_exceptions():
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 99)], ['test'])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'], idl=[range(1, 200, 2)])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test2'])
with pytest.raises(Exception):
pe.Corr([obs_a, obs_b])
def test_utility(): def test_utility():
corr_content = [] corr_content = []
for t in range(8): for t in range(8):

View file

@ -105,6 +105,7 @@ def test_json_corr_io():
my_corr.tag = corr_tag my_corr.tag = corr_tag
pe.input.json.dump_to_json(my_corr, 'corr') pe.input.json.dump_to_json(my_corr, 'corr')
recover = pe.input.json.load_json('corr') recover = pe.input.json.load_json('corr')
os.remove('corr.json.gz')
assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]]) assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]])
assert my_corr.tag == recover.tag assert my_corr.tag == recover.tag
assert my_corr.reweighted == recover.reweighted assert my_corr.reweighted == recover.reweighted
@ -120,5 +121,6 @@ def test_json_corr_2d_io():
my_corr.tag = tag my_corr.tag = tag
pe.input.json.dump_to_json(my_corr, 'corr') pe.input.json.dump_to_json(my_corr, 'corr')
recover = pe.input.json.load_json('corr') recover = pe.input.json.load_json('corr')
os.remove('corr.json.gz')
assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]]) assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]])
assert my_corr.tag == recover.tag assert my_corr.tag == recover.tag