From f51503555be8349d441c6e982f34d416f2defd74 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Tue, 8 Feb 2022 17:07:40 +0000 Subject: [PATCH] fix: CObs can now be added and multiplied to as well as subtracted from Obs in all combinations --- pyerrors/obs.py | 12 +++++------- tests/obs_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 72a4156e..94748ed6 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -693,7 +693,7 @@ class Obs: else: if isinstance(y, np.ndarray): return np.array([self + o for o in y]) - elif y.__class__.__name__ == 'Corr': + elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented else: return derived_observable(lambda x, **kwargs: x[0] + y, [self], man_grad=[1]) @@ -709,7 +709,7 @@ class Obs: return np.array([self * o for o in y]) elif isinstance(y, complex): return CObs(self * y.real, self * y.imag) - elif y.__class__.__name__ == 'Corr': + elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented else: return derived_observable(lambda x, **kwargs: x[0] * y, [self], man_grad=[y]) @@ -723,10 +723,8 @@ class Obs: else: if isinstance(y, np.ndarray): return np.array([self - o for o in y]) - - elif y.__class__.__name__ == 'Corr': + elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented - else: return derived_observable(lambda x, **kwargs: x[0] - y, [self], man_grad=[1]) @@ -742,7 +740,7 @@ class Obs: else: if isinstance(y, np.ndarray): return np.array([self / o for o in y]) - elif y.__class__.__name__ == 'Corr': + elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented else: return derived_observable(lambda x, **kwargs: x[0] / y, [self], man_grad=[1 / y]) @@ -753,7 +751,7 @@ class Obs: else: if isinstance(y, np.ndarray): return np.array([o / self for o in y]) - elif y.__class__.__name__ == 'Corr': + elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented else: return derived_observable(lambda x, **kwargs: y / x[0], [self], man_grad=[-y / self.value ** 2]) diff --git a/tests/obs_test.py b/tests/obs_test.py index 4944630f..548f1a42 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -117,6 +117,10 @@ def test_function_overloading(): np.arctanh(1 / b) np.sinc(1 / b) + b ** b + 0.5 ** b + b ** 0.5 + def test_overloading_vectorization(): a = np.random.randint(1, 100, 10) @@ -392,6 +396,9 @@ def test_cobs(): obs2 = pe.pseudo_Obs(-0.2, 0.03, 't') my_cobs = pe.CObs(obs1, obs2) + my_cobs == my_cobs + str(my_cobs) + repr(my_cobs) assert not (my_cobs + my_cobs.conjugate()).real.is_zero() assert (my_cobs + my_cobs.conjugate()).imag.is_zero() assert (my_cobs - my_cobs.conjugate()).real.is_zero() @@ -424,6 +431,23 @@ def test_cobs(): assert (other / my_cobs * my_cobs - other).is_zero() +def test_cobs_overloading(): + obs = pe.pseudo_Obs(1.1, 0.1, 't') + cobs = pe.CObs(obs, obs) + + cobs + obs + obs + cobs + + cobs - obs + obs - cobs + + cobs * obs + obs * cobs + + cobs / obs + obs / cobs + + def test_reweighting(): my_obs = pe.Obs([np.random.rand(1000)], ['t']) assert not my_obs.reweighted