mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
First try at replacing autograd by jax
This commit is contained in:
parent
8d7a5daafa
commit
8fc5d96363
9 changed files with 76 additions and 55 deletions
|
@ -1,6 +1,6 @@
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import autograd.numpy as anp
|
import jax.numpy as jnp
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import scipy.linalg
|
import scipy.linalg
|
||||||
from .pyerrors import Obs, dump_object
|
from .pyerrors import Obs, dump_object
|
||||||
|
@ -187,10 +187,10 @@ class Corr:
|
||||||
def Eigenvalue(self, t0, state=1):
|
def Eigenvalue(self, t0, state=1):
|
||||||
G = self.smearing_symmetric()
|
G = self.smearing_symmetric()
|
||||||
G0 = G.content[t0]
|
G0 = G.content[t0]
|
||||||
L = mat_mat_op(anp.linalg.cholesky, G0)
|
L = mat_mat_op(jnp.linalg.cholesky, G0)
|
||||||
Li = mat_mat_op(anp.linalg.inv, L)
|
Li = mat_mat_op(jnp.linalg.inv, L)
|
||||||
LT = L.T
|
LT = L.T
|
||||||
LTi = mat_mat_op(anp.linalg.inv, LT)
|
LTi = mat_mat_op(jnp.linalg.inv, LT)
|
||||||
newcontent = []
|
newcontent = []
|
||||||
for t in range(self.T):
|
for t in range(self.T):
|
||||||
Gt = G.content[t]
|
Gt = G.content[t]
|
||||||
|
@ -263,9 +263,9 @@ class Corr:
|
||||||
|
|
||||||
elif variant in ['periodic', 'cosh', 'sinh']:
|
elif variant in ['periodic', 'cosh', 'sinh']:
|
||||||
if variant in ['periodic', 'cosh']:
|
if variant in ['periodic', 'cosh']:
|
||||||
func = anp.cosh
|
func = jnp.cosh
|
||||||
else:
|
else:
|
||||||
func = anp.sinh
|
func = jnp.sinh
|
||||||
|
|
||||||
def root_function(x, d):
|
def root_function(x, d):
|
||||||
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
|
return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d
|
||||||
|
|
|
@ -3,15 +3,15 @@
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import autograd.numpy as anp
|
import jax.numpy as jnp
|
||||||
import scipy.optimize
|
import scipy.optimize
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib import gridspec
|
from matplotlib import gridspec
|
||||||
from scipy.odr import ODR, Model, RealData
|
from scipy.odr import ODR, Model, RealData
|
||||||
import iminuit
|
import iminuit
|
||||||
from autograd import jacobian
|
from jax import jacobian
|
||||||
from autograd import elementwise_grad as egrad
|
#from jax import elementwise_grad as egrad
|
||||||
from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs
|
from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
||||||
func has to be of the form
|
func has to be of the form
|
||||||
|
|
||||||
def func(a, x):
|
def func(a, x):
|
||||||
return a[0] + a[1] * x + a[2] * anp.sinh(x)
|
return a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||||
|
|
||||||
For multiple x values func can be of the form
|
For multiple x values func can be of the form
|
||||||
|
|
||||||
|
@ -82,10 +82,10 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
||||||
if not silent:
|
if not silent:
|
||||||
print('Fit with', n_parms, 'parameters')
|
print('Fit with', n_parms, 'parameters')
|
||||||
|
|
||||||
y_f = [o.value for o in y]
|
y_f = np.array([o.value for o in y])
|
||||||
dy_f = [o.dvalue for o in y]
|
dy_f = np.array([o.dvalue for o in y])
|
||||||
|
|
||||||
if np.any(np.asarray(dy_f) <= 0.0):
|
if np.any(dy_f <= 0.0):
|
||||||
raise Exception('No y errors available, run the gamma method first.')
|
raise Exception('No y errors available, run the gamma method first.')
|
||||||
|
|
||||||
if 'initial_guess' in kwargs:
|
if 'initial_guess' in kwargs:
|
||||||
|
@ -97,7 +97,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
||||||
|
|
||||||
def chisqfunc(p):
|
def chisqfunc(p):
|
||||||
model = func(p, x)
|
model = func(p, x)
|
||||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2)
|
chisq = jnp.sum(((y_f - model) / dy_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
if 'method' in kwargs:
|
if 'method' in kwargs:
|
||||||
|
@ -153,7 +153,7 @@ def standard_fit(x, y, func, silent=False, **kwargs):
|
||||||
|
|
||||||
def chisqfunc_compact(d):
|
def chisqfunc_compact(d):
|
||||||
model = func(d[:n_parms], x)
|
model = func(d[:n_parms], x)
|
||||||
chisq = anp.sum(((d[n_parms:] - model) / dy_f) ** 2)
|
chisq = jnp.sum(((d[n_parms:] - model) / dy_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f)))
|
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f)))
|
||||||
|
@ -188,7 +188,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
||||||
func has to be of the form
|
func has to be of the form
|
||||||
|
|
||||||
def func(a, x):
|
def func(a, x):
|
||||||
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
|
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
For multiple x values func can be of the form
|
For multiple x values func can be of the form
|
||||||
|
@ -279,7 +279,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
||||||
|
|
||||||
def odr_chisquare(p):
|
def odr_chisquare(p):
|
||||||
model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
|
model = func(p[:n_parms], p[n_parms:].reshape(x_shape))
|
||||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
|
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
if kwargs.get('expected_chisquare') is True:
|
if kwargs.get('expected_chisquare') is True:
|
||||||
|
@ -312,7 +312,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
||||||
|
|
||||||
def odr_chisquare_compact_x(d):
|
def odr_chisquare_compact_x(d):
|
||||||
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
||||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel())))
|
jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel())))
|
||||||
|
@ -321,7 +321,7 @@ def odr_fit(x, y, func, silent=False, **kwargs):
|
||||||
|
|
||||||
def odr_chisquare_compact_y(d):
|
def odr_chisquare_compact_y(d):
|
||||||
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape))
|
||||||
chisq = anp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + anp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
chisq = jnp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + jnp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f)))
|
jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f)))
|
||||||
|
@ -349,7 +349,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
||||||
func has to be of the form
|
func has to be of the form
|
||||||
|
|
||||||
def func(a, x):
|
def func(a, x):
|
||||||
y = a[0] + a[1] * x + a[2] * anp.sinh(x)
|
y = a[0] + a[1] * x + a[2] * jnp.sinh(x)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation
|
It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation
|
||||||
|
@ -441,7 +441,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
||||||
|
|
||||||
def chisqfunc(p):
|
def chisqfunc(p):
|
||||||
model = func(p, x)
|
model = func(p, x)
|
||||||
chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((p_f - p) / dp_f) ** 2)
|
chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((p_f - p) / dp_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
|
@ -469,7 +469,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs):
|
||||||
|
|
||||||
def chisqfunc_compact(d):
|
def chisqfunc_compact(d):
|
||||||
model = func(d[:n_parms], x)
|
model = func(d[:n_parms], x)
|
||||||
chisq = anp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + anp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
|
chisq = jnp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2)
|
||||||
return chisq
|
return chisq
|
||||||
|
|
||||||
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f)))
|
jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f)))
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import hashlib
|
import hashlib
|
||||||
import autograd.numpy as np # Thinly-wrapped numpy
|
import jax.numpy as np # Thinly-wrapped numpy
|
||||||
from ..pyerrors import Obs
|
from ..pyerrors import Obs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
import jax.numpy as anp # Thinly-wrapped numpy
|
||||||
from .pyerrors import derived_observable, CObs
|
from .pyerrors import derived_observable, CObs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,12 @@
|
||||||
import warnings
|
import warnings
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import autograd.numpy as anp # Thinly-wrapped numpy
|
import jax.numpy as jnp # Thinly-wrapped numpy
|
||||||
from autograd import jacobian
|
from jax import jacobian
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numdifftools as nd
|
import numdifftools as nd
|
||||||
|
from jax.config import config
|
||||||
|
config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
|
|
||||||
class Obs:
|
class Obs:
|
||||||
|
@ -573,7 +575,7 @@ class Obs:
|
||||||
return derived_observable(lambda x: y ** x[0], [self])
|
return derived_observable(lambda x: y ** x[0], [self])
|
||||||
|
|
||||||
def __abs__(self):
|
def __abs__(self):
|
||||||
return derived_observable(lambda x: anp.abs(x[0]), [self])
|
return derived_observable(lambda x: jnp.abs(x[0]), [self])
|
||||||
|
|
||||||
# Overload numpy functions
|
# Overload numpy functions
|
||||||
def sqrt(self):
|
def sqrt(self):
|
||||||
|
@ -595,13 +597,13 @@ class Obs:
|
||||||
return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])
|
return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2])
|
||||||
|
|
||||||
def arcsin(self):
|
def arcsin(self):
|
||||||
return derived_observable(lambda x: anp.arcsin(x[0]), [self])
|
return derived_observable(lambda x: jnp.arcsin(x[0]), [self])
|
||||||
|
|
||||||
def arccos(self):
|
def arccos(self):
|
||||||
return derived_observable(lambda x: anp.arccos(x[0]), [self])
|
return derived_observable(lambda x: jnp.arccos(x[0]), [self])
|
||||||
|
|
||||||
def arctan(self):
|
def arctan(self):
|
||||||
return derived_observable(lambda x: anp.arctan(x[0]), [self])
|
return derived_observable(lambda x: jnp.arctan(x[0]), [self])
|
||||||
|
|
||||||
def sinh(self):
|
def sinh(self):
|
||||||
return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])
|
return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)])
|
||||||
|
@ -613,16 +615,16 @@ class Obs:
|
||||||
return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])
|
return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2])
|
||||||
|
|
||||||
def arcsinh(self):
|
def arcsinh(self):
|
||||||
return derived_observable(lambda x: anp.arcsinh(x[0]), [self])
|
return derived_observable(lambda x: jnp.arcsinh(x[0]), [self])
|
||||||
|
|
||||||
def arccosh(self):
|
def arccosh(self):
|
||||||
return derived_observable(lambda x: anp.arccosh(x[0]), [self])
|
return derived_observable(lambda x: jnp.arccosh(x[0]), [self])
|
||||||
|
|
||||||
def arctanh(self):
|
def arctanh(self):
|
||||||
return derived_observable(lambda x: anp.arctanh(x[0]), [self])
|
return derived_observable(lambda x: jnp.arctanh(x[0]), [self])
|
||||||
|
|
||||||
def sinc(self):
|
def sinc(self):
|
||||||
return derived_observable(lambda x: anp.sinc(x[0]), [self])
|
return derived_observable(lambda x: jnp.sinc(x[0]), [self])
|
||||||
|
|
||||||
|
|
||||||
class CObs:
|
class CObs:
|
||||||
|
@ -695,7 +697,7 @@ def derived_observable(func, data, **kwargs):
|
||||||
----------
|
----------
|
||||||
func -- arbitrary function of the form func(data, **kwargs). For the
|
func -- arbitrary function of the form func(data, **kwargs). For the
|
||||||
automatic differentiation to work, all numpy functions have to have
|
automatic differentiation to work, all numpy functions have to have
|
||||||
the autograd wrapper (use 'import autograd.numpy as anp').
|
the autograd wrapper (use 'import autograd.numpy as jnp').
|
||||||
data -- list of Obs, e.g. [obs1, obs2, obs3].
|
data -- list of Obs, e.g. [obs1, obs2, obs3].
|
||||||
|
|
||||||
Keyword arguments
|
Keyword arguments
|
||||||
|
@ -748,9 +750,9 @@ def derived_observable(func, data, **kwargs):
|
||||||
if new_shape[name] != tmp:
|
if new_shape[name] != tmp:
|
||||||
raise Exception('Shapes of ensemble', name, 'do not match.')
|
raise Exception('Shapes of ensemble', name, 'do not match.')
|
||||||
if data.ndim == 1:
|
if data.ndim == 1:
|
||||||
values = np.array([o.value for o in data])
|
values = jnp.array([o.value for o in data])
|
||||||
else:
|
else:
|
||||||
values = np.vectorize(lambda x: x.value)(data)
|
values = jnp.array(np.vectorize(lambda x: x.value)(data))
|
||||||
|
|
||||||
new_values = func(values, **kwargs)
|
new_values = func(values, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
import scipy.optimize
|
import scipy.optimize
|
||||||
from autograd import jacobian
|
from jax import jacobian
|
||||||
from .pyerrors import derived_observable, pseudo_Obs
|
from .pyerrors import derived_observable, pseudo_Obs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import autograd.numpy as np
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
import math
|
import math
|
||||||
import scipy.optimize
|
import scipy.optimize
|
||||||
from scipy.odr import ODR, Model, RealData
|
from scipy.odr import ODR, Model, RealData
|
||||||
import pyerrors as pe
|
import pyerrors as pe
|
||||||
import pytest
|
import pytest
|
||||||
|
from jax.config import config
|
||||||
|
config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
@ -24,7 +27,7 @@ def test_standard_fit():
|
||||||
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
|
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
|
||||||
|
|
||||||
def func(a, x):
|
def func(a, x):
|
||||||
y = a[0] * np.exp(-a[1] * x)
|
y = a[0] * jnp.exp(-a[1] * x)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
beta = pe.fits.standard_fit(x, oy, func)
|
beta = pe.fits.standard_fit(x, oy, func)
|
||||||
|
@ -61,7 +64,7 @@ def test_odr_fit():
|
||||||
return a * np.exp(-b * x)
|
return a * np.exp(-b * x)
|
||||||
|
|
||||||
def func(a, x):
|
def func(a, x):
|
||||||
y = a[0] * np.exp(-a[1] * x)
|
y = a[0] * jnp.exp(-a[1] * x)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
|
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import autograd.numpy as np
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
import math
|
import math
|
||||||
import pyerrors as pe
|
import pyerrors as pe
|
||||||
import pytest
|
import pytest
|
||||||
|
from jax.config import config
|
||||||
|
config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
@ -14,7 +17,7 @@ def test_matrix_inverse():
|
||||||
|
|
||||||
content.append(1.0) # Add 1.0 as a float
|
content.append(1.0) # Add 1.0 as a float
|
||||||
matrix = np.diag(content)
|
matrix = np.diag(content)
|
||||||
inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||||
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
|
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
|
||||||
matrix[n, m] = entry.real.value + 1j * entry.imag.value
|
matrix[n, m] = entry.real.value + 1j * entry.imag.value
|
||||||
|
|
||||||
inverse_matrix = np.linalg.inv(matrix)
|
inverse_matrix = np.linalg.inv(matrix)
|
||||||
inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
|
inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
|
||||||
for (n, m), entry in np.ndenumerate(inverse_matrix):
|
for (n, m), entry in np.ndenumerate(inverse_matrix):
|
||||||
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
|
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
|
||||||
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
|
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
|
||||||
|
@ -53,7 +56,7 @@ def test_matrix_functions():
|
||||||
matrix = np.array(matrix) @ np.identity(dim)
|
matrix = np.array(matrix) @ np.identity(dim)
|
||||||
|
|
||||||
# Check inverse of matrix
|
# Check inverse of matrix
|
||||||
inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||||
check_inv = matrix @ inv
|
check_inv = matrix @ inv
|
||||||
|
|
||||||
for (i, j), entry in np.ndenumerate(check_inv):
|
for (i, j), entry in np.ndenumerate(check_inv):
|
||||||
|
@ -66,7 +69,7 @@ def test_matrix_functions():
|
||||||
|
|
||||||
# Check Cholesky decomposition
|
# Check Cholesky decomposition
|
||||||
sym = np.dot(matrix, matrix.T)
|
sym = np.dot(matrix, matrix.T)
|
||||||
cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
|
cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
|
||||||
check = cholesky @ cholesky.T
|
check = cholesky @ cholesky.T
|
||||||
|
|
||||||
for (i, j), entry in np.ndenumerate(check):
|
for (i, j), entry in np.ndenumerate(check):
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import autograd.numpy as np
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import copy
|
import copy
|
||||||
import pyerrors as pe
|
import pyerrors as pe
|
||||||
import pytest
|
import pytest
|
||||||
|
from jax.config import config
|
||||||
|
config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
@ -29,21 +32,31 @@ def test_comparison():
|
||||||
|
|
||||||
|
|
||||||
def test_function_overloading():
|
def test_function_overloading():
|
||||||
a = pe.pseudo_Obs(17, 2.9, 'e1')
|
a = pe.pseudo_Obs(2, 2.9, 'e1')
|
||||||
b = pe.pseudo_Obs(4, 0.8, 'e1')
|
b = pe.pseudo_Obs(4, 0.8, 'e1')
|
||||||
|
|
||||||
fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
|
fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
|
||||||
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
|
lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
|
||||||
lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
|
|
||||||
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
|
|
||||||
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
|
|
||||||
|
|
||||||
for i, f in enumerate(fs):
|
for i, f in enumerate(fs):
|
||||||
t1 = f([a, b])
|
t1 = f([a, b])
|
||||||
t2 = pe.derived_observable(f, [a, b])
|
t2 = pe.derived_observable(f, [a, b])
|
||||||
c = t2 - t1
|
c = t2 - t1
|
||||||
assert c.value == 0.0, str(i)
|
assert c.is_zero()
|
||||||
assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
|
|
||||||
|
|
||||||
|
f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
|
||||||
|
lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
|
||||||
|
lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
|
||||||
|
f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
|
||||||
|
lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
|
||||||
|
lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
|
||||||
|
|
||||||
|
for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
|
||||||
|
t1 = f1([a])
|
||||||
|
t2 = pe.derived_observable(f2, [a])
|
||||||
|
c = t2 - t1
|
||||||
|
assert c.is_zero()
|
||||||
|
|
||||||
|
|
||||||
def test_overloading_vectorization():
|
def test_overloading_vectorization():
|
||||||
|
@ -121,7 +134,7 @@ def test_derived_observables():
|
||||||
test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
|
test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
|
||||||
|
|
||||||
# Check if autograd and numgrad give the same result
|
# Check if autograd and numgrad give the same result
|
||||||
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
|
d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
|
||||||
d_Obs_ad.gamma_method()
|
d_Obs_ad.gamma_method()
|
||||||
d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
|
d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
|
||||||
d_Obs_fd.gamma_method()
|
d_Obs_fd.gamma_method()
|
||||||
|
|
Loading…
Add table
Reference in a new issue