From b07f16fe1c1a1db010014f85657746ee9c3db94c Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Tue, 2 Nov 2021 14:19:00 +0000 Subject: [PATCH] docstring and test added for Corr.reweight --- pyerrors/correlators.py | 19 +++++++++++++++++-- pyerrors/pyerrors.py | 7 ++++--- tests/correlators_test.py | 6 ++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 269b975e..ab210da9 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -233,13 +233,28 @@ class Corr: """Reverse the time ordering of the Corr""" return Corr(self.content[::-1]) - def reweight(self, weight): + def reweight(self, weight, **kwargs): + """Reweight the correlator. + + Parameters + ---------- + weight : Obs + Reweighting factor. An Observable that has to be defined on a superset of the + configurations in obs[i].idl for all i. + + Keyword arguments + ----------------- + all_configs : bool + if True, the reweighted observables are normalized by the average of + the reweighting factor on all configurations in weight.idl and not + on the configurations in obs[i].idl. + """ new_content = [] for t_slice in self: if t_slice is None: new_content.append(None) else: - new_content.append(np.array(reweight(weight, t_slice))) + new_content.append(np.array(reweight(weight, t_slice, **kwargs))) return Corr(new_content) def T_symmetry(self, partner, parity=+1): diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index fc184f53..53a5f2ac 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -1105,9 +1105,10 @@ def reweight(weight, obs, **kwargs): Keyword arguments ----------------- - all_configs -- if True, the reweighted observables are normalized by the average of - the reweighting factor on all configurations in weight.idl and not - on the configurations in obs[i].idl. + all_configs : bool + if True, the reweighted observables are normalized by the average of + the reweighting factor on all configurations in weight.idl and not + on the configurations in obs[i].idl. """ result = [] for i in range(len(obs)): diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 181c7b07..47373f68 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -54,6 +54,12 @@ def test_m_eff(): my_corr.m_eff('cosh') my_corr.m_eff('sinh') +def test_reweighting(): + my_corr = pe.correlators.Corr([pe.pseudo_Obs(10, 0.1, 't'), pe.pseudo_Obs(0, 0.05, 't')]) + assert my_corr.reweighted is False + r_my_corr = my_corr.reweight(pe.pseudo_Obs(1, 0.1, 't')) + assert r_my_corr.reweighted is True + def test_T_symmetry(): my_corr = pe.correlators.Corr([pe.pseudo_Obs(10, 0.1, 't'), pe.pseudo_Obs(0, 0.05, 't')]) with pytest.warns(RuntimeWarning):