diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index 95eadf74..736bcb7a 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -69,9 +69,6 @@ class Corr: if isinstance(data_input, list): - if all([isinstance(item, (Obs, CObs)) for item in data_input]): - _assert_equal_properties(data_input) - self.content = [np.asarray([item]) for item in data_input] if all([isinstance(item, (Obs, CObs)) or item is None for item in data_input]): _assert_equal_properties([o for o in data_input if o is not None]) self.content = [np.asarray([item]) if item is not None else None for item in data_input] @@ -972,6 +969,8 @@ class Corr: content_string += "Description: " + self.tag + "\n" if self.N != 1: return content_string + if isinstance(self[0], CObs): + return content_string if print_range[1]: print_range[1] += 1 @@ -1139,8 +1138,10 @@ class Corr: for t in range(self.T): if _check_for_none(self, newcontent[t]): continue - if np.isnan(np.sum(newcontent[t]).value): - newcontent[t] = None + tmp_sum = np.sum(newcontent[t]) + if hasattr(tmp_sum, "value"): + if np.isnan(tmp_sum.value): + newcontent[t] = None if all([item is None for item in newcontent]): raise Exception('Operation returns undefined correlator') return Corr(newcontent) @@ -1197,8 +1198,8 @@ class Corr: @property def real(self): def return_real(obs_OR_cobs): - if isinstance(obs_OR_cobs, CObs): - return obs_OR_cobs.real + if isinstance(obs_OR_cobs.flatten()[0], CObs): + return np.vectorize(lambda x: x.real)(obs_OR_cobs) else: return obs_OR_cobs @@ -1207,8 +1208,8 @@ class Corr: @property def imag(self): def return_imag(obs_OR_cobs): - if isinstance(obs_OR_cobs, CObs): - return obs_OR_cobs.imag + if isinstance(obs_OR_cobs.flatten()[0], CObs): + return np.vectorize(lambda x: x.imag)(obs_OR_cobs) else: return obs_OR_cobs * 0 # So it stays the right type diff --git a/pyerrors/misc.py b/pyerrors/misc.py index bbfb8e7d..cf0dd18a 100644 --- a/pyerrors/misc.py +++ b/pyerrors/misc.py @@ -105,16 +105,11 @@ def gen_correlated_data(means, cov, name, tau=0.5, samples=1000): def _assert_equal_properties(ol, otype=Obs): - if not isinstance(ol[0], otype): - raise Exception("Wrong data type in list.") + otype = type(ol[0]) for o in ol[1:]: if not isinstance(o, otype): raise Exception("Wrong data type in list.") - if not ol[0].is_merged == o.is_merged: - raise Exception("All Obs in list have to have the same state 'is_merged'.") - if not ol[0].reweighted == o.reweighted: - raise Exception("All Obs in list have to have the same property 'reweighted'.") - if not ol[0].e_content == o.e_content: - raise Exception("All Obs in list have to be defined on the same set of configs.") - if not ol[0].idl == o.idl: - raise Exception("All Obs in list have to be defined on the same set of configurations.") + for attr in ["is_merged", "reweighted", "e_content", "idl"]: + if hasattr(ol[0], attr): + if not getattr(ol[0], attr) == getattr(o, attr): + raise Exception(f"All Obs in list have to have the same state '{attr}'.") diff --git a/tests/correlators_test.py b/tests/correlators_test.py index 49332254..1ab31f12 100644 --- a/tests/correlators_test.py +++ b/tests/correlators_test.py @@ -532,3 +532,13 @@ def test_prune(): with pytest.raises(Exception): corr_mat.prune(3) corr_mat.prune(4) + + +def test_complex_Corr(): + o1 = pe.pseudo_Obs(1.0, 0.1, "test") + cobs = pe.CObs(o1, -o1) + ccorr = pe.Corr([cobs, cobs, cobs]) + assert np.all([ccorr.imag[i] == -ccorr.real[i] for i in range(ccorr.T)]) + print(ccorr) + mcorr = pe.Corr(np.array([[ccorr, ccorr], [ccorr, ccorr]])) + assert np.all([mcorr.imag[i] == -mcorr.real[i] for i in range(mcorr.T)])