EditableModule

class xitorch.EditableModule

EditableModule is a base class that allows the classes inheriting from it to be converted into pure functions, so that higher-order derivatives can be computed.
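
To make "converted into pure functions" concrete, here is a minimal plain-Python sketch that runs without torch or xitorch. It assumes only the getparamnames convention documented below. Square and make_pure are hypothetical names, not part of xitorch. Floats and a finite difference stand in for tensors and autograd. This is an illustration of the idea, not how xitorch implements the conversion.

```python
from typing import Any, Callable, Dict, List


class Square:
    # Plain-Python stand-in for an EditableModule subclass (no xitorch needed):
    # the output of mult depends only on the attribute "b".
    def __init__(self, a: float):
        self.b = a * a

    def mult(self, x: float) -> float:
        return self.b * x

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "mult":
            return [prefix + "b"]
        raise KeyError(methodname)


def make_pure(obj: Any, methodname: str) -> Callable[..., Any]:
    """Hypothetical helper: expose the parameters listed by getparamnames
    as explicit inputs of a function (top-level attribute names only)."""
    names = obj.getparamnames(methodname)
    method = getattr(obj, methodname)

    def pure(params: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        saved = {name: getattr(obj, name) for name in names}
        try:
            # Temporarily swap in the supplied parameter values.
            for name in names:
                setattr(obj, name, params[name])
            return method(*args, **kwargs)
        finally:
            # Restore the original values so obj is left unchanged.
            for name, value in saved.items():
                setattr(obj, name, value)

    return pure


sq = Square(2.0)           # sq.b == 4.0
f = make_pure(sq, "mult")
print(f({"b": 3.0}, 2.0))  # 6.0
print(sq.b)                # 4.0 (restored after the call)

# Because b is now an explicit input, the output can be differentiated
# with respect to it; here with a central finite difference, d(b * x)/db = x.
h = 1e-6
grad_b = (f({"b": 4.0 + h}, 0.4) - f({"b": 4.0 - h}, 0.4)) / (2 * h)
print(round(grad_b, 6))    # 0.4
```

Swapping and restoring the attributes leaves the object unchanged between calls. The sketch handles only top-level attribute names, not dotted ones.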

abstract getparamnames(methodname: str, prefix: str = '') → List[str]

This method should list the names of the tensors that affect the output of the method named methodname. If methodname is not among the methods handled by this function, it should raise KeyError.

Parameters
  • methodname (str) – The name of the method of the class.

  • prefix (str) – The prefix to be prepended to the parameter names. It usually contains dots (for example, the path to a nested attribute).

Returns

Sequence of names of the parameters affecting the output of the method.

Return type

Sequence of str

Raises

KeyError – If methodname is not among the methods handled by this function.

Example

>>> import xitorch
>>> class A(xitorch.EditableModule):
...     def __init__(self, a):
...         self.b = a*a
...
...     def mult(self, x):
...         return self.b * x
...
...     def getparamnames(self, methodname, prefix=""):
...         if methodname == "mult":
...             return [prefix+"b"]
...         else:
...             raise KeyError()
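
The example above does not exercise prefix. A common pattern, assumed here rather than stated on this page, is for a class that holds another object to forward getparamnames to it, appending the attribute name and a dot to the prefix. The sketch is plain Python so it runs without xitorch; Inner, Outer and get_by_name are hypothetical names.

```python
from typing import Any, List


class Inner:
    # Plain-Python stand-in for a nested module; only "b" affects mult.
    def __init__(self, a: float):
        self.b = a * a

    def mult(self, x: float) -> float:
        return self.b * x

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "mult":
            return [prefix + "b"]
        raise KeyError(methodname)


class Outer:
    def __init__(self, a: float, scale: float):
        self.inner = Inner(a)
        self.scale = scale

    def mult(self, x: float) -> float:
        return self.scale * self.inner.mult(x)

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "mult":
            # Own parameters get the caller's prefix; the child's parameters
            # get that prefix extended with the attribute name and a dot.
            return [prefix + "scale"] + self.inner.getparamnames(
                "mult", prefix=prefix + "inner.")
        raise KeyError(methodname)


def get_by_name(obj: Any, name: str) -> Any:
    """Hypothetical helper: resolve a dotted name such as 'inner.b'."""
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj


outer = Outer(2.0, 3.0)
print(outer.getparamnames("mult"))                   # ['scale', 'inner.b']
print(outer.getparamnames("mult", prefix="model."))  # ['model.scale', 'model.inner.b']
print(get_by_name(outer, "inner.b"))                 # 4.0
```

With this convention, every returned name is a dotted path relative to the outermost object, so the full list can be resolved from that object alone.
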

getuniqueparams(methodname: str, onlyleaves: bool = False) → List[torch.Tensor]

Returns the list of unique parameters involved in the method specified by methodname.

Parameters
  • methodname (str) – Name of the method in which the returned parameters are used.

  • onlyleaves (bool) – If True, only returns leaf tensors. Otherwise, returns all tensors.

Returns

List of tensors that are involved in the specified method of the object.

Return type

list of tensors
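
One plausible reading of "unique" and onlyleaves is sketched below in plain Python: a tensor reachable under two names is returned once (compared by identity), and leafness is read from an is_leaf attribute, as on torch.Tensor. SimpleNamespace objects stand in for tensors; unique_params and Model are hypothetical names, and this is not xitorch's implementation.

```python
from types import SimpleNamespace
from typing import Any, List


def unique_params(obj: Any, methodname: str, onlyleaves: bool = False) -> List[Any]:
    """Hypothetical sketch: resolve the names from getparamnames, drop
    duplicates by identity, and optionally keep only leaf objects."""
    result: List[Any] = []
    seen = set()
    for name in obj.getparamnames(methodname):
        value = obj
        for part in name.split("."):  # follow dotted names into nested objects
            value = getattr(value, part)
        if id(value) in seen:
            continue  # same object already reached under another name
        seen.add(id(value))
        if onlyleaves and not value.is_leaf:
            continue  # skip derived (non-leaf) objects
        result.append(value)
    return result


class Model:
    def __init__(self):
        # Stand-ins for tensors: a is a leaf, b is derived from it,
        # and c is an alias of a (the same object under another name).
        self.a = SimpleNamespace(name="a", is_leaf=True)
        self.b = SimpleNamespace(name="b", is_leaf=False)
        self.c = self.a

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "f":
            return [prefix + "a", prefix + "b", prefix + "c"]
        raise KeyError(methodname)


m = Model()
print([p.name for p in unique_params(m, "f")])                   # ['a', 'b']
print([p.name for p in unique_params(m, "f", onlyleaves=True)])  # ['a']
```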

assertparams(method, *args, **kwargs)

Performs a rigorous check of the class's getparamnames implementation for a given method, called with the given positional and keyword arguments. It issues warnings if getparamnames omits parameters that the method uses or lists parameters that it does not use.

Parameters
  • method (callable method) – The method of this class to be tested.

  • *args – Positional arguments of the method.

  • **kwargs – Keyword arguments of the method.

Example

>>> import torch
>>> import xitorch
>>> class AClass(xitorch.EditableModule):
...     def __init__(self, a):
...         self.a = a
...         self.b = a*a
...
...     def mult(self, x):
...         return self.b * x
...
...     def getparamnames(self, methodname, prefix=""):
...         if methodname == "mult":
...             return [prefix+"a"]  # intentionally wrong
...         else:
...             raise KeyError()
>>> a = torch.tensor(2.0).requires_grad_()
>>> x = torch.tensor(0.4).requires_grad_()
>>> A = AClass(a)
>>> A.assertparams(A.mult, x) 
<...>:1: UserWarning: getparams for AClass.mult does not include: b
  A.assertparams(A.mult, x) 
<...>:1: UserWarning: getparams for AClass.mult has excess parameters: a
  A.assertparams(A.mult, x) 
"mult" method check done
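
A check like this has to find which attributes a method actually uses and compare them with getparamnames. The plain-Python sketch below swaps in a simple technique: it calls the method with a proxy object as self and records which instance attributes are read, then reports missing and excess names with messages modeled on the warnings above. check_params and _RecordingProxy are hypothetical names; this is not how xitorch performs its check.

```python
import warnings
from typing import Any, List, Tuple


class _RecordingProxy:
    """Stands in for self and records which instance attributes are read."""

    def __init__(self, target: Any):
        self._target = target
        self._accessed = set()

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only for names the proxy itself lacks,
        # so attribute reads on self inside the method pass through here.
        if name in vars(self._target):
            self._accessed.add(name)  # record instance attributes only
        return getattr(self._target, name)


def check_params(obj: Any, methodname: str, *args: Any,
                 **kwargs: Any) -> Tuple[List[str], List[str]]:
    """Hypothetical sketch of an assertparams-style check."""
    proxy = _RecordingProxy(obj)
    # Call the plain function stored on the class, with the proxy as self.
    getattr(type(obj), methodname)(proxy, *args, **kwargs)
    declared = set(obj.getparamnames(methodname))
    missing = sorted(proxy._accessed - declared)
    excess = sorted(declared - proxy._accessed)
    where = f"{type(obj).__name__}.{methodname}"
    if missing:
        warnings.warn(f"getparamnames for {where} does not include: {', '.join(missing)}")
    if excess:
        warnings.warn(f"getparamnames for {where} has excess parameters: {', '.join(excess)}")
    return missing, excess


class AClass:
    # Plain-Python version of the example class above.
    def __init__(self, a):
        self.a = a
        self.b = a * a

    def mult(self, x):
        return self.b * x

    def getparamnames(self, methodname, prefix=""):
        if methodname == "mult":
            return [prefix + "a"]  # intentionally wrong
        raise KeyError(methodname)


missing, excess = check_params(AClass(2.0), "mult", 0.4)
print(missing, excess)  # ['b'] ['a']
```

Recording reads is only an approximation: an attribute that is read but does not influence the output would still count as used, attributes of nested objects are not traversed, and attribute writes inside the method land on the proxy rather than on the object.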