mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-03-16 07:10:24 +01:00
Documentation updated
This commit is contained in:
parent
f13ddce69c
commit
97014d2d25
2 changed files with 1 additions and 94 deletions
|
@ -80,9 +80,6 @@
|
|||
<li>
|
||||
<a class="function" href="#slogdet">slogdet</a>
|
||||
</li>
|
||||
<li>
|
||||
<a class="function" href="#grad_eig">grad_eig</a>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
@ -106,9 +103,6 @@
|
|||
<span class="kn">import</span> <span class="nn">autograd.numpy</span> <span class="k">as</span> <span class="nn">anp</span> <span class="c1"># Thinly-wrapped numpy</span>
|
||||
<span class="kn">from</span> <span class="nn">.obs</span> <span class="kn">import</span> <span class="n">derived_observable</span><span class="p">,</span> <span class="n">CObs</span><span class="p">,</span> <span class="n">Obs</span><span class="p">,</span> <span class="n">import_jackknife</span>
|
||||
|
||||
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
|
||||
<span class="kn">from</span> <span class="nn">autograd.extend</span> <span class="kn">import</span> <span class="n">defvjp</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">matmul</span><span class="p">(</span><span class="o">*</span><span class="n">operands</span><span class="p">):</span>
|
||||
<span class="sd">"""Matrix multiply all operands.</span>
|
||||
|
@ -631,54 +625,6 @@
|
|||
<span class="n">res_mat2</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">row</span><span class="p">)</span>
|
||||
|
||||
<span class="k">return</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">res_mat0</span><span class="p">)</span> <span class="o">@</span> <span class="n">np</span><span class="o">.</span><span class="n">identity</span><span class="p">(</span><span class="n">mid_index</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">res_mat1</span><span class="p">)</span> <span class="o">@</span> <span class="n">np</span><span class="o">.</span><span class="n">identity</span><span class="p">(</span><span class="n">mid_index</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">res_mat2</span><span class="p">)</span> <span class="o">@</span> <span class="n">np</span><span class="o">.</span><span class="n">identity</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
|
||||
|
||||
|
||||
<span class="c1"># This code block is directly taken from the current master branch of autograd and remains</span>
|
||||
<span class="c1"># only until the new version is released on PyPi</span>
|
||||
<span class="n">_dot</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">anp</span><span class="o">.</span><span class="n">einsum</span><span class="p">,</span> <span class="s1">'...ij,...jk->...ik'</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="c1"># batched diag</span>
|
||||
<span class="k">def</span> <span class="nf">_diag</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="n">anp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="n">a</span>
|
||||
|
||||
|
||||
<span class="c1"># batched diagonal, similar to matrix_diag in tensorflow</span>
|
||||
<span class="k">def</span> <span class="nf">_matrix_diag</span><span class="p">(</span><span class="n">a</span><span class="p">):</span>
|
||||
<span class="n">reps</span> <span class="o">=</span> <span class="n">anp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||||
<span class="n">reps</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="n">reps</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
<span class="n">newshape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span>
|
||||
<span class="k">return</span> <span class="n">_diag</span><span class="p">(</span><span class="n">anp</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">reps</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">newshape</span><span class="p">))</span>
|
||||
|
||||
<span class="c1"># https://arxiv.org/pdf/1701.00392.pdf Eq(4.77)</span>
|
||||
<span class="c1"># Note the formula from Sec3.1 in https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf is incomplete</span>
|
||||
|
||||
|
||||
<span class="k">def</span> <span class="nf">grad_eig</span><span class="p">(</span><span class="n">ans</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="sd">"""Gradient of a general square (complex valued) matrix"""</span>
|
||||
<span class="n">e</span><span class="p">,</span> <span class="n">u</span> <span class="o">=</span> <span class="n">ans</span> <span class="c1"># eigenvalues as 1d array, eigenvectors in columns</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">vjp</span><span class="p">(</span><span class="n">g</span><span class="p">):</span>
|
||||
<span class="n">ge</span><span class="p">,</span> <span class="n">gu</span> <span class="o">=</span> <span class="n">g</span>
|
||||
<span class="n">ge</span> <span class="o">=</span> <span class="n">_matrix_diag</span><span class="p">(</span><span class="n">ge</span><span class="p">)</span>
|
||||
<span class="n">f</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">e</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">anp</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span> <span class="o">-</span> <span class="n">e</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:,</span> <span class="n">anp</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">+</span> <span class="mf">1.e-20</span><span class="p">)</span>
|
||||
<span class="n">f</span> <span class="o">-=</span> <span class="n">_diag</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||||
<span class="n">ut</span> <span class="o">=</span> <span class="n">anp</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
|
||||
<span class="n">r1</span> <span class="o">=</span> <span class="n">f</span> <span class="o">*</span> <span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">gu</span><span class="p">)</span>
|
||||
<span class="n">r2</span> <span class="o">=</span> <span class="o">-</span><span class="n">f</span> <span class="o">*</span> <span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">anp</span><span class="o">.</span><span class="n">conj</span><span class="p">(</span><span class="n">u</span><span class="p">)),</span> <span class="n">anp</span><span class="o">.</span><span class="n">real</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">gu</span><span class="p">))</span> <span class="o">*</span> <span class="n">anp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
|
||||
<span class="n">r</span> <span class="o">=</span> <span class="n">_dot</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">anp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">inv</span><span class="p">(</span><span class="n">ut</span><span class="p">),</span> <span class="n">ge</span> <span class="o">+</span> <span class="n">r1</span> <span class="o">+</span> <span class="n">r2</span><span class="p">),</span> <span class="n">ut</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">anp</span><span class="o">.</span><span class="n">iscomplexobj</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">r</span> <span class="o">=</span> <span class="n">anp</span><span class="o">.</span><span class="n">real</span><span class="p">(</span><span class="n">r</span><span class="p">)</span>
|
||||
<span class="c1"># the derivative is still complex for real input (imaginary delta is allowed), real output</span>
|
||||
<span class="c1"># but the derivative should be real in real input case when imaginary delta is forbidden</span>
|
||||
<span class="k">return</span> <span class="n">r</span>
|
||||
<span class="k">return</span> <span class="n">vjp</span>
|
||||
|
||||
|
||||
<span class="n">defvjp</span><span class="p">(</span><span class="n">anp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">eig</span><span class="p">,</span> <span class="n">grad_eig</span><span class="p">)</span>
|
||||
<span class="c1"># End of the code block from autograd.master</span>
|
||||
</pre></div>
|
||||
|
||||
</details>
|
||||
|
@ -1184,45 +1130,6 @@ Obs valued.</li>
|
|||
</div>
|
||||
|
||||
|
||||
</section>
|
||||
<section id="grad_eig">
|
||||
<div class="attr function"><a class="headerlink" href="#grad_eig">#  </a>
|
||||
|
||||
|
||||
<span class="def">def</span>
|
||||
<span class="name">grad_eig</span><span class="signature">(ans, x)</span>:
|
||||
</div>
|
||||
|
||||
<details>
|
||||
<summary>View Source</summary>
|
||||
<div class="codehilite"><pre><span></span><span class="k">def</span> <span class="nf">grad_eig</span><span class="p">(</span><span class="n">ans</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="sd">"""Gradient of a general square (complex valued) matrix"""</span>
|
||||
<span class="n">e</span><span class="p">,</span> <span class="n">u</span> <span class="o">=</span> <span class="n">ans</span> <span class="c1"># eigenvalues as 1d array, eigenvectors in columns</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">e</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">vjp</span><span class="p">(</span><span class="n">g</span><span class="p">):</span>
|
||||
<span class="n">ge</span><span class="p">,</span> <span class="n">gu</span> <span class="o">=</span> <span class="n">g</span>
|
||||
<span class="n">ge</span> <span class="o">=</span> <span class="n">_matrix_diag</span><span class="p">(</span><span class="n">ge</span><span class="p">)</span>
|
||||
<span class="n">f</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">e</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">anp</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:]</span> <span class="o">-</span> <span class="n">e</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:,</span> <span class="n">anp</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">+</span> <span class="mf">1.e-20</span><span class="p">)</span>
|
||||
<span class="n">f</span> <span class="o">-=</span> <span class="n">_diag</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||||
<span class="n">ut</span> <span class="o">=</span> <span class="n">anp</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
|
||||
<span class="n">r1</span> <span class="o">=</span> <span class="n">f</span> <span class="o">*</span> <span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">gu</span><span class="p">)</span>
|
||||
<span class="n">r2</span> <span class="o">=</span> <span class="o">-</span><span class="n">f</span> <span class="o">*</span> <span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">anp</span><span class="o">.</span><span class="n">conj</span><span class="p">(</span><span class="n">u</span><span class="p">)),</span> <span class="n">anp</span><span class="o">.</span><span class="n">real</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">ut</span><span class="p">,</span> <span class="n">gu</span><span class="p">))</span> <span class="o">*</span> <span class="n">anp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">n</span><span class="p">)))</span>
|
||||
<span class="n">r</span> <span class="o">=</span> <span class="n">_dot</span><span class="p">(</span><span class="n">_dot</span><span class="p">(</span><span class="n">anp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">inv</span><span class="p">(</span><span class="n">ut</span><span class="p">),</span> <span class="n">ge</span> <span class="o">+</span> <span class="n">r1</span> <span class="o">+</span> <span class="n">r2</span><span class="p">),</span> <span class="n">ut</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">anp</span><span class="o">.</span><span class="n">iscomplexobj</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">r</span> <span class="o">=</span> <span class="n">anp</span><span class="o">.</span><span class="n">real</span><span class="p">(</span><span class="n">r</span><span class="p">)</span>
|
||||
<span class="c1"># the derivative is still complex for real input (imaginary delta is allowed), real output</span>
|
||||
<span class="c1"># but the derivative should be real in real input case when imaginary delta is forbidden</span>
|
||||
<span class="k">return</span> <span class="n">r</span>
|
||||
<span class="k">return</span> <span class="n">vjp</span>
|
||||
</pre></div>
|
||||
|
||||
</details>
|
||||
|
||||
<div class="docstring"><p>Gradient of a general square (complex valued) matrix</p>
|
||||
</div>
|
||||
|
||||
|
||||
</section>
|
||||
</main>
|
||||
<script>
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue