Building a custom linear operator
xitorch provides linear algebra operations that do not require an explicit matrix, such as xitorch.linalg.solve() and xitorch.linalg.symeig().
To represent a matrix implicitly, the base class xitorch.LinearOperator should be used to construct user-defined linear operators.
To write a LinearOperator class, the method _mv (matrix-vector multiplication) must be implemented.
If the LinearOperator is used in xitorch's functionals with grad enabled, e.g. xitorch.linalg.symeig() or xitorch.linalg.solve(), it must also implement the method _getparamnames, which returns a list of the parameters affecting the output, as in xitorch.EditableModule.
As an example, to write the matrix that reverses a vector and scales it element-wise by a (i.e. an anti-diagonal matrix holding the entries of a) as a LinearOperator, we can write
import torch
import xitorch
class MyFlip(xitorch.LinearOperator):
    def __init__(self, a, size):
        super().__init__(shape=(size, size))
        self.a = a

    def _mv(self, x):
        # reverse the vector, then scale element-wise by a
        return torch.flip(x, dims=(-1,)) * self.a

    def _getparamnames(self, prefix=""):
        return [prefix + "a"]
a = torch.arange(1, 6, dtype=torch.float).requires_grad_()
flip = MyFlip(a, 5)
print(flip)
LinearOperator (MyFlip) with shape (5, 5), dtype = torch.float32, device = cpu
With only _mv implemented, we can call all matrix operations, including .mv() (matrix-vector multiplication), .mm() (matrix-matrix multiplication), .fullmatrix() (returns the dense representation of the linear operator), .rmv() (matrix-vector right-multiplication), and .rmm() (matrix-matrix right-multiplication).
The matrix-matrix multiplication is calculated by batched matrix-vector calculation, while the right-multiplication is performed using the adjoint trick with the help of PyTorch’s autograd engine.
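The adjoint trick can be sketched in plain PyTorch (this is an illustration of the idea, not xitorch's exact implementation): for a linear map f(x) = A x, the vector-Jacobian product A^T v equals the gradient of f(x) · v with respect to x, which autograd computes without ever forming the dense matrix.

```python
import torch

def flip_scale(x, a):
    # the same linear map as MyFlip._mv above
    return torch.flip(x, dims=(-1,)) * a

a = torch.arange(1, 6, dtype=torch.float)
v = torch.arange(5, dtype=torch.float)

# any point works since the map is linear
x = torch.zeros(5, requires_grad=True)
y = flip_scale(x, a)
# gradient of y . v w.r.t. x gives A^T v
(rmv,) = torch.autograd.grad(y, x, grad_outputs=v)
print(rmv)  # tensor([20., 12.,  6.,  2.,  0.])
```

This matches the flip.rmv(vec) result shown below.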
vec = torch.arange(5, dtype=torch.float)
mat = torch.cat((vec.unsqueeze(-1), 2*vec.unsqueeze(-1)), dim=-1)
print(flip.mv(vec))
tensor([4., 6., 6., 4., 0.], grad_fn=<MulBackward0>)
# matrix-vector right-multiplication
print(flip.rmv(vec))
tensor([20., 12., 6., 2., 0.], grad_fn=<FlipBackward>)
# matrix-matrix multiplication
print(flip.mm(mat))
tensor([[ 4.,  8.],
        [ 6., 12.],
        [ 6., 12.],
        [ 4.,  8.],
        [ 0.,  0.]], grad_fn=<SqueezeBackward1>)
# getting the dense representation
print(flip.fullmatrix())
tensor([[0., 0., 0., 0., 1.],
        [0., 0., 0., 2., 0.],
        [0., 0., 3., 0., 0.],
        [0., 4., 0., 0., 0.],
        [5., 0., 0., 0., 0.]], grad_fn=<SqueezeBackward1>)
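As a plain-PyTorch cross-check (no xitorch required), the dense matrix above is just an anti-diagonal matrix holding the entries of a, so multiplying it with vec reproduces flip.mv(vec):

```python
import torch

# build the same dense matrix as flip.fullmatrix() directly:
# diag(a) with its columns reversed is the anti-diagonal matrix
a = torch.arange(1, 6, dtype=torch.float)
dense = torch.flip(torch.diag(a), dims=(-1,))
vec = torch.arange(5, dtype=torch.float)
print(dense @ vec)  # tensor([4., 6., 6., 4., 0.]), same as flip.mv(vec)
```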
The LinearOperator instance can also be used in xitorch's linear algebra operations, such as xitorch.linalg.solve():
from xitorch.linalg import solve
mmres = flip.mm(mat)
mat2 = solve(flip, mmres)
print(mat2)
tensor([[0., 0.],
        [1., 2.],
        [2., 4.],
        [3., 6.],
        [4., 8.]], grad_fn=<LinalgSolveBackward>)
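Because _getparamnames exposes a, gradients can propagate through these functionals with respect to a. As a sanity check of the same computation in plain PyTorch (using the dense form and torch.linalg.solve instead of xitorch, purely for illustration):

```python
import torch

a = torch.arange(1, 6, dtype=torch.float).requires_grad_()
dense = torch.flip(torch.diag(a), dims=(-1,))  # dense form of MyFlip
b = torch.ones(5, 1)
x = torch.linalg.solve(dense, b)  # row i reads a[i] * x[4-i] = 1
x.sum().backward()                # sum(x) = sum(1 / a)
print(a.grad)                     # equals -1 / a**2
```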