mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
linalg.matmul now works with any number of operands
This commit is contained in:
parent
68b25ba4ca
commit
d4b86f5f73
2 changed files with 61 additions and 7 deletions
|
@ -124,18 +124,50 @@ def derived_array(func, data, **kwargs):
|
|||
return final_result
|
||||
|
||||
|
||||
def matmul(x1, x2):
|
||||
if isinstance(x1[0, 0], CObs) or isinstance(x2[0, 0], CObs):
|
||||
Lr, Li = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x1)
|
||||
Rr, Ri = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x2)
|
||||
Nr = derived_array(lambda x: x[0] @ x[2] - x[1] @ x[3], [Lr, Li, Rr, Ri])
|
||||
Ni = derived_array(lambda x: x[0] @ x[3] + x[1] @ x[2], [Lr, Li, Rr, Ri])
|
||||
def matmul(*operands):
|
||||
if any(isinstance(o[0, 0], CObs) for o in operands):
|
||||
extended_operands = []
|
||||
for op in operands:
|
||||
tmp = np.vectorize(lambda x: (np.real(x), np.imag(x)))(op)
|
||||
extended_operands.append(tmp[0])
|
||||
extended_operands.append(tmp[1])
|
||||
|
||||
def multi_dot(operands, part):
|
||||
stack_r = operands[0]
|
||||
stack_i = operands[1]
|
||||
for op_r, op_i in zip(operands[2::2], operands[3::2]):
|
||||
tmp_r = stack_r @ op_r - stack_i @ op_i
|
||||
tmp_i = stack_r @ op_i + stack_i @ op_r
|
||||
|
||||
stack_r = tmp_r
|
||||
stack_i = tmp_i
|
||||
|
||||
if part == 'Real':
|
||||
return stack_r
|
||||
else:
|
||||
return stack_i
|
||||
|
||||
def multi_dot_r(operands):
|
||||
return multi_dot(operands, 'Real')
|
||||
|
||||
def multi_dot_i(operands):
|
||||
return multi_dot(operands, 'Imag')
|
||||
|
||||
Nr = derived_array(multi_dot_r, extended_operands)
|
||||
Ni = derived_array(multi_dot_i, extended_operands)
|
||||
|
||||
res = np.empty_like(Nr)
|
||||
for (n, m), entry in np.ndenumerate(Nr):
|
||||
res[n, m] = CObs(Nr[n, m], Ni[n, m])
|
||||
|
||||
return res
|
||||
else:
|
||||
return derived_array(lambda x: x[0] @ x[1], [x1, x2])
|
||||
def multi_dot(operands):
|
||||
stack = operands[0]
|
||||
for op in operands[1:]:
|
||||
stack = stack @ op
|
||||
return stack
|
||||
return derived_array(multi_dot, operands)
|
||||
|
||||
|
||||
def inv(x):
|
||||
|
|
|
@ -29,6 +29,28 @@ def test_matmul():
|
|||
assert e.is_zero(), t
|
||||
|
||||
|
||||
def test_multi_dot():
|
||||
for dim in [4, 8]:
|
||||
my_list = []
|
||||
length = 1000 + np.random.randint(200)
|
||||
for i in range(dim ** 2):
|
||||
my_list.append(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']))
|
||||
my_array = np.array(my_list).reshape((dim, dim))
|
||||
tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array
|
||||
for t, e in np.ndenumerate(tt):
|
||||
assert e.is_zero(), t
|
||||
|
||||
my_list = []
|
||||
length = 1000 + np.random.randint(200)
|
||||
for i in range(dim ** 2):
|
||||
my_list.append(pe.CObs(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']),
|
||||
pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2'])))
|
||||
my_array = np.array(my_list).reshape((dim, dim))
|
||||
tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array
|
||||
for t, e in np.ndenumerate(tt):
|
||||
assert e.is_zero(), t
|
||||
|
||||
|
||||
def test_matrix_inverse():
|
||||
content = []
|
||||
for t in range(9):
|
||||
|
|
Loading…
Add table
Reference in a new issue