From 7f8c2ce33b7bc16899f048216639afb10ab96dfb Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 21 Jul 2023 14:15:41 +0100 Subject: [PATCH] feat: added support for addition and multiplication of complex numbers (#209) to Corr objects. --- pyerrors/correlators.py | 4 ++-- pyerrors/obs.py | 2 ++ tests/correlators_test.py | 10 ++++++++++ tests/obs_test.py | 7 +++++++ 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index be5c69a6..5c9e236a 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -1075,7 +1075,7 @@ class Corr: newcontent.append(self.content[t] + y.content[t]) return Corr(newcontent) - elif isinstance(y, (Obs, int, float, CObs)): + elif isinstance(y, (Obs, int, float, CObs, complex)): newcontent = [] for t in range(self.T): if _check_for_none(self, self.content[t]): @@ -1103,7 +1103,7 @@ class Corr: newcontent.append(self.content[t] * y.content[t]) return Corr(newcontent) - elif isinstance(y, (Obs, int, float, CObs)): + elif isinstance(y, (Obs, int, float, CObs, complex)): newcontent = [] for t in range(self.T): if _check_for_none(self, self.content[t]): diff --git a/pyerrors/obs.py b/pyerrors/obs.py index e9dc20cb..df287a10 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -784,6 +784,8 @@ class Obs: else: if isinstance(y, np.ndarray): return np.array([self + o for o in y]) + elif isinstance(y, complex): + return CObs(self, 0) + y elif y.__class__.__name__ in ['Corr', 'CObs']: return NotImplemented else: diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 3d49164a..5f213034 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -749,3 +749,13 @@ def test_corr_item(): corr_mat = pe.Corr(np.array([[corr_aa, corr_ab], [corr_ab, corr_aa]])) corr_mat.item(0, 0) assert corr_mat[0].item(0, 1) == corr_mat.item(0, 1)[0] + + +def test_complex_add_and_mul(): + o = pe.pseudo_Obs(1.0, 0.3, "my_r345sfg16£$%&$%^%$^$", samples=47) + co = pe.CObs(o, 0.341 * o) + for obs in [o, co]: + cc = pe.Corr([obs for _ in range(4)]) + cc += 2j + cc = cc * 4j + cc.real + cc.imag diff --git a/tests/obs_test.py b/tests/obs_test.py index 334521f0..02d539ef 100644 --- a/tests/obs_test.py +++ b/tests/obs_test.py @@ -1333,3 +1333,10 @@ def test_vec_gm(): cc = pe.Corr(obs) pe.gm(cc, S=4.12) assert np.all(np.vectorize(lambda x: x.S["qq"])(cc.content) == 4.12) + +def test_complex_addition(): + o = pe.pseudo_Obs(34.12, 1e-4, "testens") + r = o + 2j + assert r.real == o + r = r * 1j + assert r.imag == o