First try at replacing autograd by jax

Fabian Joswig 2021-10-18 12:53:17 +01:00
parent 8d7a5daafa
commit 8fc5d96363
9 changed files with 76 additions and 55 deletions

View file

@@ -1,6 +1,6 @@
 import warnings
 import numpy as np
-import autograd.numpy as anp
+import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import scipy.linalg
 from .pyerrors import Obs, dump_object
@@ -187,10 +187,10 @@ class Corr:
     def Eigenvalue(self, t0, state=1):
         G = self.smearing_symmetric()
         G0 = G.content[t0]
-        L = mat_mat_op(anp.linalg.cholesky, G0)
-        Li = mat_mat_op(anp.linalg.inv, L)
+        L = mat_mat_op(jnp.linalg.cholesky, G0)
+        Li = mat_mat_op(jnp.linalg.inv, L)
         LT = L.T
-        LTi = mat_mat_op(anp.linalg.inv, LT)
+        LTi = mat_mat_op(jnp.linalg.inv, LT)
         newcontent = []
         for t in range(self.T):
             Gt = G.content[t]
@@ -263,9 +263,9 @@ class Corr:
         elif variant in ['periodic', 'cosh', 'sinh']:
             if variant in ['periodic', 'cosh']:
-                func = anp.cosh
+                func = jnp.cosh
             else:
-                func = anp.sinh
+                func = jnp.sinh
             def root_function(x, d):
                 return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
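
Note (not part of the diff): a minimal sketch of what the anp-to-jnp swap relies on, assuming a small hypothetical positive-definite test matrix. jax.numpy.linalg provides the same cholesky and inv entry points that mat_mat_op is handed above, but JAX works in single precision unless 64-bit mode is switched on first.

    import numpy as np
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)  # keep float64, as the commit does in pyerrors' main module

    G0 = np.array([[2.0, 0.5], [0.5, 1.0]])    # hypothetical correlator matrix at t0 (positive definite)
    L = jnp.linalg.cholesky(G0)                # lower-triangular factor with G0 = L @ L.T
    Li = jnp.linalg.inv(L)
    print(np.allclose(Li @ G0 @ Li.T, np.eye(2)))  # whitening check: True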

View file

@@ -3,15 +3,15 @@
 import warnings
 import numpy as np
-import autograd.numpy as anp
+import jax.numpy as jnp
 import scipy.optimize
 import scipy.stats
 import matplotlib.pyplot as plt
 from matplotlib import gridspec
 from scipy.odr import ODR, Model, RealData
 import iminuit
-from autograd import jacobian
-from autograd import elementwise_grad as egrad
+from jax import jacobian
+#from jax import elementwise_grad as egrad
 from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs
@@ -24,7 +24,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
     func has to be of the form
     def func(a, x):
-        return a[0] + a[1] * x + a[2] * anp.sinh(x)
+        return a[0] + a[1] * x + a[2] * jnp.sinh(x)
     For multiple x values func can be of the form
@@ -82,10 +82,10 @@ def standard_fit(x, y, func, silent=False, **kwargs):
     if not silent:
         print('Fit with', n_parms, 'parameters')
-    y_f = [o.value for o in y]
-    dy_f = [o.dvalue for o in y]
-    if np.any(np.asarray(dy_f) <= 0.0):
+    y_f = np.array([o.value for o in y])
+    dy_f = np.array([o.dvalue for o in y])
+    if np.any(dy_f <= 0.0):
         raise Exception('No y errors available, run the gamma method first.')
     if 'initial_guess' in kwargs:
@@ -97,7 +97,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
     def chisqfunc(p):
         model = func(p, x)
-        chisq = anp.sum(((y_f - model) / dy_f) ** 2)
+        chisq = jnp.sum(((y_f - model) / dy_f) ** 2)
         return chisq
     if 'method' in kwargs:
@@ -153,7 +153,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
     def chisqfunc_compact(d):
         model = func(d[:n_parms], x)
-        chisq = anp.sum(((d[n_parms:] - model) / dy_f) ** 2)
+        chisq = jnp.sum(((d[n_parms:] - model) / dy_f) ** 2)
         return chisq
     jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f)))
@@ -188,7 +188,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
     func has to be of the form
     def func(a, x):
-        y = a[0] + a[1] * x + a[2] * anp.sinh(x)
+        y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
         return y
     For multiple x values func can be of the form
@@ -279,7 +279,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
     def odr_chisquare(p):
         model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
-        chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
+        chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
         return chisq
     if kwargs.get('expected_chisquare') is True:
@@ -312,7 +312,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
     def odr_chisquare_compact_x(d):
         model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
-        chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
+        chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
         return chisq
     jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel())))
@@ -321,7 +321,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
     def odr_chisquare_compact_y(d):
         model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
-        chisq = anp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + anp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
+        chisq = jnp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + jnp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
         return chisq
     jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f)))
@@ -349,7 +349,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
     func has to be of the form
     def func(a, x):
-        y = a[0] + a[1] * x + a[2] * anp.sinh(x)
+        y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
         return y
     It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation
@@ -441,7 +441,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
     def chisqfunc(p):
         model = func(p, x)
-        chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((p_f - p) / dp_f) ** 2)
+        chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((p_f - p) / dp_f) ** 2)
         return chisq
     if not silent:
@@ -469,7 +469,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
     def chisqfunc_compact(d):
         model = func(d[:n_parms], x)
-        chisq = anp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + anp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
+        chisq = jnp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
        return chisq
     jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f)))
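
Note (not part of the diff): the jac_jac terms above differentiate a scalar chi-square twice with jax.jacobian. A minimal sketch with a toy chi-square (not the library's) shows that this nested call is the Hessian, and that jax.hessian would be an equivalent shortcut.

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    def chisq(p):                      # toy stand-in for chisqfunc_compact
        return jnp.sum(p ** 2) + p[0] * p[1]

    p0 = jnp.array([1.0, 2.0])
    h_nested = jax.jacobian(jax.jacobian(chisq))(p0)
    h_direct = jax.hessian(chisq)(p0)
    print(jnp.allclose(h_nested, h_direct))  # True: both give [[2., 1.], [1., 2.]]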

View file

@@ -3,7 +3,7 @@
 import ctypes
 import hashlib
-import autograd.numpy as np  # Thinly-wrapped numpy
+import jax.numpy as np  # Thinly-wrapped numpy
 from ..pyerrors import Obs

View file

@@ -2,7 +2,7 @@
 # coding: utf-8
 import numpy as np
-import autograd.numpy as anp  # Thinly-wrapped numpy
+import jax.numpy as anp  # Thinly-wrapped numpy
 from .pyerrors import derived_observable, CObs

View file

@@ -4,10 +4,12 @@
 import warnings
 import pickle
 import numpy as np
-import autograd.numpy as anp  # Thinly-wrapped numpy
-from autograd import jacobian
+import jax.numpy as jnp  # Thinly-wrapped numpy
+from jax import jacobian
 import matplotlib.pyplot as plt
 import numdifftools as nd
+from jax.config import config
+config.update("jax_enable_x64", True)
 class Obs:
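
Note (not part of the diff): JAX defaults to single precision, so the two lines added above switch on 64-bit mode before any arrays are built. A minimal sketch of the effect:

    import jax
    jax.config.update("jax_enable_x64", True)  # same flag as the `from jax.config import config` lines above

    import jax.numpy as jnp
    import numpy as np

    x = jnp.asarray(np.random.normal(size=5))
    print(x.dtype)  # float64 with the flag set, float32 otherwise
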
@@ -573,7 +575,7 @@ class Obs:
         return derived_observable(lambda x: y ** x[0], [self])
     def __abs__(self):
-        return derived_observable(lambda x: anp.abs(x[0]), [self])
+        return derived_observable(lambda x: jnp.abs(x[0]), [self])
     # Overload numpy functions
     def sqrt(self):
@@ -595,13 +597,13 @@ class Obs:
         return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])
     def arcsin(self):
-        return derived_observable(lambda x: anp.arcsin(x[0]), [self])
+        return derived_observable(lambda x: jnp.arcsin(x[0]), [self])
     def arccos(self):
-        return derived_observable(lambda x: anp.arccos(x[0]), [self])
+        return derived_observable(lambda x: jnp.arccos(x[0]), [self])
     def arctan(self):
-        return derived_observable(lambda x: anp.arctan(x[0]), [self])
+        return derived_observable(lambda x: jnp.arctan(x[0]), [self])
     def sinh(self):
         return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])
@@ -613,16 +615,16 @@ class Obs:
         return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])
     def arcsinh(self):
-        return derived_observable(lambda x: anp.arcsinh(x[0]), [self])
+        return derived_observable(lambda x: jnp.arcsinh(x[0]), [self])
     def arccosh(self):
-        return derived_observable(lambda x: anp.arccosh(x[0]), [self])
+        return derived_observable(lambda x: jnp.arccosh(x[0]), [self])
     def arctanh(self):
-        return derived_observable(lambda x: anp.arctanh(x[0]), [self])
+        return derived_observable(lambda x: jnp.arctanh(x[0]), [self])
     def sinc(self):
-        return derived_observable(lambda x: anp.sinc(x[0]), [self])
+        return derived_observable(lambda x: jnp.sinc(x[0]), [self])
 class CObs:
@@ -695,7 +697,7 @@ def derived_observable(func, data, **kwargs):
     ----------
     func -- arbitrary function of the form func(data, **kwargs). For the
             automatic differentiation to work, all numpy functions have to have
-            the autograd wrapper (use 'import autograd.numpy as anp').
+            the autograd wrapper (use 'import autograd.numpy as jnp').
     data -- list of Obs, e.g. [obs1, obs2, obs3].
     Keyword arguments
@@ -748,9 +750,9 @@ def derived_observable(func, data, **kwargs):
             if new_shape[name] != tmp:
                 raise Exception('Shapes of ensemble', name, 'do not match.')
     if data.ndim == 1:
-        values = np.array([o.value for o in data])
+        values = jnp.array([o.value for o in data])
     else:
-        values = np.vectorize(lambda x: x.value)(data)
+        values = jnp.array(np.vectorize(lambda x: x.value)(data))
     new_values = func(values, **kwargs)
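
Note (not part of the diff): once `values` is a jax.numpy array, any func written with jnp primitives can be pushed through jax.jacobian, which is how derived_observable obtains the derivatives it multiplies onto the Monte Carlo deltas. A minimal sketch with a toy function (not library code):

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    def func(x):
        return x[0] * x[1] * jnp.sin(x[0] * x[1])

    values = jnp.array([2.0, 2.0])
    print(jax.jacobian(func)(values))  # analytic check: d/dx0 = x1 * (sin(x0*x1) + x0*x1*cos(x0*x1))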

View file

@@ -2,7 +2,7 @@
 # coding: utf-8
 import scipy.optimize
-from autograd import jacobian
+from jax import jacobian
 from .pyerrors import derived_observable, pseudo_Obs

View file

@@ -1,9 +1,12 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import math
 import scipy.optimize
 from scipy.odr import ODR, Model, RealData
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 np.random.seed(0)
@@ -24,7 +27,7 @@ def test_standard_fit():
     popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
     def func(a, x):
-        y = a[0] * np.exp(-a[1] * x)
+        y = a[0] * jnp.exp(-a[1] * x)
         return y
     beta = pe.fits.standard_fit(x, oy, func)
@@ -61,7 +64,7 @@ def test_odr_fit():
         return a * np.exp(-b * x)
     def func(a, x):
-        y = a[0] * np.exp(-a[1] * x)
+        y = a[0] * jnp.exp(-a[1] * x)
         return y
     data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
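
Note (not part of the diff): the test's fit function switches np.exp to jnp.exp because standard_fit now differentiates its chi-square, and hence func, with jax.jacobian, and plain numpy ufuncs generally cannot operate on JAX's abstract tracers. A minimal sketch with a hypothetical model:

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    def func(a, x):
        return a[0] * jnp.exp(-a[1] * x)   # jnp so that jax can trace and differentiate it

    x_vals = jnp.linspace(0.0, 1.0, 5)
    d_model_d_params = jax.jacobian(lambda a: func(a, x_vals))(jnp.array([1.0, 0.5]))
    print(d_model_d_params.shape)  # (5, 2): one row per x value, one column per parameter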

View file

@@ -1,7 +1,10 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import math
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 np.random.seed(0)
@@ -14,7 +17,7 @@ def test_matrix_inverse():
     content.append(1.0)  # Add 1.0 as a float
     matrix = np.diag(content)
-    inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
+    inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
     assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
@@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
         matrix[n, m] = entry.real.value + 1j * entry.imag.value
     inverse_matrix = np.linalg.inv(matrix)
-    inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
+    inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
     for (n, m), entry in np.ndenumerate(inverse_matrix):
         assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
         assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
@@ -53,7 +56,7 @@ def test_matrix_functions():
     matrix = np.array(matrix) @ np.identity(dim)
     # Check inverse of matrix
-    inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
+    inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
     check_inv = matrix @ inv
     for (i, j), entry in np.ndenumerate(check_inv):
@@ -66,7 +69,7 @@ def test_matrix_functions():
     # Check Cholesky decomposition
     sym = np.dot(matrix, matrix.T)
-    cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
+    cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
     check = cholesky @ cholesky.T
     for (i, j), entry in np.ndenumerate(check):

View file

@@ -1,10 +1,13 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import os
 import random
 import string
 import copy
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 np.random.seed(0)
@@ -29,21 +32,31 @@ def test_comparison():
 def test_function_overloading():
-    a = pe.pseudo_Obs(17, 2.9, 'e1')
+    a = pe.pseudo_Obs(2, 2.9, 'e1')
     b = pe.pseudo_Obs(4, 0.8, 'e1')
     fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
-          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
-          lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
-          lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
-          lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
     for i, f in enumerate(fs):
         t1 = f([a, b])
         t2 = pe.derived_observable(f, [a, b])
         c = t2 - t1
-        assert c.value == 0.0, str(i)
-        assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
+        assert c.is_zero()
+    f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
+            lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
+            lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+    f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
+             lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
+             lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
+    for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
+        t1 = f1([a])
+        t2 = pe.derived_observable(f2, [a])
+        c = t2 - t1
+        assert c.is_zero()
 def test_overloading_vectorization():
@@ -121,7 +134,7 @@ def test_derived_observables():
     test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
     # Check if autograd and numgrad give the same result
-    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
+    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
     d_Obs_ad.gamma_method()
     d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
     d_Obs_fd.gamma_method()
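
Note (not part of the diff): the last hunk cross-checks the JAX gradient against pyerrors' numerical num_grad path. The same comparison can be made directly with numdifftools; a minimal sketch with a toy function (not library code):

    import jax
    import jax.numpy as jnp
    import numdifftools as nd
    import numpy as np

    jax.config.update("jax_enable_x64", True)

    def f(x):
        return x[0] * x[1] * jnp.sin(x[0] * x[1])

    x0 = np.array([2.0, 2.0])
    g_ad = jax.jacobian(f)(jnp.asarray(x0))        # automatic differentiation
    g_fd = nd.Gradient(lambda x: float(f(x)))(x0)  # finite differences
    print(np.allclose(g_ad, g_fd))                 # True up to finite-difference error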