linalg.matmul now works with any number of operands

This commit is contained in:
Fabian Joswig 2021-10-25 13:58:16 +01:00
parent 68b25ba4ca
commit d4b86f5f73
2 changed files with 61 additions and 7 deletions

View file

@ -124,18 +124,50 @@ def derived_array(func, data, **kwargs):
return final_result
def matmul(x1, x2):
if isinstance(x1[0, 0], CObs) or isinstance(x2[0, 0], CObs):
Lr, Li = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x1)
Rr, Ri = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x2)
Nr = derived_array(lambda x: x[0] @ x[2] - x[1] @ x[3], [Lr, Li, Rr, Ri])
Ni = derived_array(lambda x: x[0] @ x[3] + x[1] @ x[2], [Lr, Li, Rr, Ri])
def matmul(*operands):
if any(isinstance(o[0, 0], CObs) for o in operands):
extended_operands = []
for op in operands:
tmp = np.vectorize(lambda x: (np.real(x), np.imag(x)))(op)
extended_operands.append(tmp[0])
extended_operands.append(tmp[1])
def multi_dot(operands, part):
stack_r = operands[0]
stack_i = operands[1]
for op_r, op_i in zip(operands[2::2], operands[3::2]):
tmp_r = stack_r @ op_r - stack_i @ op_i
tmp_i = stack_r @ op_i + stack_i @ op_r
stack_r = tmp_r
stack_i = tmp_i
if part == 'Real':
return stack_r
else:
return stack_i
def multi_dot_r(operands):
return multi_dot(operands, 'Real')
def multi_dot_i(operands):
return multi_dot(operands, 'Imag')
Nr = derived_array(multi_dot_r, extended_operands)
Ni = derived_array(multi_dot_i, extended_operands)
res = np.empty_like(Nr)
for (n, m), entry in np.ndenumerate(Nr):
res[n, m] = CObs(Nr[n, m], Ni[n, m])
return res
else:
return derived_array(lambda x: x[0] @ x[1], [x1, x2])
def multi_dot(operands):
stack = operands[0]
for op in operands[1:]:
stack = stack @ op
return stack
return derived_array(multi_dot, operands)
def inv(x):