mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-05-15 03:53:41 +02:00
feat: guards added for functionality that breaks with numpy>=1.25 and
autograd==1.5.
This commit is contained in:
parent
f14042132f
commit
13ace62262
3 changed files with 31 additions and 20 deletions
|
@ -1,3 +1,4 @@
|
|||
from packaging import version
|
||||
import numpy as np
|
||||
import autograd.numpy as anp
|
||||
import math
|
||||
|
@ -291,23 +292,26 @@ def test_matrix_functions():
|
|||
diff = entry - sym[i, j]
|
||||
assert diff.is_zero()
|
||||
|
||||
# Check eigh
|
||||
e, v = pe.linalg.eigh(sym)
|
||||
for i in range(dim):
|
||||
tmp = sym @ v[:, i] - v[:, i] * e[i]
|
||||
for j in range(dim):
|
||||
assert tmp[j].is_zero()
|
||||
# These linalg functions don't work with numpy>=1.25 and autograd==1.5.
|
||||
# Remove this guard once this is fixed in autograd.
|
||||
if version.parse(np.__version__) < version.parse("1.25.0"):
|
||||
# Check eigh
|
||||
e, v = pe.linalg.eigh(sym)
|
||||
for i in range(dim):
|
||||
tmp = sym @ v[:, i] - v[:, i] * e[i]
|
||||
for j in range(dim):
|
||||
assert tmp[j].is_zero()
|
||||
|
||||
# Check eig function
|
||||
e2 = pe.linalg.eig(sym)
|
||||
assert np.all(np.sort(e) == np.sort(e2))
|
||||
# Check eig function
|
||||
e2 = pe.linalg.eig(sym)
|
||||
assert np.all(np.sort(e) == np.sort(e2))
|
||||
|
||||
# Check svd
|
||||
u, v, vh = pe.linalg.svd(sym)
|
||||
diff = sym - u @ np.diag(v) @ vh
|
||||
# Check svd
|
||||
u, v, vh = pe.linalg.svd(sym)
|
||||
diff = sym - u @ np.diag(v) @ vh
|
||||
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
for (i, j), entry in np.ndenumerate(diff):
|
||||
assert entry.is_zero()
|
||||
|
||||
# Check determinant
|
||||
assert pe.linalg.det(np.diag(np.diag(matrix))) == np.prod(np.diag(matrix))
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from packaging import version
|
||||
import numpy as np
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
|
@ -5,10 +6,11 @@ import pytest
|
|||
np.random.seed(0)
|
||||
|
||||
|
||||
def test_mpm():
|
||||
corr_content = []
|
||||
for t in range(8):
|
||||
f = 0.8 * np.exp(-0.4 * t)
|
||||
corr_content.append(pe.pseudo_Obs(np.random.normal(f, 1e-2 * f), 1e-2 * f, 't'))
|
||||
if version.parse(np.__version__) < version.parse("1.25.0"):
|
||||
def test_mpm():
|
||||
corr_content = []
|
||||
for t in range(8):
|
||||
f = 0.8 * np.exp(-0.4 * t)
|
||||
corr_content.append(pe.pseudo_Obs(np.random.normal(f, 1e-2 * f), 1e-2 * f, 't'))
|
||||
|
||||
res = pe.mpm.matrix_pencil_method(corr_content)
|
||||
res = pe.mpm.matrix_pencil_method(corr_content)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue