mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-15 12:03:42 +02:00
Merge branch 'develop' into documentation
This commit is contained in:
commit
2d72c1ef6a
5 changed files with 38 additions and 18 deletions
|
@ -4,7 +4,7 @@ import autograd.numpy as anp
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import scipy.linalg
|
import scipy.linalg
|
||||||
from .obs import Obs, reweight, correlate, CObs
|
from .obs import Obs, reweight, correlate, CObs
|
||||||
from .misc import dump_object
|
from .misc import dump_object, _assert_equal_properties
|
||||||
from .fits import least_squares
|
from .fits import least_squares
|
||||||
from .linalg import eigh, inv, cholesky
|
from .linalg import eigh, inv, cholesky
|
||||||
from .roots import find_root
|
from .roots import find_root
|
||||||
|
@ -42,10 +42,8 @@ class Corr:
|
||||||
if not isinstance(data_input, list):
|
if not isinstance(data_input, list):
|
||||||
raise TypeError('Corr__init__ expects a list of timeslices.')
|
raise TypeError('Corr__init__ expects a list of timeslices.')
|
||||||
|
|
||||||
# data_input can have multiple shapes. The simplest one is a list of Obs.
|
|
||||||
# We check, if this is the case
|
|
||||||
if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]):
|
if all([(isinstance(item, Obs) or isinstance(item, CObs)) for item in data_input]):
|
||||||
|
_assert_equal_properties(data_input)
|
||||||
self.content = [np.asarray([item]) for item in data_input]
|
self.content = [np.asarray([item]) for item in data_input]
|
||||||
self.N = 1
|
self.N = 1
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import warnings
|
||||||
from ..obs import Obs
|
from ..obs import Obs
|
||||||
from ..covobs import Covobs
|
from ..covobs import Covobs
|
||||||
from ..correlators import Corr
|
from ..correlators import Corr
|
||||||
|
from ..misc import _assert_equal_properties
|
||||||
from .. import version as pyerrorsversion
|
from .. import version as pyerrorsversion
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,20 +105,6 @@ def create_json_string(ol, description='', indent=1):
|
||||||
dl.append(ed)
|
dl.append(ed)
|
||||||
return dl
|
return dl
|
||||||
|
|
||||||
def _assert_equal_properties(ol, otype=Obs):
|
|
||||||
for o in ol:
|
|
||||||
if not isinstance(o, otype):
|
|
||||||
raise Exception("Wrong data type in list.")
|
|
||||||
for o in ol[1:]:
|
|
||||||
if not ol[0].is_merged == o.is_merged:
|
|
||||||
raise Exception("All Obs in list have to be defined on the same set of configs.")
|
|
||||||
if not ol[0].reweighted == o.reweighted:
|
|
||||||
raise Exception("All Obs in list have to have the same property 'reweighted'.")
|
|
||||||
if not ol[0].e_content == o.e_content:
|
|
||||||
raise Exception("All Obs in list have to be defined on the same set of configs.")
|
|
||||||
if not ol[0].idl == o.idl:
|
|
||||||
raise Exception("All Obs in list have to be defined on the same set of configurations.")
|
|
||||||
|
|
||||||
def write_Obs_to_dict(o):
|
def write_Obs_to_dict(o):
|
||||||
d = {}
|
d = {}
|
||||||
d['type'] = 'Obs'
|
d['type'] = 'Obs'
|
||||||
|
|
|
@ -70,3 +70,18 @@ def gen_correlated_data(means, cov, name, tau=0.5, samples=1000):
|
||||||
data.append(np.sqrt(1 - a ** 2) * rand[i] + a * data[-1])
|
data.append(np.sqrt(1 - a ** 2) * rand[i] + a * data[-1])
|
||||||
corr_data = np.array(data) - np.mean(data, axis=0) + means
|
corr_data = np.array(data) - np.mean(data, axis=0) + means
|
||||||
return [Obs([dat], [name]) for dat in corr_data.T]
|
return [Obs([dat], [name]) for dat in corr_data.T]
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_equal_properties(ol, otype=Obs):
|
||||||
|
for o in ol:
|
||||||
|
if not isinstance(o, otype):
|
||||||
|
raise Exception("Wrong data type in list.")
|
||||||
|
for o in ol[1:]:
|
||||||
|
if not ol[0].is_merged == o.is_merged:
|
||||||
|
raise Exception("All Obs in list have to be defined on the same set of configs.")
|
||||||
|
if not ol[0].reweighted == o.reweighted:
|
||||||
|
raise Exception("All Obs in list have to have the same property 'reweighted'.")
|
||||||
|
if not ol[0].e_content == o.e_content:
|
||||||
|
raise Exception("All Obs in list have to be defined on the same set of configs.")
|
||||||
|
if not ol[0].idl == o.idl:
|
||||||
|
raise Exception("All Obs in list have to be defined on the same set of configurations.")
|
||||||
|
|
|
@ -120,6 +120,24 @@ def test_padded_correlator():
|
||||||
[o for o in my_corr]
|
[o for o in my_corr]
|
||||||
|
|
||||||
|
|
||||||
|
def test_corr_exceptions():
|
||||||
|
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
|
||||||
|
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 99)], ['test'])
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
pe.Corr([obs_a, obs_b])
|
||||||
|
|
||||||
|
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
|
||||||
|
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'], idl=[range(1, 200, 2)])
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
pe.Corr([obs_a, obs_b])
|
||||||
|
|
||||||
|
obs_a = pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test'])
|
||||||
|
obs_b= pe.Obs([np.random.normal(0.1, 0.1, 100)], ['test2'])
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
pe.Corr([obs_a, obs_b])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_utility():
|
def test_utility():
|
||||||
corr_content = []
|
corr_content = []
|
||||||
for t in range(8):
|
for t in range(8):
|
||||||
|
|
|
@ -105,6 +105,7 @@ def test_json_corr_io():
|
||||||
my_corr.tag = corr_tag
|
my_corr.tag = corr_tag
|
||||||
pe.input.json.dump_to_json(my_corr, 'corr')
|
pe.input.json.dump_to_json(my_corr, 'corr')
|
||||||
recover = pe.input.json.load_json('corr')
|
recover = pe.input.json.load_json('corr')
|
||||||
|
os.remove('corr.json.gz')
|
||||||
assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]])
|
assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]])
|
||||||
assert my_corr.tag == recover.tag
|
assert my_corr.tag == recover.tag
|
||||||
assert my_corr.reweighted == recover.reweighted
|
assert my_corr.reweighted == recover.reweighted
|
||||||
|
@ -120,5 +121,6 @@ def test_json_corr_2d_io():
|
||||||
my_corr.tag = tag
|
my_corr.tag = tag
|
||||||
pe.input.json.dump_to_json(my_corr, 'corr')
|
pe.input.json.dump_to_json(my_corr, 'corr')
|
||||||
recover = pe.input.json.load_json('corr')
|
recover = pe.input.json.load_json('corr')
|
||||||
|
os.remove('corr.json.gz')
|
||||||
assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]])
|
assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]])
|
||||||
assert my_corr.tag == recover.tag
|
assert my_corr.tag == recover.tag
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue