Building a custom linear operator
xitorch provides some linear algebra operations that do not need the explicit
matrix, such as xitorch.linalg.solve() and xitorch.linalg.symeig().
To represent a matrix implicitly, subclass the base class xitorch.LinearOperator
to construct a user-defined linear operator.
To write a LinearOperator class, the method _mv (matrix-vector
multiplication) must be implemented.
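A matrix-vector product is enough because iterative solvers touch the matrix only through products with vectors. The sketch below illustrates this with a plain conjugate-gradient loop in PyTorch. `cg_solve` and `laplacian_mv` are hypothetical names chosen for the illustration. Conjugate gradient assumes a symmetric positive-definite operator, and it is not necessarily the algorithm xitorch.linalg.solve() uses internally.

```python
import torch


def laplacian_mv(x):
    # Hypothetical matrix-free operator: the 1D Laplacian with Dirichlet ends,
    # (A x)_i = 2 x_i - x_{i-1} - x_{i+1}, applied without ever building A.
    y = 2 * x
    y[1:] -= x[:-1]
    y[:-1] -= x[1:]
    return y


def cg_solve(mv, b, tol=1e-10, max_iter=1000):
    # Conjugate gradient for SPD operators: A is only accessed via mv(p).
    x = torch.zeros_like(b)
    r = b - mv(x)
    p = r.clone()
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        Ap = mv(p)
        alpha = rs_old / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        if torch.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


b = torch.ones(5, dtype=torch.float64)
x = cg_solve(laplacian_mv, b)
print(torch.allclose(laplacian_mv(x), b))  # True: A x reproduces b
```

For this right-hand side the exact solution is x_i = i(6 - i)/2 for i = 1..5, that is [2.5, 4, 4.5, 4, 2.5].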
If the LinearOperator is used in one of xitorch’s functionals with gradients enabled, e.g.
xitorch.linalg.symeig() or xitorch.linalg.solve(), it must also implement
the method _getparamnames.
_getparamnames returns a list of the names of the parameters affecting the output,
as in xitorch.EditableModule.
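The prefix argument lets an object that holds other objects report their parameters as dotted attribute paths. The plain-Python sketch below shows how nested names such as "inner.a" can be built and then resolved back to the attributes they point to. It assumes the dotted-prefix convention used by xitorch.EditableModule. The classes `Scale` and `Shifted` and the helper `resolve` are hypothetical stand-ins, and xitorch itself is not used.

```python
from functools import reduce


class Scale:
    # Hypothetical stand-in for a LinearOperator holding one parameter.
    def __init__(self, a):
        self.a = a

    def _getparamnames(self, prefix=""):
        return [prefix + "a"]


class Shifted:
    # Hypothetical composite holding another object plus its own parameter.
    def __init__(self, inner, b):
        self.inner = inner
        self.b = b

    def _getparamnames(self, prefix=""):
        # Delegate to the child with an extended prefix, then add our own name.
        return self.inner._getparamnames(prefix=prefix + "inner.") + [prefix + "b"]


def resolve(obj, name):
    # Follow a dotted attribute path such as "inner.a" down from obj.
    return reduce(getattr, name.split("."), obj)


op = Shifted(Scale(a=3.0), b=1.5)
names = op._getparamnames()
print(names)                            # ['inner.a', 'b']
print([resolve(op, n) for n in names])  # [3.0, 1.5]
```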
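As an example, to write the 5×5 matrix

$$\mathbf{A} = \begin{pmatrix} 0 & 0 & 0 & 0 & a_1 \\ 0 & 0 & 0 & a_2 & 0 \\ 0 & 0 & a_3 & 0 & 0 \\ 0 & a_4 & 0 & 0 & 0 \\ a_5 & 0 & 0 & 0 & 0 \end{pmatrix}$$

(reverse the input vector, then scale it elementwise by a)
as a LinearOperator, we can write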
import torch
import xitorch

class MyFlip(xitorch.LinearOperator):
    def __init__(self, a, size):
        super().__init__(shape=(size, size))
        self.a = a

    def _mv(self, x):
        # reverse the last dimension, then scale elementwise by self.a
        return torch.flip(x, dims=(-1,)) * self.a

    def _getparamnames(self, prefix=""):
        # self.a is the only tensor that affects the output of _mv
        return [prefix + "a"]

a = torch.arange(1, 6, dtype=torch.float).requires_grad_()
flip = MyFlip(a, 5)
print(flip)
LinearOperator (MyFlip) with shape (5, 5), dtype = torch.float32, device = cpu
With only _mv implemented, we can call all the basic matrix operations, including
.mv() (matrix-vector multiplication),
.mm() (matrix-matrix multiplication),
.fullmatrix() (returns the dense representation of the linear operator),
.rmv() (matrix-vector right-multiplication), and
.rmm() (matrix-matrix right-multiplication).
The matrix-matrix multiplication is computed as a batched matrix-vector multiplication, while the right-multiplication is performed with the adjoint trick, using PyTorch’s autograd engine.
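Both mechanisms can be reproduced in plain PyTorch. In the sketch below, `flip_mv` re-implements MyFlip._mv as a free function. `mm_via_mv` applies it to every column at once by moving the columns into the batch dimension. `rmv_via_autograd` obtains A^T v as the vector-Jacobian product of the linear map x -> A x. The helper names are hypothetical, and this illustrates the technique rather than reproducing xitorch’s actual code path.

```python
import torch

a = torch.arange(1, 6, dtype=torch.float)


def flip_mv(x):
    # Same action as MyFlip._mv: reverse the last dimension, scale by a.
    return torch.flip(x, dims=(-1,)) * a


def mm_via_mv(mv, mat):
    # mat has shape (n, k): move the k columns into the batch dimension,
    # apply mv to all of them at once, then transpose back to (n, k).
    return mv(mat.transpose(-2, -1)).transpose(-2, -1)


def rmv_via_autograd(mv, v):
    # For a linear map y = A x, the vector-Jacobian product v^T (dy/dx)
    # equals A^T v, whatever point x it is evaluated at.
    x = torch.zeros_like(v, requires_grad=True)
    (g,) = torch.autograd.grad(mv(x), x, grad_outputs=v)
    return g


vec = torch.arange(5, dtype=torch.float)
mat = torch.cat((vec.unsqueeze(-1), 2 * vec.unsqueeze(-1)), dim=-1)
print(mm_via_mv(flip_mv, mat))         # columns [4, 6, 6, 4, 0] and [8, 12, 12, 8, 0]
print(rmv_via_autograd(flip_mv, vec))  # values [20, 12, 6, 2, 0]
```

Both results match the flip.mm(mat) and flip.rmv(vec) outputs in the session that follows.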
vec = torch.arange(5, dtype=torch.float)
# a (5, 2) matrix whose columns are vec and 2*vec
mat = torch.cat((vec.unsqueeze(-1), 2*vec.unsqueeze(-1)), dim=-1)
# matrix-vector multiplication
print(flip.mv(vec))
tensor([4., 6., 6., 4., 0.], grad_fn=<MulBackward0>)
# matrix-vector right-multiplication
print(flip.rmv(vec))
tensor([20., 12., 6., 2., 0.], grad_fn=<FlipBackward>)
# matrix-matrix multiplication
print(flip.mm(mat))
tensor([[ 4.,  8.],
        [ 6., 12.],
        [ 6., 12.],
        [ 4.,  8.],
        [ 0.,  0.]], grad_fn=<SqueezeBackward1>)
# getting the dense representation
print(flip.fullmatrix())
tensor([[0., 0., 0., 0., 1.],
        [0., 0., 0., 2., 0.],
        [0., 0., 3., 0., 0.],
        [0., 4., 0., 0., 0.],
        [5., 0., 0., 0., 0.]], grad_fn=<SqueezeBackward1>)
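A dense matrix can be recovered from the matrix-vector product alone, since column j of A is A e_j. The sketch below uses a hypothetical helper, `dense_from_mv`, which applies the flip action to every basis vector at once by feeding it the identity matrix as a batch. It then checks the result against the antidiagonal matrix printed above. Whether xitorch’s fullmatrix() is implemented exactly this way is an assumption.

```python
import torch

a = torch.arange(1, 6, dtype=torch.float)


def flip_mv(x):
    # Same action as MyFlip._mv.
    return torch.flip(x, dims=(-1,)) * a


def dense_from_mv(mv, n):
    # Row i of eye(n) is e_i, so row i of mv(eye) is A e_i, i.e. column i of A.
    # mv(eye) is therefore A^T, and transposing it gives A.
    eye = torch.eye(n)
    return mv(eye).transpose(-2, -1)


A = dense_from_mv(flip_mv, 5)
expected = torch.flip(torch.diag(a), dims=(-1,))  # a on the antidiagonal
print(torch.equal(A, expected))  # True
```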
The LinearOperator instance can also be passed to xitorch’s linear algebra functions,
such as xitorch.linalg.solve():
from xitorch.linalg import solve
mmres = flip.mm(mat)
# solve flip @ mat2 = mmres; mat2 should equal mat
mat2 = solve(flip, mmres)
print(mat2)
tensor([[0., 0.],
        [1., 2.],
        [2., 4.],
        [3., 6.],
        [4., 8.]], grad_fn=<LinalgSolveBackward>)
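As a cross-check that does not depend on xitorch, the same systems can be solved densely with torch.linalg.solve on the explicit antidiagonal matrix. The sketch below does this and also backpropagates through the solve to show that a gradient reaches a. Building the dense matrix is only practical for small sizes, and this is an illustration rather than what xitorch.linalg.solve() does internally.

```python
import torch

a = torch.arange(1, 6, dtype=torch.float64, requires_grad=True)
vec = torch.arange(5, dtype=torch.float64)
mat = torch.stack((vec, 2 * vec), dim=-1)  # same (5, 2) matrix as above

# Explicit dense form of MyFlip: a on the antidiagonal.
A = torch.flip(torch.diag(a), dims=(-1,))

# Round trip from the text: solving A X = A @ mat should give back mat.
X = torch.linalg.solve(A, A @ mat)
print(torch.allclose(X, mat))  # True

# Solving A Y = mat directly: row i of A Y is a_i * Y[4 - i], so Y = flip(mat / a).
Y = torch.linalg.solve(A, mat)
print(torch.allclose(Y, torch.flip(mat / a.unsqueeze(-1), dims=(0,))))  # True

# Gradients flow through the dense solve back to a:
# d(sum Y)/d a_i = -(row sum i of mat) / a_i**2.
Y.sum().backward()
print(a.grad)
```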