Merge pull request #142 from fjosw/feat/root_of_multi_parameter_functions

Root of multi parameter functions
This commit is contained in:
Fabian Joswig 2023-01-10 10:34:36 +00:00 committed by GitHub
commit 569bf8c2f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 5 deletions

View file

@ -27,15 +27,17 @@ def find_root(d, func, guess=1.0, **kwargs):
Obs Obs
`Obs` valued root of the function. `Obs` valued root of the function.
''' '''
root = scipy.optimize.fsolve(func, guess, d.value) d_val = np.vectorize(lambda x: x.value)(np.array(d))
root = scipy.optimize.fsolve(func, guess, d_val)
# Error propagation as detailed in arXiv:1809.01289 # Error propagation as detailed in arXiv:1809.01289
dx = jacobian(func)(root[0], d.value) dx = jacobian(func)(root[0], d_val)
try: try:
da = jacobian(lambda u, v: func(v, u))(d.value, root[0]) da = jacobian(lambda u, v: func(v, u))(d_val, root[0])
except TypeError: except TypeError:
raise Exception("It is required to use autograd.numpy instead of numpy within root functions, see the documentation for details.") from None raise Exception("It is required to use autograd.numpy instead of numpy within root functions, see the documentation for details.") from None
deriv = - da / dx deriv = - da / dx
res = derived_observable(lambda x, **kwargs: (x[0] + np.finfo(np.float64).eps) / (np.array(d).reshape(-1)[0].value + np.finfo(np.float64).eps) * root[0],
res = derived_observable(lambda x, **kwargs: (x[0] + np.finfo(np.float64).eps) / (d.value + np.finfo(np.float64).eps) * root[0], [d], man_grad=[deriv]) np.array(d).reshape(-1), man_grad=np.array(deriv).reshape(-1))
return res return res

View file

@ -42,3 +42,13 @@ def test_root_no_autograd():
with pytest.raises(Exception): with pytest.raises(Exception):
my_root = pe.roots.find_root(my_obs, root_function) my_root = pe.roots.find_root(my_obs, root_function)
def test_root_multi_parameter():
o1 = pe.pseudo_Obs(1.1, 0.1, "test")
o2 = pe.pseudo_Obs(1.3, 0.12, "test")
f2 = lambda x, d: d[0] + d[1] * x
assert f2(-o1 / o2, [o1, o2]) == 0
assert pe.find_root([o1, o2], f2) == -o1 / o2