refactor: refactored construction of prior Obs in least_squares.

This commit is contained in:
Fabian Joswig 2023-03-09 14:25:37 +00:00
parent e41f869d18
commit 99a1033703
No known key found for this signature in database

View file

@ -240,13 +240,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
loc_priors = []
for i_n, i_prior in enumerate(priors):
if isinstance(i_prior, Obs):
loc_priors.append(i_prior)
elif isinstance(i_prior, str):
loc_val, loc_dval = _extract_val_and_dval(i_prior)
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}"))
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
loc_priors.append(_construct_prior_obs(i_prior, i_n))
prior_mask = np.arange(len(priors))
output.priors = loc_priors
@ -260,13 +254,8 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
prior_mask.append(pos)
else:
raise TypeError("Prior position needs to be an integer.")
if isinstance(prior, Obs):
loc_priors.append(prior)
elif isinstance(prior, str):
loc_val, loc_dval = _extract_val_and_dval(prior)
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(pos) + f"_{np.random.randint(2147483647):010d}"))
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
loc_priors.append(_construct_prior_obs(prior, pos))
output.priors[pos] = loc_priors[-1]
if max(prior_mask) >= n_parms:
raise ValueError("Prior position out of range.")
@ -823,3 +812,13 @@ def _extract_val_and_dval(string):
else:
factor = 1
return float(split_string[0]), float(split_string[1][:-1]) * factor
def _construct_prior_obs(i_prior, i_n):
if isinstance(i_prior, Obs):
return i_prior
elif isinstance(i_prior, str):
loc_val, loc_dval = _extract_val_and_dval(i_prior)
return cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")