fix: Combined fit can now handle list and array inputs for y-values, test added.

This commit is contained in:
Fabian Joswig 2022-12-19 16:06:12 +01:00
parent 140b626aae
commit 33ff2219ba
No known key found for this signature in database
2 changed files with 12 additions and 15 deletions

View file

@ -703,12 +703,8 @@ def _combined_fit(x, y, func, silent=False, **kwargs):
jacobian = auto_jacobian
hessian = auto_hessian
x_all = []
y_all = []
for key in x.keys():
y_all += y[key]
x_all = np.concatenate([np.array(o) for o in x.values()])
y_all = np.concatenate([np.array(o) for o in y.values()])
if len(x_all.shape) > 2:
raise Exception('Unknown format for x values')

View file

@ -610,18 +610,19 @@ def test_ks_test():
def test_combined_fit_list_v_array():
res = []
y_test = {'a': [pe.Obs([np.random.normal(i, 0.5, 1000)], ['ensemble1']) for i in range(1, 7)]}
for x_test in [{'a': [0, 1, 2, 3, 4, 5]}, {'a': np.arange(6)}]:
for key in y_test.keys():
[item.gamma_method() for item in y_test[key]]
def func_a(a, x):
return a[1] * x + a[0]
for y_test in [{'a': [pe.Obs([np.random.normal(i, 0.5, 1000)], ['ensemble1']) for i in range(1, 7)]},
{'a': np.array([pe.Obs([np.random.normal(i, 0.5, 1000)], ['ensemble1']) for i in range(1, 7)])}]:
for x_test in [{'a': [0, 1, 2, 3, 4, 5]}, {'a': np.arange(6)}]:
for key in y_test.keys():
[item.gamma_method() for item in y_test[key]]
def func_a(a, x):
return a[1] * x + a[0]
funcs_test = {"a": func_a}
res.append(pe.fits.least_squares(x_test, y_test, funcs_test))
funcs_test = {"a": func_a}
res.append(pe.fits.least_squares(x_test, y_test, funcs_test))
assert (res[0][0] - res[1][0]).is_zero(atol=1e-8)
assert (res[0][1] - res[1][1]).is_zero(atol=1e-8)
assert (res[0][0] - res[1][0]).is_zero(atol=1e-8)
assert (res[0][1] - res[1][1]).is_zero(atol=1e-8)
def fit_general(x, y, func, silent=False, **kwargs):