feat: multi parameter root feature added.

This commit is contained in:
Fabian Joswig 2023-01-10 10:28:12 +00:00
parent 43cf9c29d4
commit a45e20b51c
No known key found for this signature in database

View file

@ -27,15 +27,17 @@ def find_root(d, func, guess=1.0, **kwargs):
Obs Obs
`Obs` valued root of the function. `Obs` valued root of the function.
''' '''
root = scipy.optimize.fsolve(func, guess, d.value) d_val = np.vectorize(lambda x: x.value)(np.array(d))
root = scipy.optimize.fsolve(func, guess, d_val)
# Error propagation as detailed in arXiv:1809.01289 # Error propagation as detailed in arXiv:1809.01289
dx = jacobian(func)(root[0], d.value) dx = jacobian(func)(root[0], d_val)
try: try:
da = jacobian(lambda u, v: func(v, u))(d.value, root[0]) da = jacobian(lambda u, v: func(v, u))(d_val, root[0])
except TypeError: except TypeError:
raise Exception("It is required to use autograd.numpy instead of numpy within root functions, see the documentation for details.") from None raise Exception("It is required to use autograd.numpy instead of numpy within root functions, see the documentation for details.") from None
deriv = - da / dx deriv = - da / dx
res = derived_observable(lambda x, **kwargs: (x[0] + np.finfo(np.float64).eps) / (np.array(d).reshape(-1)[0].value + np.finfo(np.float64).eps) * root[0],
res = derived_observable(lambda x, **kwargs: (x[0] + np.finfo(np.float64).eps) / (d.value + np.finfo(np.float64).eps) * root[0], [d], man_grad=[deriv]) np.array(d).reshape(-1), man_grad=np.array(deriv).reshape(-1))
return res return res