diff --git a/pyerrors/fits.py b/pyerrors/fits.py index d9f963d9..20a5ee32 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -240,13 +240,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs): loc_priors = [] for i_n, i_prior in enumerate(priors): - if isinstance(i_prior, Obs): - loc_priors.append(i_prior) - elif isinstance(i_prior, str): - loc_val, loc_dval = _extract_val_and_dval(i_prior) - loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")) - else: - raise TypeError("Prior entries need to be 'Obs' or 'str'.") + loc_priors.append(_construct_prior_obs(i_prior, i_n)) prior_mask = np.arange(len(priors)) output.priors = loc_priors @@ -260,13 +254,8 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs): prior_mask.append(pos) else: raise TypeError("Prior position needs to be an integer.") - if isinstance(prior, Obs): - loc_priors.append(prior) - elif isinstance(prior, str): - loc_val, loc_dval = _extract_val_and_dval(prior) - loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(pos) + f"_{np.random.randint(2147483647):010d}")) - else: - raise TypeError("Prior entries need to be 'Obs' or 'str'.") + loc_priors.append(_construct_prior_obs(prior, pos)) + output.priors[pos] = loc_priors[-1] if max(prior_mask) >= n_parms: raise ValueError("Prior position out of range.") @@ -823,3 +812,13 @@ def _extract_val_and_dval(string): else: factor = 1 return float(split_string[0]), float(split_string[1][:-1]) * factor + + +def _construct_prior_obs(i_prior, i_n): + if isinstance(i_prior, Obs): + return i_prior + elif isinstance(i_prior, str): + loc_val, loc_dval = _extract_val_and_dval(i_prior) + return cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}") + else: + raise TypeError("Prior entries need to be 'Obs' or 'str'.")