diff --git a/tests/json_io_test.py b/tests/json_io_test.py index d13f2675..6fac73b8 100644 --- a/tests/json_io_test.py +++ b/tests/json_io_test.py @@ -353,3 +353,55 @@ def test_dobsio(): for j in range(len(or1)): o = or1[j] - or2[j] assert(o.is_zero()) + + +def test_reconstruct_non_linear_r_obs(tmp_path): + to = pe.Obs([np.random.rand(500), np.random.rand(500), np.random.rand(111)], + ["e|r1", "e|r2", "my_new_ensemble_54^£$|8'[@124435%6^7&()~#"], + idl=[range(1, 501), range(0, 500), range(1, 999, 9)]) + to = np.log(to ** 2) / to + to.dump((tmp_path / "test_equality").as_posix()) + ro = pe.input.json.load_json((tmp_path / "test_equality").as_posix()) + assert assert_equal_Obs(to, ro) + + +def test_reconstruct_non_linear_r_obs_list(tmp_path): + to = pe.Obs([np.random.rand(500), np.random.rand(500), np.random.rand(111)], + ["e|r1", "e|r2", "my_new_ensemble_54^£$|8'[@124435%6^7&()~#"], + idl=[range(1, 501), range(0, 500), range(1, 999, 9)]) + to = np.log(to ** 2) / to + for to_list in [[to, to, to], np.array([to, to, to])]: + pe.input.json.dump_to_json(to_list, (tmp_path / "test_equality_list").as_posix()) + ro_list = pe.input.json.load_json((tmp_path / "test_equality_list").as_posix()) + for oa, ob in zip(to_list, ro_list): + assert assert_equal_Obs(oa, ob) + + +def assert_equal_Obs(to, ro): + for kw in ["N", "cov_names", "covobs", "ddvalue", "dvalue", "e_content", + "e_names", "idl", "mc_names", "names", + "reweighted", "shape", "tag"]: + if not getattr(to, kw) == getattr(ro, kw): + print(kw, "does not match.") + return False + + for kw in ["value"]: + if not np.isclose(getattr(to, kw), getattr(ro, kw), atol=1e-14): + print(kw, "does not match.") + return False + + + for kw in ["r_values", "deltas"]: + for (k, v), (k2, v2) in zip(getattr(to, kw).items(), getattr(ro, kw).items()): + assert k == k2 + if not np.allclose(v, v2, atol=1e-14): + print(kw, "does not match.") + return False + + m_to = getattr(to, "is_merged") + m_ro = getattr(ro, "is_merged") + if not m_to == m_ro: + if not (all(value is False for value in m_ro.values()) and all(value is False for value in m_to.values())): + print("is_merged", "does not match.") + return False + return True