Derivatives of xitorch.linalg.symeig

Author: Muhammad Firmansyah Kasim (2020)

Problem

The function xitorch.linalg.symeig decomposes a linear operator into its \(k\) smallest or largest eigenvalues and the corresponding eigenvectors,

\[\mathbf{AX} = \mathbf{MXE}\]

where \(\mathbf{A}, \mathbf{M}\) are symmetric \(n\times n\) linear operators that act as the inputs of the function. The outputs are \(\mathbf{X}\), an \(n\times k\) matrix containing the eigenvectors in its columns, and \(\mathbf{E}\), a \(k\times k\) diagonal matrix containing the corresponding eigenvalues.
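
As a quick illustration, here is a minimal sketch of calling the function on small dense operators. The dense-matrix wrapper xitorch.LinearOperator.m and the keyword names neig and M are assumptions based on typical xitorch usage, not taken from this page.

    import torch
    import xitorch
    from xitorch.linalg import symeig

    n, k = 10, 3
    a = torch.randn(n, n); a = (a + a.T) / 2                 # symmetric A
    m = torch.randn(n, n); m = m @ m.T + n * torch.eye(n)    # symmetric positive-definite M

    A = xitorch.LinearOperator.m(a)   # wrap the dense matrices as linear operators (assumed API)
    M = xitorch.LinearOperator.m(m)

    evals, evecs = symeig(A, neig=k, M=M)   # E as a (k,) tensor and X as an (n, k) tensor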

The elements of the linear operators \(\mathbf{A}\) and \(\mathbf{M}\) depend on parameters, denoted by \(\theta_A\) and \(\theta_M\), respectively. Here we consider only one parameter for each linear operator; extending the result to multiple parameters per operator is trivial, because the expression for each parameter takes the same form.

On this page, we derive the expressions for the backward derivatives (a.k.a. the vector-Jacobian products) with respect to the linear operators' parameters, \(\overline{\theta_A} \equiv \partial \mathcal{L}/\partial \theta_A\) and \(\overline{\theta_M} \equiv \partial \mathcal{L}/\partial \theta_M\), as functions of \(\mathbf{\overline{X}} \equiv \partial \mathcal{L}/\partial \mathbf{X}\) and \(\mathbf{\overline{E}} \equiv \partial \mathcal{L}/\partial \mathbf{E}\) for a loss value \(\mathcal{L}\). One challenge is that the linear operators \(\mathbf{A}\) and \(\mathbf{M}\) are only available implicitly, through their matrix-vector and right-multiplication routines, without explicit access to their matrix elements. Another challenge is that only \(k\) eigenpairs are available, so calculations involving the full set of eigenvectors and eigenvalues cannot be used.

This derivation assumes that all eigenvalues are distinct; cases with degenerate eigenvalues require a separate treatment.

Forward derivative of a single eigenpair

Let’s start with the eigendecomposition expression for one eigenvector and eigenvalue,

(1)\[\mathbf{Ax} = \lambda \mathbf{Mx},\]

where the eigenvector is normalized,

(2)\[\mathbf{x}^T\mathbf{Mx} = 1.\]
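
As a quick numerical check (a sketch using scipy.linalg.eigh, which solves the dense generalized problem and returns \(\mathbf{M}\)-normalized eigenvectors), the two equations above can be verified on random matrices:

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(0)
    n = 6
    A = rng.normal(size=(n, n)); A = (A + A.T) / 2            # symmetric A
    M = rng.normal(size=(n, n)); M = M @ M.T + n * np.eye(n)  # symmetric positive-definite M

    lams, V = eigh(A, M)          # generalized eigendecomposition with V.T @ M @ V = I
    lam, x = lams[0], V[:, 0]     # one eigenpair

    print(np.allclose(A @ x, lam * M @ x))   # equation (1): A x = lambda M x
    print(np.allclose(x @ M @ x, 1.0))       # equation (2): x^T M x = 1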

Taking the first-order derivative of the equations above, we obtain

(3)\[\mathbf{A'x} + \mathbf{A}\mathbf{x'} = \lambda' \mathbf{Mx} + \lambda \mathbf{M'x} + \lambda \mathbf{Mx'}\]

and

(4)\[\mathbf{x}^T \mathbf{M'x} + 2\mathbf{x}^T \mathbf{Mx'} = 0.\]

Multiplying both sides of equation (3) by \(\mathbf{x}^T\) from the left, we obtain

(5)\[\mathbf{x}^T\mathbf{A'x} + \mathbf{x}^T\mathbf{A}\mathbf{x'} = \lambda' \mathbf{x}^T\mathbf{Mx} + \lambda \mathbf{x}^T\mathbf{M'x} + \lambda \mathbf{x}^T\mathbf{Mx'}.\]

Substituting \(\mathbf{x}^T\mathbf{Mx}\) from equation (2) and \(\mathbf{x}^T\mathbf{A}\) from the transposed equation (1), we get the derivative of the eigenvalue,

(6)\[\lambda' = \mathbf{x}^T(\mathbf{A'} - \lambda\mathbf{M'})\mathbf{x}.\]
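
Equation (6) can be checked numerically against finite differences. The sketch below parameterizes the matrices linearly in a scalar \(t\), so that \(\mathbf{A'}\) and \(\mathbf{M'}\) are known exactly, and uses scipy.linalg.eigh as the dense reference (the parameterization and tolerances are choices made for this illustration):

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(0)
    n = 6
    sym = lambda B: (B + B.T) / 2
    A0, A1 = sym(rng.normal(size=(n, n))), sym(rng.normal(size=(n, n)))
    M0 = rng.normal(size=(n, n)); M0 = M0 @ M0.T + n * np.eye(n)   # keep M positive definite
    M1 = sym(rng.normal(size=(n, n)))

    def lowest_eigval(t):
        return eigh(A0 + t * A1, M0 + t * M1, eigvals_only=True)[0]

    t, eps = 0.1, 1e-5
    lam = lowest_eigval(t)
    x = eigh(A0 + t * A1, M0 + t * M1)[1][:, 0]    # M-normalized eigenvector of the lowest eigenvalue

    lam_prime = x @ (A1 - lam * M1) @ x            # equation (6) with A' = A1 and M' = M1
    lam_prime_fd = (lowest_eigval(t + eps) - lowest_eigval(t - eps)) / (2 * eps)
    print(np.allclose(lam_prime, lam_prime_fd, atol=1e-6))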

To obtain the derivative of the eigenvector, we substitute (6) into (3) and rearrange to obtain

(7)\[(\mathbf{A} - \lambda \mathbf{M})\mathbf{x'} = -(\mathbf{I} - \mathbf{Mxx}^T)(\mathbf{A'} - \lambda \mathbf{M'})\mathbf{x}\]

The matrix \((\mathbf{A} - \lambda \mathbf{M})\) is not full rank, so when it multiplies \(\mathbf{x'}\), the component of \(\mathbf{x'}\) lying in its null space is lost. To handle this, we split \(\mathbf{x'}\) into 2 components, an orthogonal one (\(\mathbf{x_M'}\)) and a parallel one (\(\mathbf{x_{-M}'}\)):

(8)\[\mathbf{x'} = \mathbf{x_M'} + \mathbf{x_{-M}'}\]

where

(9)\[\begin{split}\left(\mathbf{I} - \mathbf{xx}^T\mathbf{M}\right) \mathbf{x_M'} &= \mathbf{x_M'} \\ \left(\mathbf{I} - \mathbf{xx}^T\mathbf{M}\right) \mathbf{x_{-M}'} &= \mathbf{0}.\end{split}\]

A simple rearrangement of the equations above yields

(10)\[\begin{split}\mathbf{xx}^T\mathbf{M}\mathbf{x_M'} &= \mathbf{0} \\ \mathbf{x_{-M}'} &= \mathbf{xx}^T\mathbf{M}\mathbf{x_{-M}'}.\end{split}\]

Using equations (10) in equations (4) and (7) produces

(11)\[\begin{split}\mathbf{x}^T\mathbf{Mx_{-M}'} &= -\frac{1}{2}\mathbf{x}^T\mathbf{M'x} \\ (\mathbf{A} - \lambda \mathbf{M})\mathbf{x_M'} &= -(\mathbf{I} - \mathbf{Mxx}^T)(\mathbf{A'} - \lambda \mathbf{M'})\mathbf{x}.\end{split}\]

Multiplying the first equation above with \(\mathbf{x}\) and using the second equation from (10), we obtain,

(12)\[\mathbf{x_{-M}'} = -\frac{1}{2}\mathbf{xx}^T\mathbf{M'x}.\]

Moving the matrix \((\mathbf{A} - \lambda \mathbf{M})\) in the second equation of (11) to the right-hand side gives us

(13)\[\mathbf{x_M'} = -(\mathbf{I} - \mathbf{xx}^T\mathbf{M})(\mathbf{A} - \lambda \mathbf{M})^{+} (\mathbf{I} - \mathbf{Mxx}^T)(\mathbf{A'} - \lambda \mathbf{M'})\mathbf{x},\]

where the symbol \(\mathbf{C}^{+}\) indicates the pseudo-inverse of the matrix \(\mathbf{C}\). The additional term \((\mathbf{I} - \mathbf{xx}^T\mathbf{M})\) ensures that the result is orthogonal. The action of the pseudo-inverse can be computed using a standard linear-equation solver.
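
To illustrate the last point, here is a sketch using numpy's least-squares routine as the "standard linear-equation solver" (the random test problem is an assumption for this illustration). Because the right-hand side has the form \((\mathbf{I} - \mathbf{Mxx}^T)\mathbf{b}\), it lies in the range of \((\mathbf{A} - \lambda \mathbf{M})\), so the singular system can be solved exactly without forming the pseudo-inverse explicitly:

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(0)
    n = 6
    A = rng.normal(size=(n, n)); A = (A + A.T) / 2
    M = rng.normal(size=(n, n)); M = M @ M.T + n * np.eye(n)
    lams, V = eigh(A, M)
    lam, x = lams[0], V[:, 0]                     # one M-normalized eigenpair

    b = rng.normal(size=n)
    rhs = b - M @ x * (x @ b)                     # (I - M x x^T) b, lies in range(A - lam M)
    y = np.linalg.lstsq(A - lam * M, rhs, rcond=None)[0]   # solve the singular system
    y = y - x * (x @ M @ y)                       # apply (I - x x^T M) to remove the null-space part
    print(np.allclose((A - lam * M) @ y, rhs))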

To summarize, the forward derivatives are given by

\[\begin{split}\lambda' &= \mathbf{x}^T(\mathbf{A'} - \lambda\mathbf{M'})\mathbf{x}. \\ \mathbf{x'} &= -\frac{1}{2}\mathbf{xx}^T\mathbf{M'x} - (\mathbf{I} - \mathbf{xx}^T\mathbf{M})(\mathbf{A} - \lambda \mathbf{M})^{+} (\mathbf{I} - \mathbf{Mxx}^T)(\mathbf{A'} - \lambda \mathbf{M'})\mathbf{x}.\end{split}\]
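
These expressions can be verified end-to-end with finite differences. Below is a sketch (numpy/scipy only; the linear parameterization \(\mathbf{A}(t) = \mathbf{A}_0 + t\mathbf{A}_1\), \(\mathbf{M}(t) = \mathbf{M}_0 + t\mathbf{M}_1\), and the sign alignment of the finite-difference eigenvectors are choices made for this illustration):

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(0)
    n = 6
    sym = lambda B: (B + B.T) / 2
    A0, A1 = sym(rng.normal(size=(n, n))), sym(rng.normal(size=(n, n)))
    M0 = rng.normal(size=(n, n)); M0 = M0 @ M0.T + n * np.eye(n)
    M1 = sym(rng.normal(size=(n, n)))

    def eigpair(t):
        lams, V = eigh(A0 + t * A1, M0 + t * M1)   # M-normalized eigenvectors
        return lams[0], V[:, 0]

    t, eps = 0.1, 1e-5
    A, M, Ap, Mp = A0 + t * A1, M0 + t * M1, A1, M1
    lam, x = eigpair(t)

    # forward derivatives from the summary above
    lam_prime = x @ (Ap - lam * Mp) @ x
    Pl = np.eye(n) - np.outer(x, x) @ M            # I - x x^T M
    Pr = np.eye(n) - M @ np.outer(x, x)            # I - M x x^T
    x_prime = (-0.5 * np.outer(x, x) @ Mp @ x
               - Pl @ np.linalg.pinv(A - lam * M) @ Pr @ (Ap - lam * Mp) @ x)

    # central finite differences, with the eigenvector signs aligned to x
    lam_hi, x_hi = eigpair(t + eps)
    lam_lo, x_lo = eigpair(t - eps)
    x_hi *= np.sign(x_hi @ M @ x)
    x_lo *= np.sign(x_lo @ M @ x)
    print(np.allclose(lam_prime, (lam_hi - lam_lo) / (2 * eps), atol=1e-6))
    print(np.allclose(x_prime, (x_hi - x_lo) / (2 * eps), atol=1e-5))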

Backward derivative

From the forward derivatives, it is relatively straightforward to get the backward derivatives. Using the relation

\[\mathbf{P'} = \mathbf{QR'S} \implies \mathbf{\overline{R}} = \mathbf{Q}^T\mathbf{\overline{P}}\mathbf{S}^T,\]

we get

\[\begin{split}\mathbf{\overline{A}} &= \mathbf{xx}^T \overline{\lambda} - (\mathbf{I} - \mathbf{xx}^T\mathbf{M})(\mathbf{A} - \lambda \mathbf{M})^{+} (\mathbf{I} - \mathbf{Mxx}^T)\mathbf{\overline{x}} \mathbf{x}^T \\ \mathbf{\overline{M}} &= -\mathbf{xx}^T \lambda \overline{\lambda} -\frac{1}{2}\mathbf{xx}^T\mathbf{\overline{x}}\mathbf{x}^T + \lambda (\mathbf{I} - \mathbf{xx}^T\mathbf{M})(\mathbf{A} - \lambda \mathbf{M})^{+} (\mathbf{I} - \mathbf{Mxx}^T)\mathbf{\overline{x}} \mathbf{x}^T.\end{split}\]
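
The relation used above follows from writing the differential of the loss in trace form and using the cyclic property of the trace,

\[\mathcal{L}' = \mathrm{tr}\left(\mathbf{\overline{P}}^T\mathbf{P'}\right) = \mathrm{tr}\left(\mathbf{\overline{P}}^T\mathbf{Q}\mathbf{R'}\mathbf{S}\right) = \mathrm{tr}\left[\left(\mathbf{Q}^T\mathbf{\overline{P}}\mathbf{S}^T\right)^T\mathbf{R'}\right],\]

so matching the coefficient of \(\mathbf{R'}\) identifies \(\mathbf{\overline{R}} = \mathbf{Q}^T\mathbf{\overline{P}}\mathbf{S}^T\).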

For cases with multiple eigenpairs, the contributions from all eigenvalues and eigenvectors should be summed,

(14)\[\begin{split}\mathbf{\overline{A}} &= \mathbf{X\overline{E}X}^T - \mathbf{\overline{Y}X}^T\\ \mathbf{\overline{M}} &= -\mathbf{XE\overline{E}X}^T - \frac{1}{2}\mathbf{X}(\mathbf{I}\circ\mathbf{X}^T\mathbf{\overline{X}})\mathbf{X}^T + \mathbf{\overline{Y}EX}^T,\end{split}\]

where \(\circ\) indicates element-wise multiplication and

(15)\[\begin{split}\mathbf{\overline{Y}} &= \mathbf{\overline{V}} - \mathbf{X}\left(\mathbf{I}\circ\mathbf{X}^T\mathbf{M\overline{V}}\right) \\ \mathbf{\overline{V}} &: \mathrm{solve}\ \mathbf{A\overline{V}} - \mathbf{M\overline{V}E} = \mathbf{\overline{X}} - \mathbf{MX} \left(\mathbf{I}\circ\mathbf{X}^T \mathbf{\overline{X}}\right).\end{split}\]
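
The multi-eigenpair formulas (14) and (15) can be checked against finite differences on a small dense problem. The sketch below (numpy/scipy only; the loss function, the parameterizations \(\mathbf{A}(\theta_A) = \mathbf{A}_0 + \theta_A\mathbf{A}_1\) and \(\mathbf{M}(\theta_M) = \mathbf{M}_0 + \theta_M\mathbf{M}_1\), and the use of a least-squares solver for the singular systems in (15) are choices made for this illustration) builds \(\mathbf{\overline{A}}\) and \(\mathbf{\overline{M}}\) from the \(k\) lowest eigenpairs only:

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(0)
    n, k = 8, 3
    sym = lambda C: (C + C.T) / 2
    A0, A1 = sym(rng.normal(size=(n, n))), sym(rng.normal(size=(n, n)))
    B = sym(rng.normal(size=(n, n)))             # fixed matrix used in the loss below
    M0 = rng.normal(size=(n, n)); M0 = M0 @ M0.T + n * np.eye(n)
    M1 = sym(rng.normal(size=(n, n)))
    w = rng.normal(size=k)                       # weights of the eigenvalues in the loss

    def loss(ta, tm):
        # sign-invariant loss of the k lowest eigenpairs: sum_i w_i lam_i + sum_i x_i^T B x_i
        lams, V = eigh(A0 + ta * A1, M0 + tm * M1)
        X = V[:, :k]
        return w @ lams[:k] + np.einsum('ij,il,lj->', X, B, X)

    ta, tm = 0.1, 0.2
    A, M = A0 + ta * A1, M0 + tm * M1
    lams, V = eigh(A, M)
    E, X = lams[:k], V[:, :k]
    Ebar = w                                     # dL/dlam_i
    Xbar = 2 * B @ X                             # dL/dx_i

    # equation (15): solve the singular systems column by column, then project out the x_i component
    rhs = Xbar - M @ X * np.sum(X * Xbar, axis=0)
    Vbar = np.stack([np.linalg.lstsq(A - E[i] * M, rhs[:, i], rcond=None)[0] for i in range(k)], axis=1)
    Ybar = Vbar - X * np.sum(X * (M @ Vbar), axis=0)

    # equation (14)
    Abar = (X * Ebar) @ X.T - Ybar @ X.T
    Mbar = -(X * (E * Ebar)) @ X.T - 0.5 * (X * np.sum(X * Xbar, axis=0)) @ X.T + (Ybar * E) @ X.T

    # theta_bar = tr(Abar^T dA/dtheta_A) etc., checked with central differences of the loss
    eps = 1e-5
    print(np.allclose(np.sum(Abar * A1), (loss(ta + eps, tm) - loss(ta - eps, tm)) / (2 * eps), atol=1e-5))
    print(np.allclose(np.sum(Mbar * M1), (loss(ta, tm + eps) - loss(ta, tm - eps)) / (2 * eps), atol=1e-5))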

Given the gradient with respect to each element of the linear operators, the gradients with respect to the parameters of \(\mathbf{A}\) and \(\mathbf{M}\) are

\[\begin{split}\overline{\theta_A} &= \mathrm{tr}\left(\mathbf{\overline{A}}^T \frac{\partial \mathbf{A}}{\partial \theta_A}\right) \\ \overline{\theta_M} &= \mathrm{tr}\left(\mathbf{\overline{M}}^T \frac{\partial \mathbf{M}}{\partial \theta_M}\right)\end{split}\]

or more conveniently written as

\[\begin{split}\overline{\theta_A} &= \mathrm{tr}\left[(\mathbf{X\overline{E} - \overline{Y}})^T \frac{\partial (\mathbf{AX})}{\partial \theta_A}\right] \\ \overline{\theta_M} &= \mathrm{tr}\left[ \left(-\mathbf{XE\overline{E}} - \frac{1}{2}\mathbf{X}(\mathbf{I}\circ\mathbf{X}^T\mathbf{\overline{X}}) + \mathbf{\overline{Y}E}\right)^T \frac{\partial (\mathbf{MX})}{\partial \theta_M}\right].\end{split}\]

In PyTorch, the terms above can be calculated by propagating the gradient from \(\mathbf{AX}\) or \(\mathbf{MX}\) with the left term supplied as the initial gradient, e.g. \((\mathbf{X\overline{E}} - \mathbf{\overline{Y}})\) for \(\overline{\theta_A}\).
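
For illustration, here is a minimal dense sketch of this trick using torch.autograd.grad. The parameterization \(\mathbf{A}(\theta_A) = \mathbf{A}_0 + \theta_A\mathbf{A}_1\) and the placeholder tensor G standing in for the precomputed \((\mathbf{X\overline{E}} - \mathbf{\overline{Y}})\) are assumptions for this example:

    import torch

    n, k = 6, 3
    theta_A = torch.randn((), requires_grad=True)      # scalar parameter of A
    A0 = torch.randn(n, n); A0 = (A0 + A0.T) / 2
    A1 = torch.randn(n, n); A1 = (A1 + A1.T) / 2

    X = torch.randn(n, k)        # eigenvectors, treated as constants here
    G = torch.randn(n, k)        # placeholder for the precomputed left term (X Ebar - Ybar)

    AX = (A0 + theta_A * A1) @ X                       # the quantity A(theta_A) X
    theta_A_bar, = torch.autograd.grad(AX, theta_A, grad_outputs=G)

    # equals tr[(X Ebar - Ybar)^T d(AX)/d(theta_A)] for this dense parameterization
    print(torch.allclose(theta_A_bar, (G * (A1 @ X)).sum()))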