First try at replacing autograd by jax

This commit is contained in:
Fabian Joswig 2021-10-18 12:53:17 +01:00
parent 8d7a5daafa
commit 8fc5d96363
9 changed files with 76 additions and 55 deletions

View file

@ -1,9 +1,12 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import math
import scipy.optimize
from scipy.odr import ODR, Model, RealData
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@ -24,7 +27,7 @@ def test_standard_fit():
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
beta = pe.fits.standard_fit(x, oy, func)
@ -61,7 +64,7 @@ def test_odr_fit():
return a * np.exp(-b * x)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])

View file

@ -1,7 +1,10 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import math
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@ -14,7 +17,7 @@ def test_matrix_inverse():
content.append(1.0) # Add 1.0 as a float
matrix = np.diag(content)
inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
matrix[n, m] = entry.real.value + 1j * entry.imag.value
inverse_matrix = np.linalg.inv(matrix)
inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
for (n, m), entry in np.ndenumerate(inverse_matrix):
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
@ -53,7 +56,7 @@ def test_matrix_functions():
matrix = np.array(matrix) @ np.identity(dim)
# Check inverse of matrix
inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
check_inv = matrix @ inv
for (i, j), entry in np.ndenumerate(check_inv):
@ -66,7 +69,7 @@ def test_matrix_functions():
# Check Cholesky decomposition
sym = np.dot(matrix, matrix.T)
cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
check = cholesky @ cholesky.T
for (i, j), entry in np.ndenumerate(check):

View file

@ -1,10 +1,13 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import os
import random
import string
import copy
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@ -29,21 +32,31 @@ def test_comparison():
def test_function_overloading():
a = pe.pseudo_Obs(17, 2.9, 'e1')
a = pe.pseudo_Obs(2, 2.9, 'e1')
b = pe.pseudo_Obs(4, 0.8, 'e1')
fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
for i, f in enumerate(fs):
t1 = f([a, b])
t2 = pe.derived_observable(f, [a, b])
c = t2 - t1
assert c.value == 0.0, str(i)
assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
assert c.is_zero()
f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
t1 = f1([a])
t2 = pe.derived_observable(f2, [a])
c = t2 - t1
assert c.is_zero()
def test_overloading_vectorization():
@ -121,7 +134,7 @@ def test_derived_observables():
test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
# Check if autograd and numgrad give the same result
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
d_Obs_ad.gamma_method()
d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
d_Obs_fd.gamma_method()