mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
Complex scalar array operations now correctly overloaded
This commit is contained in:
parent
ad53f28e46
commit
f395cb3d88
2 changed files with 44 additions and 4 deletions
|
@ -695,7 +695,9 @@ class CObs:
|
|||
return CObs(self.real, -self.imag)
|
||||
|
||||
def __add__(self, other):
|
||||
if hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
if isinstance(other, np.ndarray):
|
||||
return other + self
|
||||
elif hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
return CObs(self.real + other.real,
|
||||
self.imag + other.imag)
|
||||
else:
|
||||
|
@ -705,7 +707,9 @@ class CObs:
|
|||
return self + y
|
||||
|
||||
def __sub__(self, other):
|
||||
if hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
if isinstance(other, np.ndarray):
|
||||
return -1 * (other - self)
|
||||
elif hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
return CObs(self.real - other.real, self.imag - other.imag)
|
||||
else:
|
||||
return CObs(self.real - other, self.imag)
|
||||
|
@ -714,7 +718,9 @@ class CObs:
|
|||
return -1 * (self - other)
|
||||
|
||||
def __mul__(self, other):
|
||||
if hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
if isinstance(other, np.ndarray):
|
||||
return other * self
|
||||
elif hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
if all(isinstance(i, Obs) for i in [self.real, self.imag, other.real, other.imag]):
|
||||
return CObs(derived_observable(lambda x, **kwargs: x[0] * x[1] - x[2] * x[3],
|
||||
[self.real, other.real, self.imag, other.imag],
|
||||
|
@ -734,7 +740,9 @@ class CObs:
|
|||
return self * other
|
||||
|
||||
def __truediv__(self, other):
|
||||
if hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
if isinstance(other, np.ndarray):
|
||||
return 1 / (other / self)
|
||||
elif hasattr(other, 'real') and hasattr(other, 'imag'):
|
||||
r = other.real ** 2 + other.imag ** 2
|
||||
return CObs((self.real * other.real + self.imag * other.imag) / r, (self.imag * other.real - self.real * other.imag) / r)
|
||||
else:
|
||||
|
|
|
@ -124,3 +124,35 @@ def test_matrix_functions():
|
|||
tmp = sym @ v[:, i] - v[:, i] * e[i]
|
||||
for j in range(dim):
|
||||
assert tmp[j].is_zero()
|
||||
|
||||
|
||||
def test_complex_matrix_operations():
|
||||
dimension = 4
|
||||
base_matrix = np.empty((dimension, dimension), dtype=object)
|
||||
for (n, m), entry in np.ndenumerate(base_matrix):
|
||||
exponent_real = np.random.normal(3, 5)
|
||||
exponent_imag = np.random.normal(3, 5)
|
||||
base_matrix[n, m] = pe.CObs(pe.pseudo_Obs(2 + 10 ** exponent_real, 10 ** (exponent_real - 1), 't'),
|
||||
pe.pseudo_Obs(2 + 10 ** exponent_imag, 10 ** (exponent_imag - 1), 't'))
|
||||
|
||||
for other in [2, 2.3, (1 - 0.1j), (0 + 2.1j)]:
|
||||
ta = base_matrix * other
|
||||
tb = other * base_matrix
|
||||
diff = ta - tb
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
ta = base_matrix + other
|
||||
tb = other + base_matrix
|
||||
diff = ta - tb
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
ta = base_matrix - other
|
||||
tb = other - base_matrix
|
||||
diff = ta + tb
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
ta = base_matrix / other
|
||||
tb = other / base_matrix
|
||||
diff = ta * tb - 1
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
|
|
Loading…
Add table
Reference in a new issue