mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-16 15:20:24 +01:00
Documentation updated
This commit is contained in:
parent
97334aa841
commit
0353836815
3 changed files with 103 additions and 289 deletions
|
@ -47,9 +47,6 @@
|
|||
|
||||
<h2>API Documentation</h2>
|
||||
<ul class="memberlist">
|
||||
<li>
|
||||
<a class="function" href="#derived_array">derived_array</a>
|
||||
</li>
|
||||
<li>
|
||||
<a class="function" href="#matmul">matmul</a>
|
||||
</li>
|
||||
|
@ -106,125 +103,13 @@
|
|||
<details>
|
||||
<summary>View Source</summary>
|
||||
<div class="codehilite"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
|
||||
<span class="kn">from</span> <span class="nn">autograd</span> <span class="kn">import</span> <span class="n">jacobian</span>
|
||||
<span class="kn">import</span> <span class="nn">autograd.numpy</span> <span class="k">as</span> <span class="nn">anp</span> <span class="c1"># Thinly-wrapped numpy</span>
|
||||
<span class="kn">from</span> <span class="nn">.obs</span> <span class="kn">import</span> <span class="n">derived_observable</span><span class="p">,</span> <span class="n">CObs</span><span class="p">,</span> <span class="n">Obs</span><span class="p">,</span> <span class="n">_merge_idx</span><span class="p">,</span> <span class="n">_expand_deltas_for_merge</span><span class="p">,</span> <span class="n">_filter_zeroes</span><span class="p">,</span> <span class="n">import_jackknife</span>
|
||||
<span class="kn">from</span> <span class="nn">.obs</span> <span class="kn">import</span> <span class="n">derived_observable</span><span class="p">,</span> <span class="n">CObs</span><span class="p">,</span> <span class="n">Obs</span><span class="p">,</span> <span class="n">import_jackknife</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||||
<span class="kn">from</span> <span class="nn">autograd.extend</span> <span class="kn">import</span> <span class="n">defvjp</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">derived_array</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<span class="sd">"""Construct a derived Obs for a matrix valued function according to func(data, **kwargs) using automatic differentiation.</span>
|
||||
|
||||
<span class="sd"> Parameters</span>
|
||||
<span class="sd"> ----------</span>
|
||||
<span class="sd"> func : object</span>
|
||||
<span class="sd"> arbitrary function of the form func(data, **kwargs). For the</span>
|
||||
<span class="sd"> automatic differentiation to work, all numpy functions have to have</span>
|
||||
<span class="sd"> the autograd wrapper (use 'import autograd.numpy as anp').</span>
|
||||
<span class="sd"> data : list</span>
|
||||
<span class="sd"> list of Obs, e.g. [obs1, obs2, obs3].</span>
|
||||
<span class="sd"> man_grad : list</span>
|
||||
<span class="sd"> manually supply a list or an array which contains the jacobian</span>
|
||||
<span class="sd"> of func. Use cautiously, supplying the wrong derivative will</span>
|
||||
<span class="sd"> not be intercepted.</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
|
||||
<span class="n">raveled_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Workaround for matrix operations containing non Obs data</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i_data</span><span class="p">,</span> <span class="n">Obs</span><span class="p">):</span>
|
||||
<span class="n">first_name</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">first_shape</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">first_name</span><span class="p">]</span>
|
||||
<span class="n">first_idl</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="p">[</span><span class="n">first_name</span><span class="p">]</span>
|
||||
<span class="k">break</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)):</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">)):</span>
|
||||
<span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">Obs</span><span class="p">([</span><span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">first_shape</span><span class="p">)],</span> <span class="p">[</span><span class="n">first_name</span><span class="p">],</span> <span class="n">idl</span><span class="o">=</span><span class="p">[</span><span class="n">first_idl</span><span class="p">])</span>
|
||||
|
||||
<span class="n">n_obs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)</span>
|
||||
<span class="n">new_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">}</span>
|
||||
<span class="n">reweighted</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">reweighted</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="n">new_idl_d</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">tmp</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">idl</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tmp</span><span class="p">)</span>
|
||||
<span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">_merge_idx</span><span class="p">(</span><span class="n">idl</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]:</span>
|
||||
<span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="p">[</span><span class="o">*</span><span class="n">idl</span><span class="p">,</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]]])))</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">o</span><span class="o">.</span><span class="n">value</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">value</span><span class="p">)(</span><span class="n">data</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_values</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_r_values</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_obs</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">):</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">r_values</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">tmp</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">value</span>
|
||||
<span class="n">tmp_values</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">tmp</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tmp_values</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="n">new_r_values</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">tmp_values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="s1">'man_grad'</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
|
||||
<span class="n">deriv</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'man_grad'</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">new_values</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">deriv</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'Manual derivative does not have correct shape.'</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">deriv</span> <span class="o">=</span> <span class="n">jacobian</span><span class="p">(</span><span class="n">func</span><span class="p">)(</span><span class="n">values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="n">final_result</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">new_values</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
|
||||
|
||||
<span class="n">d_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">_expand_deltas_for_merge</span><span class="p">(</span><span class="n">o</span><span class="o">.</span><span class="n">deltas</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">o</span><span class="o">.</span><span class="n">idl</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">ens_length</span><span class="p">,</span> <span class="p">)))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i_val</span><span class="p">,</span> <span class="n">new_val</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">new_values</span><span class="p">):</span>
|
||||
<span class="n">new_deltas</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_samples</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">new_means</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">new_idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]:</span>
|
||||
<span class="n">filtered_deltas</span><span class="p">,</span> <span class="n">filtered_idl_d</span> <span class="o">=</span> <span class="n">_filter_zeroes</span><span class="p">(</span><span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">filtered_deltas</span> <span class="o">=</span> <span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
||||
<span class="n">filtered_idl_d</span> <span class="o">=</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
||||
|
||||
<span class="n">new_samples</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filtered_deltas</span><span class="p">)</span>
|
||||
<span class="n">new_idl</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filtered_idl_d</span><span class="p">)</span>
|
||||
<span class="n">new_means</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">new_r_values</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="n">i_val</span><span class="p">])</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span> <span class="o">=</span> <span class="n">Obs</span><span class="p">(</span><span class="n">new_samples</span><span class="p">,</span> <span class="n">new_names</span><span class="p">,</span> <span class="n">means</span><span class="o">=</span><span class="n">new_means</span><span class="p">,</span> <span class="n">idl</span><span class="o">=</span><span class="n">new_idl</span><span class="p">)</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">_value</span> <span class="o">=</span> <span class="n">new_val</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">is_merged</span> <span class="o">=</span> <span class="n">is_merged</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">reweighted</span> <span class="o">=</span> <span class="n">reweighted</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">final_result</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="o">*</span><span class="n">operands</span><span class="p">):</span>
|
||||
<span class="sd">"""Matrix multiply all operands.</span>
|
||||
|
||||
|
@ -264,8 +149,8 @@
|
|||
<span class="k">def</span> <span class="nf">multi_dot_i</span><span class="p">(</span><span class="n">operands</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">multi_dot</span><span class="p">(</span><span class="n">operands</span><span class="p">,</span> <span class="s1">'Imag'</span><span class="p">)</span>
|
||||
|
||||
<span class="n">Nr</span> <span class="o">=</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot_r</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">)</span>
|
||||
<span class="n">Ni</span> <span class="o">=</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot_i</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">)</span>
|
||||
<span class="n">Nr</span> <span class="o">=</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot_r</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">Ni</span> <span class="o">=</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot_i</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">res</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">Nr</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">m</span><span class="p">),</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">Nr</span><span class="p">):</span>
|
||||
|
@ -278,7 +163,7 @@
|
|||
<span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="n">operands</span><span class="p">[</span><span class="mi">1</span><span class="p">:]:</span>
|
||||
<span class="n">stack</span> <span class="o">=</span> <span class="n">stack</span> <span class="o">@</span> <span class="n">op</span>
|
||||
<span class="k">return</span> <span class="n">stack</span>
|
||||
<span class="k">return</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot</span><span class="p">,</span> <span class="n">operands</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot</span><span class="p">,</span> <span class="n">operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">jack_matmul</span><span class="p">(</span><span class="o">*</span><span class="n">operands</span><span class="p">):</span>
|
||||
|
@ -467,7 +352,7 @@
|
|||
<span class="k">if</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'num_grad'</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
<span class="n">op_big_matrix</span> <span class="o">=</span> <span class="n">_num_diff_mat_mat_op</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">big_matrix</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">op_big_matrix</span> <span class="o">=</span> <span class="n">derived_array</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">op</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">[</span><span class="n">big_matrix</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">op_big_matrix</span> <span class="o">=</span> <span class="n">derived_observable</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">op</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">[</span><span class="n">big_matrix</span><span class="p">],</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">dim</span> <span class="o">=</span> <span class="n">op_big_matrix</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">op_A</span> <span class="o">=</span> <span class="n">op_big_matrix</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span> <span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span> <span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">]</span>
|
||||
<span class="n">op_B</span> <span class="o">=</span> <span class="n">op_big_matrix</span><span class="p">[</span><span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">:,</span> <span class="mi">0</span><span class="p">:</span> <span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">]</span>
|
||||
|
@ -478,7 +363,7 @@
|
|||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'num_grad'</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
<span class="k">return</span> <span class="n">_num_diff_mat_mat_op</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">obs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">derived_array</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">op</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">[</span><span class="n">obs</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="k">return</span> <span class="n">derived_observable</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">op</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">[</span><span class="n">obs</span><span class="p">],</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">eigh</span><span class="p">(</span><span class="n">obs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
|
@ -798,149 +683,6 @@
|
|||
|
||||
</details>
|
||||
|
||||
</section>
|
||||
<section id="derived_array">
|
||||
<div class="attr function"><a class="headerlink" href="#derived_array">#  </a>
|
||||
|
||||
|
||||
<span class="def">def</span>
|
||||
<span class="name">derived_array</span><span class="signature">(func, data, **kwargs)</span>:
|
||||
</div>
|
||||
|
||||
<details>
|
||||
<summary>View Source</summary>
|
||||
<div class="codehilite"><pre><span></span><span class="k">def</span> <span class="nf">derived_array</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<span class="sd">"""Construct a derived Obs for a matrix valued function according to func(data, **kwargs) using automatic differentiation.</span>
|
||||
|
||||
<span class="sd"> Parameters</span>
|
||||
<span class="sd"> ----------</span>
|
||||
<span class="sd"> func : object</span>
|
||||
<span class="sd"> arbitrary function of the form func(data, **kwargs). For the</span>
|
||||
<span class="sd"> automatic differentiation to work, all numpy functions have to have</span>
|
||||
<span class="sd"> the autograd wrapper (use 'import autograd.numpy as anp').</span>
|
||||
<span class="sd"> data : list</span>
|
||||
<span class="sd"> list of Obs, e.g. [obs1, obs2, obs3].</span>
|
||||
<span class="sd"> man_grad : list</span>
|
||||
<span class="sd"> manually supply a list or an array which contains the jacobian</span>
|
||||
<span class="sd"> of func. Use cautiously, supplying the wrong derivative will</span>
|
||||
<span class="sd"> not be intercepted.</span>
|
||||
<span class="sd"> """</span>
|
||||
|
||||
<span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
|
||||
<span class="n">raveled_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Workaround for matrix operations containing non Obs data</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i_data</span><span class="p">,</span> <span class="n">Obs</span><span class="p">):</span>
|
||||
<span class="n">first_name</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">first_shape</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">first_name</span><span class="p">]</span>
|
||||
<span class="n">first_idl</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="p">[</span><span class="n">first_name</span><span class="p">]</span>
|
||||
<span class="k">break</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)):</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">)):</span>
|
||||
<span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">Obs</span><span class="p">([</span><span class="n">raveled_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">first_shape</span><span class="p">)],</span> <span class="p">[</span><span class="n">first_name</span><span class="p">],</span> <span class="n">idl</span><span class="o">=</span><span class="p">[</span><span class="n">first_idl</span><span class="p">])</span>
|
||||
|
||||
<span class="n">n_obs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)</span>
|
||||
<span class="n">new_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">}</span>
|
||||
<span class="n">reweighted</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">reweighted</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="n">new_idl_d</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">tmp</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">idl</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tmp</span><span class="p">)</span>
|
||||
<span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">_merge_idx</span><span class="p">(</span><span class="n">idl</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]:</span>
|
||||
<span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="p">[</span><span class="o">*</span><span class="n">idl</span><span class="p">,</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]]])))</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">data</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">o</span><span class="o">.</span><span class="n">value</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">data</span><span class="p">])</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vectorize</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">value</span><span class="p">)(</span><span class="n">data</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_values</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_r_values</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_obs</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">):</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">r_values</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">tmp</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">value</span>
|
||||
<span class="n">tmp_values</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">tmp</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tmp_values</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="n">new_r_values</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="n">tmp_values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="s1">'man_grad'</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
|
||||
<span class="n">deriv</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'man_grad'</span><span class="p">))</span>
|
||||
<span class="k">if</span> <span class="n">new_values</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">deriv</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
||||
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">'Manual derivative does not have correct shape.'</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">deriv</span> <span class="o">=</span> <span class="n">jacobian</span><span class="p">(</span><span class="n">func</span><span class="p">)(</span><span class="n">values</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||||
|
||||
<span class="n">final_result</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">new_values</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
|
||||
|
||||
<span class="n">d_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">_expand_deltas_for_merge</span><span class="p">(</span><span class="n">o</span><span class="o">.</span><span class="n">deltas</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">o</span><span class="o">.</span><span class="n">idl</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">ens_length</span><span class="p">,</span> <span class="p">)))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i_val</span><span class="p">,</span> <span class="n">new_val</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">new_values</span><span class="p">):</span>
|
||||
<span class="n">new_deltas</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
|
||||
<span class="n">new_samples</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">new_means</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">new_idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">is_merged</span><span class="p">[</span><span class="n">name</span><span class="p">]:</span>
|
||||
<span class="n">filtered_deltas</span><span class="p">,</span> <span class="n">filtered_idl_d</span> <span class="o">=</span> <span class="n">_filter_zeroes</span><span class="p">(</span><span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">filtered_deltas</span> <span class="o">=</span> <span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
||||
<span class="n">filtered_idl_d</span> <span class="o">=</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
||||
|
||||
<span class="n">new_samples</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filtered_deltas</span><span class="p">)</span>
|
||||
<span class="n">new_idl</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">filtered_idl_d</span><span class="p">)</span>
|
||||
<span class="n">new_means</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">new_r_values</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="n">i_val</span><span class="p">])</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span> <span class="o">=</span> <span class="n">Obs</span><span class="p">(</span><span class="n">new_samples</span><span class="p">,</span> <span class="n">new_names</span><span class="p">,</span> <span class="n">means</span><span class="o">=</span><span class="n">new_means</span><span class="p">,</span> <span class="n">idl</span><span class="o">=</span><span class="n">new_idl</span><span class="p">)</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">_value</span> <span class="o">=</span> <span class="n">new_val</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">is_merged</span> <span class="o">=</span> <span class="n">is_merged</span>
|
||||
<span class="n">final_result</span><span class="p">[</span><span class="n">i_val</span><span class="p">]</span><span class="o">.</span><span class="n">reweighted</span> <span class="o">=</span> <span class="n">reweighted</span>
|
||||
|
||||
<span class="k">return</span> <span class="n">final_result</span>
|
||||
</pre></div>
|
||||
|
||||
</details>
|
||||
|
||||
<div class="docstring"><p>Construct a derived Obs for a matrix valued function according to func(data, **kwargs) using automatic differentiation.</p>
|
||||
|
||||
<h6 id="parameters">Parameters</h6>
|
||||
|
||||
<ul>
|
||||
<li><strong>func</strong> (object):
|
||||
arbitrary function of the form func(data, **kwargs). For the
|
||||
automatic differentiation to work, all numpy functions have to have
|
||||
the autograd wrapper (use 'import autograd.numpy as anp').</li>
|
||||
<li><strong>data</strong> (list):
|
||||
list of Obs, e.g. [obs1, obs2, obs3].</li>
|
||||
<li><strong>man_grad</strong> (list):
|
||||
manually supply a list or an array which contains the jacobian
|
||||
of func. Use cautiously, supplying the wrong derivative will
|
||||
not be intercepted.</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
</section>
|
||||
<section id="matmul">
|
||||
<div class="attr function"><a class="headerlink" href="#matmul">#  </a>
|
||||
|
@ -991,8 +733,8 @@ not be intercepted.</li>
|
|||
<span class="k">def</span> <span class="nf">multi_dot_i</span><span class="p">(</span><span class="n">operands</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">multi_dot</span><span class="p">(</span><span class="n">operands</span><span class="p">,</span> <span class="s1">'Imag'</span><span class="p">)</span>
|
||||
|
||||
<span class="n">Nr</span> <span class="o">=</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot_r</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">)</span>
|
||||
<span class="n">Ni</span> <span class="o">=</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot_i</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">)</span>
|
||||
<span class="n">Nr</span> <span class="o">=</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot_r</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">Ni</span> <span class="o">=</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot_i</span><span class="p">,</span> <span class="n">extended_operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
|
||||
<span class="n">res</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">Nr</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">m</span><span class="p">),</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">Nr</span><span class="p">):</span>
|
||||
|
@ -1005,7 +747,7 @@ not be intercepted.</li>
|
|||
<span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="n">operands</span><span class="p">[</span><span class="mi">1</span><span class="p">:]:</span>
|
||||
<span class="n">stack</span> <span class="o">=</span> <span class="n">stack</span> <span class="o">@</span> <span class="n">op</span>
|
||||
<span class="k">return</span> <span class="n">stack</span>
|
||||
<span class="k">return</span> <span class="n">derived_array</span><span class="p">(</span><span class="n">multi_dot</span><span class="p">,</span> <span class="n">operands</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">derived_observable</span><span class="p">(</span><span class="n">multi_dot</span><span class="p">,</span> <span class="n">operands</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
|
||||
</details>
|
||||
|
|
|
@ -1346,7 +1346,7 @@
|
|||
<span class="k">return</span> <span class="n">deltas</span><span class="p">,</span> <span class="n">idx</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">derived_observable</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<span class="k">def</span> <span class="nf">derived_observable</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<span class="sd">"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.</span>
|
||||
|
||||
<span class="sd"> Parameters</span>
|
||||
|
@ -1379,6 +1379,7 @@
|
|||
<span class="n">raveled_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Workaround for matrix operations containing non Obs data</span>
|
||||
<span class="c1"># TODO: Find more elegant solution here.</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i_data</span><span class="p">,</span> <span class="n">Obs</span><span class="p">):</span>
|
||||
<span class="n">first_name</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
@ -1401,11 +1402,13 @@
|
|||
|
||||
<span class="n">n_obs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)</span>
|
||||
<span class="n">new_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
<span class="n">new_cov_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">cov_names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
<span class="n">new_sample_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">new_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">new_cov_names</span><span class="p">))</span>
|
||||
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">}</span>
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">}</span>
|
||||
<span class="n">reweighted</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">reweighted</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="n">new_idl_d</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
|
@ -1427,7 +1430,7 @@
|
|||
<span class="n">multi</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
|
||||
<span class="n">new_r_values</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_obs</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">):</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">r_values</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
|
@ -1469,9 +1472,42 @@
|
|||
|
||||
<span class="n">final_result</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">new_values</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">array_mode</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
|
||||
<span class="n">new_covobs_lengths</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[[(</span><span class="n">n</span><span class="p">,</span> <span class="n">o</span><span class="o">.</span><span class="n">covobs</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">N</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">o</span><span class="o">.</span><span class="n">cov_names</span><span class="p">]</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">_Zero_grad</span><span class="p">():</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">N</span><span class="p">):</span>
|
||||
<span class="c1"># self.grad = np.zeros(N)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
|
||||
<span class="n">d_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="n">g_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">_expand_deltas_for_merge</span><span class="p">(</span><span class="n">o</span><span class="o">.</span><span class="n">deltas</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)),</span> <span class="n">o</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]),</span> <span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">ens_length</span><span class="p">),</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">ens_length</span><span class="p">,</span> <span class="p">)))</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_cov_names</span><span class="p">:</span>
|
||||
<span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">zero_grad</span> <span class="o">=</span> <span class="n">_Zero_grad</span><span class="p">(</span><span class="n">new_covobs_lengths</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">o</span><span class="o">.</span><span class="n">covobs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">zero_grad</span><span class="p">)</span><span class="o">.</span><span class="n">grad</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">new_covobs_lengths</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="mi">1</span><span class="p">)))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i_val</span><span class="p">,</span> <span class="n">new_val</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">new_values</span><span class="p">):</span>
|
||||
<span class="n">new_deltas</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="n">new_grad</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">if</span> <span class="n">array_mode</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_cov_names</span><span class="p">:</span>
|
||||
<span class="n">new_grad</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_grad</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">j_obs</span><span class="p">,</span> <span class="n">obs</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">obs</span><span class="o">.</span><span class="n">names</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">obs</span><span class="o">.</span><span class="n">cov_names</span><span class="p">:</span>
|
||||
|
@ -4587,12 +4623,12 @@ should agree with samples from a full jackknife analysis up to O(1/N).</li>
|
|||
|
||||
|
||||
<span class="def">def</span>
|
||||
<span class="name">derived_observable</span><span class="signature">(func, data, **kwargs)</span>:
|
||||
<span class="name">derived_observable</span><span class="signature">(func, data, array_mode=False, **kwargs)</span>:
|
||||
</div>
|
||||
|
||||
<details>
|
||||
<summary>View Source</summary>
|
||||
<div class="codehilite"><pre><span></span><span class="k">def</span> <span class="nf">derived_observable</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<div class="codehilite"><pre><span></span><span class="k">def</span> <span class="nf">derived_observable</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">array_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||||
<span class="sd">"""Construct a derived Obs according to func(data, **kwargs) using automatic differentiation.</span>
|
||||
|
||||
<span class="sd"> Parameters</span>
|
||||
|
@ -4625,6 +4661,7 @@ should agree with samples from a full jackknife analysis up to O(1/N).</li>
|
|||
<span class="n">raveled_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
|
||||
|
||||
<span class="c1"># Workaround for matrix operations containing non Obs data</span>
|
||||
<span class="c1"># TODO: Find more elegant solution here.</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i_data</span><span class="p">,</span> <span class="n">Obs</span><span class="p">):</span>
|
||||
<span class="n">first_name</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
|
@ -4647,11 +4684,13 @@ should agree with samples from a full jackknife analysis up to O(1/N).</li>
|
|||
|
||||
<span class="n">n_obs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">)</span>
|
||||
<span class="n">new_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
<span class="n">new_cov_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[</span><span class="n">o</span><span class="o">.</span><span class="n">cov_names</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
<span class="n">new_sample_names</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">new_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">new_cov_names</span><span class="p">))</span>
|
||||
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">}</span>
|
||||
<span class="n">is_merged</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">is_merged</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">}</span>
|
||||
<span class="n">reweighted</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">o</span><span class="p">:</span> <span class="n">o</span><span class="o">.</span><span class="n">reweighted</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">,</span> <span class="n">raveled_data</span><span class="p">)))</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="n">new_idl_d</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">idl</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">i_data</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">:</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">i_data</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
|
@ -4673,7 +4712,7 @@ should agree with samples from a full jackknife analysis up to O(1/N).</li>
|
|||
<span class="n">multi</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
|
||||
<span class="n">new_r_values</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_names</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">tmp_values</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n_obs</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">raveled_data</span><span class="p">):</span>
|
||||
<span class="n">tmp</span> <span class="o">=</span> <span class="n">item</span><span class="o">.</span><span class="n">r_values</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
||||
|
@ -4715,9 +4754,42 @@ should agree with samples from a full jackknife analysis up to O(1/N).</li>
|
|||
|
||||
<span class="n">final_result</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">new_values</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">object</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">array_mode</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
|
||||
<span class="n">new_covobs_lengths</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">[[(</span><span class="n">n</span><span class="p">,</span> <span class="n">o</span><span class="o">.</span><span class="n">covobs</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">N</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">o</span><span class="o">.</span><span class="n">cov_names</span><span class="p">]</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">raveled_data</span><span class="p">]</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">x</span><span class="p">]))</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">_Zero_grad</span><span class="p">():</span>
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">N</span><span class="p">):</span>
|
||||
<span class="c1"># self.grad = np.zeros(N)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
|
||||
<span class="n">d_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="n">g_extracted</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">_expand_deltas_for_merge</span><span class="p">(</span><span class="n">o</span><span class="o">.</span><span class="n">deltas</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)),</span> <span class="n">o</span><span class="o">.</span><span class="n">idl</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">]),</span> <span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">ens_length</span><span class="p">),</span> <span class="n">new_idl_d</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">ens_length</span><span class="p">,</span> <span class="p">)))</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_cov_names</span><span class="p">:</span>
|
||||
<span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">zero_grad</span> <span class="o">=</span> <span class="n">_Zero_grad</span><span class="p">(</span><span class="n">new_covobs_lengths</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">o</span><span class="o">.</span><span class="n">covobs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">zero_grad</span><span class="p">)</span><span class="o">.</span><span class="n">grad</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">dat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span><span class="p">))])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dat</span><span class="o">.</span><span class="n">shape</span> <span class="o">+</span> <span class="p">(</span><span class="n">new_covobs_lengths</span><span class="p">[</span><span class="n">name</span><span class="p">],</span> <span class="mi">1</span><span class="p">)))</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">i_val</span><span class="p">,</span> <span class="n">new_val</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">new_values</span><span class="p">):</span>
|
||||
<span class="n">new_deltas</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="n">new_grad</span> <span class="o">=</span> <span class="p">{}</span>
|
||||
<span class="k">if</span> <span class="n">array_mode</span> <span class="ow">is</span> <span class="kc">True</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_sample_names</span><span class="p">:</span>
|
||||
<span class="n">ens_length</span> <span class="o">=</span> <span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">ens_length</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">d_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_deltas</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">new_cov_names</span><span class="p">:</span>
|
||||
<span class="n">new_grad</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">for</span> <span class="n">i_dat</span><span class="p">,</span> <span class="n">dat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">g_extracted</span><span class="p">[</span><span class="n">name</span><span class="p">]):</span>
|
||||
<span class="n">new_grad</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">tensordot</span><span class="p">(</span><span class="n">deriv</span><span class="p">[</span><span class="n">i_val</span> <span class="o">+</span> <span class="p">(</span><span class="n">i_dat</span><span class="p">,</span> <span class="p">)],</span> <span class="n">dat</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">for</span> <span class="n">j_obs</span><span class="p">,</span> <span class="n">obs</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ndenumerate</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
|
||||
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">obs</span><span class="o">.</span><span class="n">names</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">obs</span><span class="o">.</span><span class="n">cov_names</span><span class="p">:</span>
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue