mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
refactor: refactored construction of prior Obs in least_squares.
This commit is contained in:
parent
e41f869d18
commit
99a1033703
1 changed files with 13 additions and 14 deletions
|
@ -240,13 +240,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
|
|||
|
||||
loc_priors = []
|
||||
for i_n, i_prior in enumerate(priors):
|
||||
if isinstance(i_prior, Obs):
|
||||
loc_priors.append(i_prior)
|
||||
elif isinstance(i_prior, str):
|
||||
loc_val, loc_dval = _extract_val_and_dval(i_prior)
|
||||
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}"))
|
||||
else:
|
||||
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
|
||||
loc_priors.append(_construct_prior_obs(i_prior, i_n))
|
||||
|
||||
prior_mask = np.arange(len(priors))
|
||||
output.priors = loc_priors
|
||||
|
@ -260,13 +254,8 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
|
|||
prior_mask.append(pos)
|
||||
else:
|
||||
raise TypeError("Prior position needs to be an integer.")
|
||||
if isinstance(prior, Obs):
|
||||
loc_priors.append(prior)
|
||||
elif isinstance(prior, str):
|
||||
loc_val, loc_dval = _extract_val_and_dval(prior)
|
||||
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(pos) + f"_{np.random.randint(2147483647):010d}"))
|
||||
else:
|
||||
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
|
||||
loc_priors.append(_construct_prior_obs(prior, pos))
|
||||
|
||||
output.priors[pos] = loc_priors[-1]
|
||||
if max(prior_mask) >= n_parms:
|
||||
raise ValueError("Prior position out of range.")
|
||||
|
@ -823,3 +812,13 @@ def _extract_val_and_dval(string):
|
|||
else:
|
||||
factor = 1
|
||||
return float(split_string[0]), float(split_string[1][:-1]) * factor
|
||||
|
||||
|
||||
def _construct_prior_obs(i_prior, i_n):
|
||||
if isinstance(i_prior, Obs):
|
||||
return i_prior
|
||||
elif isinstance(i_prior, str):
|
||||
loc_val, loc_dval = _extract_val_and_dval(i_prior)
|
||||
return cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")
|
||||
else:
|
||||
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
|
||||
|
|
Loading…
Add table
Reference in a new issue