mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 06:40:24 +01:00
First try at replacing autograd by jax
This commit is contained in:
parent
8d7a5daafa
commit
8fc5d96363
9 changed files with 76 additions and 55 deletions
|
@ -1,6 +1,6 @@
|
|||
import warnings
|
||||
import numpy as np
|
||||
import autograd.numpy as anp
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.linalg
|
||||
from .pyerrors import Obs, dump_object
|
||||
|
@ -187,10 +187,10 @@ class Corr:
|
|||
def Eigenvalue(self, t0, state=1):
|
||||
G = self.smearing_symmetric()
|
||||
G0 = G.content[t0]
|
||||
L = mat_mat_op(anp.linalg.cholesky, G0)
|
||||
Li = mat_mat_op(anp.linalg.inv, L)
|
||||
L = mat_mat_op(jnp.linalg.cholesky, G0)
|
||||
Li = mat_mat_op(jnp.linalg.inv, L)
|
||||
LT = L.T
|
||||
LTi = mat_mat_op(anp.linalg.inv, LT)
|
||||
LTi = mat_mat_op(jnp.linalg.inv, LT)
|
||||
newcontent = []
|
||||
for t in range(self.T):
|
||||
Gt = G.content[t]
|
||||
|
@ -263,9 +263,9 @@ class Corr:
|
|||
|
||||
elif variant in ['periodic', 'cosh', 'sinh']:
|
||||
if variant in ['periodic', 'cosh']:
|
||||
func = anp.cosh
|
||||
func = jnp.cosh
|
||||
else:
|
||||
func = anp.sinh
|
||||
func = jnp.sinh
|
||||
|
||||
def root_function(x, d):
|
||||
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
|
||||
|
|
|
@ -3,15 +3,15 @@
|
|||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import autograd.numpy as anp
|
||||
import jax.numpy as jnp
|
||||
import scipy.optimize
|
||||
import scipy.stats
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import gridspec
|
||||
from scipy.odr import ODR, Model, RealData
|
||||
import iminuit
|
||||
from autograd import jacobian
|
||||
from autograd import elementwise_grad as egrad
|
||||
from jax import jacobian
|
||||
#from jax import elementwise_grad as egrad
|
||||
from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
|||
func has to be of the form
|
||||
|
||||
def func(a, x):
|
||||
return a[0] + a[1] * x + a[2] * anp.sinh(x)
|
||||
return a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||
|
||||
For multiple x values func can be of the form
|
||||
|
||||
|
@ -82,10 +82,10 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
|||
if not silent:
|
||||
print('Fit with', n_parms, 'parameters')
|
||||
|
||||
y_f = [o.value for o in y]
|
||||
dy_f = [o.dvalue for o in y]
|
||||
y_f = np.array([o.value for o in y])
|
||||
dy_f = np.array([o.dvalue for o in y])
|
||||
|
||||
if np.any(np.asarray(dy_f) <= 0.0):
|
||||
if np.any(dy_f <= 0.0):
|
||||
raise Exception('No y errors available, run the gamma method first.')
|
||||
|
||||
if 'initial_guess' in kwargs:
|
||||
|
@ -97,7 +97,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
|||
|
||||
def chisqfunc(p):
|
||||
model = func(p, x)
|
||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2)
|
||||
chisq = jnp.sum(((y_f - model) / dy_f) ** 2)
|
||||
return chisq
|
||||
|
||||
if 'method' in kwargs:
|
||||
|
@ -153,7 +153,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
|||
|
||||
def chisqfunc_compact(d):
|
||||
model = func(d[:n_parms], x)
|
||||
chisq = anp.sum(((d[n_parms:] - model) / dy_f) ** 2)
|
||||
chisq = jnp.sum(((d[n_parms:] - model) / dy_f) ** 2)
|
||||
return chisq
|
||||
|
||||
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f)))
|
||||
|
@ -188,7 +188,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
|||
func has to be of the form
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
|
||||
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||
return y
|
||||
|
||||
For multiple x values func can be of the form
|
||||
|
@ -279,7 +279,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
|||
|
||||
def odr_chisquare(p):
|
||||
model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
|
||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
|
||||
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
|
||||
return chisq
|
||||
|
||||
if kwargs.get('expected_chisquare') is True:
|
||||
|
@ -312,7 +312,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
|||
|
||||
def odr_chisquare_compact_x(d):
|
||||
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||
return chisq
|
||||
|
||||
jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel())))
|
||||
|
@ -321,7 +321,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
|||
|
||||
def odr_chisquare_compact_y(d):
|
||||
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
||||
chisq = anp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + anp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||
chisq = jnp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + jnp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||
return chisq
|
||||
|
||||
jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f)))
|
||||
|
@ -349,7 +349,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
|||
func has to be of the form
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
|
||||
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||
return y
|
||||
|
||||
It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation
|
||||
|
@ -441,7 +441,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
|||
|
||||
def chisqfunc(p):
|
||||
model = func(p, x)
|
||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((p_f - p) / dp_f) ** 2)
|
||||
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((p_f - p) / dp_f) ** 2)
|
||||
return chisq
|
||||
|
||||
if not silent:
|
||||
|
@ -469,7 +469,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
|||
|
||||
def chisqfunc_compact(d):
|
||||
model = func(d[:n_parms], x)
|
||||
chisq = anp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + anp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
|
||||
chisq = jnp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
|
||||
return chisq
|
||||
|
||||
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f)))
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import ctypes
|
||||
import hashlib
|
||||
import autograd.numpy as np # Thinly-wrapped numpy
|
||||
import jax.numpy as np # Thinly-wrapped numpy
|
||||
from ..pyerrors import Obs
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# coding: utf-8
|
||||
|
||||
import numpy as np
|
||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
||||
import jax.numpy as anp # Thinly-wrapped numpy
|
||||
from .pyerrors import derived_observable, CObs
|
||||
|
||||
|
||||
|
|
|
@ -4,10 +4,12 @@
|
|||
import warnings
|
||||
import pickle
|
||||
import numpy as np
|
||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
||||
from autograd import jacobian
|
||||
import jax.numpy as jnp # Thinly-wrapped numpy
|
||||
from jax import jacobian
|
||||
import matplotlib.pyplot as plt
|
||||
import numdifftools as nd
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
class Obs:
|
||||
|
@ -573,7 +575,7 @@ class Obs:
|
|||
return derived_observable(lambda x: y ** x[0], [self])
|
||||
|
||||
def __abs__(self):
|
||||
return derived_observable(lambda x: anp.abs(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.abs(x[0]), [self])
|
||||
|
||||
# Overload numpy functions
|
||||
def sqrt(self):
|
||||
|
@ -595,13 +597,13 @@ class Obs:
|
|||
return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])
|
||||
|
||||
def arcsin(self):
|
||||
return derived_observable(lambda x: anp.arcsin(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arcsin(x[0]), [self])
|
||||
|
||||
def arccos(self):
|
||||
return derived_observable(lambda x: anp.arccos(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arccos(x[0]), [self])
|
||||
|
||||
def arctan(self):
|
||||
return derived_observable(lambda x: anp.arctan(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arctan(x[0]), [self])
|
||||
|
||||
def sinh(self):
|
||||
return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])
|
||||
|
@ -613,16 +615,16 @@ class Obs:
|
|||
return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])
|
||||
|
||||
def arcsinh(self):
|
||||
return derived_observable(lambda x: anp.arcsinh(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arcsinh(x[0]), [self])
|
||||
|
||||
def arccosh(self):
|
||||
return derived_observable(lambda x: anp.arccosh(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arccosh(x[0]), [self])
|
||||
|
||||
def arctanh(self):
|
||||
return derived_observable(lambda x: anp.arctanh(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.arctanh(x[0]), [self])
|
||||
|
||||
def sinc(self):
|
||||
return derived_observable(lambda x: anp.sinc(x[0]), [self])
|
||||
return derived_observable(lambda x: jnp.sinc(x[0]), [self])
|
||||
|
||||
|
||||
class CObs:
|
||||
|
@ -695,7 +697,7 @@ def derived_observable(func, data, **kwargs):
|
|||
----------
|
||||
func -- arbitrary function of the form func(data, **kwargs). For the
|
||||
automatic differentiation to work, all numpy functions have to have
|
||||
the autograd wrapper (use 'import autograd.numpy as anp').
|
||||
the autograd wrapper (use 'import autograd.numpy as jnp').
|
||||
data -- list of Obs, e.g. [obs1, obs2, obs3].
|
||||
|
||||
Keyword arguments
|
||||
|
@ -748,9 +750,9 @@ def derived_observable(func, data, **kwargs):
|
|||
if new_shape[name] != tmp:
|
||||
raise Exception('Shapes of ensemble', name, 'do not match.')
|
||||
if data.ndim == 1:
|
||||
values = np.array([o.value for o in data])
|
||||
values = jnp.array([o.value for o in data])
|
||||
else:
|
||||
values = np.vectorize(lambda x: x.value)(data)
|
||||
values = jnp.array(np.vectorize(lambda x: x.value)(data))
|
||||
|
||||
new_values = func(values, **kwargs)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# coding: utf-8
|
||||
|
||||
import scipy.optimize
|
||||
from autograd import jacobian
|
||||
from jax import jacobian
|
||||
from .pyerrors import derived_observable, pseudo_Obs
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import autograd.numpy as np
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import math
|
||||
import scipy.optimize
|
||||
from scipy.odr import ODR, Model, RealData
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -24,7 +27,7 @@ def test_standard_fit():
|
|||
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] * np.exp(-a[1] * x)
|
||||
y = a[0] * jnp.exp(-a[1] * x)
|
||||
return y
|
||||
|
||||
beta = pe.fits.standard_fit(x, oy, func)
|
||||
|
@ -61,7 +64,7 @@ def test_odr_fit():
|
|||
return a * np.exp(-b * x)
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] * np.exp(-a[1] * x)
|
||||
y = a[0] * jnp.exp(-a[1] * x)
|
||||
return y
|
||||
|
||||
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import autograd.numpy as np
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import math
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -14,7 +17,7 @@ def test_matrix_inverse():
|
|||
|
||||
content.append(1.0) # Add 1.0 as a float
|
||||
matrix = np.diag(content)
|
||||
inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
||||
inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
|
||||
|
||||
|
||||
|
@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
|
|||
matrix[n, m] = entry.real.value + 1j * entry.imag.value
|
||||
|
||||
inverse_matrix = np.linalg.inv(matrix)
|
||||
inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
|
||||
inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
|
||||
for (n, m), entry in np.ndenumerate(inverse_matrix):
|
||||
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
|
||||
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
|
||||
|
@ -53,7 +56,7 @@ def test_matrix_functions():
|
|||
matrix = np.array(matrix) @ np.identity(dim)
|
||||
|
||||
# Check inverse of matrix
|
||||
inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
||||
inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||
check_inv = matrix @ inv
|
||||
|
||||
for (i, j), entry in np.ndenumerate(check_inv):
|
||||
|
@ -66,7 +69,7 @@ def test_matrix_functions():
|
|||
|
||||
# Check Cholesky decomposition
|
||||
sym = np.dot(matrix, matrix.T)
|
||||
cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
|
||||
cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
|
||||
check = cholesky @ cholesky.T
|
||||
|
||||
for (i, j), entry in np.ndenumerate(check):
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import autograd.numpy as np
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import copy
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -29,21 +32,31 @@ def test_comparison():
|
|||
|
||||
|
||||
def test_function_overloading():
|
||||
a = pe.pseudo_Obs(17, 2.9, 'e1')
|
||||
a = pe.pseudo_Obs(2, 2.9, 'e1')
|
||||
b = pe.pseudo_Obs(4, 0.8, 'e1')
|
||||
|
||||
fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
|
||||
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
|
||||
lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
|
||||
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
|
||||
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
|
||||
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
|
||||
|
||||
for i, f in enumerate(fs):
|
||||
t1 = f([a, b])
|
||||
t2 = pe.derived_observable(f, [a, b])
|
||||
c = t2 - t1
|
||||
assert c.value == 0.0, str(i)
|
||||
assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
|
||||
assert c.is_zero()
|
||||
|
||||
|
||||
f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
|
||||
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
|
||||
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
|
||||
f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
|
||||
lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
|
||||
lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
|
||||
|
||||
for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
|
||||
t1 = f1([a])
|
||||
t2 = pe.derived_observable(f2, [a])
|
||||
c = t2 - t1
|
||||
assert c.is_zero()
|
||||
|
||||
|
||||
def test_overloading_vectorization():
|
||||
|
@ -121,7 +134,7 @@ def test_derived_observables():
|
|||
test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
|
||||
|
||||
# Check if autograd and numgrad give the same result
|
||||
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
|
||||
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
|
||||
d_Obs_ad.gamma_method()
|
||||
d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
|
||||
d_Obs_fd.gamma_method()
|
||||
|
|
Loading…
Add table
Reference in a new issue