Huge speedup for derived_observable in the case of irregular MC chains, achieved by introducing the dict is_merged and rewriting _filter_zeroes. The cost of the filtering is still non-negligible.

This commit is contained in:
Simon Kuberski 2021-11-17 16:44:54 +01:00
parent cca8d6fbfa
commit 17252c4f0d

View file

@ -118,7 +118,7 @@ class Obs:
raise Exception('Incompatible samples and idx for %s: %d vs. %d' % (name, len(sample), self.shape[name]))
self.r_values[name] = np.mean(sample)
self.deltas[name] = sample - self.r_values[name]
self.is_merged = False
self.is_merged = {}
self.N = sum(list(self.shape.values()))
self._value = 0
@ -941,44 +941,33 @@ def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))])
def _filter_zeroes(deltas, idx, eps=Obs.filter_eps):
    """Filter out all configurations with vanishing fluctuation such that they do not
    contribute to the error estimate anymore. Returns the new deltas and
    idx according to the filtering.
    A fluctuation is considered to be vanishing, if it is smaller than eps times
    the mean of the absolute values of all deltas in one list.

    Parameters
    ----------
    deltas : list
        List of fluctuations
    idx : list
        List or ranges of configs on which the deltas are defined.
    eps : float
        Prefactor that enters the filter criterion.

    Returns
    -------
    tuple
        (filtered deltas as np.ndarray, list of kept config indices).
        If the filter would discard every configuration, the inputs are
        returned unchanged instead.
    """
    abs_deltas = np.fabs(np.asarray(deltas))
    # Vectorized criterion; replaces a per-element Python loop over the array.
    keep = abs_deltas > eps * np.mean(abs_deltas)
    if np.any(keep):
        # idx may be a range, so index element-wise rather than slicing with the mask.
        return np.asarray(deltas)[keep], [idx[i] for i in np.flatnonzero(keep)]
    else:
        return deltas, idx
def derived_observable(func, data, **kwargs):
@ -1028,7 +1017,9 @@ def derived_observable(func, data, **kwargs):
n_obs = len(raveled_data)
new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x]))
is_merged = len(list(filter(lambda o: o.is_merged is True, raveled_data))) > 0
is_merged = {}
for name in list(set().union(*[o.names for o in raveled_data])):
is_merged[name] = len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0
reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0
new_idl_d = {}
for name in new_names:
@ -1038,8 +1029,8 @@ def derived_observable(func, data, **kwargs):
if tmp is not None:
idl.append(tmp)
new_idl_d[name] = _merge_idx(idl)
if not is_merged:
is_merged = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
if not is_merged[name]:
is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]])))
if data.ndim == 1:
values = np.array([o.value for o in data])
@ -1104,12 +1095,16 @@ def derived_observable(func, data, **kwargs):
new_samples = []
new_means = []
new_idl = []
if is_merged:
filtered_names, filtered_deltas, filtered_idl_d = _filter_zeroes(new_names, new_deltas, new_idl_d)
else:
filtered_names = new_names
filtered_deltas = new_deltas
filtered_idl_d = new_idl_d
filtered_names = []
filtered_deltas = {}
filtered_idl_d = {}
for name in new_names:
filtered_names.append(name)
if is_merged[name]:
filtered_deltas[name], filtered_idl_d[name] = _filter_zeroes(new_deltas[name], new_idl_d[name])
else:
filtered_deltas[name] = new_deltas[name]
filtered_idl_d[name] = new_idl_d[name]
for name in filtered_names:
new_samples.append(filtered_deltas[name])
new_means.append(new_r_values[name][i_val])
@ -1236,7 +1231,7 @@ def correlate(obs_a, obs_b):
new_idl.append(obs_a.idl[name])
o = Obs(new_samples, sorted(obs_a.names), idl=new_idl)
o.is_merged = obs_a.is_merged or obs_b.is_merged
o.is_merged = {name: (obs_a.is_merged.get(name, False) or obs_b.is_merged.get(name, False)) for name in o.names}
o.reweighted = obs_a.reweighted or obs_b.reweighted
return o
@ -1578,6 +1573,6 @@ def merge_obs(list_of_obs):
names = sorted(new_dict.keys())
o = Obs([new_dict[name] for name in names], names, idl=[idl_dict[name] for name in names])
o.is_merged = np.any([oi.is_merged for oi in list_of_obs])
o.is_merged = {name: np.any([oi.is_merged.get(name, False) for oi in list_of_obs]) for name in o.names}
o.reweighted = np.max([oi.reweighted for oi in list_of_obs])
return o