Debugging EditableModule and LinearOperator
If you are implementing xitorch.EditableModule or xitorch.LinearOperator, how can you be sure that your implementation is correct?
For example, is the list of parameters returned by the getparamnames() method of an xitorch.EditableModule complete, or does it contain excess entries?
Does your implementation of xitorch.LinearOperator actually behave like a proper linear operator?
This page answers those questions.
Checking parameters in xitorch.EditableModule
Let’s say we have a class derived from xitorch.EditableModule:
import torch
import xitorch

class AClass(xitorch.EditableModule):
    def __init__(self, a):
        self.a = a
        self.b = a*a

    def mult(self, x):
        return self.b * x

    def getparamnames(self, methodname, prefix=""):
        if methodname == "mult":
            return [prefix+"a"]  # intentionally wrong
        else:
            raise KeyError()
The getparamnames method above returns the wrong parameter for the mult method: it returns a, whereas mult actually depends on b.
To detect this fault, use the assertparams method, which is available on every class derived from xitorch.EditableModule.
It takes a method, followed by that method's positional and keyword arguments.
It issues a warning if getparamnames omits a variable that affects the method (a missing parameter) or lists one that does not (an excess parameter).
An example is shown below.
a = torch.tensor(2.0).requires_grad_()
x = torch.tensor(0.4).requires_grad_()
A = AClass(a)
A.assertparams(A.mult, x)
"mult" method check done
/home/docs/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: getparams for AClass.mult does not include: b
  after removing the cwd from sys.path.
/home/docs/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: getparams for AClass.mult has excess parameters: a
  after removing the cwd from sys.path.
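The two warnings above amount to a set comparison between the names that getparamnames declares and the variables the method actually depends on. The following is a minimal pure-Python sketch of that comparison only. The helper compare_params is hypothetical and not part of xitorch, and the sketch does not show how xitorch discovers the actual dependencies.

```python
def compare_params(declared, actual):
    """Return (missing, excess) parameter names, each sorted.

    declared: names returned by getparamnames for a method
    actual:   names the method really depends on
    Hypothetical helper for illustration; not part of xitorch.
    """
    declared_set, actual_set = set(declared), set(actual)
    missing = sorted(actual_set - declared_set)  # used but not declared
    excess = sorted(declared_set - actual_set)   # declared but not used
    return missing, excess


# AClass.mult reads self.b, but getparamnames declared "a".
missing, excess = compare_params(declared=["a"], actual=["b"])
print("does not include:", missing)
print("has excess parameters:", excess)
```

In the AClass example, both lists become empty once getparamnames returns [prefix+"b"] for mult.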
Is my LinearOperator actually a linear operator?
Programmatically, to implement a LinearOperator you only need to implement the matrix-vector multiplication function, ._mv().
But does the implemented operation behave like a linear operator?
To check, use the .check() method available on classes derived from LinearOperator.
It can be called without arguments and runs several checks, raising an error if any of them fails.
Let’s take an example of a wrong implementation of a linear operator.
import torch
import xitorch

class WrongLinearOp(xitorch.LinearOperator):
    def __init__(self, a):
        shape = (torch.numel(a), torch.numel(a))
        super().__init__(shape=shape, dtype=a.dtype, device=a.device)
        self.a = a

    def _mv(self, x):
        return self.a * x + 1.0  # not a linear operator

a = torch.tensor(1.2, requires_grad=True)
linop = WrongLinearOp(a)
linop.check()
/home/docs/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/ipykernel_launcher.py:15: UserWarning: The linear operator check is performed. This might slow down your program.
  from ipykernel import kernelapp as app
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/tmp/ipykernel_223/342784360.py in <module>
13 a = torch.tensor(1.2, requires_grad=True)
14 linop = WrongLinearOp(a)
---> 15 linop.check()
~/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/xitorch-0.4.0.dev0+3327cc0-py3.7.egg/xitorch/_core/linop.py in check(self, warn)
518 msg = "The linear operator check is performed. This might slow down your program."
519 warnings.warn(msg, stacklevel=2)
--> 520 checklinop(self)
521 print("Check linear operator done")
522
~/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/xitorch-0.4.0.dev0+3327cc0-py3.7.egg/xitorch/_core/linop.py in checklinop(linop)
778 r = 2
779 for (mv_xshape, mv_yshape) in zip(mv_xshapes, mv_yshapes):
--> 780 runtest("mv", mv_xshape, mv_yshape)
781 runtest("mm", (*mv_xshape, r), (*mv_yshape, r))
782
~/checkouts/readthedocs.org/user_builds/xitorch/envs/latest/lib/python3.7/site-packages/xitorch-0.4.0.dev0+3327cc0-py3.7.egg/xitorch/_core/linop.py in runtest(methodname, xshape, yshape)
749 y2 = fcn(x2)
750 msg = "Linearity check fails\n%s\n" % str(linop)
--> 751 assert torch.allclose(y2, 1.25 * y), msg
752 y0 = fcn(0 * x)
753 assert torch.allclose(y0, y * 0), "Linearity check (with 0) fails\n" + str(linop)
AssertionError: Linearity check fails
LinearOperator (WrongLinearOp) with shape (1, 1), dtype = torch.float32, device = cpu
As expected, an error is raised at the check that fails, in this case the linearity check. Because the check takes a considerable amount of time, it should only be run while debugging.
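To see why a * x + 1.0 fails, it can help to reproduce the two assertions visible in the traceback (scaling the input by 1.25, and feeding a zero input) without any library. The following is a minimal pure-Python sketch on lists of floats. The names check_linearity, right_mv and wrong_mv are hypothetical, and the sketch covers only those two assertions, not everything .check() does.

```python
import math


def check_linearity(fcn, x, scale=1.25):
    """Raise AssertionError if fcn is not homogeneous on the sample x.

    Mirrors the two assertions visible in the traceback:
    fcn(scale * x) == scale * fcn(x) and fcn(0 * x) == 0.
    Illustrative helper on lists of floats; not xitorch's checklinop.
    """
    y = fcn(x)
    y2 = fcn([scale * xi for xi in x])
    assert all(
        math.isclose(a, scale * b, abs_tol=1e-8) for a, b in zip(y2, y)
    ), "Linearity check fails"
    y0 = fcn([0.0 * xi for xi in x])
    assert all(
        math.isclose(a, 0.0, abs_tol=1e-8) for a in y0
    ), "Linearity check (with 0) fails"


def right_mv(x):
    return [1.2 * xi for xi in x]        # a * x: linear


def wrong_mv(x):
    return [1.2 * xi + 1.0 for xi in x]  # a * x + 1: affine, not linear


check_linearity(right_mv, [0.4, -2.0])   # passes silently

try:
    check_linearity(wrong_mv, [0.4, -2.0])
    wrong_detected = False
except AssertionError as err:
    wrong_detected = True
    print(err)
```

The constant offset is what breaks the scaling test: scaling the input by 1.25 does not scale the + 1.0 term, so the two sides of the first assertion differ.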