diff --git a/pyerrors/input/json.py b/pyerrors/input/json.py index 131dccf3..1ba6682d 100644 --- a/pyerrors/input/json.py +++ b/pyerrors/input/json.py @@ -8,6 +8,7 @@ import platform import warnings from ..obs import Obs from ..covobs import Covobs +from ..correlators import Corr from .. import version as pyerrorsversion @@ -173,6 +174,18 @@ def create_json_string(ol, description='', indent=1): d['cdata'] = cdata return d + def write_Corr_to_dict(my_corr): + front_padding = next(i for i, j in enumerate(my_corr.content) if np.all(j)) + back_padding_start = front_padding + next((i for i, j in enumerate(my_corr.content[front_padding:]) if not np.all(j)), my_corr.T) + dat = write_Array_to_dict(np.array(my_corr.content[front_padding:back_padding_start])) + dat['type'] = 'Corr' + corr_meta_data = str(front_padding) + '|' + str(my_corr.T - back_padding_start) + '|' + str(my_corr.tag) + if 'tag' in dat.keys(): + dat['tag'].append(corr_meta_data) + else: + dat['tag'] = [corr_meta_data] + return dat + if not isinstance(ol, list): ol = [ol] @@ -193,6 +206,10 @@ def create_json_string(ol, description='', indent=1): d['obsdata'].append(write_List_to_dict(io)) elif isinstance(io, np.ndarray): d['obsdata'].append(write_Array_to_dict(io)) + elif isinstance(io, Corr): + d['obsdata'].append(write_Corr_to_dict(io)) + else: + raise Exception("Unkown datatype.") jsonstring = json.dumps(d, indent=indent, cls=my_encoder, ensure_ascii=False) @@ -374,6 +391,22 @@ def import_json_string(json_string, verbose=True, full_output=False): ret[-1].tag = taglist[i] return np.reshape(ret, layout) + def get_Corr_from_dict(o): + taglist = o.get('tag') + corr_meta_data = taglist[-1].split('|') + padding_front = int(corr_meta_data[0]) + padding_back = int(corr_meta_data[1]) + corr_tag = corr_meta_data[2] + tmp_o = o + tmp_o['tag'] = taglist[:-1] + if len(tmp_o['tag']) == 0: + del tmp_o['tag'] + dat = get_Array_from_dict(tmp_o) + my_corr = Corr(list(dat), padding_front=padding_front, padding_back=padding_back) + if corr_tag != 'None': + my_corr.tag = corr_tag + return my_corr + json_dict = json.loads(json_string) prog = json_dict.get('program', '') @@ -400,6 +433,10 @@ def import_json_string(json_string, verbose=True, full_output=False): ol.append(get_List_from_dict(io)) elif io['type'] == 'Array': ol.append(get_Array_from_dict(io)) + elif io['type'] == 'Corr': + ol.append(get_Corr_from_dict(io)) + else: + raise Exception("Unkown datatype.") if full_output: retd = {} diff --git a/tests/io_test.py b/tests/io_test.py index 92781785..d660f34a 100644 --- a/tests/io_test.py +++ b/tests/io_test.py @@ -89,3 +89,36 @@ def test_json_string_reconstruction(): assert reconstructed_string == json_string assert my_obs == reconstructed_obs2 + + +def test_json_corr_io(): + my_list = [pe.Obs([np.random.normal(1.0, 0.1, 100)], ['ens1']) for o in range(8)] + rw_list = pe.reweight(pe.Obs([np.random.normal(1.0, 0.1, 100)], ['ens1']), my_list) + + for obs_list in [my_list, rw_list]: + for tag in [None, "test"]: + obs_list[3].tag = tag + for fp in [0, 2]: + for bp in [0, 7]: + for corr_tag in [None, 'my_Corr_tag']: + my_corr = pe.Corr(obs_list, padding_front=fp, padding_back=bp) + my_corr.tag = corr_tag + pe.input.json.dump_to_json(my_corr, 'corr') + recover = pe.input.json.load_json('corr') + assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]]) + assert my_corr.tag == recover.tag + assert my_corr.reweighted == recover.reweighted + + +def test_json_corr_2d_io(): + obs_list = [np.array([[pe.pseudo_Obs(1.0 + i, 0.1 * i, 'test'), pe.pseudo_Obs(0.0, 0.1 * i, 'test')], [pe.pseudo_Obs(0.0, 0.1 * i, 'test'), pe.pseudo_Obs(1.0 + i, 0.1 * i, 'test')]]) for i in range(8)] + + for tag in [None, "test"]: + obs_list[3][0, 1].tag = tag + for padding in [0, 1]: + my_corr = pe.Corr(obs_list, padding_front=padding, padding_back=padding) + my_corr.tag = tag + pe.input.json.dump_to_json(my_corr, 'corr') + recover = pe.input.json.load_json('corr') + assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]]) + assert my_corr.tag == recover.tag diff --git a/tests/obs_test.py b/tests/obs_test.py index a5e72ec9..210ac67c 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -615,7 +615,7 @@ def test_covariance_symmetry(): cov_ab = pe.covariance(test_obs1, a) cov_ba = pe.covariance(a, test_obs1) assert np.abs(cov_ab - cov_ba) <= 10 * np.finfo(np.float64).eps - assert np.abs(cov_ab) < test_obs1.dvalue * test_obs2.dvalue * (1 + 10 * np.finfo(np.float64).eps) + assert np.abs(cov_ab) < test_obs1.dvalue * a.dvalue * (1 + 10 * np.finfo(np.float64).eps) def test_empty_obs():