From 9b023e339f3a70463f19827a301134bd3736286e Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Mon, 25 Oct 2021 09:41:34 +0100 Subject: [PATCH] Bug in complex mutliplication fixed --- pyerrors/pyerrors.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index 32500af3..7d6f4257 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -714,16 +714,17 @@ class CObs: return -1 * (self - other) def __mul__(self, other): - if all(isinstance(i, Obs) for i in [self.real, self.imag, other.real, other.imag]): - return CObs(derived_observable(lambda x, **kwargs: x[0] * x[1] - x[2] * x[3], - [self.real, other.real, self.imag, other.imag], - man_grad=[other.real.value, self.real.value, -other.imag.value, -self.imag.value]), - derived_observable(lambda x, **kwargs: x[2] * x[1] + x[0] * x[3], - [self.real, other.real, self.imag, other.imag], - man_grad=[other.imag.value, self.imag.value, other.real.value, self.real.value])) - elif hasattr(other, 'real') and getattr(other, 'imag', 0) != 0: - return CObs(self.real * other.real - self.imag * other.imag, - self.imag * other.real + self.real * other.imag) + if hasattr(other, 'real') and getattr(other, 'imag', 0) != 0: + if all(isinstance(i, Obs) for i in [self.real, self.imag, other.real, other.imag]): + return CObs(derived_observable(lambda x, **kwargs: x[0] * x[1] - x[2] * x[3], + [self.real, other.real, self.imag, other.imag], + man_grad=[other.real.value, self.real.value, -other.imag.value, -self.imag.value]), + derived_observable(lambda x, **kwargs: x[2] * x[1] + x[0] * x[3], + [self.real, other.real, self.imag, other.imag], + man_grad=[other.imag.value, self.imag.value, other.real.value, self.real.value])) + else: + return CObs(self.real * other.real - self.imag * other.imag, + self.imag * other.real + self.real * other.imag) else: return CObs(self.real * np.real(other), self.imag * np.real(other))