Mirror of https://github.com/fjosw/pyerrors.git
First try at replacing autograd by jax
parent 8d7a5daafa
commit 8fc5d96363

9 changed files with 76 additions and 55 deletions
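The change this commit works toward is a swap of the differentiation backend: pyerrors propagates errors through user-supplied functions by differentiating them, and that derivative now comes from jax instead of autograd. A minimal, hypothetical sketch of the swap (not code from this commit; the internals of pe.derived_observable are not shown in the diff):

    from jax.config import config
    config.update("jax_enable_x64", True)  # jax defaults to float32; see the note after the first hunk
    import jax
    import jax.numpy as jnp

    def f(p):
        # toy stand-in for a function of observable mean values
        return p[0] * jnp.exp(-p[1] * 2.0)

    grad_f = jax.grad(f)                  # drop-in replacement for autograd.grad(f)
    print(grad_f(jnp.array([1.0, 0.5])))  # [ 0.36787944 -0.73575888]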
@@ -1,9 +1,12 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import math
 import scipy.optimize
 from scipy.odr import ODR, Model, RealData
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 
 np.random.seed(0)
 
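The two added lines at the bottom of the import block matter more than they look: jax computes in single precision by default, while the numpy reference values in these tests are doubles. A quick illustrative check of what the flag does (not from the commit):

    from jax.config import config
    config.update("jax_enable_x64", True)  # must happen before the first jax computation
    import jax.numpy as jnp

    print(jnp.array(1.0).dtype)  # float64; without the flag this would be float32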
@@ -24,7 +27,7 @@ def test_standard_fit():
     popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
 
     def func(a, x):
-        y = a[0] * np.exp(-a[1] * x)
+        y = a[0] * jnp.exp(-a[1] * x)
         return y
 
     beta = pe.fits.standard_fit(x, oy, func)
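Only the fit function's body changes here: func is the piece the backend differentiates, so it must be built from jnp primitives that jax can trace, while the surrounding arithmetic can stay plain numpy. A hedged sketch of the distinction (the diff does not show what pe.fits.standard_fit does internally):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def good(a, x):
        return a[0] * jnp.exp(-a[1] * x)  # jnp: jax can trace and differentiate this

    def bad(a, x):
        return a[0] * np.exp(-a[1] * x)   # np.exp cannot digest abstract jax tracers

    print(jax.grad(lambda a: good(a, 2.0))(jnp.array([1.0, 0.5])))  # works
    # jax.grad(lambda a: bad(a, 2.0))(jnp.array([1.0, 0.5]))        # raises a tracer conversion error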
@@ -61,7 +64,7 @@ def test_odr_fit():
         return a * np.exp(-b * x)
 
     def func(a, x):
-        y = a[0] * np.exp(-a[1] * x)
+        y = a[0] * jnp.exp(-a[1] * x)
         return y
 
     data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])

@@ -1,7 +1,10 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import math
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 
 np.random.seed(0)
 
@@ -14,7 +17,7 @@ def test_matrix_inverse():
 
     content.append(1.0) # Add 1.0 as a float
     matrix = np.diag(content)
-    inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
+    inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
     assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
 
 
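pe.linalg.mat_mat_op applies a matrix-valued function to a matrix of Obs and propagates the errors through it, so the function argument is what has to move from np.linalg to jnp.linalg. A rough usage sketch, with made-up entry values and errors:

    import numpy as np
    import jax.numpy as jnp
    import pyerrors as pe

    obs_matrix = np.array([[pe.pseudo_Obs(4.0, 0.2, 'e1'), pe.pseudo_Obs(1.0, 0.1, 'e1')],
                           [pe.pseudo_Obs(1.0, 0.1, 'e1'), pe.pseudo_Obs(3.0, 0.2, 'e1')]])

    inv = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)  # errors propagated through the inverse
    check = obs_matrix @ inv                                # identity matrix within errors

The same pattern covers the Cholesky hunk further down: only the first argument changes, from np.linalg.cholesky to jnp.linalg.cholesky.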
@@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
         matrix[n, m] = entry.real.value + 1j * entry.imag.value
 
     inverse_matrix = np.linalg.inv(matrix)
-    inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
+    inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
     for (n, m), entry in np.ndenumerate(inverse_matrix):
         assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
         assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
@@ -53,7 +56,7 @@ def test_matrix_functions():
     matrix = np.array(matrix) @ np.identity(dim)
 
     # Check inverse of matrix
-    inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
+    inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
     check_inv = matrix @ inv
 
     for (i, j), entry in np.ndenumerate(check_inv):
@@ -66,7 +69,7 @@ def test_matrix_functions():
 
     # Check Cholesky decomposition
     sym = np.dot(matrix, matrix.T)
-    cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
+    cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
     check = cholesky @ cholesky.T
 
     for (i, j), entry in np.ndenumerate(check):

@@ -1,10 +1,13 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import os
 import random
 import string
 import copy
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 
 np.random.seed(0)
 
@@ -29,21 +32,31 @@ def test_comparison():
 
 
 def test_function_overloading():
-    a = pe.pseudo_Obs(17, 2.9, 'e1')
+    a = pe.pseudo_Obs(2, 2.9, 'e1')
     b = pe.pseudo_Obs(4, 0.8, 'e1')
 
     fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
-          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
-          lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
-          lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
-          lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
 
     for i, f in enumerate(fs):
         t1 = f([a, b])
         t2 = pe.derived_observable(f, [a, b])
         c = t2 - t1
-        assert c.value == 0.0, str(i)
-        assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
+        assert c.is_zero()
+
+
+    f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
+            lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
+            lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+    f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
+             lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
+             lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
+
+    for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
+        t1 = f1([a])
+        t2 = pe.derived_observable(f2, [a])
+        c = t2 - t1
+        assert c.is_zero()
 
 
 def test_overloading_vectorization():
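The test splits in two because the two halves exercise different code paths: calling np.exp on an Obs directly goes through the overloads on the Obs object itself, while a function handed to pe.derived_observable is traced by the new jax backend and therefore has to use jnp. The paired loop then checks both paths against each other; roughly, per pair:

    import numpy as np
    import jax.numpy as jnp
    import pyerrors as pe

    a = pe.pseudo_Obs(2, 0.3, 'e1')

    t1 = np.exp(a)                                            # direct call, handled by the Obs overloads
    t2 = pe.derived_observable(lambda x: jnp.exp(x[0]), [a])  # traced by the jax backend
    assert (t2 - t1).is_zero()

The switch of a from pseudo_Obs(17, ...) to pseudo_Obs(2, ...) presumably keeps exp and tan of a numerically tame once the elementary functions are applied to a alone.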
@@ -121,7 +134,7 @@ def test_derived_observables():
     test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
 
     # Check if autograd and numgrad give the same result
-    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
+    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
     d_Obs_ad.gamma_method()
     d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
     d_Obs_fd.gamma_method()
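The num_grad=True variant keeps plain np.sin because it differentiates by finite differences instead of tracing, which makes this test a direct regression check of the new jax derivatives against numerics. Schematically, with illustrative values:

    import numpy as np
    import jax.numpy as jnp
    import pyerrors as pe

    obs = pe.pseudo_Obs(2, 0.2, 't')
    ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]),
                               [obs, obs])                 # gradient via jax
    fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]),
                               [obs, obs], num_grad=True)  # gradient via finite differences
    ad.gamma_method()
    fd.gamma_method()
    assert np.isclose(ad.dvalue, fd.dvalue)  # both error estimates should agree numerically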