From af28f77ec56f359bb85bcce9870c8eb14b79853b Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 19 Jul 2023 15:06:19 +0100 Subject: [PATCH] __eq__ method for Corr class (#206) * feat: implemented __eq__ method for Corr class. * feat: __eq__ method now respects None entries in correlators. * feat: Obs can now be compared to None, __ne__ method removed as it is not required. * feat: Corr.__eq__ rewritten to give a per element comparison. * tests: additional test case for correlator comparison added. * feat: comparison now also works for padding. --- pyerrors/correlators.py | 7 +++++++ pyerrors/obs.py | 5 ++--- tests/correlators_test.py | 36 ++++++++++++++++++++++++++++++++++++ tests/obs_test.py | 1 + 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 002d2c87..be5c69a6 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -1056,6 +1056,13 @@ class Corr: __array_priority__ = 10000 + def __eq__(self, y): + if isinstance(y, Corr): + comp = np.asarray(y.content, dtype=object) + else: + comp = np.asarray(y) + return np.asarray(self.content, dtype=object) == comp + def __add__(self, y): if isinstance(y, Corr): if ((self.N != y.N) or (self.T != y.T)): diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 31ed84af..0409a5d4 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -773,11 +773,10 @@ class Obs: return self.value >= other def __eq__(self, other): + if other is None: + return False return (self - other).is_zero() - def __ne__(self, other): - return not (self - other).is_zero() - # Overload math operations def __add__(self, y): if isinstance(y, Obs): diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 5b1c5e62..3d49164a 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -713,3 +713,39 @@ def test_corr_roll(): tt = mcorr.roll(T) - mcorr for el in tt: assert np.all(el == 0) + + +def test_correlator_comparison(): + scorr = pe.Corr([pe.pseudo_Obs(0.3, 0.1, "test") for o in range(4)]) + mcorr = pe.Corr(np.array([[scorr, scorr], [scorr, scorr]])) + for corr in [scorr, mcorr]: + assert (corr == corr).all() + assert np.all(corr == 1 * corr) + assert np.all(corr == (1 + 1e-16) * corr) + assert not np.all(corr == (1 + 1e-5) * corr) + assert np.all(corr == 1 / (1 / corr)) + assert np.all(corr - corr == 0) + assert np.all(corr * 0 == 0) + assert np.all(0 * corr == 0) + assert np.all(0 * corr + scorr[2] == scorr[2]) + assert np.all(-corr == 0 - corr) + assert np.all(corr ** 2 == corr * corr) + acorr = pe.Corr([scorr[0]] * 6) + assert np.all(acorr == scorr[0]) + assert not np.all(acorr == scorr[1]) + + mcorr[1][0, 1] = None + assert not np.all(mcorr == pe.Corr(np.array([[scorr, scorr], [scorr, scorr]]))) + + pcorr = pe.Corr([pe.pseudo_Obs(0.25, 0.1, "test") for o in range(2)], padding=[1, 1]) + assert np.all(pcorr == pcorr) + assert np.all(1 * pcorr == pcorr) + + +def test_corr_item(): + corr_aa = _gen_corr(1) + corr_ab = 0.5 * corr_aa + + corr_mat = pe.Corr(np.array([[corr_aa, corr_ab], [corr_ab, corr_aa]])) + corr_mat.item(0, 0) + assert corr_mat[0].item(0, 1) == corr_mat.item(0, 1)[0] diff --git a/tests/obs_test.py b/tests/obs_test.py index cc13c845..334521f0 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -103,6 +103,7 @@ def test_comparison(): test_obs1 = pe.pseudo_Obs(value1, 0.1, 't') value2 = np.random.normal(0, 100) test_obs2 = pe.pseudo_Obs(value2, 0.1, 't') + assert test_obs1 != None assert (value1 > value2) == (test_obs1 > test_obs2) assert (value1 < value2) == (test_obs1 < test_obs2) assert (value1 >= value2) == (test_obs1 >= test_obs2)