Source code for xitorch.linalg.symeig

import warnings
import torch
from typing import Union, Mapping, Any, Optional, Tuple, Callable
from xitorch import LinearOperator
from xitorch._core.linop import MatrixLinearOperator
from xitorch.linalg.solve import solve
from xitorch.debug.modes import is_debug_enabled
from xitorch._utils.assertfuncs import assert_runtime
from xitorch._utils.misc import set_default_option, \
    dummy_context_manager, get_method, get_and_pop_keys
from xitorch._docstr.api_docstr import get_methods_docstr
from xitorch._impls.linalg.symeig import exacteig, davidson
from xitorch._utils.exceptions import MathWarning

__all__ = ["lsymeig", "usymeig", "symeig", "svd"]

def lsymeig(A: LinearOperator, neig: Optional[int] = None,
            M: Optional[LinearOperator] = None,
            bck_options: Mapping[str, Any] = {},
            method: Union[str, Callable, None] = None,
            **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    return symeig(A, neig, "lowest", M, method=method, bck_options=bck_options, **fwd_options)

def usymeig(A: LinearOperator, neig: Optional[int] = None,
            M: Optional[LinearOperator] = None,
            bck_options: Mapping[str, Any] = {},
            method: Union[str, Callable, None] = None,
            **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    return symeig(A, neig, "uppest", M, method=method, bck_options=bck_options, **fwd_options)
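
# Illustrative usage sketch (a hypothetical helper, not part of the xitorch API):
# lsymeig and usymeig are thin wrappers that forward to symeig with
# mode="lowest"/"uppest". The sketch below assumes LinearOperator.m is xitorch's
# dense-matrix wrapper accepting an ``is_hermitian`` flag.
def _example_lsymeig_equivalence():
    torch.manual_seed(0)
    mat = torch.randn(4, 4, dtype=torch.float64)
    mat = (mat + mat.transpose(-2, -1)) * 0.5          # symmetrize so A is Hermitian
    A = LinearOperator.m(mat, is_hermitian=True)       # assumed dense-matrix wrapper
    evals1, evecs1 = lsymeig(A, neig=2)                 # two lowest eigenpairs
    evals2, evecs2 = symeig(A, 2, "lowest")             # equivalent direct call
    assert torch.allclose(evals1, evals2)               # eigenvalues agree (vectors may differ by sign)
    return evals1, evecs1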

def symeig(A: LinearOperator, neig: Optional[int] = None,
           mode: str = "lowest",
           M: Optional[LinearOperator] = None,
           bck_options: Mapping[str, Any] = {},
           method: Union[str, Callable, None] = None,
           **fwd_options) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Obtain ``neig`` lowest eigenvalues and eigenvectors of a linear operator,

    .. math::

        \mathbf{AX = MXE}

    where :math:`\mathbf{A}, \mathbf{M}` are linear operators,
    :math:`\mathbf{E}` is a diagonal matrix containing the eigenvalues, and
    :math:`\mathbf{X}` is a matrix containing the eigenvectors.
    This function can handle derivatives for degenerate cases by setting non-zero
    ``degen_atol`` and ``degen_rtol`` in the backward option using the expressions
    in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator object on which the eigenpairs are constructed.
        It must be a Hermitian linear operator with shape ``(*BA, q, q)``
    neig: int or None
        The number of eigenpairs to be retrieved. If ``None``, all eigenpairs
        are retrieved
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``neig`` eigenpairs.
        If ``"uppest"``, it will take the uppermost ``neig``.
    M: xitorch.LinearOperator
        The transformation on the right hand side. If ``None``, then ``M=I``.
        If specified, it must be Hermitian with shape ``(*BM, q, q)``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference between
          two eigenvalues to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference between
          two eigenvalues to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        explicit values.
    method: str or callable or None
        Method for the eigendecomposition. If ``None``, it will choose
        ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (eigenvalues, eigenvectors)
        It will return eigenvalues and eigenvectors with shapes respectively
        ``(*BAM, neig)`` and ``(*BAM, na, neig)``, where ``*BAM`` is the
        broadcasted shape of ``*BA`` and ``*BM``.

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric matrix
           for degenerate cases". arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    assert_runtime(A.is_hermitian, "The linear operator A must be Hermitian")
    assert_runtime(
        not torch.is_grad_enabled() or A.is_getparamnames_implemented,
        "The _getparamnames(self, prefix) of linear operator A must be "
        "implemented if using symeig with grad enabled")
    if M is not None:
        assert_runtime(M.is_hermitian, "The linear operator M must be Hermitian")
        assert_runtime(M.shape[-1] == A.shape[-1],
                       "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape))
        assert_runtime(
            not torch.is_grad_enabled() or M.is_getparamnames_implemented,
            "The _getparamnames(self, prefix) of linear operator M must be "
            "implemented if using symeig with grad enabled")
    mode = mode.lower()
    if mode == "uppermost":
        mode = "uppest"
    if method is None:
        if isinstance(A, MatrixLinearOperator) and \
           (M is None or isinstance(M, MatrixLinearOperator)):
            method = "exacteig"
        else:
            # TODO: implement robust LOBPCG and put it here
            method = "exacteig"
    if neig is None:
        neig = A.shape[-1]

    # perform expensive check if debug mode is enabled
    if is_debug_enabled():
        A.check()
        if M is not None:
            M.check()

    if method == "exacteig":
        return exacteig(A, neig, mode, M)
    else:
        fwd_options["method"] = method
        # get the unique parameters of A & M
        params = A.getlinopparams()
        mparams = M.getlinopparams() if M is not None else []
        na = len(params)
        return symeig_torchfcn.apply(A, neig, mode, M, fwd_options, bck_options,
                                     na, *params, *mparams)
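
# Illustrative usage sketch for symeig on a generalized problem (a hypothetical
# helper, not part of the xitorch API; it again assumes LinearOperator.m as the
# dense-matrix wrapper). It solves A X = M X E with a positive-definite M and
# checks the residual of the returned eigenpairs.
def _example_symeig_generalized():
    torch.manual_seed(0)
    n = 6
    a = torch.randn(n, n, dtype=torch.float64)
    b = torch.randn(n, n, dtype=torch.float64)
    amat = (a + a.transpose(-2, -1)) * 0.5                                   # Hermitian A
    mmat = torch.matmul(b, b.transpose(-2, -1)) + n * torch.eye(n, dtype=torch.float64)  # positive-definite M
    A = LinearOperator.m(amat, is_hermitian=True)
    M = LinearOperator.m(mmat, is_hermitian=True)
    evals, evecs = symeig(A, neig=3, mode="lowest", M=M)                     # (3,), (n, 3)
    # residual of A X - M X E should vanish up to round-off
    resid = torch.matmul(amat, evecs) - torch.matmul(mmat, evecs) * evals
    assert torch.allclose(resid, torch.zeros_like(resid), atol=1e-8)
    return evals, evecs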
def svd(A: LinearOperator, k: Optional[int] = None,
        mode: str = "uppest", bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Perform the singular value decomposition (SVD):

    .. math::

        \mathbf{A} = \mathbf{U\Sigma V}^H

    where :math:`\mathbf{U}` and :math:`\mathbf{V}` are semi-unitary matrices and
    :math:`\mathbf{\Sigma}` is a diagonal matrix containing real non-negative
    numbers.
    This function can handle derivatives for degenerate singular values by setting
    non-zero ``degen_atol`` and ``degen_rtol`` in the backward option using the
    expressions in [1]_.

    Arguments
    ---------
    A: xitorch.LinearOperator
        The linear operator to be decomposed. It has a shape of ``(*BA, m, n)``
        where ``(*BA)`` is the batched dimension of ``A``.
    k: int or None
        The number of decompositions obtained. If ``None``, it will be
        ``min(*A.shape[-2:])``
    mode: str
        ``"lowest"`` or ``"uppermost"``/``"uppest"``. If ``"lowest"``,
        it will take the lowest ``k`` decompositions.
        If ``"uppest"``, it will take the uppermost ``k``.
    bck_options: dict
        Method-specific options for :func:`solve` which is used in the
        backpropagation calculation, with some additional arguments for
        computing the backward derivatives:

        * ``degen_atol`` (``float`` or None): Minimum absolute difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.6``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)
        * ``degen_rtol`` (``float`` or None): Minimum relative difference between
          two singular values to be treated as degenerate. If None, it is
          ``torch.finfo(dtype).eps**0.4``. If 0.0, no special treatment on
          degeneracy is applied. (default: None)

        Note: the default values of ``degen_atol`` and ``degen_rtol`` are going
        to change in the future. So, for future compatibility, please specify
        explicit values.
    method: str or callable or None
        Method for the svd (same options as for :func:`symeig`). If ``None``,
        it will choose ``"exacteig"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    tuple of tensors (u, s, vh)
        It will return ``u, s, vh`` with shapes respectively ``(*BA, m, k)``,
        ``(*BA, k)``, and ``(*BA, k, n)``.

    Note
    ----
    It is a naive implementation of symmetric eigendecomposition of ``A.H @ A``
    or ``A @ A.H`` (depending on which one is cheaper)

    References
    ----------
    .. [1] Muhammad F. Kasim,
           "Derivatives of partial eigendecomposition of a real symmetric matrix
           for degenerate cases". arXiv:2011.04366 (2020)
           `https://arxiv.org/abs/2011.04366 <https://arxiv.org/abs/2011.04366>`_
    """
    # A: (*BA, m, n)
    # adapted from scipy.sparse.linalg.svds
    if is_debug_enabled():
        A.check()

    BA = A.shape[:-2]
    m = A.shape[-2]
    n = A.shape[-1]
    if m < n:
        AAsym = A.matmul(A.H, is_hermitian=True)
        min_nm = m
    else:
        AAsym = A.H.matmul(A, is_hermitian=True)
        min_nm = n
    eivals, eivecs = symeig(AAsym, k, mode, bck_options=bck_options,
                            method=method, **fwd_options)  # (*BA, k) and (*BA, min(mn), k)

    # clamp the eigenvalues to small positive values to avoid numerical
    # instability
    eivals = torch.clamp(eivals, min=0.0)
    s = torch.sqrt(eivals)  # (*BA, k)
    sdiv = torch.clamp(s, min=1e-12).unsqueeze(-2)  # (*BA, 1, k)
    if m < n:
        u = eivecs  # (*BA, m, k)
        v = A.rmm(u) / sdiv  # (*BA, n, k)
    else:
        v = eivecs  # (*BA, n, k)
        u = A.mm(v) / sdiv  # (*BA, m, k)
    vh = v.transpose(-2, -1).conj()
    return u, s, vh
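
# Illustrative usage sketch for svd (a hypothetical helper, not part of the
# xitorch API; it assumes LinearOperator.m as the dense-matrix wrapper). With
# the default k = min(m, n), the reconstruction U @ diag(s) @ Vh should recover
# the original matrix.
def _example_svd_reconstruction():
    torch.manual_seed(0)
    mat = torch.randn(5, 3, dtype=torch.float64)
    A = LinearOperator.m(mat)          # assumed dense-matrix wrapper
    u, s, vh = svd(A)                  # u: (5, 3), s: (3,), vh: (3, 3)
    recon = torch.matmul(u * s, vh)    # equivalent to u @ diag(s) @ vh
    assert torch.allclose(recon, mat, atol=1e-8)
    return u, s, vh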
class symeig_torchfcn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na, *amparams):
        # A: LinearOperator (*BA, q, q)
        # M: LinearOperator (*BM, q, q) or None

        # separate the sets of parameters
        params = amparams[:na]
        mparams = amparams[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option({
            "degen_atol": None,
            "degen_rtol": None,
        }, bck_options)

        # options for calculating the backward (not for `solve`)
        alg_keys = ["degen_atol", "degen_rtol"]
        ctx.bck_alg_config = get_and_pop_keys(ctx.bck_config, alg_keys)

        method = config.pop("method")
        with A.uselinopparams(*params), \
             M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
            methods = {
                "davidson": davidson,
                "custom_exacteig": custom_exacteig,
            }
            method_fcn = get_method("symeig", methods, method)
            evals, evecs = method_fcn(A, neig, mode, M, **config)

        # save for the backward
        # evals: (*BAM, neig)
        # evecs: (*BAM, na, neig)
        ctx.save_for_backward(evals, evecs, *amparams)
        ctx.na = na
        ctx.A = A
        ctx.M = M
        return evals, evecs

    @staticmethod
    def backward(ctx, grad_evals, grad_evecs):
        # grad_evals: (*BAM, neig)
        # grad_evecs: (*BAM, na, neig)

        # get the variables from ctx
        evals, evecs = ctx.saved_tensors[:2]
        na = ctx.na
        amparams = ctx.saved_tensors[2:]
        params = amparams[:na]
        mparams = amparams[na:]

        M = ctx.M
        A = ctx.A

        degen_atol: Optional[float] = ctx.bck_alg_config["degen_atol"]
        degen_rtol: Optional[float] = ctx.bck_alg_config["degen_rtol"]

        # set the default values of degen_*tol
        dtype = evals.dtype
        if degen_atol is None:
            degen_atol = torch.finfo(dtype).eps**0.6
        if degen_rtol is None:
            degen_rtol = torch.finfo(dtype).eps**0.4

        # check the degeneracy
        if degen_atol > 0 or degen_rtol > 0:
            # idx_degen: (*BAM, neig, neig)
            idx_degen, isdegenerate = _check_degen(evals, degen_atol, degen_rtol)
        else:
            isdegenerate = False
        if not isdegenerate:
            idx_degen = None

        # the loss function where the gradient will be retrieved
        # warning: if not all params have a connection to the output of A,
        # it could cause an infinite loop because pytorch will keep looking
        # for the *params node and propagate further backward via the `evecs`
        # path. So make sure all the *params are connected in the graph.
        with torch.enable_grad():
            params = [p.clone().requires_grad_() for p in params]
            with A.uselinopparams(*params):
                loss = A.mm(evecs)  # (*BAM, na, neig)

        # if degenerate, check the conditions for a finite derivative
        if is_debug_enabled() and isdegenerate:
            xtg = torch.matmul(evecs.transpose(-2, -1).conj(), grad_evecs)
            req1 = idx_degen * (xtg - xtg.transpose(-2, -1).conj())
            reqtol = xtg.abs().max() * grad_evecs.shape[-2] * torch.finfo(grad_evecs.dtype).eps

            if not torch.all(torch.abs(req1) <= reqtol):
                # if the requirements are not satisfied, raise a warning
                msg = ("Degeneracy appears but the loss function seems to depend "
                       "strongly on the eigenvector. The gradient might be incorrect.\n")
                msg += "Eigenvalues:\n%s\n" % str(evals)
                msg += "Degenerate map:\n%s\n" % str(idx_degen)
                msg += "Requirements (should be all 0s):\n%s" % str(req1)
                warnings.warn(MathWarning(msg))

        # calculate the contributions from the eigenvalues
        gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

        # calculate the contributions from the eigenvectors
        with M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
            # orthogonalize the grad_evecs with evecs
            B = _ortho(grad_evecs, evecs, D=idx_degen, M=M, mright=False)

            # Based on test cases, the complex datatype is more likely to suffer
            # from a singularity error when doing the inverse. Therefore, a small
            # offset is added here to prevent that from happening.
            if torch.is_complex(B):
                evals_offset = evals + 1e-14
            else:
                evals_offset = evals

            with A.uselinopparams(*params):
                gevecs = solve(A, -B, evals_offset, M,
                               bck_options=ctx.bck_config, **ctx.bck_config)  # (*BAM, na, neig)
            # orthogonalize gevecs w.r.t. evecs
            gevecsA = _ortho(gevecs, evecs, D=None, M=M, mright=True)

        # accumulate the gradient contributions
        gaccumA = gevalsA + gevecsA
        grad_params = torch.autograd.grad(
            outputs=(loss,),
            inputs=params,
            grad_outputs=(gaccumA,),
            create_graph=torch.is_grad_enabled(),
        )

        grad_mparams = []
        if ctx.M is not None:
            with torch.enable_grad():
                mparams = [p.clone().requires_grad_() for p in mparams]
                with M.uselinopparams(*mparams):
                    mloss = M.mm(evecs)  # (*BAM, na, neig)
            gevalsM = -gevalsA * evals.unsqueeze(-2)
            gevecsM = -gevecsA * evals.unsqueeze(-2)

            # the contribution from the parallel elements
            gevecsM_par = (-0.5 * torch.einsum("...ae,...ae->...e", grad_evecs, evecs.conj())
                           ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

            gaccumM = gevalsM + gevecsM + gevecsM_par
            grad_mparams = torch.autograd.grad(
                outputs=(mloss,),
                inputs=mparams,
                grad_outputs=(gaccumM,),
                create_graph=torch.is_grad_enabled(),
            )

        return (None, None, None, None, None, None, None,
                *grad_params, *grad_mparams)

def _check_degen(evals: torch.Tensor, degen_atol: float, degen_rtol: float) -> \
        Tuple[torch.Tensor, bool]:
    # evals: (*BAM, neig)

    # get the index of degeneracies
    neig = evals.shape[-1]
    evals_diff = torch.abs(evals.unsqueeze(-2) - evals.unsqueeze(-1))  # (*BAM, neig, neig)
    degen_thrsh = degen_atol + degen_rtol * torch.abs(evals).unsqueeze(-1)
    idx_degen = (evals_diff < degen_thrsh).to(evals.dtype)
    isdegenerate = bool(torch.sum(idx_degen) > torch.numel(evals))
    return idx_degen, isdegenerate

def _ortho(A: torch.Tensor, B: torch.Tensor, *,
           D: Optional[torch.Tensor] = None,
           M: Optional[LinearOperator] = None,
           mright: bool = False) -> torch.Tensor:
    # orthogonalize every column in A w.r.t. the columns in B
    # D is the degeneracy map; if None, it is the identity matrix
    # M is the overlap matrix (in LinearOperator)
    # mright indicates whether to operate M at the right or at the left
    # shapes:
    # A: (*BAM, na, neig)
    # B: (*BAM, na, neig)
    if D is None:
        # contracted using opt_einsum
        str1 = "...rc,...rc->...c"
        Bconj = B.conj()
        if M is None:
            return A - torch.einsum(str1, A, Bconj).unsqueeze(-2) * B
        elif mright:
            return A - torch.einsum(str1, M.mm(A), Bconj).unsqueeze(-2) * B
        else:
            return A - M.mm(torch.einsum(str1, A, Bconj).unsqueeze(-2) * B)
    else:
        BH = B.transpose(-2, -1).conj()
        if M is None:
            DBHA = D * torch.matmul(BH, A)
            return A - torch.matmul(B, DBHA)
        elif mright:
            DBHA = D * torch.matmul(BH, M.mm(A))
            return A - torch.matmul(B, DBHA)
        else:
            DBHA = D * torch.matmul(BH, A)
            return A - M.mm(torch.matmul(B, DBHA))

def custom_exacteig(A, neig, mode, M=None, **options):
    return exacteig(A, neig, mode, M)

# docstring completion
_symeig_methods = {
    "exacteig": exacteig,
    "davidson": davidson,
}
ignore_kwargs = ["M", "mparams"]
symeig.__doc__ = get_methods_docstr(symeig, _symeig_methods, ignore_kwargs)
svd.__doc__ = get_methods_docstr(svd, _symeig_methods)
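
# Illustrative sketch of the private backward helpers (a hypothetical helper,
# not used by the module): _check_degen flags pairs of nearly equal eigenvalues,
# and _ortho (with D=None) removes, column by column, the component of a
# gradient that is parallel to the corresponding eigenvector.
def _example_check_degen_and_ortho():
    evals = torch.tensor([[1.0, 1.0 + 1e-9, 2.0]], dtype=torch.float64)  # (*BAM, neig)
    idx_degen, isdegenerate = _check_degen(evals, degen_atol=1e-6, degen_rtol=0.0)
    # idx_degen[0] is a 3x3 0/1 map; the first two eigenvalues are flagged as degenerate
    assert isdegenerate

    evecs = torch.eye(3, dtype=torch.float64).unsqueeze(0)   # (1, 3, 3), orthonormal columns
    grad = torch.randn(1, 3, 3, dtype=torch.float64)
    g_orth = _ortho(grad, evecs, D=None, M=None, mright=False)
    # each column of g_orth is now orthogonal to the matching column of evecs
    overlap = torch.einsum("...rc,...rc->...c", g_orth, evecs)
    assert torch.allclose(overlap, torch.zeros_like(overlap))
    return idx_degen, g_orth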