feat: linalg.einsum optimized

This commit is contained in:
Fabian Joswig 2021-12-10 15:39:57 +00:00
parent 207a60c085
commit fe03bf9844

View file

@ -178,7 +178,8 @@ def einsum(subscripts, *operands):
tmp_subscripts = ','.join([o + '...' for o in subscripts.split(',')])
extended_subscripts = '->'.join([o + '...' for o in tmp_subscripts.split('->')[:-1]] + [tmp_subscripts.split('->')[-1]])
jack_einsum = np.einsum(extended_subscripts, *conv_operands)
einsum_path = np.einsum_path(extended_subscripts, *conv_operands, optimize='optimal')[0]
jack_einsum = np.einsum(extended_subscripts, *conv_operands, optimize=einsum_path)
if jack_einsum.dtype == complex:
result = _imp_from_jack_c(jack_einsum, name, idl)