mirror of
				https://github.com/fjosw/pyerrors.git
				synced 2025-11-04 09:35:45 +01:00 
			
		
		
		
	refactor!: fit_general deprecated and moved to tests
This commit is contained in:
		
					parent
					
						
							
								5a8b6483c8
							
						
					
				
			
			
				commit
				
					
						87c50f54c0
					
				
			
		
					 2 changed files with 101 additions and 105 deletions
				
			
		
							
								
								
									
										105
									
								
								pyerrors/fits.py
									
										
									
									
									
								
							
							
						
						
									
										105
									
								
								pyerrors/fits.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -646,10 +646,10 @@ def fit_lin(x, y, **kwargs):
 | 
			
		|||
        return y
 | 
			
		||||
 | 
			
		||||
    if all(isinstance(n, Obs) for n in x):
 | 
			
		||||
        out = odr_fit(x, y, f, **kwargs)
 | 
			
		||||
        out = total_least_squares(x, y, f, **kwargs)
 | 
			
		||||
        return out.fit_parameters
 | 
			
		||||
    elif all(isinstance(n, float) or isinstance(n, int) for n in x) or isinstance(x, np.ndarray):
 | 
			
		||||
        out = standard_fit(x, y, f, **kwargs)
 | 
			
		||||
        out = least_squares(x, y, f, **kwargs)
 | 
			
		||||
        return out.fit_parameters
 | 
			
		||||
    else:
 | 
			
		||||
        raise Exception('Unsupported types for x')
 | 
			
		||||
| 
						 | 
				
			
			@ -785,104 +785,3 @@ def ks_test(obs=None):
 | 
			
		|||
    plt.draw()
 | 
			
		||||
 | 
			
		||||
    print(scipy.stats.kstest(Qs, 'uniform'))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fit_general(x, y, func, silent=False, **kwargs):
 | 
			
		||||
    """Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
 | 
			
		||||
 | 
			
		||||
    Plausibility of the results should be checked. To control the numerical differentiation
 | 
			
		||||
    the kwargs of numdifftools.step_generators.MaxStepGenerator can be used.
 | 
			
		||||
 | 
			
		||||
    func has to be of the form
 | 
			
		||||
 | 
			
		||||
    def func(a, x):
 | 
			
		||||
        y = a[0] + a[1] * x + a[2] * np.sinh(x)
 | 
			
		||||
        return y
 | 
			
		||||
 | 
			
		||||
    y has to be a list of Obs, the dvalues of the Obs are used as yerror for the fit.
 | 
			
		||||
    x can either be a list of floats in which case no xerror is assumed, or
 | 
			
		||||
    a list of Obs, where the dvalues of the Obs are used as xerror for the fit.
 | 
			
		||||
 | 
			
		||||
    Keyword arguments
 | 
			
		||||
    -----------------
 | 
			
		||||
    silent -- If true all output to the console is omitted (default False).
 | 
			
		||||
    initial_guess -- can provide an initial guess for the input parameters. Relevant for non-linear fits
 | 
			
		||||
                     with many parameters.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    warnings.warn("New fit functions with exact error propagation are now available as alternative.", DeprecationWarning)
 | 
			
		||||
 | 
			
		||||
    if not callable(func):
 | 
			
		||||
        raise TypeError('func has to be a function.')
 | 
			
		||||
 | 
			
		||||
    for i in range(10):
 | 
			
		||||
        try:
 | 
			
		||||
            func(np.arange(i), 0)
 | 
			
		||||
        except:
 | 
			
		||||
            pass
 | 
			
		||||
        else:
 | 
			
		||||
            break
 | 
			
		||||
    n_parms = i
 | 
			
		||||
    if not silent:
 | 
			
		||||
        print('Fit with', n_parms, 'parameters')
 | 
			
		||||
 | 
			
		||||
    global print_output, beta0
 | 
			
		||||
    print_output = 1
 | 
			
		||||
    if 'initial_guess' in kwargs:
 | 
			
		||||
        beta0 = kwargs.get('initial_guess')
 | 
			
		||||
        if len(beta0) != n_parms:
 | 
			
		||||
            raise Exception('Initial guess does not have the correct length.')
 | 
			
		||||
    else:
 | 
			
		||||
        beta0 = np.arange(n_parms)
 | 
			
		||||
 | 
			
		||||
    if len(x) != len(y):
 | 
			
		||||
        raise Exception('x and y have to have the same length')
 | 
			
		||||
 | 
			
		||||
    if all(isinstance(n, Obs) for n in x):
 | 
			
		||||
        obs = x + y
 | 
			
		||||
        x_constants = None
 | 
			
		||||
        xerr = [o.dvalue for o in x]
 | 
			
		||||
        yerr = [o.dvalue for o in y]
 | 
			
		||||
    elif all(isinstance(n, float) or isinstance(n, int) for n in x) or isinstance(x, np.ndarray):
 | 
			
		||||
        obs = y
 | 
			
		||||
        x_constants = x
 | 
			
		||||
        xerr = None
 | 
			
		||||
        yerr = [o.dvalue for o in y]
 | 
			
		||||
    else:
 | 
			
		||||
        raise Exception('Unsupported types for x')
 | 
			
		||||
 | 
			
		||||
    def do_the_fit(obs, **kwargs):
 | 
			
		||||
 | 
			
		||||
        global print_output, beta0
 | 
			
		||||
 | 
			
		||||
        func = kwargs.get('function')
 | 
			
		||||
        yerr = kwargs.get('yerr')
 | 
			
		||||
        length = len(yerr)
 | 
			
		||||
 | 
			
		||||
        xerr = kwargs.get('xerr')
 | 
			
		||||
 | 
			
		||||
        if length == len(obs):
 | 
			
		||||
            assert 'x_constants' in kwargs
 | 
			
		||||
            data = RealData(kwargs.get('x_constants'), obs, sy=yerr)
 | 
			
		||||
            fit_type = 2
 | 
			
		||||
        elif length == len(obs) // 2:
 | 
			
		||||
            data = RealData(obs[:length], obs[length:], sx=xerr, sy=yerr)
 | 
			
		||||
            fit_type = 0
 | 
			
		||||
        else:
 | 
			
		||||
            raise Exception('x and y do not fit together.')
 | 
			
		||||
 | 
			
		||||
        model = Model(func)
 | 
			
		||||
 | 
			
		||||
        odr = ODR(data, model, beta0, partol=np.finfo(np.float64).eps)
 | 
			
		||||
        odr.set_job(fit_type=fit_type, deriv=1)
 | 
			
		||||
        output = odr.run()
 | 
			
		||||
        if print_output and not silent:
 | 
			
		||||
            print(*output.stopreason)
 | 
			
		||||
            print('chisquare/d.o.f.:', output.res_var)
 | 
			
		||||
            print_output = 0
 | 
			
		||||
        beta0 = output.beta
 | 
			
		||||
        return output.beta[kwargs.get('n')]
 | 
			
		||||
    res = []
 | 
			
		||||
    for n in range(n_parms):
 | 
			
		||||
        res.append(derived_observable(do_the_fit, obs, function=func, xerr=xerr, yerr=yerr, x_constants=x_constants, num_grad=True, n=n, **kwargs))
 | 
			
		||||
    return res
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -232,8 +232,7 @@ def test_odr_derivatives():
 | 
			
		|||
    out = pe.total_least_squares(x, y, func)
 | 
			
		||||
    fit1 = out.fit_parameters
 | 
			
		||||
 | 
			
		||||
    with pytest.warns(DeprecationWarning):
 | 
			
		||||
        tfit = pe.fits.fit_general(x, y, func, base_step=0.1, step_ratio=1.1, num_steps=20)
 | 
			
		||||
    tfit = fit_general(x, y, func, base_step=0.1, step_ratio=1.1, num_steps=20)
 | 
			
		||||
    assert np.abs(np.max(np.array(list(fit1[1].deltas.values()))
 | 
			
		||||
                  - np.array(list(tfit[1].deltas.values())))) < 10e-8
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -274,3 +273,101 @@ def test_r_value_persistence():
 | 
			
		|||
    assert np.isclose(fitp[1].value, fitp[1].r_values['a'])
 | 
			
		||||
    assert np.isclose(fitp[1].value, fitp[1].r_values['b'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fit_general(x, y, func, silent=False, **kwargs):
 | 
			
		||||
    """Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
 | 
			
		||||
 | 
			
		||||
    Plausibility of the results should be checked. To control the numerical differentiation
 | 
			
		||||
    the kwargs of numdifftools.step_generators.MaxStepGenerator can be used.
 | 
			
		||||
 | 
			
		||||
    func has to be of the form
 | 
			
		||||
 | 
			
		||||
    def func(a, x):
 | 
			
		||||
        y = a[0] + a[1] * x + a[2] * np.sinh(x)
 | 
			
		||||
        return y
 | 
			
		||||
 | 
			
		||||
    y has to be a list of Obs, the dvalues of the Obs are used as yerror for the fit.
 | 
			
		||||
    x can either be a list of floats in which case no xerror is assumed, or
 | 
			
		||||
    a list of Obs, where the dvalues of the Obs are used as xerror for the fit.
 | 
			
		||||
 | 
			
		||||
    Keyword arguments
 | 
			
		||||
    -----------------
 | 
			
		||||
    silent -- If true all output to the console is omitted (default False).
 | 
			
		||||
    initial_guess -- can provide an initial guess for the input parameters. Relevant for non-linear fits
 | 
			
		||||
                     with many parameters.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if not callable(func):
 | 
			
		||||
        raise TypeError('func has to be a function.')
 | 
			
		||||
 | 
			
		||||
    for i in range(10):
 | 
			
		||||
        try:
 | 
			
		||||
            func(np.arange(i), 0)
 | 
			
		||||
        except:
 | 
			
		||||
            pass
 | 
			
		||||
        else:
 | 
			
		||||
            break
 | 
			
		||||
    n_parms = i
 | 
			
		||||
    if not silent:
 | 
			
		||||
        print('Fit with', n_parms, 'parameters')
 | 
			
		||||
 | 
			
		||||
    global print_output, beta0
 | 
			
		||||
    print_output = 1
 | 
			
		||||
    if 'initial_guess' in kwargs:
 | 
			
		||||
        beta0 = kwargs.get('initial_guess')
 | 
			
		||||
        if len(beta0) != n_parms:
 | 
			
		||||
            raise Exception('Initial guess does not have the correct length.')
 | 
			
		||||
    else:
 | 
			
		||||
        beta0 = np.arange(n_parms)
 | 
			
		||||
 | 
			
		||||
    if len(x) != len(y):
 | 
			
		||||
        raise Exception('x and y have to have the same length')
 | 
			
		||||
 | 
			
		||||
    if all(isinstance(n, pe.Obs) for n in x):
 | 
			
		||||
        obs = x + y
 | 
			
		||||
        x_constants = None
 | 
			
		||||
        xerr = [o.dvalue for o in x]
 | 
			
		||||
        yerr = [o.dvalue for o in y]
 | 
			
		||||
    elif all(isinstance(n, float) or isinstance(n, int) for n in x) or isinstance(x, np.ndarray):
 | 
			
		||||
        obs = y
 | 
			
		||||
        x_constants = x
 | 
			
		||||
        xerr = None
 | 
			
		||||
        yerr = [o.dvalue for o in y]
 | 
			
		||||
    else:
 | 
			
		||||
        raise Exception('Unsupported types for x')
 | 
			
		||||
 | 
			
		||||
    def do_the_fit(obs, **kwargs):
 | 
			
		||||
 | 
			
		||||
        global print_output, beta0
 | 
			
		||||
 | 
			
		||||
        func = kwargs.get('function')
 | 
			
		||||
        yerr = kwargs.get('yerr')
 | 
			
		||||
        length = len(yerr)
 | 
			
		||||
 | 
			
		||||
        xerr = kwargs.get('xerr')
 | 
			
		||||
 | 
			
		||||
        if length == len(obs):
 | 
			
		||||
            assert 'x_constants' in kwargs
 | 
			
		||||
            data = RealData(kwargs.get('x_constants'), obs, sy=yerr)
 | 
			
		||||
            fit_type = 2
 | 
			
		||||
        elif length == len(obs) // 2:
 | 
			
		||||
            data = RealData(obs[:length], obs[length:], sx=xerr, sy=yerr)
 | 
			
		||||
            fit_type = 0
 | 
			
		||||
        else:
 | 
			
		||||
            raise Exception('x and y do not fit together.')
 | 
			
		||||
 | 
			
		||||
        model = Model(func)
 | 
			
		||||
 | 
			
		||||
        odr = ODR(data, model, beta0, partol=np.finfo(np.float64).eps)
 | 
			
		||||
        odr.set_job(fit_type=fit_type, deriv=1)
 | 
			
		||||
        output = odr.run()
 | 
			
		||||
        if print_output and not silent:
 | 
			
		||||
            print(*output.stopreason)
 | 
			
		||||
            print('chisquare/d.o.f.:', output.res_var)
 | 
			
		||||
            print_output = 0
 | 
			
		||||
        beta0 = output.beta
 | 
			
		||||
        return output.beta[kwargs.get('n')]
 | 
			
		||||
    res = []
 | 
			
		||||
    for n in range(n_parms):
 | 
			
		||||
        res.append(pe.derived_observable(do_the_fit, obs, function=func, xerr=xerr, yerr=yerr, x_constants=x_constants, num_grad=True, n=n, **kwargs))
 | 
			
		||||
    return res
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue