odr_fit renamed, deprecation warning added

least_squares and total_least_squares are now available to the top level
namespace
This commit is contained in:
Fabian Joswig 2021-11-01 12:01:46 +00:00
parent 1013307e24
commit 0e8fc7d36a
3 changed files with 11 additions and 5 deletions

View file

@ -1,7 +1,7 @@
from .pyerrors import *
from .fits import *
from . import correlators
from . import dirac
from . import fits
from . import linalg
from . import misc
from . import mpm

View file

@ -214,6 +214,11 @@ def standard_fit(x, y, func, silent=False, **kwargs):
def odr_fit(x, y, func, silent=False, **kwargs):
warnings.warn("odr_fit renamed to total_least_squares", DeprecationWarning)
return total_least_squares(x, y, func, silent=silent, **kwargs)
def total_least_squares(x, y, func, silent=False, **kwargs):
"""Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
x has to be a list of Obs, or a tuple of lists of Obs

View file

@ -27,7 +27,7 @@ def test_standard_fit():
y = a[0] * np.exp(-a[1] * x)
return y
out = pe.fits.standard_fit(x, oy, func)
out = pe.least_squares(x, oy, func)
beta = out.fit_parameters
pe.Obs.e_tag_global = 5
@ -71,7 +71,7 @@ def test_odr_fit():
odr.set_job(fit_type=0, deriv=1)
output = odr.run()
out = pe.fits.odr_fit(ox, oy, func)
out = pe.total_least_squares(ox, oy, func)
beta = out.fit_parameters
pe.Obs.e_tag_global = 5
@ -97,9 +97,10 @@ def test_odr_derivatives():
def func(a, x):
return a[0] + a[1] * x ** 2
out = pe.fits.odr_fit(x, y, func)
out = pe.total_least_squares(x, y, func)
fit1 = out.fit_parameters
with pytest.warns(DeprecationWarning):
tfit = pe.fits.fit_general(x, y, func, base_step=0.1, step_ratio=1.1, num_steps=20)
assert np.abs(np.max(np.array(list(fit1[1].deltas.values()))
- np.array(list(tfit[1].deltas.values())))) < 10e-8