From 82cd2f11eaa83388179793bbcb48c2004599c1e6 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 3 Mar 2023 16:35:26 +0000 Subject: [PATCH] fix: multi dim fits fixed in least squares. Test added. (#160) Co-authored-by: Simon Kuberski --- pyerrors/fits.py | 2 +- tests/fits_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyerrors/fits.py b/pyerrors/fits.py index dbbebc61..0ddfd54a 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -554,7 +554,7 @@ def _combined_fit(x, y, func, silent=False, **kwargs): for key in key_ls: if not callable(funcd[key]): raise TypeError('func (key=' + key + ') is not a function.') - if len(xd[key]) != len(yd[key]): + if np.asarray(xd[key]).shape[-1] != len(yd[key]): raise Exception('x and y input (key=' + key + ') do not have the same length') for i in range(100): try: diff --git a/tests/fits_test.py b/tests/fits_test.py index d596d89e..206f5f40 100644 --- a/tests/fits_test.py +++ b/tests/fits_test.py @@ -913,6 +913,17 @@ def test_combined_resplot_qqplot(): plt.close('all') +def test_x_multidim_fit(): + x1 = np.arange(1, 10) + x = np.array([[xi, xi] for xi in x1]).T + y = [pe.pseudo_Obs(i + 2 / i, .1 * i, 't') for i in x[0]] + [o.gm() for o in y] + def fitf(a, x): + return a[0] * x[0] + a[1] / x[1] + + pe.fits.least_squares(x, y, fitf) + + def fit_general(x, y, func, silent=False, **kwargs): """Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.