refactor: calculation of N in Obs.__init__ optimized

This commit is contained in:
Fabian Joswig 2021-12-08 16:14:48 +00:00
parent 2702b5519d
commit ae53daa915

View file

@ -120,23 +120,24 @@ class Obs:
self.idl[name] = range(1, len(sample) + 1)
self._value = 0
self.N = 0
if means is not None:
for name, sample, mean in sorted(zip(names, samples, means)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
if len(sample) != self.shape[name]:
raise Exception('Incompatible samples and idx for %s: %d vs. %d' % (name, len(sample), self.shape[name]))
self.r_values[name] = mean
self.deltas[name] = sample
self.N = sum(list(self.shape.values()))
else:
for name, sample in sorted(zip(names, samples)):
self.shape[name] = len(self.idl[name])
self.N += self.shape[name]
if len(sample) != self.shape[name]:
raise Exception('Incompatible samples and idx for %s: %d vs. %d' % (name, len(sample), self.shape[name]))
self.r_values[name] = np.mean(sample)
self.deltas[name] = sample - self.r_values[name]
self._value += self.shape[name] * self.r_values[name]
self.N = sum(list(self.shape.values()))
self._value /= self.N
self.is_merged = {}