diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 2274f7e1..a363852c 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -867,6 +867,11 @@ class Corr: else: newcontent.append(self.content[t] + y) return Corr(newcontent, prange=self.prange) + elif isinstance(y, np.ndarray): + if y.shape == (self.T,): + return Corr(list((np.array(self.content).T + y).T)) + else: + raise ValueError("operands could not be broadcast together") else: raise TypeError("Corr + wrong type") @@ -890,6 +895,11 @@ class Corr: else: newcontent.append(self.content[t] * y) return Corr(newcontent, prange=self.prange) + elif isinstance(y, np.ndarray): + if y.shape == (self.T,): + return Corr(list((np.array(self.content).T * y).T)) + else: + raise ValueError("operands could not be broadcast together") else: raise TypeError("Corr * wrong type") @@ -939,6 +949,11 @@ class Corr: else: newcontent.append(self.content[t] / y) return Corr(newcontent, prange=self.prange) + elif isinstance(y, np.ndarray): + if y.shape == (self.T,): + return Corr(list((np.array(self.content).T / y).T)) + else: + raise ValueError("operands could not be broadcast together") else: raise TypeError('Corr / wrong type') diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 27dc0f67..2d068d90 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -219,13 +219,6 @@ def test_prange(): def test_matrix_corr(): - def _gen_corr(val): - corr_content = [] - for t in range(16): - corr_content.append(pe.pseudo_Obs(val, 0.1, 't', 2000)) - - return pe.correlators.Corr(corr_content) - corr_aa = _gen_corr(1) corr_ab = _gen_corr(0.5) @@ -311,3 +304,26 @@ def test_corr_matrix_none_entries(): corr = pe.Corr(oy) corr = corr.deriv() pe.Corr(np.array([[corr, corr], [corr, corr]])) + + +def test_corr_vector_operations(): + my_corr = _gen_corr(1.0) + my_vec = np.arange(1, 17) + + my_corr + my_vec + my_corr - my_vec + my_corr * my_vec + my_corr / my_vec + + assert np.all([o == 0 for o in ((my_corr + my_vec) - my_vec) - my_corr]) + assert np.all([o == 0 for o in ((my_corr - my_vec) + my_vec) - my_corr]) + assert np.all([o == 0 for o in ((my_corr * my_vec) / my_vec) - my_corr]) + assert np.all([o == 0 for o in ((my_corr / my_vec) * my_vec) - my_corr]) + +def _gen_corr(val): + corr_content = [] + for t in range(16): + corr_content.append(pe.pseudo_Obs(val, 0.1, 't', 2000)) + + return pe.correlators.Corr(corr_content) +