diff --git a/pyerrors/fits.py b/pyerrors/fits.py index fbaca972..fd8a1a51 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - import gc import warnings import numpy as np @@ -16,6 +13,13 @@ from autograd import elementwise_grad as egrad from .pyerrors import Obs, derived_observable, covariance, pseudo_Obs +class Fit_result: + + def __init__(self): + self.fit_parameters = None + + + def standard_fit(x, y, func, silent=False, **kwargs): """Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters. @@ -38,8 +42,6 @@ def standard_fit(x, y, func, silent=False, **kwargs): Keyword arguments ----------------- - dict_output -- If true, the output is a dictionary containing all relevant - data instead of just a list of the fit parameters. silent -- If true all output to the console is omitted (default False). initial_guess -- can provide an initial guess for the input parameters. Relevant for non-linear fits with many parameters. @@ -55,9 +57,10 @@ def standard_fit(x, y, func, silent=False, **kwargs): has to be calculated (default False). """ - result_dict = {} - result_dict['fit_function'] = func + output = Fit_result() + + output.fit_function = func x = np.asarray(x) @@ -102,7 +105,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): return chisq if 'method' in kwargs: - result_dict['method'] = kwargs.get('method') + utput.method = kwargs.get('method') if not silent: print('Method:', kwargs.get('method')) if kwargs.get('method') == 'migrad': @@ -114,7 +117,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): chisquare = fit_result.fun else: - result_dict['method'] = 'Levenberg-Marquardt' + output.method = 'Levenberg-Marquardt' if not silent: print('Method: Levenberg-Marquardt') @@ -131,13 +134,13 @@ def standard_fit(x, y, func, silent=False, **kwargs): raise Exception('The minimization procedure did not converge.') if x.shape[-1] - n_parms > 0: - result_dict['chisquare/d.o.f.'] = chisquare / (x.shape[-1] - n_parms) + output.chisquare_by_dof = chisquare / (x.shape[-1] - n_parms) else: - result_dict['chisquare/d.o.f.'] = float('nan') + output.chisquare_by_dof = float('nan') if not silent: print(fit_result.message) - print('chisquare/d.o.f.:', result_dict['chisquare/d.o.f.']) + print('chisquare/d.o.f.:', output.chisquare_by_dof) if kwargs.get('expected_chisquare') is True: W = np.diag(1 / np.asarray(dy_f)) @@ -145,10 +148,10 @@ def standard_fit(x, y, func, silent=False, **kwargs): A = W @ jacobian(func)(fit_result.x, x) P_phi = A @ np.linalg.inv(A.T @ A) @ A.T expected_chisquare = np.trace((np.identity(x.shape[-1]) - P_phi) @ W @ cov @ W) - result_dict['chisquare/expected_chisquare'] = chisquare / expected_chisquare + output.chisquare_by_expected_chisquare = chisquare / expected_chisquare if not silent: print('chisquare/expected_chisquare:', - result_dict['chisquare/expected_chisquare']) + output.chisquare_by_expected_chisquare) hess_inv = np.linalg.pinv(jacobian(jacobian(chisqfunc))(fit_result.x)) @@ -165,10 +168,10 @@ def standard_fit(x, y, func, silent=False, **kwargs): for i in range(n_parms): result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(fit_result.x[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(y), man_grad=[0] + list(deriv[i]))) - result_dict['fit_parameters'] = result + output.fit_parameters = result - result_dict['chisquare'] = chisqfunc(fit_result.x) - result_dict['d.o.f.'] = x.shape[-1] - n_parms + output.chisquare = chisqfunc(fit_result.x) + output.dof = x.shape[-1] - n_parms if kwargs.get('resplot') is True: residual_plot(x, y, func, result) @@ -176,7 +179,7 @@ def standard_fit(x, y, func, silent=False, **kwargs): if kwargs.get('qqplot') is True: qqplot(x, y, func, result) - return result_dict if kwargs.get('dict_output') else result + return output def odr_fit(x, y, func, silent=False, **kwargs): @@ -215,9 +218,9 @@ def odr_fit(x, y, func, silent=False, **kwargs): has to be calculated (default False). """ - result_dict = {} + output = Fit_result() - result_dict['fit_function'] = func + output.fit_function = func x = np.array(x) @@ -262,16 +265,16 @@ def odr_fit(x, y, func, silent=False, **kwargs): odr.set_job(fit_type=0, deriv=1) output = odr.run() - result_dict['residual_variance'] = output.res_var + output.residual_variance = output.res_var - result_dict['method'] = 'ODR' + output.method = 'ODR' - result_dict['xplus'] = output.xplus + output.xplus = output.xplus if not silent: print('Method: ODR') print(*output.stopreason) - print('Residual variance:', result_dict['residual_variance']) + print('Residual variance:', output.residual_variance) if output.info > 3: raise Exception('The minimization procedure did not converge.') @@ -304,10 +307,10 @@ def odr_fit(x, y, func, silent=False, **kwargs): if expected_chisquare <= 0.0: warnings.warn("Negative expected_chisquare.", RuntimeWarning) expected_chisquare = np.abs(expected_chisquare) - result_dict['chisquare/expected_chisquare'] = odr_chisquare(np.concatenate((output.beta, output.xplus.ravel()))) / expected_chisquare + output.chisquare_by_expected_chisquare = odr_chisquare(np.concatenate((output.beta, output.xplus.ravel()))) / expected_chisquare if not silent: print('chisquare/expected_chisquare:', - result_dict['chisquare/expected_chisquare']) + output.chisquare_by_expected_chisquare) hess_inv = np.linalg.pinv(jacobian(jacobian(odr_chisquare))(np.concatenate((output.beta, output.xplus.ravel())))) @@ -333,12 +336,12 @@ def odr_fit(x, y, func, silent=False, **kwargs): for i in range(n_parms): result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(output.beta[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(x.ravel()) + list(y), man_grad=[0] + list(deriv_x[i]) + list(deriv_y[i]))) - result_dict['fit_parameters'] = result + output.fit_parameters = result - result_dict['odr_chisquare'] = odr_chisquare(np.concatenate((output.beta, output.xplus.ravel()))) - result_dict['d.o.f.'] = x.shape[-1] - n_parms + output.odr_chisquare = odr_chisquare(np.concatenate((output.beta, output.xplus.ravel()))) + output.dof = x.shape[-1] - n_parms - return result_dict if kwargs.get('dict_output') else result + return output def prior_fit(x, y, func, priors, silent=False, **kwargs): @@ -375,9 +378,9 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): tol -- Specify the tolerance of the migrad solver (default 1e-4) """ - result_dict = {} + output = Fit_result() - result_dict['fit_function'] = func + output.fit_function = func if Obs.e_tag_global < 4: warnings.warn("e_tag_global is smaller than 4, this can cause problems when calculating errors from fits with priors", RuntimeWarning) @@ -416,7 +419,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): loc_val, loc_dval = extract_val_and_dval(i_prior) loc_priors.append(pseudo_Obs(loc_val, loc_dval, 'p' + str(i_n))) - result_dict['priors'] = loc_priors + output.priors = loc_priors if not silent: print('Fit with', n_parms, 'parameters') @@ -456,12 +459,12 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): m.migrad() params = np.asarray(m.values.values()) - result_dict['chisquare/d.o.f.'] = m.fval / len(x) + output.chisquare_by_dof = m.fval / len(x) - result_dict['method'] = 'migrad' + output.method = 'migrad' if not silent: - print('chisquare/d.o.f.:', result_dict['chisquare/d.o.f.']) + print('chisquare/d.o.f.:', output.chisquare_by_dof) if not m.get_fmin().is_valid: raise Exception('The minimization procedure did not converge.') @@ -481,8 +484,8 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): for i in range(n_parms): result.append(derived_observable(lambda x, **kwargs: x[0], [pseudo_Obs(params[i], 0.0, y[0].names[0], y[0].shape[y[0].names[0]])] + list(y) + list(loc_priors), man_grad=[0] + list(deriv[i]))) - result_dict['fit_parameters'] = result - result_dict['chisquare'] = chisqfunc(np.asarray(params)) + output.fit_parameters = result + output.chisquare = chisqfunc(np.asarray(params)) if kwargs.get('resplot') is True: residual_plot(x, y, func, result) @@ -490,7 +493,7 @@ def prior_fit(x, y, func, priors, silent=False, **kwargs): if kwargs.get('qqplot') is True: qqplot(x, y, func, result) - return result_dict if kwargs.get('dict_output') else result + return output def fit_lin(x, y, **kwargs): @@ -506,9 +509,11 @@ def fit_lin(x, y, **kwargs): return y if all(isinstance(n, Obs) for n in x): - return odr_fit(x, y, f, **kwargs) + out = odr_fit(x, y, f, **kwargs) + return out.fit_parameters elif all(isinstance(n, float) or isinstance(n, int) for n in x) or isinstance(x, np.ndarray): - return standard_fit(x, y, f, **kwargs) + out = standard_fit(x, y, f, **kwargs) + return out.fit_parameters else: raise Exception('Unsupported types for x')