# Source code for xitorch._core.linop

from __future__ import annotations
from typing import Sequence, Optional, List, Union
import warnings
import traceback
import torch
from abc import abstractmethod
from contextlib import contextmanager
from scipy.sparse.linalg import LinearOperator as spLinearOperator
from xitorch._core.editable_module import EditableModule
from xitorch.debug.modes import is_debug_enabled
from xitorch._utils.bcast import get_bcasted_dims

# Public names exported by ``from xitorch._core.linop import *``: only the
# ``LinearOperator`` base class. The concrete operator classes and helper
# functions defined later in this module are not listed here.
__all__ = ["LinearOperator"]

class LinearOperator(EditableModule):
    """
    ``LinearOperator`` is a base class designed to behave as a linear operator
    without explicitly determining the matrix.
    This ``LinearOperator`` should be able to operate as batched linear operators
    where its shape is ``(B1,B2,...,Bb,p,q)`` with ``B*`` as the
    (optional) batch dimensions.

    For a user-defined class to behave as ``LinearOperator``, it must use
    ``LinearOperator`` as one of the parent and it has to have ``._mv()``
    method implemented and ``._getparamnames()`` if used in xitorch's
    functionals with torch grad enabled.
    """

    # Flags describing which of the optional ``_*`` methods a class overrides.
    # ``__new__`` computes them once per concrete class and stores them on that
    # class, so every instance of the class shares the result.
    _implementation_checked = False
    _is_mv_implemented = False
    _is_mm_implemented = False
    _is_rmv_implemented = False
    _is_rmm_implemented = False
    _is_fullmatrix_implemented = False
    _is_gpn_implemented = False

    def __new__(cls, *args, **kwargs):
        # Check the implemented functions in the class.
        # Only ``cls``'s *own* namespace is consulted: the flag is a class
        # attribute, so reading ``cls._implementation_checked`` would also see
        # a ``True`` set on a parent class (or on ``LinearOperator`` itself)
        # and skip the check, leaving this subclass with the parent's flags.
        if not cls.__dict__.get("_implementation_checked", False):
            cls._is_mv_implemented = cls.__check_if_implemented("_mv")
            cls._is_mm_implemented = cls.__check_if_implemented("_mm")
            cls._is_rmv_implemented = cls.__check_if_implemented("_rmv")
            cls._is_rmm_implemented = cls.__check_if_implemented("_rmm")
            cls._is_fullmatrix_implemented = cls.__check_if_implemented("_fullmatrix")
            cls._is_gpn_implemented = cls.__check_if_implemented("_getparamnames")
            cls._implementation_checked = True

        if not cls._is_mv_implemented:
            raise RuntimeError("LinearOperator must have at least _mv(self) "
                               "method implemented")
        return super(LinearOperator, cls).__new__(cls)

    @classmethod
    def __check_if_implemented(cls, methodname: str) -> bool:
        # A method counts as implemented when ``cls`` resolves the name to
        # something other than the placeholder defined on ``LinearOperator``.
        this_method = getattr(cls, methodname)
        base_method = getattr(LinearOperator, methodname)
        return this_method is not base_method

    @classmethod
    def m(cls, mat: torch.Tensor, is_hermitian: Optional[bool] = None):
        """
        Class method to wrap a matrix into ``LinearOperator``.

        Arguments
        ---------
        mat: torch.Tensor
            Matrix to be wrapped in the ``LinearOperator``.
        is_hermitian: bool or None
            Indicating if the matrix is Hermitian. If ``None``, the symmetry
            will be checked (non-square matrices are treated as non-Hermitian).
            If ``True``, the symmetry is verified and an error is raised if it
            does not hold. If ``False``, no check is performed.

        Returns
        -------
        LinearOperator
            Linear operator object that represents the matrix.

        Raises
        ------
        RuntimeError
            If ``is_hermitian`` is ``True`` but ``mat`` is not Hermitian.

        Example
        -------
        .. testsetup:: *

            import torch
            import xitorch
            torch.manual_seed(100)

        .. doctest::

            >>> mat = torch.rand(1,3,1,2)  # 1x2 matrix with (1,3) batch dimensions
            >>> linop = xitorch.LinearOperator.m(mat)
            >>> print(linop)
            MatrixLinearOperator with shape (1, 3, 1, 2):
             tensor([[[[0.1117, 0.8158]],
            <BLANKLINE>
                        [[0.2626, 0.4839]],
            <BLANKLINE>
                        [[0.6765, 0.7539]]]])
        """
        if is_hermitian is None:
            if mat.shape[-2] != mat.shape[-1]:
                is_hermitian = False
            else:
                is_hermitian = torch.allclose(mat, mat.transpose(-2, -1).conj())
        elif is_hermitian:
            # verify the claim made by the caller
            if not torch.allclose(mat, mat.transpose(-2, -1).conj()):
                raise RuntimeError("The linear operator is indicated to be hermitian, but the matrix is not")

        return MatrixLinearOperator(mat, is_hermitian)

    def __init__(self, shape: Sequence[int],
                 is_hermitian: bool = False,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 _suppress_hermit_warning: bool = False) -> None:
        """
        Arguments
        ---------
        shape: sequence of int
            Shape of the operator, ``(B1,...,Bb,p,q)``; at least 2 dimensions.
        is_hermitian: bool
            Whether the operator is Hermitian (requires ``p == q``).
        dtype: torch.dtype or None
            Data type of the operator; defaults to ``torch.float32``.
        device: torch.device or None
            Device of the operator; defaults to CPU.
        _suppress_hermit_warning: bool
            Internal flag to silence the warning emitted when a Hermitian
            operator also implements ``_rmv``/``_rmm``.
        """
        super(LinearOperator, self).__init__()
        if len(shape) < 2:
            raise RuntimeError("The shape must have at least 2 dimensions")
        self._shape = shape
        self._batchshape = list(shape[:-2])
        self._is_hermitian = is_hermitian
        self._dtype = dtype if dtype is not None else torch.float32
        self._device = device if device is not None else torch.device("cpu")
        if is_hermitian and shape[-1] != shape[-2]:
            raise RuntimeError("The object is indicated as Hermitian, but the shape is not square")

        # Hermitian operators always dispatch rmv/rmm to mv/mm, so any
        # user-provided _rmv/_rmm would be silently ignored: tell the user.
        if not _suppress_hermit_warning and self._is_hermitian and \
                (self._is_rmv_implemented or self._is_rmm_implemented):
            warnings.warn("The LinearOperator is Hermitian with implemented "
                          "rmv or rmm. We will use the mv and mm methods "
                          "instead",
                          stacklevel=2)

    def __repr__(self) -> str:
        return "LinearOperator (%s) with shape %s, dtype = %s, device = %s" % \
            (self.__class__.__name__, _shape2str(self.shape), self.dtype, self.device)

    @abstractmethod
    def _getparamnames(self, prefix: str = "") -> List[str]:
        """
        List the self's parameters that affecting the ``LinearOperator``.
        This is for the derivative purpose.

        Arguments
        ---------
        prefix: str
            The prefix to be appended in front of the parameters name.
            This usually contains the dots.

        Returns
        -------
        list of str
            List of parameter names (including the prefix) that affecting
            the ``LinearOperator``.
        """
        return []

    @abstractmethod
    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        """
        Abstract method to be implemented for matrix-vector multiplication.
        Required for all ``LinearOperator`` objects.
        """
        pass

    # The following ``_*`` methods are optional (deliberately not marked
    # ``@abstractmethod``). ``__new__`` detects whether a subclass overrides
    # them and the public methods fall back to ``_mv`` when they are missing.

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        """
        Abstract method to be implemented for transposed matrix-vector
        multiplication. Optional. If not implemented, it will use the adjoint
        trick to compute ``.rmv()``. Usually implemented for efficiency reasons.
        """
        pass

    def _mm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Abstract method to be implemented for matrix-matrix multiplication.
        If not implemented, then it uses batched version of matrix-vector
        multiplication. Usually this is implemented for efficiency reasons.
        """
        pass

    def _rmm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Abstract method to be implemented for transposed matrix-matrix
        multiplication. If not implemented, then it uses batched version of
        transposed matrix-vector multiplication. Usually this is implemented
        for efficiency reasons.
        """
        pass

    def _fullmatrix(self) -> torch.Tensor:
        """
        Optional method returning the dense matrix of the operator. If not
        implemented, ``.fullmatrix()`` builds it by applying ``.mm()`` to the
        identity matrix.
        """
        pass

    # linear operators must have a set of parameters that affects most of
    # the methods (i.e. mm, mv, rmm, rmv)
    def getlinopparams(self) -> Sequence[torch.Tensor]:
        """Return the unique parameters the linear operations depend on."""
        return self.getuniqueparams("mm")

    @contextmanager
    def uselinopparams(self, *params):
        """
        Context manager that temporarily replaces the operator's parameters
        (as listed by ``getlinopparams``) with ``params`` and restores the
        original ones on exit, even if the body raises.
        """
        methodname = "mm"
        # Snapshot before entering ``try``: if the lookup itself fails there
        # is nothing to restore, and ``finally`` must not touch an unbound name.
        orig_params = self.getuniqueparams(methodname)
        try:
            self.setuniqueparams(methodname, *params)
            yield self
        finally:
            self.setuniqueparams(methodname, *orig_params)

    ############# implemented functions ################
    def mv(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the matrix-vector operation to vector ``x`` with shape ``(...,q)``.
        The batch dimensions of ``x`` need not be the same as the batch dimensions
        of the ``LinearOperator``, but it must be broadcastable.

        Arguments
        ---------
        x: torch.tensor
            The vector with shape ``(...,q)`` where the linear operation is operated on

        Returns
        -------
        y: torch.tensor
            The result of the linear operation with shape ``(...,p)``
        """
        self.__assert_if_init_executed()
        if x.shape[-1] != self.shape[-1]:
            raise RuntimeError("Cannot operate .mv on shape %s. Expected (...,%d)" %
                               (str(tuple(x.shape)), self.shape[-1]))

        return self._mv(x)

    def mm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the matrix-matrix operation to matrix ``x`` with shape ``(...,q,r)``.
        The batch dimensions of ``x`` need not be the same as the batch dimensions
        of the ``LinearOperator``, but it must be broadcastable.

        Arguments
        ---------
        x: torch.tensor
            The matrix with shape ``(...,q,r)`` where the linear operation is
            operated on.

        Returns
        -------
        y: torch.tensor
            The result of the linear operation with shape ``(...,p,r)``
        """
        self.__assert_if_init_executed()
        if x.shape[-2] != self.shape[-1]:
            raise RuntimeError("Cannot operate .mm on shape %s. Expected (...,%d,*)" %
                               (str(tuple(x.shape)), self.shape[-1]))

        if self._is_mm_implemented:
            return self._mm(x)
        # no dedicated _mm: apply _mv to every column at once
        return self.__mv_as_mm(self._mv, x)

    def rmv(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the matrix-vector adjoint operation to vector ``x`` with shape ``(...,p)``,
        i.e. ``A^H x``.
        The batch dimensions of ``x`` need not be the same as the batch dimensions
        of the ``LinearOperator``, but it must be broadcastable.

        Arguments
        ---------
        x: torch.tensor
            The vector of shape ``(...,p)`` where the adjoint linear operation is operated at.

        Returns
        -------
        y: torch.tensor
            The result of the adjoint linear operation with shape ``(...,q)``
        """
        self.__assert_if_init_executed()
        if x.shape[-1] != self.shape[-2]:
            raise RuntimeError("Cannot operate .rmv on shape %s. Expected (...,%d)" %
                               (str(tuple(x.shape)), self.shape[-2]))

        if self._is_hermitian:
            return self._mv(x)
        elif not self._is_rmv_implemented:
            return self.__adjoint_rmv(x)
        return self._rmv(x)

    def rmm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the matrix-matrix adjoint operation to matrix ``x`` with shape ``(...,p,r)``,
        i.e. ``A^H X``.
        The batch dimensions of ``x`` need not be the same as the batch dimensions
        of the ``LinearOperator``, but it must be broadcastable.

        Arguments
        ---------
        x: torch.Tensor
            The matrix of shape ``(...,p,r)`` where the adjoint linear operation is operated on.

        Returns
        -------
        y: torch.Tensor
            The result of the adjoint linear operation with shape ``(...,q,r)``.
        """
        self.__assert_if_init_executed()
        if x.shape[-2] != self.shape[-2]:
            raise RuntimeError("Cannot operate .rmm on shape %s. Expected (...,%d,*)" %
                               (str(tuple(x.shape)), self.shape[-2]))

        if self._is_hermitian:
            return self.mm(x)

        if self._is_rmm_implemented:
            return self._rmm(x)
        # no dedicated _rmm: apply rmv (user _rmv, or the adjoint trick via
        # the public .rmv) to every column at once
        rmv = self._rmv if self._is_rmv_implemented else self.rmv
        return self.__mv_as_mm(rmv, x)

    def fullmatrix(self) -> torch.Tensor:
        """
        Return the dense matrix of the operator with shape ``(...,p,q)``.
        Uses ``_fullmatrix`` if implemented, otherwise applies ``.mm()`` to
        the ``q x q`` identity matrix.
        """
        if self._is_fullmatrix_implemented:
            return self._fullmatrix()
        self.__assert_if_init_executed()
        nq = self._shape[-1]
        V = torch.eye(nq, dtype=self._dtype, device=self._device)  # (nq,nq)
        return self.mm(V)  # (B1,B2,...,Bb,np,nq)

    def scipy_linalg_op(self):
        """
        Wrap this operator as a ``scipy.sparse.linalg.LinearOperator``.
        Inputs are converted to tensors with this operator's dtype/device and
        results are detached and returned as CPU numpy arrays.
        """
        def to_tensor(x):
            return torch.tensor(x, dtype=self.dtype, device=self.device)

        return spLinearOperator(
            shape=self.shape,
            matvec=lambda v: self.mv(to_tensor(v)).detach().cpu().numpy(),
            rmatvec=lambda v: self.rmv(to_tensor(v)).detach().cpu().numpy(),
            matmat=lambda v: self.mm(to_tensor(v)).detach().cpu().numpy(),
            rmatmat=lambda v: self.rmm(to_tensor(v)).detach().cpu().numpy(),
        )

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        """"""
        # Empty docstring on purpose: hides EditableModule's docstring because
        # users do not need to know about this function.
        if methodname in ["mv", "rmv", "mm", "rmm", "fullmatrix"]:
            return self._getparamnames(prefix=prefix)
        raise KeyError("getparamnames for method %s is not implemented" % methodname)

    ############# derived linear operators ################
    @property
    def H(self):
        """
        Returns a LinearOperator representing the Hermite / transposed of the
        self LinearOperator.

        Returns
        -------
        LinearOperator
            The Hermite / transposed LinearOperator
        """
        if self._is_hermitian:
            return self
        elif isinstance(self, MatrixLinearOperator):
            return LinearOperator.m(self.fullmatrix().transpose(-2, -1).conj())
        return AdjointLinearOperator(self)

    ############# special functions ################
    def matmul(self, b: LinearOperator, is_hermitian: bool = False):
        """
        Returns a LinearOperator representing `self @ b`.

        Arguments
        ---------
        b: LinearOperator
            Other linear operator
        is_hermitian: bool
            Flag to indicate if the resulting LinearOperator is Hermitian.

        Returns
        -------
        LinearOperator
            LinearOperator representing `self @ b`
        """
        if self.shape[-1] != b.shape[-2]:
            raise RuntimeError("Mismatch shape of matmul operation: %s and %s" % (self.shape, b.shape))
        if isinstance(self, MatrixLinearOperator) and isinstance(b, MatrixLinearOperator):
            return LinearOperator.m(self.fullmatrix() @ b.fullmatrix(), is_hermitian=is_hermitian)
        return MatmulLinearOperator(self, b, is_hermitian=is_hermitian)

    def __add__(self, b: LinearOperator):
        """Return a LinearOperator representing ``self + b``."""
        assert isinstance(b, LinearOperator), \
            "Only addition with another LinearOperator is supported"
        # compare as tuples: shapes may be lists, tuples or torch.Size
        if tuple(self.shape[-2:]) != tuple(b.shape[-2:]):
            raise RuntimeError("Mismatch shape of add operation: %s and %s" % (self.shape, b.shape))
        if isinstance(self, MatrixLinearOperator) and isinstance(b, MatrixLinearOperator):
            return LinearOperator.m(self.fullmatrix() + b.fullmatrix())
        return AddLinearOperator(self, b)

    def __sub__(self, b: LinearOperator):
        """Return a LinearOperator representing ``self - b``."""
        assert isinstance(b, LinearOperator), \
            "Only subtraction with another LinearOperator is supported"
        # compare as tuples: shapes may be lists, tuples or torch.Size
        if tuple(self.shape[-2:]) != tuple(b.shape[-2:]):
            raise RuntimeError("Mismatch shape of sub operation: %s and %s" % (self.shape, b.shape))
        if isinstance(self, MatrixLinearOperator) and isinstance(b, MatrixLinearOperator):
            return LinearOperator.m(self.fullmatrix() - b.fullmatrix())
        return AddLinearOperator(self, b, -1)

    def __rsub__(self, b: LinearOperator):
        return b.__sub__(self)

    def __mul__(self, f: Union[int, float]):
        """Return a LinearOperator representing ``self * f`` for a real scalar ``f``."""
        if not isinstance(f, (int, float)):
            raise TypeError("LinearOperator multiplication only supports integer or floating point")
        if isinstance(self, MatrixLinearOperator):
            return LinearOperator.m(self.fullmatrix() * f)
        return MulLinearOperator(self, f)

    def __rmul__(self, f: Union[int, float]):
        return self.__mul__(f)

    ############# properties ################
    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def shape(self) -> Sequence[int]:
        return self._shape

    @property
    def is_hermitian(self) -> bool:
        return self._is_hermitian

    # implementation
    @property
    def is_mv_implemented(self) -> bool:
        # _mv is mandatory (enforced in __new__)
        return True

    @property
    def is_mm_implemented(self) -> bool:
        return self._is_mm_implemented

    @property
    def is_rmv_implemented(self) -> bool:
        return self._is_rmv_implemented

    @property
    def is_rmm_implemented(self) -> bool:
        return self._is_rmm_implemented

    @property
    def is_fullmatrix_implemented(self) -> bool:
        return self._is_fullmatrix_implemented

    @property
    def is_getparamnames_implemented(self) -> bool:
        return self._is_gpn_implemented

    ############ debug functions ##############
    def check(self, warn: Optional[bool] = None) -> None:
        """
        Perform checks to make sure the ``LinearOperator`` behaves as a proper
        linear operator.

        Arguments
        ---------
        warn: bool or None
            If ``True``, then raises a warning to the user that the check might slow
            down the program. This is to remind the user to turn off the check
            when not in a debugging mode.
            If ``None``, it will raise a warning if it runs not in a debug mode, but
            will be silent if it runs in a debug mode.

        Raises
        ------
        RuntimeError
            Raised if an error is raised when performing linear operations of the
            object (e.g. calling ``.mv()``, ``.mm()``, etc)
        AssertionError
            Raised if the linear operations do not behave as proper linear operations.
            (e.g. not scaling linearly)
        """
        if warn is None:
            warn = not is_debug_enabled()
        if warn:
            msg = "The linear operator check is performed. This might slow down your program."
            warnings.warn(msg, stacklevel=2)
        checklinop(self)
        print("Check linear operator done")

    ############ private functions #################
    def __mv_as_mm(self, fcn, x: torch.Tensor) -> torch.Tensor:
        # Apply the vector operation ``fcn`` ((...,n) -> (...,m)) to every
        # column of ``x`` ((...,n,r)) in one batched call, returning (...,m,r).
        # The column axis r is moved to the very first dimension so it
        # broadcasts as an extra batch dimension in ``fcn``.
        xbatchshape = list(x.shape[:-2])
        nmissing = len(self._batchshape) - len(xbatchshape)
        if nmissing > 0:
            xbatchshape = [1] * nmissing + xbatchshape
        x1 = x.reshape(1, *xbatchshape, *x.shape[-2:])  # (1,...,n,r)
        xnew = x1.transpose(0, -1).squeeze(-1)  # (r,...,n)
        ynew = fcn(xnew)  # (r,...,m)
        return ynew.unsqueeze(-1).transpose(0, -1).squeeze(0)  # (...,m,r)

    def __adjoint_rmv(self, xt: torch.Tensor) -> torch.Tensor:
        # xt: (*BY, p)
        # xdummy: (*BY, q)

        # calculate the right matvec multiplication by using the adjoint trick:
        # the vector-Jacobian product of y = A x w.r.t. x with cotangent xt
        # is A^H xt.
        BY = xt.shape[:-1]
        BA = self.shape[:-2]
        BAY = get_bcasted_dims(BY, BA)

        # calculate y = Ax
        q = self.shape[-1]
        xdummy = torch.zeros((*BAY, q), dtype=xt.dtype, device=xt.device).requires_grad_()
        with torch.enable_grad():
            y = self.mv(xdummy)  # (*BAY, p)

        # calculate (dL/dx)^T = A^T (dL/dy)^T with (dL/dy)^T = xt
        xt2 = xt.contiguous().expand_as(y)  # (*BAY, p)
        # keep the graph only if the caller is itself differentiating
        res = torch.autograd.grad(y, xdummy, grad_outputs=xt2,
                                  create_graph=torch.is_grad_enabled())[0]  # (*BAY, q)
        return res

    def __assert_if_init_executed(self):
        if not hasattr(self, "_shape"):
            raise RuntimeError("super().__init__ must be executed first")
############## special linear operators ##############
class AdjointLinearOperator(LinearOperator):
    """Lazy adjoint (conjugate transpose) ``A^H`` of another ``LinearOperator``."""

    def __init__(self, obj: LinearOperator):
        super(AdjointLinearOperator, self).__init__(
            shape=(*obj.shape[:-2], obj.shape[-1], obj.shape[-2]),
            is_hermitian=obj.is_hermitian,
            dtype=obj.dtype,
            device=obj.device,
            _suppress_hermit_warning=True,
        )
        self.obj = obj

    def __repr__(self):
        return "AdjointLinearOperator with shape %s of:\n - %s" % \
            (_shape2str(self.shape), _indent(self.obj.__repr__(), 3))

    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        # (A^H) x is A.rmv(x). The public ``rmv`` uses the wrapped operator's
        # ``_rmv`` when available, ``_mv`` when it is Hermitian, and the
        # autograd adjoint trick otherwise, so ``.H.mv()`` works for any
        # operator instead of requiring ``_rmv`` to be implemented.
        return self.obj.rmv(x)

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        # (A^H)^H x == A x
        return self.obj._mv(x)

    def _getparamnames(self, prefix: str = "") -> List[str]:
        return self.obj._getparamnames(prefix=prefix + "obj.")

    @property
    def H(self):
        return self.obj


class MatmulLinearOperator(LinearOperator):
    """Lazy product ``a @ b`` of two ``LinearOperator`` objects."""

    def __init__(self, a: LinearOperator, b: LinearOperator, is_hermitian: bool = False):
        shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
        super(MatmulLinearOperator, self).__init__(
            shape=shape,
            is_hermitian=is_hermitian,
            dtype=a.dtype,
            device=a.device,
            _suppress_hermit_warning=True,
        )
        self.a = a
        self.b = b

    def __repr__(self):
        return "MatmulLinearOperator with shape %s of:\n * %s\n * %s" % \
            (_shape2str(self.shape),
             _indent(self.a.__repr__(), 3),
             _indent(self.b.__repr__(), 3))

    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        return self.a._mv(self.b._mv(x))

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        # (a b)^H x = b^H (a^H x)
        return self.b.rmv(self.a.rmv(x))

    def _getparamnames(self, prefix: str = "") -> List[str]:
        return self.a._getparamnames(prefix=prefix + "a.") + \
            self.b._getparamnames(prefix=prefix + "b.")


class AddLinearOperator(LinearOperator):
    """Lazy sum ``a + mul * b`` of two ``LinearOperator`` objects, ``mul`` in {1, -1}."""

    def __init__(self, a: LinearOperator, b: LinearOperator, mul: int = 1):
        if mul not in (1, -1):
            raise ValueError("mul must be either 1 or -1, got %r" % (mul,))
        shape = (*get_bcasted_dims(a.shape[:-2], b.shape[:-2]), a.shape[-2], b.shape[-1])
        is_hermitian = a.is_hermitian and b.is_hermitian
        super(AddLinearOperator, self).__init__(
            shape=shape,
            is_hermitian=is_hermitian,
            dtype=a.dtype,
            device=a.device,
            _suppress_hermit_warning=True,
        )
        self.a = a
        self.b = b
        self.mul = mul

    def __repr__(self):
        return "AddLinearOperator with shape %s of:\n * %s\n * %s" % \
            (_shape2str(self.shape),
             _indent(self.a.__repr__(), 3),
             _indent(self.b.__repr__(), 3))

    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        return self.a._mv(x) + self.mul * self.b._mv(x)

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        return self.a.rmv(x) + self.mul * self.b.rmv(x)

    def _getparamnames(self, prefix: str = "") -> List[str]:
        return self.a._getparamnames(prefix=prefix + "a.") + \
            self.b._getparamnames(prefix=prefix + "b.")


class MulLinearOperator(LinearOperator):
    """Lazy product of a ``LinearOperator`` with a real scalar ``f``."""

    def __init__(self, a: LinearOperator, f: Union[int, float]):
        shape = a.shape
        is_hermitian = a.is_hermitian
        super(MulLinearOperator, self).__init__(
            shape=shape,
            is_hermitian=is_hermitian,
            dtype=a.dtype,
            device=a.device,
            _suppress_hermit_warning=True,
        )
        self.a = a
        self.f = f

    def __repr__(self):
        return "MulLinearOperator with shape %s of: \n * %s\n * %s" % \
            (_shape2str(self.shape),
             _indent(self.a.__repr__(), 3),
             _indent(self.f.__repr__(), 3))

    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        return self.a._mv(x) * self.f

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        # Use the public ``rmv``: ``a`` may not implement ``_rmv`` (the base
        # placeholder returns None) and ``rmv`` also handles the Hermitian
        # and adjoint-trick fallbacks. ``f`` is real, so (f A)^H = f A^H.
        return self.a.rmv(x) * self.f

    def _getparamnames(self, prefix: str = "") -> List[str]:
        return self.a._getparamnames(prefix=prefix + "a.")


class MatrixLinearOperator(LinearOperator):
    """``LinearOperator`` backed by an explicit (possibly batched) dense matrix."""

    def __init__(self, mat: torch.Tensor, is_hermitian: bool) -> None:
        super(MatrixLinearOperator, self).__init__(
            shape=mat.shape,
            is_hermitian=is_hermitian,
            dtype=mat.dtype,
            device=mat.device,
            _suppress_hermit_warning=True,
        )
        self.mat = mat

    def __repr__(self):
        return "MatrixLinearOperator with shape %s:\n %s" % \
            (_shape2str(self.shape), _indent(self.mat.__repr__(), 3))

    def _mv(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.mat, x.unsqueeze(-1)).squeeze(-1)

    def _mm(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.mat, x)

    def _rmv(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.mat.transpose(-2, -1).conj(), x.unsqueeze(-1)).squeeze(-1)

    def _rmm(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.mat.transpose(-2, -1).conj(), x)

    def _fullmatrix(self) -> torch.Tensor:
        return self.mat

    def _getparamnames(self, prefix: str = "") -> List[str]:
        return [prefix + "mat"]


def checklinop(linop: LinearOperator) -> None:
    """
    Check if the implemented mv and mm can receive the possible shapes and
    returns the correct shape. If an error is found, then this function raise
    AssertionError.

    Arguments
    ---------
    linop: LinearOperator
        The instance of LinearOperator to be checked

    Raises
    ------
    AssertionError
        Raised if there is a shape mismatch, or if the output does not scale
        linearly or changes when extra batch dimensions are added
    RuntimeError
        Raised if there is an error when evaluating the .mv, .mm, .rmv, or .rmm methods
    """
    shape = linop.shape
    p, q = shape[-2:]
    batchshape = shape[:-2]

    def runtest(methodname, xshape, yshape):
        # Call ``linop.<methodname>`` on a random input of shape ``xshape``
        # and check the output shape, linearity and batch consistency.
        x = torch.rand(xshape, dtype=linop.dtype, device=linop.device)
        fcn = getattr(linop, methodname)
        try:
            y = fcn(x)
        except Exception:
            s = traceback.format_exc()
            msg = "An error is raised from .%s with input shape: %s (linear operator shape: %s)\n" % \
                (methodname, tuple(xshape), tuple(linop.shape))
            msg += "--- full traceback ---\n%s" % s
            raise RuntimeError(msg)
        msg = "The output shape of .%s is not correct. Input: %s, expected output: %s, output: %s" % \
            (methodname, tuple(x.shape), tuple(yshape), tuple(y.shape))
        msg += "\n" + str(linop)
        assert list(y.shape) == list(yshape), msg

        # linearity test
        x2 = 1.25 * x
        y2 = fcn(x2)
        msg = "Linearity check fails\n%s\n" % str(linop)
        assert torch.allclose(y2, 1.25 * y), msg
        y0 = fcn(0 * x)
        assert torch.allclose(y0, y * 0), "Linearity check (with 0) fails\n" + str(linop)

        # batched test
        xnew = torch.cat((x.unsqueeze(0), x2.unsqueeze(0)), dim=0)
        ynew = fcn(xnew)  # (2, ..., q)
        msg = "Batched test fails (expanding batches changes the results)" + str(linop)
        assert torch.allclose(ynew[0], y), msg
        assert torch.allclose(ynew[1], y2), msg

    # generate shapes
    mv_xshapes = [
        (q,),
        (1, q),
        (1, 1, q),
        (*batchshape, q),
        (1, *batchshape, q),
    ]
    mv_yshapes = [
        (*batchshape, p),
        (*batchshape, p) if len(batchshape) >= 1 else (1, p),
        (*batchshape, p) if len(batchshape) >= 2 else (1, 1, p),
        (*batchshape, p),
        (1, *batchshape, p),
    ]
    # test matvec and matmat, run input in multiple shapes to make sure no error is raised
    r = 2
    for (mv_xshape, mv_yshape) in zip(mv_xshapes, mv_yshapes):
        runtest("mv", mv_xshape, mv_yshape)
        runtest("mm", (*mv_xshape, r), (*mv_yshape, r))

    if not linop.is_rmv_implemented:
        return

    rmv_xshapes = [
        (p,),
        (1, p),
        (1, 1, p),
        (*batchshape, p),
        (1, *batchshape, p),
    ]
    rmv_yshapes = [
        (*batchshape, q),
        (*batchshape, q) if len(batchshape) >= 1 else (1, q),
        (*batchshape, q) if len(batchshape) >= 2 else (1, 1, q),
        (*batchshape, q),
        (1, *batchshape, q),
    ]
    for (rmv_xshape, rmv_yshape) in zip(rmv_xshapes, rmv_yshapes):
        runtest("rmv", rmv_xshape, rmv_yshape)
        runtest("rmm", (*rmv_xshape, r), (*rmv_yshape, r))


########### repr helper functions ###########
def _indent(s, nspace):
    # give indentation of the second line and next lines
    spaces = " " * nspace
    lines = [spaces + c if i > 0 else c for i, c in enumerate(s.split("\n"))]
    return "\n".join(lines)


def _shape2str(shape):
    # format any shape sequence as "(d0, d1, ...)"
    return "(%s)" % (", ".join([str(s) for s in shape]))