From 2363b755dd1d77fa815515bfb3f7f3e082419211 Mon Sep 17 00:00:00 2001 From: nils-ht <127844041+nils-ht@users.noreply.github.com> Date: Fri, 17 Mar 2023 13:52:07 +0000 Subject: [PATCH] NHT changes plots combined fit (#166) * NHT changes plots combined fit * feat: Exception for illegal combination added and test fixed. --------- Co-authored-by: Fabian Joswig --- pyerrors/correlators.py | 14 ++++++++++---- tests/fits_test.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index c581fbe8..930120bc 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -788,7 +788,7 @@ class Corr: self.prange = prange return - def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None): + def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, fit_key=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None): """Plots the correlator using the tag of the correlator as label if available. Parameters @@ -804,6 +804,8 @@ class Corr: Plateau value to be visualized in the figure. fit_res : Fit_result Fit_result object to be visualized. + fit_key : str + Key for the fit function in Fit_result.fit_function (for combined fits). ylabel : str Label for the y-axis. save : str @@ -883,9 +885,13 @@ class Corr: if fit_res: x_samples = np.arange(x_range[0], x_range[1] + 1, 0.05) - ax1.plot(x_samples, - fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples), - ls='-', marker=',', lw=2) + if isinstance(fit_res.fit_function, dict): + if fit_key: + ax1.plot(x_samples, fit_res.fit_function[fit_key]([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2) + else: + raise ValueError("Please provide a 'fit_key' for visualizing combined fits.") + else: + ax1.plot(x_samples, fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2) ax1.set_xlabel(r'$x_0 / a$') if ylabel: diff --git a/tests/fits_test.py b/tests/fits_test.py index f9f9c773..48e788bd 100644 --- a/tests/fits_test.py +++ b/tests/fits_test.py @@ -681,6 +681,42 @@ def test_combined_fit_no_autograd(): pe.least_squares(xs, ys, funcs, num_grad=True) +def test_plot_combined_fit_function(): + + def func_exp1(x): + return 0.3*anp.exp(0.5*x) + + def func_exp2(x): + return 0.3*anp.exp(0.8*x) + + xvals_b = np.arange(0,6) + xvals_a = np.arange(0,8) + + def func_a(a,x): + return a[0]*anp.exp(a[1]*x) + + def func_b(a,x): + return a[0]*anp.exp(a[2]*x) + + corr_a = pe.Corr([pe.Obs([np.random.normal(item, item*1.5, 1000)],['ensemble1']) for item in func_exp1(xvals_a)]) + corr_b = pe.Corr([pe.Obs([np.random.normal(item, item*1.4, 1000)],['ensemble1']) for item in func_exp2(xvals_b)]) + + funcs = {'a':func_a, 'b':func_b} + xs = {'a':xvals_a, 'b':xvals_b} + ys = {'a': [o[0] for o in corr_a.content], + 'b': [o[0] for o in corr_b.content]} + + corr_a.gm() + corr_b.gm() + + comb_fit = pe.least_squares(xs, ys, funcs) + + with pytest.raises(ValueError): + corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit) + + corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit, fit_key="a") + corr_b.show(x_range=[xs["b"][0], xs["b"][-1]], fit_res=comb_fit, fit_key="b") + def test_combined_fit_invalid_fit_functions(): def func1(a, x):