feat: guards added for functionality that breaks with numpy>=1.25 and

autograd==1.5.
This commit is contained in:
Fabian Joswig 2023-06-19 13:28:30 +01:00
parent f14042132f
commit 13ace62262
No known key found for this signature in database
3 changed files with 31 additions and 20 deletions

View file

@ -1,3 +1,4 @@
from packaging import version
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
from .obs import derived_observable, CObs, Obs, import_jackknife
@ -260,6 +261,8 @@ def _mat_mat_op(op, obs, **kwargs):
def eigh(obs, **kwargs):
"""Computes the eigenvalues and eigenvectors of a given hermitian matrix of Obs according to np.linalg.eigh."""
if version.parse(np.__version__) >= version.parse("1.25.0"):
raise NotImplementedError("eigh error propagation is not working with numpy>=1.25 and autograd==1.5.")
w = derived_observable(lambda x, **kwargs: anp.linalg.eigh(x)[0], obs)
v = derived_observable(lambda x, **kwargs: anp.linalg.eigh(x)[1], obs)
return w, v
@ -278,6 +281,8 @@ def pinv(obs, **kwargs):
def svd(obs, **kwargs):
"""Computes the singular value decomposition of a matrix of Obs."""
if version.parse(np.__version__) >= version.parse("1.25.0"):
raise NotImplementedError("svd error propagation is not working with numpy>=1.25 and autograd==1.5.")
u = derived_observable(lambda x, **kwargs: anp.linalg.svd(x, full_matrices=False)[0], obs)
s = derived_observable(lambda x, **kwargs: anp.linalg.svd(x, full_matrices=False)[1], obs)
vh = derived_observable(lambda x, **kwargs: anp.linalg.svd(x, full_matrices=False)[2], obs)

View file

@ -1,3 +1,4 @@
from packaging import version
import numpy as np
import autograd.numpy as anp
import math
@ -291,6 +292,9 @@ def test_matrix_functions():
diff = entry - sym[i, j]
assert diff.is_zero()
# These linalg functions don't work with numpy>=1.25 and autograd==1.5.
# Remove this guard once this is fixed in autograd.
if version.parse(np.__version__) < version.parse("1.25.0"):
# Check eigh
e, v = pe.linalg.eigh(sym)
for i in range(dim):

View file

@ -1,3 +1,4 @@
from packaging import version
import numpy as np
import pyerrors as pe
import pytest
@ -5,6 +6,7 @@ import pytest
np.random.seed(0)
if version.parse(np.__version__) < version.parse("1.25.0"):
def test_mpm():
corr_content = []
for t in range(8):