Mirror of https://github.com/fjosw/pyerrors.git
Synced 2025-03-15 23:00:25 +01:00

Merge branch 'develop' into feature/covobs

This commit is contained in: commit 9db709a171
9 changed files with 135 additions and 176 deletions
@@ -246,12 +246,11 @@ See `pyerrors.obs.Obs.export_jackknife` for details.
 from .obs import *
 from .correlators import *
 from .fits import *
+from .misc import *
 from . import dirac
 from . import input
 from . import linalg
-from . import misc
 from . import mpm
-from . import npr
 from . import roots

 from .version import __version__
@@ -3,7 +3,8 @@ import numpy as np
 import autograd.numpy as anp
 import matplotlib.pyplot as plt
 import scipy.linalg
-from .obs import Obs, dump_object, reweight, correlate
+from .obs import Obs, reweight, correlate
+from .misc import dump_object
 from .fits import least_squares
 from .linalg import eigh, inv, cholesky
 from .roots import find_root
@@ -301,7 +301,8 @@ def total_least_squares(x, y, func, silent=False, **kwargs):

     result = []
     for i in range(n_parms):
-        result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(out.beta[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(x.ravel()) + list(y), man_grad=[0] + list(deriv_x[i]) + list(deriv_y[i])))
+        result.append(derived_observable(lambda x, **kwargs: x[0], list(x.ravel()) + list(y), man_grad=list(deriv_x[i]) + list(deriv_y[i])))
+        result[-1]._value = out.beta[i]

     output.fit_parameters = result + const_par

@@ -418,7 +419,8 @@ def _prior_fit(x, y, func, priors, silent=False, **kwargs):

     result = []
     for i in range(n_parms):
-        result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(params[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(y) + list(loc_priors), man_grad=[0] + list(deriv[i])))
+        result.append(derived_observable(lambda x, **kwargs: x[0], list(y) + list(loc_priors), man_grad=list(deriv[i])))
+        result[-1]._value = params[i]

     output.fit_parameters = result
     output.chisquare = chisqfunc(np.asarray(params))
@@ -612,7 +614,8 @@ def _standard_fit(x, y, func, silent=False, **kwargs):

     result = []
     for i in range(n_parms):
-        result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(fit_result.x[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(y), man_grad=[0] + list(deriv[i])))
+        result.append(derived_observable(lambda x, **kwargs: x[0], list(y), man_grad=list(deriv[i])))
+        result[-1]._value = fit_result.x[i]

     output.fit_parameters = result + const_par

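The three fit hunks above (apparently from pyerrors' fits module) all make the same change: instead of passing the optimizer result into derived_observable via an extra pseudo_Obs with a zero entry in the manual gradient, the parameter is now built from the data alone and the central value is written into _value afterwards. A minimal sketch of that pattern with invented data, gradient and central value:

import pyerrors as pe
from pyerrors.obs import derived_observable

# Illustrative inputs: three data points and a made-up manual gradient
# of fit parameter i with respect to them.
y = [pe.pseudo_Obs(v, 0.01, 'test_ensemble') for v in (1.0, 1.2, 0.9)]
deriv_i = [0.4, 0.3, 0.3]
beta_i = 1.05  # central value returned by the optimizer

# Error propagation uses only the data and the manual gradient ...
par_i = derived_observable(lambda x, **kwargs: x[0], list(y), man_grad=deriv_i)
# ... while the central value is pinned to the optimizer result.
par_i._value = beta_i
par_i.gamma_method()
print(par_i)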
@@ -1,15 +1,11 @@
-#!/usr/bin/env python
-# coding: utf-8
-
 import os
 import h5py
 import numpy as np
 from ..obs import Obs, CObs
 from ..correlators import Corr
-from ..npr import Npr_matrix


-def _get_files(path, filestem):
+def _get_files(path, filestem, idl):
     ls = os.listdir(path)

     # Clean up file list
@@ -24,11 +20,19 @@ def _get_files(path, filestem):
     # Sort according to configuration number
     files.sort(key=get_cnfg_number)

-    # Check that configurations are evenly spaced
     cnfg_numbers = []
+    filtered_files = []
     for line in files:
-        cnfg_numbers.append(get_cnfg_number(line))
+        no = get_cnfg_number(line)
+        if idl:
+            if no in list(idl):
+                filtered_files.append(line)
+                cnfg_numbers.append(no)
+        else:
+            filtered_files.append(line)
+            cnfg_numbers.append(no)

+    # Check that configurations are evenly spaced
     dc = np.unique(np.diff(cnfg_numbers))
     if np.any(dc < 0):
         raise Exception("Unsorted files")
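The new idl handling keeps only those files whose configuration number lies in the requested range before the spacing check runs. A self-contained sketch of that selection logic, with a hypothetical get_cnfg_number stand-in (the real helper is defined earlier in _get_files):

def filter_by_idl(files, idl, get_cnfg_number):
    """Keep only the files whose configuration number is contained in idl (if given)."""
    filtered_files, cnfg_numbers = [], []
    for line in files:
        no = get_cnfg_number(line)
        if not idl or no in list(idl):
            filtered_files.append(line)
            cnfg_numbers.append(no)
    return filtered_files, cnfg_numbers


# Hypothetical file names ending in '.<config>.h5'
files = ['run.100.h5', 'run.104.h5', 'run.108.h5', 'run.112.h5']
print(filter_by_idl(files, range(100, 110, 4), lambda f: int(f.split('.')[-2])))
# (['run.100.h5', 'run.104.h5', 'run.108.h5'], [100, 104, 108])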
@@ -37,10 +41,10 @@ def _get_files(path, filestem):
     else:
         raise Exception('Configurations are not evenly spaced.')

-    return files, idx
+    return filtered_files, idx


-def read_meson_hd5(path, filestem, ens_id, meson='meson_0', tree='meson'):
+def read_meson_hd5(path, filestem, ens_id, meson='meson_0', tree='meson', idl=None):
     """Read hadrons meson hdf5 file and extract the meson labeled 'meson'

     Parameters
@@ -58,9 +62,11 @@ def read_meson_hd5(path, filestem, ens_id, meson='meson_0', tree='meson'):
         Label of the upmost directory in the hdf5 file, default 'meson'
         for outputs of the Meson module. Can be altered to read input
         from other modules with similar structures.
+    idl : range
+        If specified only configurations in the given range are read in.
     """

-    files, idx = _get_files(path, filestem)
+    files, idx = _get_files(path, filestem, idl)

     corr_data = []
     infos = []
|
||||||
return corr
|
return corr
|
||||||
|
|
||||||
|
|
||||||
def read_ExternalLeg_hd5(path, filestem, ens_id, order='F'):
|
class Npr_matrix(np.ndarray):
|
||||||
|
|
||||||
|
def __new__(cls, input_array, mom_in=None, mom_out=None):
|
||||||
|
obj = np.asarray(input_array).view(cls)
|
||||||
|
obj.mom_in = mom_in
|
||||||
|
obj.mom_out = mom_out
|
||||||
|
return obj
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g5H(self):
|
||||||
|
"""Gamma_5 hermitean conjugate
|
||||||
|
|
||||||
|
Uses the fact that the propagator is gamma5 hermitean, so just the
|
||||||
|
in and out momenta of the propagator are exchanged.
|
||||||
|
"""
|
||||||
|
return Npr_matrix(self,
|
||||||
|
mom_in=self.mom_out,
|
||||||
|
mom_out=self.mom_in)
|
||||||
|
|
||||||
|
def _propagate_mom(self, other, name):
|
||||||
|
s_mom = getattr(self, name, None)
|
||||||
|
o_mom = getattr(other, name, None)
|
||||||
|
if s_mom is not None and o_mom is not None:
|
||||||
|
if not np.allclose(s_mom, o_mom):
|
||||||
|
raise Exception(name + ' does not match.')
|
||||||
|
return o_mom if o_mom is not None else s_mom
|
||||||
|
|
||||||
|
def __matmul__(self, other):
|
||||||
|
return self.__new__(Npr_matrix,
|
||||||
|
super().__matmul__(other),
|
||||||
|
self._propagate_mom(other, 'mom_in'),
|
||||||
|
self._propagate_mom(other, 'mom_out'))
|
||||||
|
|
||||||
|
def __array_finalize__(self, obj):
|
||||||
|
if obj is None:
|
||||||
|
return
|
||||||
|
self.mom_in = getattr(obj, 'mom_in', None)
|
||||||
|
self.mom_out = getattr(obj, 'mom_out', None)
|
||||||
|
|
||||||
|
|
||||||
|
def read_ExternalLeg_hd5(path, filestem, ens_id, idl=None):
|
||||||
"""Read hadrons ExternalLeg hdf5 file and output an array of CObs
|
"""Read hadrons ExternalLeg hdf5 file and output an array of CObs
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
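The Npr_matrix class added above tags a propagator array with its in- and out-momenta and keeps that metadata consistent under matrix multiplication and under the gamma_5-hermitian conjugate g5H. A small sketch of the bookkeeping with plain NumPy input (momenta and matrices are made up, and the import path assumes the class now lives in pyerrors.input.hadrons):

import numpy as np
from pyerrors.input.hadrons import Npr_matrix  # location after this commit (assumption)

a = Npr_matrix(np.eye(12), mom_in=[1, 1, 1, 1], mom_out=[1, 1, 1, 1])
b = Npr_matrix(np.ones((12, 12)), mom_in=[1, 1, 1, 1], mom_out=[1, 1, 1, 1])

prod = a @ b                      # __matmul__ checks that the momenta agree
print(prod.mom_in, prod.mom_out)  # [1, 1, 1, 1] [1, 1, 1, 1]

c = a.g5H                         # swaps mom_in and mom_out
print(c.mom_in is a.mom_out)      # True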
@@ -92,12 +138,11 @@ def read_ExternalLeg_hd5(path, filestem, ens_id, order='F'):
     path -- path to the files to read
     filestem -- namestem of the files to read
     ens_id -- name of the ensemble, required for internal bookkeeping
-    order -- order in which the array is to be reshaped,
-             'F' for the first index changing fastest (9 4x4 matrices) default.
-             'C' for the last index changing fastest (16 3x3 matrices),
+    idl : range
+        If specified only configurations in the given range are read in.
     """

-    files, idx = _get_files(path, filestem)
+    files, idx = _get_files(path, filestem, idl)

     mom = None

@@ -119,10 +164,10 @@ def read_ExternalLeg_hd5(path, filestem, ens_id, order='F'):
         imag = Obs([rolled_array[si, sj, ci, cj].imag], [ens_id], idl=[idx])
         matrix[si, sj, ci, cj] = CObs(real, imag)

-    return Npr_matrix(matrix.swapaxes(1, 2).reshape((12, 12), order=order), mom_in=mom)
+    return Npr_matrix(matrix.swapaxes(1, 2).reshape((12, 12), order='F'), mom_in=mom)


-def read_Bilinear_hd5(path, filestem, ens_id, order='F'):
+def read_Bilinear_hd5(path, filestem, ens_id, idl=None):
     """Read hadrons Bilinear hdf5 file and output an array of CObs

     Parameters
@@ -130,12 +175,11 @@ def read_Bilinear_hd5(path, filestem, ens_id, order='F'):
     path -- path to the files to read
     filestem -- namestem of the files to read
     ens_id -- name of the ensemble, required for internal bookkeeping
-    order -- order in which the array is to be reshaped,
-             'F' for the first index changing fastest (9 4x4 matrices) default.
-             'C' for the last index changing fastest (16 3x3 matrices),
+    idl : range
+        If specified only configurations in the given range are read in.
     """

-    files, idx = _get_files(path, filestem)
+    files, idx = _get_files(path, filestem, idl)

     mom_in = None
     mom_out = None
@@ -169,6 +213,6 @@ def read_Bilinear_hd5(path, filestem, ens_id, order='F'):
             imag = Obs([rolled_array[si, sj, ci, cj].imag], [ens_id], idl=[idx])
             matrix[si, sj, ci, cj] = CObs(real, imag)

-        result_dict[key] = Npr_matrix(matrix.swapaxes(1, 2).reshape((12, 12), order=order), mom_in=mom_in, mom_out=mom_out)
+        result_dict[key] = Npr_matrix(matrix.swapaxes(1, 2).reshape((12, 12), order='F'), mom_in=mom_in, mom_out=mom_out)

     return result_dict
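read_Bilinear_hd5 gains the same idl keyword and, per the hunk above, returns a dict of Npr_matrix objects, one entry per key found in the file. A hypothetical call (paths, stems and the available keys depend entirely on the input files):

import pyerrors.input.hadrons as hd

bilinears = hd.read_Bilinear_hd5('/path/to/data', 'bilinear_filestem', 'A1_ensemble',
                                 idl=range(100, 200, 4))
for key, mat in bilinears.items():
    print(key, mat.shape, mat.mom_in, mat.mom_out)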
@@ -1,18 +1,56 @@
+import pickle
 import numpy as np
 from .obs import Obs


+def dump_object(obj, name, **kwargs):
+    """Dump object into pickle file.
+
+    Parameters
+    ----------
+    obj : object
+        object to be saved in the pickle file
+    name : str
+        name of the file
+    path : str
+        specifies a custom path for the file (default '.')
+    """
+    if 'path' in kwargs:
+        file_name = kwargs.get('path') + '/' + name + '.p'
+    else:
+        file_name = name + '.p'
+    with open(file_name, 'wb') as fb:
+        pickle.dump(obj, fb)
+
+
+def load_object(path):
+    """Load object from pickle file.
+
+    Parameters
+    ----------
+    path : str
+        path to the file
+    """
+    with open(path, 'rb') as file:
+        return pickle.load(file)
+
+
 def gen_correlated_data(means, cov, name, tau=0.5, samples=1000):
     """ Generate observables with given covariance and autocorrelation times.

     Parameters
     ----------
-    means -- list containing the mean value of each observable.
-    cov -- covariance matrix for the data to be geneated.
-    name -- ensemble name for the data to be geneated.
-    tau -- can either be a real number or a list with an entry for
-           every dataset.
-    samples -- number of samples to be generated for each observable.
+    means : list
+        list containing the mean value of each observable.
+    cov : numpy.ndarray
+        covariance matrix for the data to be generated.
+    name : str
+        ensemble name for the data to be generated.
+    tau : float or list
+        can either be a real number or a list with an entry for
+        every dataset.
+    samples : int
+        number of samples to be generated for each observable.
     """

     assert len(means) == cov.shape[-1]
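dump_object and load_object now live in pyerrors.misc (and are removed from obs.py further down). A quick round trip, combined with the re-documented gen_correlated_data; the means, covariance and ensemble name are invented for the example:

import numpy as np
import pyerrors as pe

# Two correlated observables on a fictitious ensemble.
means = [1.0, 0.5]
cov = np.array([[0.01, 0.002], [0.002, 0.02]])
data = pe.misc.gen_correlated_data(means, cov, 'demo_ensemble', tau=0.5, samples=1000)

# Pickle to ./correlated_data.p and read back.
pe.misc.dump_object(data, 'correlated_data')
restored = pe.misc.load_object('correlated_data.p')
restored[0].gamma_method()
print(restored[0])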
pyerrors/npr.py (108 lines deleted)
@@ -1,108 +0,0 @@
-import warnings
-import numpy as np
-from .linalg import inv, matmul
-from .dirac import gamma
-
-
-L = None
-T = None
-
-
-class Npr_matrix(np.ndarray):
-
-    def __new__(cls, input_array, mom_in=None, mom_out=None):
-        obj = np.asarray(input_array).view(cls)
-        obj.mom_in = mom_in
-        obj.mom_out = mom_out
-        return obj
-
-    @property
-    def g5H(self):
-        """Gamma_5 hermitean conjugate
-
-        Uses the fact that the propagator is gamma5 hermitean, so just the
-        in and out momenta of the propagator are exchanged.
-        """
-        return Npr_matrix(self,
-                          mom_in=self.mom_out,
-                          mom_out=self.mom_in)
-
-    def _propagate_mom(self, other, name):
-        s_mom = getattr(self, name, None)
-        o_mom = getattr(other, name, None)
-        if s_mom is not None and o_mom is not None:
-            if not np.allclose(s_mom, o_mom):
-                raise Exception(name + ' does not match.')
-        return o_mom if o_mom is not None else s_mom
-
-    def __matmul__(self, other):
-        return self.__new__(Npr_matrix,
-                            super().__matmul__(other),
-                            self._propagate_mom(other, 'mom_in'),
-                            self._propagate_mom(other, 'mom_out'))
-
-    def __array_finalize__(self, obj):
-        if obj is None:
-            return
-        self.mom_in = getattr(obj, 'mom_in', None)
-        self.mom_out = getattr(obj, 'mom_out', None)
-
-
-def _check_geometry():
-    if L is None:
-        raise Exception("Spatial extent 'L' not set.")
-    else:
-        if not isinstance(L, int):
-            raise Exception("Spatial extent 'L' must be an integer.")
-    if T is None:
-        raise Exception("Temporal extent 'T' not set.")
-    if not isinstance(T, int):
-        raise Exception("Temporal extent 'T' must be an integer.")
-
-
-def inv_propagator(prop):
-    """ Inverts a 12x12 quark propagator"""
-    if prop.shape != (12, 12):
-        raise Exception("Only 12x12 propagators can be inverted.")
-    return Npr_matrix(inv(prop), prop.mom_in)
-
-
-def Zq(inv_prop, fermion='Wilson'):
-    """ Calculates the quark field renormalization constant Zq
-
-    Parameters
-    ----------
-    inv_prop : array
-        Inverted 12x12 quark propagator
-    fermion : str
-        Fermion type for which the tree-level propagator is used
-        in the calculation of Zq. Default Wilson.
-    """
-    _check_geometry()
-    mom = np.copy(inv_prop.mom_in)
-    mom[3] /= T / L
-    sin_mom = np.sin(2 * np.pi / L * mom)
-
-    if fermion == 'Wilson':
-        p_slash = -1j * (sin_mom[0] * gamma[0] + sin_mom[1] * gamma[1] + sin_mom[2] * gamma[2] + sin_mom[3] * gamma[3]) / np.sum(sin_mom ** 2)
-    elif fermion == 'Continuum':
-        p_mom = 2 * np.pi / L * mom
-        p_slash = -1j * (p_mom[0] * gamma[0] + p_mom[1] * gamma[1] + p_mom[2] * gamma[2] + p_mom[3] * gamma[3]) / np.sum(p_mom ** 2)
-    elif fermion == 'DWF':
-        W = np.sum(1 - np.cos(2 * np.pi / L * mom))
-        s2 = np.sum(sin_mom ** 2)
-        p_slash = -1j * (sin_mom[0] * gamma[0] + sin_mom[1] * gamma[1] + sin_mom[2] * gamma[2] + sin_mom[3] * gamma[3])
-        p_slash /= 2 * (W - 1 + np.sqrt((1 - W) ** 2 + s2))
-    else:
-        raise Exception("Fermion type '" + fermion + "' not implemented")
-
-    res = 1 / 12. * np.trace(matmul(inv_prop, np.kron(np.eye(3, dtype=int), p_slash)))
-    res.gamma_method()
-
-    if not res.imag.is_zero_within_error(5):
-        warnings.warn("Imaginary part of Zq is not zero within 5 sigma")
-        return res
-
-    res.real.tag = "Zq '" + fermion + "', p=" + str(inv_prop.mom_in)
-
-    return res.real
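With pyerrors/npr.py gone, the Npr_matrix class only survives in the hadrons input module shown above, while the geometry globals L and T and the Zq / inv_propagator helpers appear to be dropped from the package by this commit. Downstream code would adjust its import accordingly, e.g.:

# before this commit
# from pyerrors.npr import Npr_matrix

# after this commit (assuming the hunks above modify pyerrors/input/hadrons.py)
from pyerrors.input.hadrons import Npr_matrix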
@@ -1578,38 +1578,6 @@ def pseudo_Obs(value, dvalue, name, samples=1000):
     return res


-def dump_object(obj, name, **kwargs):
-    """Dump object into pickle file.
-
-    Parameters
-    ----------
-    obj : object
-        object to be saved in the pickle file
-    name : str
-        name of the file
-    path : str
-        specifies a custom path for the file (default '.')
-    """
-    if 'path' in kwargs:
-        file_name = kwargs.get('path') + '/' + name + '.p'
-    else:
-        file_name = name + '.p'
-    with open(file_name, 'wb') as fb:
-        pickle.dump(obj, fb)
-
-
-def load_object(path):
-    """Load object from pickle file.
-
-    Parameters
-    ----------
-    path : str
-        path to the file
-    """
-    with open(path, 'rb') as file:
-        return pickle.load(file)
-
-
 def import_jackknife(jacks, name, idl=None):
     """Imports jackknife samples and returns an Obs

@@ -1,6 +1,6 @@
 import scipy.optimize
 from autograd import jacobian
-from .obs import derived_observable, pseudo_Obs
+from .obs import derived_observable


 def find_root(d, func, guess=1.0, **kwargs):
@@ -33,4 +33,6 @@ def find_root(d, func, guess=1.0, **kwargs):
     da = jacobian(lambda u, v: func(v, u))(d.value, root[0])
     deriv = - da / dx

-    return derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(root, 0.0, d.names[0], d.shape[d.names[0]]), d], man_grad=[0, deriv])
+    res = derived_observable(lambda x, **kwargs: x[0], [d], man_grad=[deriv])
+    res._value = root[0]
+    return res
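find_root now follows the same _value pattern as the fit functions: the error comes from propagating the implicit-function derivative d(root)/dd = -(∂f/∂d)/(∂f/∂x) through the input observable, and the central value is set to the scipy root afterwards. The new test below exercises exactly this; a minimal standalone call with an invented observable and root function might look like:

import pyerrors as pe

my_obs = pe.pseudo_Obs(1.3, 0.05, 'demo_ensemble')
root = pe.roots.find_root(my_obs, lambda x, d: x ** 2 - d)  # solves x**2 = d
root.gamma_method()
print(root)  # compatible with sqrt(1.3)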
@@ -17,3 +17,15 @@ def test_root_linear():
     assert np.isclose(my_root.value, value)
     difference = my_obs - my_root
     assert difference.is_zero()
+
+
+def test_root_linear_idl():
+
+    def root_function(x, d):
+        return x - d
+
+    my_obs = pe.Obs([np.random.rand(50)], ['t'], idl=[range(20, 120, 2)])
+    my_root = pe.roots.find_root(my_obs, root_function)
+
+    difference = my_obs - my_root
+    assert difference.is_zero()