diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index 5d54b9af..84a41f17 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -655,6 +655,9 @@ class CObs: if isinstance(self.imag, Obs): self.imag.gamma_method(**kwargs) + def is_zero(self): + return self.real == 0.0 and self.imag == 0.0 + def conjugate(self): return CObs(self.real, -self.imag) @@ -678,7 +681,14 @@ class CObs: return -1 * (self - other) def __mul__(self, other): - if hasattr(other, 'real') and hasattr(other, 'imag'): + if all(isinstance(i, Obs) for i in [self.real, self.imag, other.real, other.imag]): + return CObs(derived_observable(lambda x, **kwargs: x[0] * x[1] - x[2] * x[3], + [self.real, other.real, self.imag, other.imag], + man_grad=[other.real.value, self.real.value, -other.imag.value, -self.imag.value]), + derived_observable(lambda x, **kwargs: x[2] * x[1] + x[0] * x[3], + [self.real, other.real, self.imag, other.imag], + man_grad=[other.imag.value, self.imag.value, other.real.value, self.real.value])) + elif hasattr(other, 'real') and hasattr(other, 'imag'): return CObs(self.real * other.real - self.imag * other.imag, self.imag * other.real + self.real * other.imag) else: diff --git a/tests/test_pyerrors.py b/tests/test_pyerrors.py index 58ec21f3..7d304464 100644 --- a/tests/test_pyerrors.py +++ b/tests/test_pyerrors.py @@ -204,6 +204,10 @@ def test_cobs(): assert (my_cobs - my_cobs.conjugate()).real.is_zero() assert not (my_cobs - my_cobs.conjugate()).imag.is_zero() np.abs(my_cobs) + + assert (my_cobs * my_cobs / my_cobs - my_cobs).is_zero() + assert (my_cobs + my_cobs - 2 * my_cobs).is_zero() + fs = [[lambda x: x[0] + x[1], lambda x: x[1] + x[0]], [lambda x: x[0] * x[1], lambda x: x[1] * x[0]]] for other in [1, 1.1, (1.1-0.2j), pe.CObs(obs1), pe.CObs(obs1, obs2)]: