feat: json import export for correlator class implemented, tests added

This commit is contained in:
Fabian Joswig 2022-01-18 12:13:20 +00:00
parent 268f71fa19
commit f282ecdaf6
2 changed files with 70 additions and 0 deletions

View file

@ -8,6 +8,7 @@ import platform
import warnings import warnings
from ..obs import Obs from ..obs import Obs
from ..covobs import Covobs from ..covobs import Covobs
from ..correlators import Corr
from .. import version as pyerrorsversion from .. import version as pyerrorsversion
@ -173,6 +174,18 @@ def create_json_string(ol, description='', indent=1):
d['cdata'] = cdata d['cdata'] = cdata
return d return d
def write_Corr_to_dict(my_corr):
front_padding = next(i for i, j in enumerate(my_corr.content) if np.all(j))
back_padding_start = front_padding + next((i for i, j in enumerate(my_corr.content[front_padding:]) if not np.all(j)), my_corr.T)
dat = write_Array_to_dict(np.array(my_corr.content[front_padding:back_padding_start]))
dat['type'] = 'Corr'
corr_meta_data = str(front_padding) + '|' + str(my_corr.T - back_padding_start) + '|' + str(my_corr.tag)
if 'tag' in dat.keys():
dat['tag'].append(corr_meta_data)
else:
dat['tag'] = [corr_meta_data]
return dat
if not isinstance(ol, list): if not isinstance(ol, list):
ol = [ol] ol = [ol]
@ -193,6 +206,10 @@ def create_json_string(ol, description='', indent=1):
d['obsdata'].append(write_List_to_dict(io)) d['obsdata'].append(write_List_to_dict(io))
elif isinstance(io, np.ndarray): elif isinstance(io, np.ndarray):
d['obsdata'].append(write_Array_to_dict(io)) d['obsdata'].append(write_Array_to_dict(io))
elif isinstance(io, Corr):
d['obsdata'].append(write_Corr_to_dict(io))
else:
raise Exception("Unkown datatype.")
jsonstring = json.dumps(d, indent=indent, cls=my_encoder, ensure_ascii=False) jsonstring = json.dumps(d, indent=indent, cls=my_encoder, ensure_ascii=False)
@ -374,6 +391,22 @@ def import_json_string(json_string, verbose=True, full_output=False):
ret[-1].tag = taglist[i] ret[-1].tag = taglist[i]
return np.reshape(ret, layout) return np.reshape(ret, layout)
def get_Corr_from_dict(o):
taglist = o.get('tag')
corr_meta_data = taglist[-1].split('|')
padding_front = int(corr_meta_data[0])
padding_back = int(corr_meta_data[1])
corr_tag = corr_meta_data[2]
tmp_o = o
tmp_o['tag'] = taglist[:-1]
if len(tmp_o['tag']) == 0:
del tmp_o['tag']
dat = get_Array_from_dict(tmp_o)
my_corr = Corr(list(dat), padding_front=padding_front, padding_back=padding_back)
if corr_tag != 'None':
my_corr.tag = corr_tag
return my_corr
json_dict = json.loads(json_string) json_dict = json.loads(json_string)
prog = json_dict.get('program', '') prog = json_dict.get('program', '')
@ -400,6 +433,10 @@ def import_json_string(json_string, verbose=True, full_output=False):
ol.append(get_List_from_dict(io)) ol.append(get_List_from_dict(io))
elif io['type'] == 'Array': elif io['type'] == 'Array':
ol.append(get_Array_from_dict(io)) ol.append(get_Array_from_dict(io))
elif io['type'] == 'Corr':
ol.append(get_Corr_from_dict(io))
else:
raise Exception("Unkown datatype.")
if full_output: if full_output:
retd = {} retd = {}

View file

@ -89,3 +89,36 @@ def test_json_string_reconstruction():
assert reconstructed_string == json_string assert reconstructed_string == json_string
assert my_obs == reconstructed_obs2 assert my_obs == reconstructed_obs2
def test_json_corr_io():
my_list = [pe.Obs([np.random.normal(1.0, 0.1, 100)], ['ens1']) for o in range(8)]
rw_list = pe.reweight(pe.Obs([np.random.normal(1.0, 0.1, 100)], ['ens1']), my_list)
for obs_list in [my_list, rw_list]:
for tag in [None, "test"]:
obs_list[3].tag = tag
for fp in [0, 2]:
for bp in [0, 7]:
for corr_tag in [None, 'my_Corr_tag']:
my_corr = pe.Corr(obs_list, padding_front=fp, padding_back=bp)
my_corr.tag = corr_tag
pe.input.json.dump_to_json(my_corr, 'corr')
recover = pe.input.json.load_json('corr')
assert np.all([o.is_zero() for o in [x for x in (my_corr - recover) if x is not None]])
assert my_corr.tag == recover.tag
assert my_corr.reweighted == recover.reweighted
def test_json_corr_2d_io():
obs_list = [np.array([[pe.pseudo_Obs(1.0 + i, 0.1 * i, 'test'), pe.pseudo_Obs(0.0, 0.1 * i, 'test')], [pe.pseudo_Obs(0.0, 0.1 * i, 'test'), pe.pseudo_Obs(1.0 + i, 0.1 * i, 'test')]]) for i in range(8)]
for tag in [None, "test"]:
obs_list[3][0, 1].tag = tag
for padding in [0, 1]:
my_corr = pe.Corr(obs_list, padding_front=padding, padding_back=padding)
my_corr.tag = tag
pe.input.json.dump_to_json(my_corr, 'corr')
recover = pe.input.json.load_json('corr')
assert np.all([np.all([o.is_zero() for o in q]) for q in [x.ravel() for x in (my_corr - recover) if x is not None]])
assert my_corr.tag == recover.tag