Further optimiuations in Obs.__init__

This commit is contained in:
Fabian Joswig 2021-10-21 16:06:29 +01:00
parent 0cea84a9c1
commit a3ddf51965

View file

@ -49,6 +49,7 @@ class Obs:
def __init__(self, samples, names, **kwargs): def __init__(self, samples, names, **kwargs):
if 'means' not in kwargs:
if len(samples) != len(names): if len(samples) != len(names):
raise Exception('Length of samples and names incompatible.') raise Exception('Length of samples and names incompatible.')
if len(names) != len(set(names)): if len(names) != len(set(names)):
@ -64,8 +65,6 @@ class Obs:
self.deltas = {} self.deltas = {}
if 'means' in kwargs: if 'means' in kwargs:
if len(samples) != len(kwargs.get('means')):
raise Exception('Length of samples and means incompatible.')
for name, sample, mean in sorted(zip(names, samples, kwargs.get('means'))): for name, sample, mean in sorted(zip(names, samples, kwargs.get('means'))):
self.shape[name] = np.size(sample) self.shape[name] = np.size(sample)
self.r_values[name] = mean self.r_values[name] = mean
@ -79,6 +78,7 @@ class Obs:
self.N = sum(map(np.size, list(self.deltas.values()))) self.N = sum(map(np.size, list(self.deltas.values())))
self.value = 0 self.value = 0
if 'means' not in kwargs:
for name in self.names: for name in self.names:
self.value += self.shape[name] * self.r_values[name] self.value += self.shape[name] * self.r_values[name]
self.value /= self.N self.value /= self.N