From 3e29cf9ca8bb4b7cd6745085bb0af3d24e89488f Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Fri, 24 Jun 2022 12:50:26 +0100 Subject: [PATCH] fix: detection of invalid fit functions extended. --- pyerrors/fits.py | 16 +++++++++++----- tests/fits_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/pyerrors/fits.py b/pyerrors/fits.py index 997c7f6f..eef16518 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -177,13 +177,15 @@ def total_least_squares(x, y, func, silent=False, **kwargs): if not callable(func): raise TypeError('func has to be a function.') - for i in range(25): + for i in range(42): try: func(np.arange(i), x.T[0]) except Exception: - pass + continue else: break + else: + raise RuntimeError("Fit function is not valid.") n_parms = i if not silent: @@ -321,9 +323,11 @@ def _prior_fit(x, y, func, priors, silent=False, **kwargs): try: func(np.arange(i), 0) except Exception: - pass + continue else: break + else: + raise RuntimeError("Fit function is not valid.") n_parms = i @@ -442,13 +446,15 @@ def _standard_fit(x, y, func, silent=False, **kwargs): if not callable(func): raise TypeError('func has to be a function.') - for i in range(25): + for i in range(42): try: func(np.arange(i), x.T[0]) except Exception: - pass + continue else: break + else: + raise RuntimeError("Fit function is not valid.") n_parms = i diff --git a/tests/fits_test.py b/tests/fits_test.py index 6cb806d2..c578a86d 100644 --- a/tests/fits_test.py +++ b/tests/fits_test.py @@ -495,6 +495,31 @@ def test_fit_no_autograd(): pe.total_least_squares(oy, oy, func) +def test_invalid_fit_function(): + def func1(a, x): + return a[0] + a[1] * x + a[2] * anp.sinh(x) + a[199] + + def func2(a, x, y): + return a[0] + a[1] * x + + def func3(x): + return x + + xvals =[] + yvals =[] + err = 0.1 + + for x in range(1, 8, 2): + xvals.append(x) + yvals.append(pe.pseudo_Obs(x + np.random.normal(0.0, err), err, 'test1') + pe.pseudo_Obs(0, err / 100, 'test2', samples=87)) + [o.gamma_method() for o in yvals] + for func in [func1, func2, func3]: + with pytest.raises(Exception): + pe.least_squares(xvals, yvals, func) + with pytest.raises(Exception): + pe.total_least_squares(yvals, yvals, func) + + def test_singular_correlated_fit(): obs1 = pe.pseudo_Obs(1.0, 0.1, 'test') with pytest.raises(Exception):