Merge branch 'develop' into documentation

This commit is contained in:
fjosw 2021-11-12 09:40:00 +00:00
commit 9382041ad6
3 changed files with 26 additions and 3 deletions

View file

@ -86,7 +86,7 @@ def derived_array(func, data, **kwargs):
for name in new_names:
d_extracted[name] = []
for i_dat, dat in enumerate(data):
ens_length = new_idl_d[name][-1] - new_idl_d[name][0] + 1
ens_length = len(new_idl_d[name])
d_extracted[name].append(np.array([_expand_deltas_for_merge(o.deltas[name], o.idl[name], o.shape[name], new_idl_d[name]) for o in dat.reshape(np.prod(dat.shape))]).reshape(dat.shape + (ens_length, )))
for i_val, new_val in np.ndenumerate(new_values):

View file

@ -248,7 +248,7 @@ class Obs:
r_length = []
for r_name in e_content[e_name]:
if self.idl[r_name] is range:
if isinstance(self.idl[r_name], range):
r_length.append(len(self.idl[r_name]))
else:
r_length.append((self.idl[r_name][-1] - self.idl[r_name][0] + 1))
@ -339,7 +339,7 @@ class Obs:
idx -- List or range of configs on which the deltas are defined.
shape -- Number of configs in idx.
"""
if type(idx) is range:
if isinstance(idx, range):
return deltas
else:
ret = np.zeros(idx[-1] - idx[0] + 1)

View file

@ -51,6 +51,29 @@ def test_multi_dot():
assert e.is_zero(), t
def test_matmul_irregular_histories():
dim = 2
length = 500
standard_array = []
for i in range(dim ** 2):
standard_array.append(pe.Obs([np.random.normal(1.1, 0.2, length)], ['ens1']))
standard_matrix = np.array(standard_array).reshape((dim, dim))
for idl in [range(1, 501, 2), range(250, 273), [2, 8, 19, 20, 78]]:
irregular_array = []
for i in range(dim ** 2):
irregular_array.append(pe.Obs([np.random.normal(1.1, 0.2, len(idl))], ['ens1'], idl=[idl]))
irregular_matrix = np.array(irregular_array).reshape((dim, dim))
t1 = standard_matrix @ irregular_matrix
t2 = pe.linalg.matmul(standard_matrix, irregular_matrix)
assert np.all([o.is_zero() for o in (t1 - t2).ravel()])
assert np.all([o.is_merged for o in t1.ravel()])
assert np.all([o.is_merged for o in t2.ravel()])
def test_matrix_inverse():
content = []
for t in range(9):