mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-15 03:53:41 +02:00
linalg.matmul now works with any number of operands
This commit is contained in:
parent
68b25ba4ca
commit
d4b86f5f73
2 changed files with 61 additions and 7 deletions
|
@ -124,18 +124,50 @@ def derived_array(func, data, **kwargs):
|
||||||
return final_result
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
def matmul(x1, x2):
|
def matmul(*operands):
|
||||||
if isinstance(x1[0, 0], CObs) or isinstance(x2[0, 0], CObs):
|
if any(isinstance(o[0, 0], CObs) for o in operands):
|
||||||
Lr, Li = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x1)
|
extended_operands = []
|
||||||
Rr, Ri = np.vectorize(lambda x: (np.real(x), np.imag(x)))(x2)
|
for op in operands:
|
||||||
Nr = derived_array(lambda x: x[0] @ x[2] - x[1] @ x[3], [Lr, Li, Rr, Ri])
|
tmp = np.vectorize(lambda x: (np.real(x), np.imag(x)))(op)
|
||||||
Ni = derived_array(lambda x: x[0] @ x[3] + x[1] @ x[2], [Lr, Li, Rr, Ri])
|
extended_operands.append(tmp[0])
|
||||||
|
extended_operands.append(tmp[1])
|
||||||
|
|
||||||
|
def multi_dot(operands, part):
|
||||||
|
stack_r = operands[0]
|
||||||
|
stack_i = operands[1]
|
||||||
|
for op_r, op_i in zip(operands[2::2], operands[3::2]):
|
||||||
|
tmp_r = stack_r @ op_r - stack_i @ op_i
|
||||||
|
tmp_i = stack_r @ op_i + stack_i @ op_r
|
||||||
|
|
||||||
|
stack_r = tmp_r
|
||||||
|
stack_i = tmp_i
|
||||||
|
|
||||||
|
if part == 'Real':
|
||||||
|
return stack_r
|
||||||
|
else:
|
||||||
|
return stack_i
|
||||||
|
|
||||||
|
def multi_dot_r(operands):
|
||||||
|
return multi_dot(operands, 'Real')
|
||||||
|
|
||||||
|
def multi_dot_i(operands):
|
||||||
|
return multi_dot(operands, 'Imag')
|
||||||
|
|
||||||
|
Nr = derived_array(multi_dot_r, extended_operands)
|
||||||
|
Ni = derived_array(multi_dot_i, extended_operands)
|
||||||
|
|
||||||
res = np.empty_like(Nr)
|
res = np.empty_like(Nr)
|
||||||
for (n, m), entry in np.ndenumerate(Nr):
|
for (n, m), entry in np.ndenumerate(Nr):
|
||||||
res[n, m] = CObs(Nr[n, m], Ni[n, m])
|
res[n, m] = CObs(Nr[n, m], Ni[n, m])
|
||||||
|
|
||||||
return res
|
return res
|
||||||
else:
|
else:
|
||||||
return derived_array(lambda x: x[0] @ x[1], [x1, x2])
|
def multi_dot(operands):
|
||||||
|
stack = operands[0]
|
||||||
|
for op in operands[1:]:
|
||||||
|
stack = stack @ op
|
||||||
|
return stack
|
||||||
|
return derived_array(multi_dot, operands)
|
||||||
|
|
||||||
|
|
||||||
def inv(x):
|
def inv(x):
|
||||||
|
|
|
@ -29,6 +29,28 @@ def test_matmul():
|
||||||
assert e.is_zero(), t
|
assert e.is_zero(), t
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_dot():
|
||||||
|
for dim in [4, 8]:
|
||||||
|
my_list = []
|
||||||
|
length = 1000 + np.random.randint(200)
|
||||||
|
for i in range(dim ** 2):
|
||||||
|
my_list.append(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']))
|
||||||
|
my_array = np.array(my_list).reshape((dim, dim))
|
||||||
|
tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array
|
||||||
|
for t, e in np.ndenumerate(tt):
|
||||||
|
assert e.is_zero(), t
|
||||||
|
|
||||||
|
my_list = []
|
||||||
|
length = 1000 + np.random.randint(200)
|
||||||
|
for i in range(dim ** 2):
|
||||||
|
my_list.append(pe.CObs(pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2']),
|
||||||
|
pe.Obs([np.random.rand(length), np.random.rand(length + 1)], ['t1', 't2'])))
|
||||||
|
my_array = np.array(my_list).reshape((dim, dim))
|
||||||
|
tt = pe.linalg.matmul(my_array, my_array, my_array, my_array) - my_array @ my_array @ my_array @ my_array
|
||||||
|
for t, e in np.ndenumerate(tt):
|
||||||
|
assert e.is_zero(), t
|
||||||
|
|
||||||
|
|
||||||
def test_matrix_inverse():
|
def test_matrix_inverse():
|
||||||
content = []
|
content = []
|
||||||
for t in range(9):
|
for t in range(9):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue