Debugging EditableModule and LinearOperator

If you are implementing xitorch.EditableModule or xitorch.LinearOperator, how can you be sure that your implementation is correct? For example, is the list of parameters returned by the getparamnames() method of xitorch.EditableModule complete, or does it contain excess entries? Does the implementation of xitorch.LinearOperator actually behave like a proper linear operator? We will answer those questions here.

Checking parameters in xitorch.EditableModule

Let’s say we have a class derived from xitorch.EditableModule:

import torch
import xitorch

class AClass(xitorch.EditableModule):
    def __init__(self, a):
        self.a = a
        self.b = a*a

    def mult(self, x):
        return self.b * x

    def getparamnames(self, methodname, prefix=""):
        if methodname == "mult":
            return [prefix+"a"]  # intentionally wrong
        else:
            raise KeyError()

The method getparamnames above returns the wrong parameter for the method mult: it lists a while it should list b, since mult only uses self.b. To detect this fault, you can use the method assertparams available in classes derived from xitorch.EditableModule.

The method assertparams takes a method, followed by that method's arguments and keyword arguments, as its inputs. It raises warnings if it detects parameters that affect the method but are missing from getparamnames, as well as listed parameters that do not affect the method (excess parameters). An example is shown below.

a = torch.tensor(2.0).requires_grad_()
x = torch.tensor(0.4).requires_grad_()
A = AClass(a)
A.assertparams(A.mult, x)
"mult" method check done
/home/docs/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: getparams for AClass.mult does not include: b
  after removing the cwd from sys.path.
/home/docs/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: getparams for AClass.mult has excess parameters: a
  after removing the cwd from sys.path.
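
The warnings point directly at the fix: mult uses self.b, not self.a, so getparamnames should list b. A minimal corrected sketch of the same class could look like this:

import torch
import xitorch

class AClass(xitorch.EditableModule):
    def __init__(self, a):
        self.a = a
        self.b = a*a

    def mult(self, x):
        return self.b * x

    def getparamnames(self, methodname, prefix=""):
        if methodname == "mult":
            return [prefix+"b"]  # b is the only attribute affecting mult
        else:
            raise KeyError()

a = torch.tensor(2.0).requires_grad_()
x = torch.tensor(0.4).requires_grad_()
A = AClass(a)
A.assertparams(A.mult, x)

With this definition, assertparams should report that the check is done without raising any warnings.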

Is my LinearOperator actually a linear operator?

Programmatically, to implement a LinearOperator, you just need to implement the matrix-vector multiplication function, ._mv(). But does the implemented operation behave like a linear operator?

To check if your implementation is correct, you can use the method .check() available in classes derived from LinearOperator. It takes no input and performs several checks, for example that scaling the input scales the output by the same factor and that the zero vector is mapped to zero, raising an error if any of them fails.

Let’s take an example of a wrong implementation of a linear operator.

import torch
import xitorch

class WrongLinearOp(xitorch.LinearOperator):
    def __init__(self, a):
        shape = (torch.numel(a), torch.numel(a))
        super().__init__(shape=shape, dtype=a.dtype, device=a.device)
        self.a = a

    def _mv(self, x):
        return self.a * x + 1.0  # not a linear operator

a = torch.tensor(1.2, requires_grad=True)
linop = WrongLinearOp(a)
linop.check()
UserWarning: The linear operator check is performed. This might slow down your program.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_223/342784360.py in <module>
     13 a = torch.tensor(1.2, requires_grad=True)
     14 linop = WrongLinearOp(a)
---> 15 linop.check()

xitorch/_core/linop.py in check(self, warn)
    518             msg = "The linear operator check is performed. This might slow down your program."
    519             warnings.warn(msg, stacklevel=2)
--> 520         checklinop(self)
    521         print("Check linear operator done")
    522 

xitorch/_core/linop.py in checklinop(linop)
    778     r = 2
    779     for (mv_xshape, mv_yshape) in zip(mv_xshapes, mv_yshapes):
--> 780         runtest("mv", mv_xshape, mv_yshape)
    781         runtest("mm", (*mv_xshape, r), (*mv_yshape, r))
    782 

xitorch/_core/linop.py in runtest(methodname, xshape, yshape)
    749         y2 = fcn(x2)
    750         msg = "Linearity check fails\n%s\n" % str(linop)
--> 751         assert torch.allclose(y2, 1.25 * y), msg
    752         y0 = fcn(0 * x)
    753         assert torch.allclose(y0, y * 0), "Linearity check (with 0) fails\n" + str(linop)

AssertionError: Linearity check fails
LinearOperator (WrongLinearOp) with shape (1, 1), dtype = torch.float32, device = cpu

As expected, it raises an error at the point where the check fails (i.e., in the linearity check). This check should only be done while debugging, as it takes a considerable amount of time.
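
The fix is to remove the constant offset so that _mv is purely multiplicative. A minimal corrected sketch (the class name FixedLinearOp is just illustrative) could look like this:

import torch
import xitorch

class FixedLinearOp(xitorch.LinearOperator):
    def __init__(self, a):
        shape = (torch.numel(a), torch.numel(a))
        super().__init__(shape=shape, dtype=a.dtype, device=a.device)
        self.a = a

    def _mv(self, x):
        return self.a * x  # no offset: scaling and zero inputs behave linearly

a = torch.tensor(1.2, requires_grad=True)
linop = FixedLinearOp(a)
linop.check()

This time .check() should finish without an assertion error and print "Check linear operator done" (the warning that the check might slow down your program may still appear).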