Merge branch 'develop' into documentation

This commit is contained in:
fjosw 2023-03-09 14:26:18 +00:00
commit dd14b35905
2 changed files with 17 additions and 15 deletions

View file

@ -9,7 +9,10 @@
"source": [ "source": [
"import pyerrors as pe\n", "import pyerrors as pe\n",
"import numpy as np\n", "import numpy as np\n",
"import matplotlib.pyplot as plt" "import matplotlib.pyplot as plt\n",
"from packaging import version\n",
"if version.parse(pe.__version__) < version.parse(\"2.6.0\"):\n",
" raise Exception(f\"v2.6.0 or newer is required for this example, you are using {pe.__version__}\")"
] ]
}, },
{ {

View file

@ -240,13 +240,7 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
loc_priors = [] loc_priors = []
for i_n, i_prior in enumerate(priors): for i_n, i_prior in enumerate(priors):
if isinstance(i_prior, Obs): loc_priors.append(_construct_prior_obs(i_prior, i_n))
loc_priors.append(i_prior)
elif isinstance(i_prior, str):
loc_val, loc_dval = _extract_val_and_dval(i_prior)
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}"))
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
prior_mask = np.arange(len(priors)) prior_mask = np.arange(len(priors))
output.priors = loc_priors output.priors = loc_priors
@ -260,13 +254,8 @@ def least_squares(x, y, func, priors=None, silent=False, **kwargs):
prior_mask.append(pos) prior_mask.append(pos)
else: else:
raise TypeError("Prior position needs to be an integer.") raise TypeError("Prior position needs to be an integer.")
if isinstance(prior, Obs): loc_priors.append(_construct_prior_obs(prior, pos))
loc_priors.append(prior)
elif isinstance(prior, str):
loc_val, loc_dval = _extract_val_and_dval(prior)
loc_priors.append(cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(pos) + f"_{np.random.randint(2147483647):010d}"))
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")
output.priors[pos] = loc_priors[-1] output.priors[pos] = loc_priors[-1]
if max(prior_mask) >= n_parms: if max(prior_mask) >= n_parms:
raise ValueError("Prior position out of range.") raise ValueError("Prior position out of range.")
@ -823,3 +812,13 @@ def _extract_val_and_dval(string):
else: else:
factor = 1 factor = 1
return float(split_string[0]), float(split_string[1][:-1]) * factor return float(split_string[0]), float(split_string[1][:-1]) * factor
def _construct_prior_obs(i_prior, i_n):
if isinstance(i_prior, Obs):
return i_prior
elif isinstance(i_prior, str):
loc_val, loc_dval = _extract_val_and_dval(i_prior)
return cov_Obs(loc_val, loc_dval ** 2, '#prior' + str(i_n) + f"_{np.random.randint(2147483647):010d}")
else:
raise TypeError("Prior entries need to be 'Obs' or 'str'.")