mirror of
				https://github.com/fjosw/pyerrors.git
				synced 2025-10-31 15:55:45 +01:00 
			
		
		
		
	NHT changes plots combined fit (#166)
* NHT changes plots combined fit * feat: Exception for illegal combination added and test fixed. --------- Co-authored-by: Fabian Joswig <fabian.joswig@ed.ac.uk>
This commit is contained in:
		
					parent
					
						
							
								83204ce794
							
						
					
				
			
			
				commit
				
					
						2363b755dd
					
				
			
		
					 2 changed files with 46 additions and 4 deletions
				
			
		|  | @ -788,7 +788,7 @@ class Corr: | ||||||
|         self.prange = prange |         self.prange = prange | ||||||
|         return |         return | ||||||
| 
 | 
 | ||||||
|     def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None): |     def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, fit_key=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None): | ||||||
|         """Plots the correlator using the tag of the correlator as label if available. |         """Plots the correlator using the tag of the correlator as label if available. | ||||||
| 
 | 
 | ||||||
|         Parameters |         Parameters | ||||||
|  | @ -804,6 +804,8 @@ class Corr: | ||||||
|             Plateau value to be visualized in the figure. |             Plateau value to be visualized in the figure. | ||||||
|         fit_res : Fit_result |         fit_res : Fit_result | ||||||
|             Fit_result object to be visualized. |             Fit_result object to be visualized. | ||||||
|  |         fit_key : str | ||||||
|  |             Key for the fit function in Fit_result.fit_function (for combined fits). | ||||||
|         ylabel : str |         ylabel : str | ||||||
|             Label for the y-axis. |             Label for the y-axis. | ||||||
|         save : str |         save : str | ||||||
|  | @ -883,9 +885,13 @@ class Corr: | ||||||
| 
 | 
 | ||||||
|         if fit_res: |         if fit_res: | ||||||
|             x_samples = np.arange(x_range[0], x_range[1] + 1, 0.05) |             x_samples = np.arange(x_range[0], x_range[1] + 1, 0.05) | ||||||
|             ax1.plot(x_samples, |             if isinstance(fit_res.fit_function, dict): | ||||||
|                      fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples), |                 if fit_key: | ||||||
|                      ls='-', marker=',', lw=2) |                     ax1.plot(x_samples, fit_res.fit_function[fit_key]([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2) | ||||||
|  |                 else: | ||||||
|  |                     raise ValueError("Please provide a 'fit_key' for visualizing combined fits.") | ||||||
|  |             else: | ||||||
|  |                 ax1.plot(x_samples, fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2) | ||||||
| 
 | 
 | ||||||
|         ax1.set_xlabel(r'$x_0 / a$') |         ax1.set_xlabel(r'$x_0 / a$') | ||||||
|         if ylabel: |         if ylabel: | ||||||
|  |  | ||||||
|  | @ -681,6 +681,42 @@ def test_combined_fit_no_autograd(): | ||||||
| 
 | 
 | ||||||
|     pe.least_squares(xs, ys, funcs, num_grad=True) |     pe.least_squares(xs, ys, funcs, num_grad=True) | ||||||
| 
 | 
 | ||||||
|  | def test_plot_combined_fit_function(): | ||||||
|  | 
 | ||||||
|  |     def func_exp1(x): | ||||||
|  |         return 0.3*anp.exp(0.5*x) | ||||||
|  | 
 | ||||||
|  |     def func_exp2(x): | ||||||
|  |         return 0.3*anp.exp(0.8*x) | ||||||
|  | 
 | ||||||
|  |     xvals_b = np.arange(0,6) | ||||||
|  |     xvals_a = np.arange(0,8) | ||||||
|  | 
 | ||||||
|  |     def func_a(a,x): | ||||||
|  |         return a[0]*anp.exp(a[1]*x) | ||||||
|  | 
 | ||||||
|  |     def func_b(a,x): | ||||||
|  |         return a[0]*anp.exp(a[2]*x) | ||||||
|  | 
 | ||||||
|  |     corr_a = pe.Corr([pe.Obs([np.random.normal(item, item*1.5, 1000)],['ensemble1']) for item in func_exp1(xvals_a)]) | ||||||
|  |     corr_b = pe.Corr([pe.Obs([np.random.normal(item, item*1.4, 1000)],['ensemble1']) for item in func_exp2(xvals_b)]) | ||||||
|  | 
 | ||||||
|  |     funcs = {'a':func_a, 'b':func_b} | ||||||
|  |     xs = {'a':xvals_a, 'b':xvals_b} | ||||||
|  |     ys = {'a': [o[0] for o in corr_a.content], | ||||||
|  |           'b': [o[0] for o in corr_b.content]} | ||||||
|  | 
 | ||||||
|  |     corr_a.gm() | ||||||
|  |     corr_b.gm() | ||||||
|  | 
 | ||||||
|  |     comb_fit = pe.least_squares(xs, ys, funcs) | ||||||
|  | 
 | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit) | ||||||
|  | 
 | ||||||
|  |     corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit, fit_key="a") | ||||||
|  |     corr_b.show(x_range=[xs["b"][0], xs["b"][-1]], fit_res=comb_fit, fit_key="b") | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def test_combined_fit_invalid_fit_functions(): | def test_combined_fit_invalid_fit_functions(): | ||||||
|     def func1(a, x): |     def func1(a, x): | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue