refactor: removed code from autograd master, test adjusted

Fabian Joswig 2021-12-10 14:26:05 +00:00
parent 5ebfda5250
commit 525dea0209
2 changed files with 1 addition and 52 deletions


@@ -2,9 +2,6 @@ import numpy as np
 import autograd.numpy as anp  # Thinly-wrapped numpy
 from .obs import derived_observable, CObs, Obs, import_jackknife
-from functools import partial
-from autograd.extend import defvjp
 def matmul(*operands):
     """Matrix multiply all operands.
@@ -527,51 +524,3 @@ def _num_diff_svd(obs, **kwargs):
         res_mat2.append(row)
     return (np.array(res_mat0) @ np.identity(mid_index), np.array(res_mat1) @ np.identity(mid_index), np.array(res_mat2) @ np.identity(shape[1]))
-
-
-# This code block is directly taken from the current master branch of autograd and remains
-# only until the new version is released on PyPi
-_dot = partial(anp.einsum, '...ij,...jk->...ik')
-
-
-# batched diag
-def _diag(a):
-    return anp.eye(a.shape[-1]) * a
-
-
-# batched diagonal, similar to matrix_diag in tensorflow
-def _matrix_diag(a):
-    reps = anp.array(a.shape)
-    reps[:-1] = 1
-    reps[-1] = a.shape[-1]
-    newshape = list(a.shape) + [a.shape[-1]]
-    return _diag(anp.tile(a, reps).reshape(newshape))
-
-
-# https://arxiv.org/pdf/1701.00392.pdf Eq(4.77)
-# Note the formula from Sec3.1 in https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf is incomplete
-def grad_eig(ans, x):
-    """Gradient of a general square (complex valued) matrix"""
-    e, u = ans  # eigenvalues as 1d array, eigenvectors in columns
-    n = e.shape[-1]
-
-    def vjp(g):
-        ge, gu = g
-        ge = _matrix_diag(ge)
-        f = 1 / (e[..., anp.newaxis, :] - e[..., :, anp.newaxis] + 1.e-20)
-        f -= _diag(f)
-        ut = anp.swapaxes(u, -1, -2)
-        r1 = f * _dot(ut, gu)
-        r2 = -f * (_dot(_dot(ut, anp.conj(u)), anp.real(_dot(ut, gu)) * anp.eye(n)))
-        r = _dot(_dot(anp.linalg.inv(ut), ge + r1 + r2), ut)
-        if not anp.iscomplexobj(x):
-            r = anp.real(r)
-            # the derivative is still complex for real input (imaginary delta is allowed), real output
-            # but the derivative should be real in real input case when imaginary delta is forbidden
-        return r
-    return vjp
-
-
-defvjp(anp.linalg.eig, grad_eig)
-# End of the code block from autograd.master
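
The block deleted above had vendored autograd's vector-Jacobian product for anp.linalg.eig and registered it locally via defvjp. With it gone, the installed autograd is expected to provide this derivative itself. A rough illustration of what that enables (not part of this commit; the function name and test matrix are made up here, and it assumes an autograd release whose anp.linalg.eig already ships a registered VJP):

import autograd.numpy as anp  # thinly-wrapped numpy
from autograd import grad


def eigenvalue_sum(x):
    """Sum of the real parts of the eigenvalues of a square matrix."""
    e, _ = anp.linalg.eig(x)
    return anp.sum(anp.real(e))


# Differentiating through anp.linalg.eig only works if the installed
# autograd provides the eig derivative, i.e. exactly the code that this
# commit stops vendoring.
x = anp.array([[2.0, 1.0], [0.5, 3.0]])
print(grad(eigenvalue_sum)(x))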


@@ -302,7 +302,7 @@ def test_matrix_functions():
     # Check eig function
     e2 = pe.linalg.eig(sym)
-    assert np.all(e == e2[::-1])
+    assert np.all(np.sort(e) == np.sort(e2))
     # Check svd
     u, v, vh = pe.linalg.svd(sym)
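
The adjusted assertion compares sorted eigenvalue arrays because eig makes no guarantee about the order in which eigenvalues are returned; the previous check, e == e2[::-1], assumed one specific ordering and breaks when that ordering changes. A small stand-alone illustration in plain numpy (the matrix is an arbitrary example of mine, and allclose is used rather than the exact comparison in the test):

import numpy as np

sym = np.array([[2.0, 1.0], [1.0, 3.0]])  # arbitrary real symmetric matrix
e = np.linalg.eigh(sym)[0]   # eigh returns eigenvalues in ascending order
e2 = np.linalg.eig(sym)[0]   # eig returns them in no particular order
assert np.allclose(np.sort(e), np.sort(e2))  # robust to either ordering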