mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-15 14:50:25 +01:00
Corr.correlate implemented
This commit is contained in:
parent
5fb29fd326
commit
ac21569620
3 changed files with 32 additions and 1 deletions
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import autograd.numpy as anp
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.linalg
|
||||
from .pyerrors import Obs, dump_object, reweight
|
||||
from .pyerrors import Obs, dump_object, reweight, correlate
|
||||
from .fits import least_squares
|
||||
from .linalg import eigh, inv, cholesky
|
||||
from .roots import find_root
|
||||
|
@ -237,6 +237,22 @@ class Corr:
|
|||
"""Reverse the time ordering of the Corr"""
|
||||
return Corr(self.content[::-1])
|
||||
|
||||
def correlate(self, partner):
|
||||
"""Correlate the correlator with another correlator or Obs"""
|
||||
new_content = []
|
||||
for x0, t_slice in enumerate(self.content):
|
||||
if t_slice is None:
|
||||
new_content.append(None)
|
||||
else:
|
||||
if isinstance(partner, Corr):
|
||||
new_content.append(np.array([correlate(o, partner.content[x0][0]) for o in t_slice]))
|
||||
elif isinstance(partner, Obs):
|
||||
new_content.append(np.array([correlate(o, partner) for o in t_slice]))
|
||||
else:
|
||||
raise Exception("Can only correlate with an Obs or a Corr.")
|
||||
|
||||
return Corr(new_content)
|
||||
|
||||
def reweight(self, weight, **kwargs):
|
||||
"""Reweight the correlator.
|
||||
|
||||
|
|
|
@ -1145,6 +1145,13 @@ def reweight(weight, obs, **kwargs):
|
|||
def correlate(obs_a, obs_b):
|
||||
"""Correlate two observables.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
obs_a : Obs
|
||||
First observable
|
||||
obs_b : Obs
|
||||
Second observable
|
||||
|
||||
Keep in mind to only correlate primary observables which have not been reweighted
|
||||
yet. The reweighting has to be applied after correlating the observables.
|
||||
Currently only works if ensembles are identical. This is not really necessary.
|
||||
|
|
|
@ -290,6 +290,14 @@ def test_merge_obs():
|
|||
assert diff == -(my_obs1.value + my_obs2.value) / 2
|
||||
|
||||
|
||||
def test_correlate():
|
||||
my_obs1 = pe.Obs([np.random.rand(100)], ['t'])
|
||||
my_obs2 = pe.Obs([np.random.rand(100)], ['t'])
|
||||
corr1 = pe.correlate(my_obs1, my_obs2)
|
||||
corr2 = pe.correlate(my_obs2, my_obs1)
|
||||
assert corr1 == corr2
|
||||
|
||||
|
||||
def test_irregular_error_propagation():
|
||||
obs_list = [pe.Obs([np.random.rand(100)], ['t']),
|
||||
pe.Obs([np.random.rand(50)], ['t'], idl=[range(1, 100, 2)]),
|
||||
|
|
Loading…
Add table
Reference in a new issue