First try at replacing autograd by jax

Fabian Joswig 2021-10-18 12:53:17 +01:00
parent 8d7a5daafa
commit 8fc5d96363
9 changed files with 76 additions and 55 deletions


@@ -1,6 +1,6 @@
import warnings
import numpy as np
import autograd.numpy as anp
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.linalg
from .pyerrors import Obs, dump_object
@@ -187,10 +187,10 @@ class Corr:
def Eigenvalue(self, t0, state=1):
G = self.smearing_symmetric()
G0 = G.content[t0]
L = mat_mat_op(anp.linalg.cholesky, G0)
Li = mat_mat_op(anp.linalg.inv, L)
L = mat_mat_op(jnp.linalg.cholesky, G0)
Li = mat_mat_op(jnp.linalg.inv, L)
LT = L.T
LTi = mat_mat_op(anp.linalg.inv, LT)
LTi = mat_mat_op(jnp.linalg.inv, LT)
newcontent = []
for t in range(self.T):
Gt = G.content[t]
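
For reference, a minimal sketch (outside of pyerrors, on a plain positive-definite matrix instead of a matrix of Obs) of what the whitening step computes once jnp.linalg is plugged into mat_mat_op; the matrices G0 and Gt below are made-up stand-ins for G.content[t0] and G.content[t]:

    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    # Hypothetical stand-ins for G.content[t0] and G.content[t]
    G0 = jnp.array([[2.0, 0.5], [0.5, 1.0]])
    Gt = jnp.array([[1.0, 0.3], [0.3, 0.8]])

    L = jnp.linalg.cholesky(G0)      # G0 = L @ L.T
    Li = jnp.linalg.inv(L)
    LTi = jnp.linalg.inv(L.T)

    # Eigenvalues of the whitened correlator solve the generalized
    # eigenvalue problem Gt v = lambda G0 v
    lam = jnp.linalg.eigvalsh(Li @ Gt @ LTi)
    print(lam)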
@@ -263,9 +263,9 @@ class Corr:
elif variant in ['periodic', 'cosh', 'sinh']:
if variant in ['periodic', 'cosh']:
func = anp.cosh
func = jnp.cosh
else:
func = anp.sinh
func = jnp.sinh
def root_function(x, d):
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
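
As a standalone illustration of what this root function is used for, here is a hedged sketch (T, t and the ratio d are made-up numbers) of solving for the periodic effective mass with scipy on top of the jnp variant:

    import scipy.optimize
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    T, t = 32, 10          # hypothetical temporal extent and timeslice
    d = 1.2                # hypothetical ratio C(t) / C(t + 1)
    func = jnp.cosh        # 'periodic' / 'cosh' variant

    def root_function(x, d):
        return func(x * (t - T / 2)) / func(x * (t + 1 - T / 2)) - d

    # The root is the effective mass; its derivative w.r.t. d (needed for
    # error propagation) can later be obtained from jax.grad of root_function.
    m_eff = scipy.optimize.fsolve(root_function, 0.1, args=(d,))[0]
    print(m_eff)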


@@ -3,15 +3,15 @@
import warnings
import numpy as np
import autograd.numpy as anp
import jax.numpy as jnp
import scipy.optimize
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib import gridspec
from scipy.odr import ODR, Model, RealData
import iminuit
from autograd import jacobian
from autograd import elementwise_grad as egrad
from jax import jacobian
#from jax import elementwise_grad as egrad
from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs
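
jax has no direct counterpart to autograd's elementwise_grad, which is why the import above is commented out. A minimal sketch of two equivalent jax replacements, assuming the function is applied elementwise (f here is just an example):

    import jax
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    def f(x):
        return jnp.sinh(x) ** 2

    x = jnp.linspace(0.0, 1.0, 5)

    egrad_vmap = jax.vmap(jax.grad(f))(x)               # per-element gradient
    egrad_sum = jax.grad(lambda y: jnp.sum(f(y)))(x)    # same result via the sum trick

    assert jnp.allclose(egrad_vmap, egrad_sum)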
@@ -24,7 +24,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
func has to be of the form
def func(a, x):
return a[0] + a[1] * x + a[2] * anp.sinh(x)
return a[0] + a[1] * x + a[2] * jnp.sinh(x)
For multiple x values func can be of the form
@@ -82,10 +82,10 @@ def standard_fit(x, y, func, silent=False, **kwargs):
if not silent:
print('Fit with', n_parms, 'parameters')
y_f = [o.value for o in y]
dy_f = [o.dvalue for o in y]
y_f = np.array([o.value for o in y])
dy_f = np.array([o.dvalue for o in y])
if np.any(np.asarray(dy_f) <= 0.0):
if np.any(dy_f <= 0.0):
raise Exception('No y errors available, run the gamma method first.')
if 'initial_guess' in kwargs:
@@ -97,7 +97,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
def chisqfunc(p):
model = func(p, x)
chisq = anp.sum(((y_f - model) / dy_f) ** 2)
chisq = jnp.sum(((y_f - model) / dy_f) ** 2)
return chisq
if 'method' in kwargs:
@@ -153,7 +153,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
def chisqfunc_compact(d):
model = func(d[:n_parms], x)
chisq = anp.sum(((d[n_parms:] - model) / dy_f) ** 2)
chisq = jnp.sum(((d[n_parms:] - model) / dy_f) ** 2)
return chisq
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f)))
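
The nested jacobian call carries over from autograd unchanged; in jax it is equivalent to jax.hessian. A self-contained sketch with made-up data (x, y_f, dy_f, p_opt and the exponential model are hypothetical stand-ins for the objects used in standard_fit):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)    # double precision matters for second derivatives

    x = np.arange(1.0, 6.0)
    y_f = np.array([1.10, 0.62, 0.35, 0.21, 0.12])
    dy_f = np.full(5, 0.05)
    n_parms = 2

    def func(p, x):
        return p[0] * jnp.exp(-p[1] * x)

    def chisqfunc_compact(d):
        model = func(d[:n_parms], x)
        return jnp.sum(((d[n_parms:] - model) / dy_f) ** 2)

    p_opt = np.array([2.0, 0.55])            # hypothetical minimizer
    d0 = np.concatenate((p_opt, y_f))
    jac_jac = jax.jacobian(jax.jacobian(chisqfunc_compact))(d0)
    assert np.allclose(jac_jac, jax.hessian(chisqfunc_compact)(d0))
    print(jac_jac.shape)                     # (n_parms + len(y_f), n_parms + len(y_f))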
@@ -188,7 +188,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
func has to be of the form
def func(a, x):
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
return y
For multiple x values func can be of the form
@@ -279,7 +279,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
def odr_chisquare(p):
model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
return chisq
if kwargs.get('expected_chisquare') is True:
@@ -312,7 +312,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
def odr_chisquare_compact_x(d):
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
return chisq
jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel())))
@@ -321,7 +321,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
def odr_chisquare_compact_y(d):
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
chisq = anp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + anp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
chisq = jnp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + jnp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
return chisq
jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f)))
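
The same pattern covers the total-least-squares case, where the shifted x positions are part of the vector being differentiated. A hedged sketch of the plain odr_chisquare above on made-up data (x_f, dx_f, y_f, dy_f and the model are hypothetical), differentiated once with jax.grad:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    x_f = np.array([1.0, 2.0, 3.0])
    dx_f = np.full(3, 0.1)
    y_f = np.array([0.90, 0.35, 0.15])
    dy_f = np.full(3, 0.05)
    n_parms, x_shape = 2, x_f.shape

    def func(a, x):
        return a[0] * jnp.exp(-a[1] * x)

    def odr_chisquare(p):
        # p = (fit parameters, shifted x positions)
        model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
        return (jnp.sum(((y_f - model) / dy_f) ** 2)
                + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2))

    p0 = np.concatenate(([1.5, 1.0], x_f))
    print(jax.grad(odr_chisquare)(p0))   # gradient w.r.t. parameters and x positions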
@@ -349,7 +349,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
func has to be of the form
def func(a, x):
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
return y
It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation
@@ -441,7 +441,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
def chisqfunc(p):
model = func(p, x)
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((p_f - p) / dp_f) ** 2)
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((p_f - p) / dp_f) ** 2)
return chisq
if not silent:
@@ -469,7 +469,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
def chisqfunc_compact(d):
model = func(d[:n_parms], x)
chisq = anp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + anp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
chisq = jnp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
return chisq
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f)))


@@ -3,7 +3,7 @@
import ctypes
import hashlib
import autograd.numpy as np # Thinly-wrapped numpy
import jax.numpy as np # Thinly-wrapped numpy
from ..pyerrors import Obs


@@ -2,7 +2,7 @@
# coding: utf-8
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
import jax.numpy as anp # Thinly-wrapped numpy
from .pyerrors import derived_observable, CObs


@@ -4,10 +4,12 @@
import warnings
import pickle
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
from autograd import jacobian
import jax.numpy as jnp # Thinly-wrapped numpy
from jax import jacobian
import matplotlib.pyplot as plt
import numdifftools as nd
from jax.config import config
config.update("jax_enable_x64", True)
class Obs:
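
Unlike autograd, jax defaults to single precision, which would be too coarse for the error propagation; the two config lines above switch the whole process to float64. A quick hedged check:

    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    # With jax_enable_x64 the default dtype is float64, otherwise float32
    print(jnp.array([1.0]).dtype)    # float64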
@@ -573,7 +575,7 @@ class Obs:
return derived_observable(lambda x: y ** x[0], [self])
def __abs__(self):
return derived_observable(lambda x: anp.abs(x[0]), [self])
return derived_observable(lambda x: jnp.abs(x[0]), [self])
# Overload numpy functions
def sqrt(self):
@@ -595,13 +597,13 @@
return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])
def arcsin(self):
return derived_observable(lambda x: anp.arcsin(x[0]), [self])
return derived_observable(lambda x: jnp.arcsin(x[0]), [self])
def arccos(self):
return derived_observable(lambda x: anp.arccos(x[0]), [self])
return derived_observable(lambda x: jnp.arccos(x[0]), [self])
def arctan(self):
return derived_observable(lambda x: anp.arctan(x[0]), [self])
return derived_observable(lambda x: jnp.arctan(x[0]), [self])
def sinh(self):
return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])
@@ -613,16 +615,16 @@
return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])
def arcsinh(self):
return derived_observable(lambda x: anp.arcsinh(x[0]), [self])
return derived_observable(lambda x: jnp.arcsinh(x[0]), [self])
def arccosh(self):
return derived_observable(lambda x: anp.arccosh(x[0]), [self])
return derived_observable(lambda x: jnp.arccosh(x[0]), [self])
def arctanh(self):
return derived_observable(lambda x: anp.arctanh(x[0]), [self])
return derived_observable(lambda x: jnp.arctanh(x[0]), [self])
def sinc(self):
return derived_observable(lambda x: anp.sinc(x[0]), [self])
return derived_observable(lambda x: jnp.sinc(x[0]), [self])
class CObs:
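
Each of the overloads above hands a jnp lambda to derived_observable, which evaluates it at the central value and differentiates it with jax for the error propagation. A minimal sketch outside of pyerrors, using a made-up central value, showing that jax reproduces the analytic derivative of arcsin:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    f = lambda x: jnp.arcsin(x[0])

    value = 0.3                              # hypothetical central value of an Obs
    x0 = jnp.array([value])
    print(f(x0))                             # function value at the mean
    print(jax.grad(f)(x0))                   # gradient used for error propagation
    print(1.0 / np.sqrt(1.0 - value ** 2))   # analytic d/dx arcsin(x) for comparison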
@@ -695,7 +697,7 @@ def derived_observable(func, data, **kwargs):
----------
func -- arbitrary function of the form func(data, **kwargs). For the
automatic differentiation to work, all numpy functions have to have
the autograd wrapper (use 'import autograd.numpy as anp').
the jax wrapper (use 'import jax.numpy as jnp').
data -- list of Obs, e.g. [obs1, obs2, obs3].
Keyword arguments
@@ -748,9 +750,9 @@ def derived_observable(func, data, **kwargs):
if new_shape[name] != tmp:
raise Exception('Shapes of ensemble', name, 'do not match.')
if data.ndim == 1:
values = np.array([o.value for o in data])
values = jnp.array([o.value for o in data])
else:
values = np.vectorize(lambda x: x.value)(data)
values = jnp.array(np.vectorize(lambda x: x.value)(data))
new_values = func(values, **kwargs)
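
The central values now have to live in a jax array before func is traced; note that, in contrast to numpy, jax arrays are immutable, so any in-place modification would need the functional .at[...].set(...) syntax. A hedged sketch with a dummy stand-in for a matrix of Obs:

    import numpy as np
    import jax.numpy as jnp
    from jax.config import config
    config.update("jax_enable_x64", True)

    class FakeObs:                       # hypothetical stand-in for pyerrors.Obs
        def __init__(self, value):
            self.value = value

    data = np.array([[FakeObs(1.0), FakeObs(2.0)],
                     [FakeObs(3.0), FakeObs(4.0)]])
    values = jnp.array(np.vectorize(lambda x: x.value)(data))
    print(values.dtype)                  # float64 with jax_enable_x64

    # values[0, 0] = 5.0 would raise: jax arrays are immutable
    values = values.at[0, 0].set(5.0)    # functional update instead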


@@ -2,7 +2,7 @@
# coding: utf-8
import scipy.optimize
from autograd import jacobian
from jax import jacobian
from .pyerrors import derived_observable, pseudo_Obs


@@ -1,9 +1,12 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import math
import scipy.optimize
from scipy.odr import ODR, Model, RealData
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@@ -24,7 +27,7 @@ def test_standard_fit():
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
beta = pe.fits.standard_fit(x, oy, func)
@@ -61,7 +64,7 @@ def test_odr_fit():
return a * np.exp(-b * x)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])


@@ -1,7 +1,10 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import math
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@@ -14,7 +17,7 @@ def test_matrix_inverse():
content.append(1.0) # Add 1.0 as a float
matrix = np.diag(content)
inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
@@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
matrix[n, m] = entry.real.value + 1j * entry.imag.value
inverse_matrix = np.linalg.inv(matrix)
inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
for (n, m), entry in np.ndenumerate(inverse_matrix):
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
@@ -53,7 +56,7 @@ def test_matrix_functions():
matrix = np.array(matrix) @ np.identity(dim)
# Check inverse of matrix
inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
check_inv = matrix @ inv
for (i, j), entry in np.ndenumerate(check_inv):
@@ -66,7 +69,7 @@ def test_matrix_functions():
# Check Cholesky decomposition
sym = np.dot(matrix, matrix.T)
cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
check = cholesky @ cholesky.T
for (i, j), entry in np.ndenumerate(check):


@@ -1,10 +1,13 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import os
import random
import string
import copy
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@@ -29,21 +32,31 @@ def test_comparison():
def test_function_overloading():
a = pe.pseudo_Obs(17, 2.9, 'e1')
a = pe.pseudo_Obs(2, 2.9, 'e1')
b = pe.pseudo_Obs(4, 0.8, 'e1')
fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
for i, f in enumerate(fs):
t1 = f([a, b])
t2 = pe.derived_observable(f, [a, b])
c = t2 - t1
assert c.value == 0.0, str(i)
assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
assert c.is_zero()
f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
t1 = f1([a])
t2 = pe.derived_observable(f2, [a])
c = t2 - t1
assert c.is_zero()
def test_overloading_vectorization():
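
The loop above is split because numpy ufuncs applied to an Obs fall back to the method of the same name (np.exp(a) calls a.exp()), whereas jax.numpy functions accept only numeric arrays and therefore have to go through derived_observable. A toy illustration of the two behaviours, not using the real Obs class:

    import numpy as np
    import jax.numpy as jnp

    class Wrapped:                        # toy stand-in for an Obs-like object
        def __init__(self, value):
            self.value = value
        def exp(self):
            return Wrapped(float(np.exp(self.value)))

    w = Wrapped(2.0)
    print(np.exp(w).value)                # numpy dispatches to Wrapped.exp

    try:
        jnp.exp(w)                        # jax cannot handle arbitrary objects
    except TypeError as err:
        print('jnp.exp raised:', err)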
@@ -121,7 +134,7 @@ def test_derived_observables():
test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
# Check if autograd and numgrad give the same result
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
d_Obs_ad.gamma_method()
d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
d_Obs_fd.gamma_method()
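
As a standalone version of this check, a hedged sketch comparing the jax gradient with a numdifftools finite-difference gradient for the same toy function (the central values are made up):

    import numpy as np
    import jax
    import jax.numpy as jnp
    import numdifftools as nd
    from jax.config import config
    config.update("jax_enable_x64", True)

    def f_jax(x):
        return x[0] * x[1] * jnp.sin(x[0] * x[1])

    def f_np(x):
        return x[0] * x[1] * np.sin(x[0] * x[1])

    x0 = np.array([2.0, 2.0])             # hypothetical central values
    grad_ad = jax.grad(f_jax)(x0)         # automatic differentiation
    grad_fd = nd.Gradient(f_np)(x0)       # finite differences

    assert np.allclose(grad_ad, grad_fd)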