Source code for xitorch.linalg.solve

import torch
import warnings
from typing import Union, Any, Mapping, Optional, Callable
from xitorch import LinearOperator
from xitorch._core.linop import MatrixLinearOperator
from xitorch._utils.assertfuncs import assert_runtime
from xitorch._utils.misc import set_default_option, dummy_context_manager, get_method
from xitorch._docstr.api_docstr import get_methods_docstr
from xitorch.debug.modes import is_debug_enabled
from xitorch._impls.linalg.solve import exactsolve, wrap_gmres, \
    cg, bicgstab, broyden1_solve, _get_batchdims, gmres

def solve(A: LinearOperator, B: torch.Tensor,
          E: Union[torch.Tensor, None] = None,
          M: Optional[LinearOperator] = None,
          bck_options: Mapping[str, Any] = {},
          method: Union[str, Callable, None] = None,
          **fwd_options) -> torch.Tensor:
    r"""
    Perform an iterative method to solve the equation

    .. math::

        \mathbf{AX=B}

    or

    .. math::

        \mathbf{AX-MXE=B}

    where :math:`\mathbf{E}` is a diagonal matrix.
    This function can also solve multiple inverse equations in batch
    by applying :math:`\mathbf{A}` to a tensor :math:`\mathbf{X}` with shape
    ``(...,na,ncols)``.
    The applied :math:`\mathbf{E}` are not necessarily identical for each
    column.

    Arguments
    ---------
    A: xitorch.LinearOperator
        A linear operator that takes an input ``X`` and produces vectors in
        the same space as ``B``.
        It should have the shape ``(*BA, na, na)``.
    B: torch.Tensor
        The tensor on the right-hand side with shape ``(*BB, na, ncols)``.
    E: torch.Tensor or None
        If a tensor, it solves :math:`\mathbf{AX-MXE = B}` with ``E``
        regarded as the diagonal of the matrix.
        Otherwise, it just solves :math:`\mathbf{AX = B}` and ``M`` is
        ignored.
        If it is a tensor, it should have shape ``(*BE, ncols)``.
    M: xitorch.LinearOperator or None
        The transformation on the ``E`` side. If ``E`` is ``None``,
        this argument is ignored.
        If ``E`` is not ``None`` and ``M`` is ``None``, then ``M=I``.
        If it is a LinearOperator, it must be Hermitian with shape
        ``(*BM, na, na)``.
    bck_options: dict
        Options of the iterative solver in the backward calculation.
    method: str or callable or None
        The method of the linear equation solver. If ``None``, it will choose
        ``"cg"`` or ``"bicgstab"`` based on the symmetry of the matrices.
        `Note`: the default method may change between versions, so if you
        want future compatibility, please specify a method.
    **fwd_options
        Method-specific options (see the method section below).

    Returns
    -------
    torch.Tensor
        The tensor :math:`\mathbf{X}` that satisfies
        :math:`\mathbf{AX-MXE=B}`.
""" assert_runtime(A.shape[-1] == A.shape[-2], "The linear operator A must have a square shape") assert_runtime(A.shape[-1] == B.shape[-2], "Mismatch shape of A & B (A: %s, B: %s)" % (A.shape, B.shape)) assert_runtime( not torch.is_grad_enabled() or A.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator A must be " "implemented if using solve with grad enabled") if M is not None: assert_runtime(M.shape[-1] == M.shape[-2], "The linear operator M must have a square shape") assert_runtime(M.shape[-1] == A.shape[-1], "The shape of A & M must match (A: %s, M: %s)" % (A.shape, M.shape)) assert_runtime(M.is_hermitian, "The linear operator M must be a Hermitian matrix") assert_runtime( not torch.is_grad_enabled() or M.is_getparamnames_implemented, "The _getparamnames(self, prefix) of linear operator M must be " "implemented if using solve with grad enabled") if E is not None: assert_runtime(E.shape[-1] == B.shape[-1], "The last dimension of E & B must match (E: %s, B: %s)" % (E.shape, B.shape)) if E is None and M is not None: warnings.warn("M is supplied but will be ignored because E is not supplied") # perform expensive check if debug mode is enabled if is_debug_enabled(): A.check() if M is not None: M.check() if method is None: if isinstance(A, MatrixLinearOperator) and \ (M is None or isinstance(M, MatrixLinearOperator)): method = "exactsolve" elif A.shape[-1] <= 5: # for small matrix method = "exactsolve" else: is_hermit = A.is_hermitian and (M is None or M.is_hermitian) method = "cg" if is_hermit else "bicgstab" if method == "exactsolve": return exactsolve(A, B, E, M) else: # get the unique parameters of A params = A.getlinopparams() mparams = M.getlinopparams() if M is not None else [] na = len(params) return solve_torchfcn.apply( A, B, E, M, method, fwd_options, bck_options, na, *params, *mparams)
class solve_torchfcn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, E, M, method, fwd_options, bck_options, na, *all_params):
        # A: (*BA, nr, nr)
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols) or None
        # M: (*BM, nr, nr) or None
        # all_params: list of tensors of any shape
        # returns: (*BABEM, nr, ncols)

        # separate the parameters for A and for M
        params = all_params[:na]
        mparams = all_params[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option({}, bck_options)

        if torch.all(B == 0):  # special case: the solution is trivially zero
            dims = (*_get_batchdims(A, B, E, M), *B.shape[-2:])
            x = torch.zeros(dims, dtype=B.dtype, device=B.device)
        else:
            with A.uselinopparams(*params), \
                 M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
                methods = {
                    "custom_exactsolve": custom_exactsolve,
                    "scipy_gmres": wrap_gmres,
                    "broyden1": broyden1_solve,
                    "cg": cg,
                    "bicgstab": bicgstab,
                    "gmres": gmres,
                }
                method_fcn = get_method("solve", methods, method)
                x = method_fcn(A, B, E, M, **config)

        ctx.e_is_none = E is None
        ctx.A = A
        ctx.M = M
        if ctx.e_is_none:
            ctx.save_for_backward(x, *all_params)
        else:
            ctx.save_for_backward(x, E, *all_params)
        ctx.na = na
        return x

    @staticmethod
    def backward(ctx, grad_x):
        # grad_x: (*BABEM, nr, ncols)
        # x: (*BABEM, nr, ncols)
        x = ctx.saved_tensors[0]
        idx_all_params = 1 if ctx.e_is_none else 2
        all_params = ctx.saved_tensors[idx_all_params:]
        params = all_params[:ctx.na]
        mparams = all_params[ctx.na:]
        E = None if ctx.e_is_none else ctx.saved_tensors[1]

        # solve the adjoint equation (A - E*M)^H v = grad_x
        # v is the grad of B
        with ctx.A.uselinopparams(*params), \
             ctx.M.uselinopparams(*mparams) if ctx.M is not None else dummy_context_manager():
            AT = ctx.A.H  # (*BA, nr, nr)
            MT = ctx.M.H if ctx.M is not None else None  # (*BM, nr, nr)
            Econj = E.conj() if E is not None else None
            # the backward config is reused as the forward options of this solve
            v = solve(AT, grad_x, Econj, MT,
                      bck_options=ctx.bck_config, **ctx.bck_config)  # (*BABEM, nr, ncols)
        grad_B = v

        # calculate the grad of the parameters of A
        with torch.enable_grad():
            params = [p.clone().requires_grad_() for p in params]
            with ctx.A.uselinopparams(*params):
                loss = -ctx.A.mm(x)  # (*BABEM, nr, ncols)

        grad_params = torch.autograd.grad((loss,), params, grad_outputs=(v,),
                                          create_graph=torch.is_grad_enabled(),
                                          allow_unused=True)

        # calculate the biases gradient
        grad_E = None
        if E is not None:
            if ctx.M is None:
                Mx = x
            else:
                with ctx.M.uselinopparams(*mparams):
                    Mx = ctx.M.mm(x)  # (*BABEM, nr, ncols)
            grad_E = torch.einsum('...rc,...rc->...c', v, Mx.conj())  # (*BABEM, ncols)

        # calculate the gradient to the biases matrices
        grad_mparams = []
        if ctx.M is not None and E is not None:
            with torch.enable_grad():
                mparams = [p.clone().requires_grad_() for p in mparams]
                lmbdax = x * E.unsqueeze(-2)
                with ctx.M.uselinopparams(*mparams):
                    mloss = ctx.M.mm(lmbdax)

            grad_mparams = torch.autograd.grad((mloss,), mparams,
                                               grad_outputs=(v,),
                                               create_graph=torch.is_grad_enabled(),
                                               allow_unused=True)

        return (None, grad_B, grad_E, None, None, None, None, None,
                *grad_params, *grad_mparams)

def custom_exactsolve(A, B, E=None, M=None, **options):
    # A: (*BA, na, na)
    # B: (*BB, na, ncols)
    # E: (*BE, ncols)
    # M: (*BM, na, na)
    return exactsolve(A, B, E, M)

# docstring completion
_solve_methods = {
    "cg": cg,
    "bicgstab": bicgstab,
    "exactsolve": exactsolve,
    "broyden1": broyden1_solve,
    "scipy_gmres": wrap_gmres,
    "gmres": gmres,
}
ignore_kwargs = ["E", "M", "mparams"]
solve.__doc__ = get_methods_docstr(solve, _solve_methods, ignore_kwargs)
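# Example with the generalized form AX - MXE = B and gradient propagation
# (a minimal sketch; the tensors below are illustrative assumptions, not part
# of this module). When ``M`` is omitted it acts as the identity, so column i
# solves the shifted system (A - E[i] * I) x_i = b_i:
#
#     import torch
#     import xitorch
#     from xitorch.linalg import solve
#
#     mat = torch.randn(4, 4, dtype=torch.float64)
#     Amat = (mat + mat.T + 8 * torch.eye(4, dtype=torch.float64)).requires_grad_()
#     A = xitorch.LinearOperator.m(Amat)
#     B = torch.randn(4, 3, dtype=torch.float64)
#     E = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)  # one shift per column
#     X = solve(A, B, E=E)  # M=None here means M = I
#     for i in range(3):
#         res = Amat @ X[:, i] - E[i] * X[:, i]
#         assert torch.allclose(res, B[:, i])
#
#     # gradients flow back to the parameters of A through the backward above
#     X.sum().backward()
#     print(Amat.grad.shape)  # torch.Size([4, 4])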