diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 8ef26ce2..72ce606b 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -1,6 +1,6 @@ import warnings import numpy as np -import autograd.numpy as anp +import jax.numpy as jnp import matplotlib.pyplot as plt import scipy.linalg from .pyerrors import Obs, dump_object @@ -187,10 +187,10 @@ class Corr: def Eigenvalue(self, t0, state=1): G = self.smearing_symmetric() G0 = G.content[t0] - L = mat_mat_op(anp.linalg.cholesky, G0) - Li = mat_mat_op(anp.linalg.inv, L) + L = mat_mat_op(jnp.linalg.cholesky, G0) + Li = mat_mat_op(jnp.linalg.inv, L) LT = L.T - LTi = mat_mat_op(anp.linalg.inv, LT) + LTi = mat_mat_op(jnp.linalg.inv, LT) newcontent = [] for t in range(self.T): Gt = G.content[t] @@ -263,9 +263,9 @@ class Corr: elif variant in ['periodic', 'cosh', 'sinh']: if variant in ['periodic', 'cosh']: - func = anp.cosh + func = jnp.cosh else: - func = anp.sinh + func = jnp.sinh def root_function(x, d): return func(x * (t - self.T / 2)) / func(x * (t + 1 - self.T / 2)) - d diff --git a/pyerrors/fits.py b/pyerrors/fits.py index 1781df37..40ea7e78 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -3,15 +3,15 @@ import warnings import numpy as np -import autograd.numpy as anp +import jax.numpy as jnp import scipy.optimize import scipy.stats import matplotlib.pyplot as plt from matplotlib import gridspec from scipy.odr import ODR, Model, RealData import iminuit -from autograd import jacobian -from autograd import elementwise_grad as egrad +from jax import jacobian +#from jax import elementwise_grad as egrad from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs @@ -24,7 +24,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): func has to be of the form def func(a, x): - return a[0] + a[1] * x + a[2] * anp.sinh(x) + return a[0] + a[1] * x + a[2] * jnp.sinh(x) For multiple x values func can be of the form @@ -82,10 +82,10 @@ def standard_fit(x, y, func, silent=False, **kwargs): if not silent: print('Fit with', n_parms, 'parameters') - y_f = [o.value for o in y] - dy_f = [o.dvalue for o in y] + y_f = np.array([o.value for o in y]) + dy_f = np.array([o.dvalue for o in y]) - if np.any(np.asarray(dy_f) <= 0.0): + if np.any(dy_f <= 0.0): raise Exception('No y errors available, run the gamma method first.') if 'initial_guess' in kwargs: @@ -97,7 +97,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): def chisqfunc(p): model = func(p, x) - chisq = anp.sum(((y_f - model) / dy_f) ** 2) + chisq = jnp.sum(((y_f - model) / dy_f) ** 2) return chisq if 'method' in kwargs: @@ -153,7 +153,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): def chisqfunc_compact(d): model = func(d[:n_parms], x) - chisq = anp.sum(((d[n_parms:] - model) / dy_f) ** 2) + chisq = jnp.sum(((d[n_parms:] - model) / dy_f) ** 2) return chisq jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((fit_result.x, y_f))) @@ -188,7 +188,7 @@ def odr_fit(x, y, func, silent=False, **kwargs): func has to be of the form def func(a, x): - y = a[0] + a[1] * x + a[2] * anp.sinh(x) + y = a[0] + a[1] * x + a[2] * jnp.sinh(x) return y For multiple x values func can be of the form @@ -279,7 +279,7 @@ def odr_fit(x, y, func, silent=False, **kwargs): def odr_chisquare(p): model = func(p[:n_parms], p[n_parms:].reshape(x_shape)) - chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2) + chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((x_f - p[n_parms:].reshape(x_shape)) / dx_f) ** 2) return chisq if kwargs.get('expected_chisquare') is True: @@ -312,7 +312,7 @@ def odr_fit(x, y, func, silent=False, **kwargs): def odr_chisquare_compact_x(d): model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape)) - chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2) + chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + m:].reshape(x_shape) - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2) return chisq jac_jac_x = jacobian(jacobian(odr_chisquare_compact_x))(np.concatenate((output.beta, output.xplus.ravel(), x_f.ravel()))) @@ -321,7 +321,7 @@ def odr_fit(x, y, func, silent=False, **kwargs): def odr_chisquare_compact_y(d): model = func(d[:n_parms], d[n_parms:n_parms + m].reshape(x_shape)) - chisq = anp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + anp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2) + chisq = jnp.sum(((d[n_parms + m:] - model) / dy_f) ** 2) + jnp.sum(((x_f - d[n_parms:n_parms + m].reshape(x_shape)) / dx_f) ** 2) return chisq jac_jac_y = jacobian(jacobian(odr_chisquare_compact_y))(np.concatenate((output.beta, output.xplus.ravel(), y_f))) @@ -349,7 +349,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): func has to be of the form def func(a, x): - y = a[0] + a[1] * x + a[2] * anp.sinh(x) + y = a[0] + a[1] * x + a[2] * jnp.sinh(x) return y It is important that all numpy functions refer to autograd.numpy, otherwise the differentiation @@ -441,7 +441,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): def chisqfunc(p): model = func(p, x) - chisq = anp.sum(((y_f - model) / dy_f) ** 2) + anp.sum(((p_f - p) / dp_f) ** 2) + chisq = jnp.sum(((y_f - model) / dy_f) ** 2) + jnp.sum(((p_f - p) / dp_f) ** 2) return chisq if not silent: @@ -469,7 +469,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): def chisqfunc_compact(d): model = func(d[:n_parms], x) - chisq = anp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + anp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2) + chisq = jnp.sum(((d[n_parms: n_parms + len(x)] - model) / dy_f) ** 2) + jnp.sum(((d[n_parms + len(x):] - d[:n_parms]) / dp_f) ** 2) return chisq jac_jac = jacobian(jacobian(chisqfunc_compact))(np.concatenate((params, y_f, p_f))) diff --git a/pyerrors/input/bdio.py b/pyerrors/input/bdio.py index 8d14b440..c6a53bfd 100644 --- a/pyerrors/input/bdio.py +++ b/pyerrors/input/bdio.py @@ -3,7 +3,7 @@ import ctypes import hashlib -import autograd.numpy as np # Thinly-wrapped numpy +import jax.numpy as np # Thinly-wrapped numpy from ..pyerrors import Obs diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index a8810091..467f74c3 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -2,7 +2,7 @@ # coding: utf-8 import numpy as np -import autograd.numpy as anp # Thinly-wrapped numpy +import jax.numpy as anp # Thinly-wrapped numpy from .pyerrors import derived_observable, CObs diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index fa30c0f4..f86fd1c1 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -4,10 +4,12 @@ import warnings import pickle import numpy as np -import autograd.numpy as anp # Thinly-wrapped numpy -from autograd import jacobian +import jax.numpy as jnp # Thinly-wrapped numpy +from jax import jacobian import matplotlib.pyplot as plt import numdifftools as nd +from jax.config import config +config.update("jax_enable_x64", True) class Obs: @@ -573,7 +575,7 @@ class Obs: return derived_observable(lambda x: y ** x[0], [self]) def __abs__(self): - return derived_observable(lambda x: anp.abs(x[0]), [self]) + return derived_observable(lambda x: jnp.abs(x[0]), [self]) # Overload numpy functions def sqrt(self): @@ -595,13 +597,13 @@ class Obs: return derived_observable(lambda x, **kwargs: np.tan(x[0]), [self], man_grad=[1 / np.cos(self.value) ** 2]) def arcsin(self): - return derived_observable(lambda x: anp.arcsin(x[0]), [self]) + return derived_observable(lambda x: jnp.arcsin(x[0]), [self]) def arccos(self): - return derived_observable(lambda x: anp.arccos(x[0]), [self]) + return derived_observable(lambda x: jnp.arccos(x[0]), [self]) def arctan(self): - return derived_observable(lambda x: anp.arctan(x[0]), [self]) + return derived_observable(lambda x: jnp.arctan(x[0]), [self]) def sinh(self): return derived_observable(lambda x, **kwargs: np.sinh(x[0]), [self], man_grad=[np.cosh(self.value)]) @@ -613,16 +615,16 @@ class Obs: return derived_observable(lambda x, **kwargs: np.tanh(x[0]), [self], man_grad=[1 / np.cosh(self.value) ** 2]) def arcsinh(self): - return derived_observable(lambda x: anp.arcsinh(x[0]), [self]) + return derived_observable(lambda x: jnp.arcsinh(x[0]), [self]) def arccosh(self): - return derived_observable(lambda x: anp.arccosh(x[0]), [self]) + return derived_observable(lambda x: jnp.arccosh(x[0]), [self]) def arctanh(self): - return derived_observable(lambda x: anp.arctanh(x[0]), [self]) + return derived_observable(lambda x: jnp.arctanh(x[0]), [self]) def sinc(self): - return derived_observable(lambda x: anp.sinc(x[0]), [self]) + return derived_observable(lambda x: jnp.sinc(x[0]), [self]) class CObs: @@ -695,7 +697,7 @@ def derived_observable(func, data, **kwargs): ---------- func -- arbitrary function of the form func(data, **kwargs). For the automatic differentiation to work, all numpy functions have to have - the autograd wrapper (use 'import autograd.numpy as anp'). + the autograd wrapper (use 'import autograd.numpy as jnp'). data -- list of Obs, e.g. [obs1, obs2, obs3]. Keyword arguments @@ -748,9 +750,9 @@ def derived_observable(func, data, **kwargs): if new_shape[name] != tmp: raise Exception('Shapes of ensemble', name, 'do not match.') if data.ndim == 1: - values = np.array([o.value for o in data]) + values = jnp.array([o.value for o in data]) else: - values = np.vectorize(lambda x: x.value)(data) + values = jnp.array(np.vectorize(lambda x: x.value)(data)) new_values = func(values, **kwargs) diff --git a/pyerrors/roots.py b/pyerrors/roots.py index 0c7c1566..6bfccbbd 100644 --- a/pyerrors/roots.py +++ b/pyerrors/roots.py @@ -2,7 +2,7 @@ # coding: utf-8 import scipy.optimize -from autograd import jacobian +from jax import jacobian from .pyerrors import derived_observable, pseudo_Obs diff --git a/tests/test_fits.py b/tests/test_fits.py index 461136ce..b8b0fe47 100644 --- a/tests/test_fits.py +++ b/tests/test_fits.py @@ -1,9 +1,12 @@ -import autograd.numpy as np +import numpy as np +import jax.numpy as jnp import math import scipy.optimize from scipy.odr import ODR, Model, RealData import pyerrors as pe import pytest +from jax.config import config +config.update("jax_enable_x64", True) np.random.seed(0) @@ -24,7 +27,7 @@ def test_standard_fit(): popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True) def func(a, x): - y = a[0] * np.exp(-a[1] * x) + y = a[0] * jnp.exp(-a[1] * x) return y beta = pe.fits.standard_fit(x, oy, func) @@ -61,7 +64,7 @@ def test_odr_fit(): return a * np.exp(-b * x) def func(a, x): - y = a[0] * np.exp(-a[1] * x) + y = a[0] * jnp.exp(-a[1] * x) return y data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy]) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index af121912..a153f7e9 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,7 +1,10 @@ -import autograd.numpy as np +import numpy as np +import jax.numpy as jnp import math import pyerrors as pe import pytest +from jax.config import config +config.update("jax_enable_x64", True) np.random.seed(0) @@ -14,7 +17,7 @@ def test_matrix_inverse(): content.append(1.0) # Add 1.0 as a float matrix = np.diag(content) - inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix) + inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix) assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1]) @@ -35,7 +38,7 @@ def test_complex_matrix_inverse(): matrix[n, m] = entry.real.value + 1j * entry.imag.value inverse_matrix = np.linalg.inv(matrix) - inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix) + inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix) for (n, m), entry in np.ndenumerate(inverse_matrix): assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value) assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value) @@ -53,7 +56,7 @@ def test_matrix_functions(): matrix = np.array(matrix) @ np.identity(dim) # Check inverse of matrix - inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix) + inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix) check_inv = matrix @ inv for (i, j), entry in np.ndenumerate(check_inv): @@ -66,7 +69,7 @@ def test_matrix_functions(): # Check Cholesky decomposition sym = np.dot(matrix, matrix.T) - cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym) + cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym) check = cholesky @ cholesky.T for (i, j), entry in np.ndenumerate(check): diff --git a/tests/test_pyerrors.py b/tests/test_pyerrors.py index 982dcab3..5798381b 100644 --- a/tests/test_pyerrors.py +++ b/tests/test_pyerrors.py @@ -1,10 +1,13 @@ -import autograd.numpy as np +import numpy as np +import jax.numpy as jnp import os import random import string import copy import pyerrors as pe import pytest +from jax.config import config +config.update("jax_enable_x64", True) np.random.seed(0) @@ -29,21 +32,31 @@ def test_comparison(): def test_function_overloading(): - a = pe.pseudo_Obs(17, 2.9, 'e1') + a = pe.pseudo_Obs(2, 2.9, 'e1') b = pe.pseudo_Obs(4, 0.8, 'e1') fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0], - lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0], - lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]), - lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])), - lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])] + lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]] for i, f in enumerate(fs): t1 = f([a, b]) t2 = pe.derived_observable(f, [a, b]) c = t2 - t1 - assert c.value == 0.0, str(i) - assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i) + assert c.is_zero() + + + f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]), + lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])), + lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])] + f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]), + lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])), + lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])] + + for i, (f1, f2) in enumerate(zip(f_np, f_jnp)): + t1 = f1([a]) + t2 = pe.derived_observable(f2, [a]) + c = t2 - t1 + assert c.is_zero() def test_overloading_vectorization(): @@ -121,7 +134,7 @@ def test_derived_observables(): test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand()))) # Check if autograd and numgrad give the same result - d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs]) + d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs]) d_Obs_ad.gamma_method() d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True) d_Obs_fd.gamma_method()