EditableModule
class xitorch.EditableModule
EditableModule is a base class that enables the classes inheriting from it to be converted into pure functions, for the purpose of computing higher-order derivatives.
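To make "converted to pure functions" concrete, here is a plain-Python sketch in which floats stand in for tensors. It turns a method into a function that receives its parameters explicitly. The helper `make_functional` and the class `A` are hypothetical illustrations, not part of xitorch. The temporary swap-and-restore of attributes is an assumed mechanism, not necessarily the one xitorch uses, and the sketch only handles undotted parameter names.

```python
from typing import Any, Callable, List, Sequence


def make_functional(obj: Any, methodname: str) -> Callable[..., Any]:
    """Hypothetical helper: return pure(params, *args, **kwargs), which evaluates
    obj.<methodname> using explicitly passed values for the parameters named
    by obj.getparamnames(methodname)."""
    names: List[str] = obj.getparamnames(methodname)
    method = getattr(obj, methodname)

    def pure(params: Sequence[Any], *args: Any, **kwargs: Any) -> Any:
        if len(params) != len(names):
            raise ValueError(f"expected {len(names)} parameters, got {len(params)}")
        saved = [getattr(obj, n) for n in names]
        try:
            for n, v in zip(names, params):
                setattr(obj, n, v)  # swap in the explicitly passed values
            return method(*args, **kwargs)
        finally:
            for n, v in zip(names, saved):
                setattr(obj, n, v)  # restore the object's original state

    return pure


class A:
    """Plain-Python class shaped like the getparamnames example: mult uses only b."""

    def __init__(self, a: float) -> None:
        self.b = a * a

    def mult(self, x: float) -> float:
        return self.b * x

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "mult":
            return [prefix + "b"]
        raise KeyError(methodname)


obj = A(2.0)                      # obj.b == 4.0
f = make_functional(obj, "mult")
print(f([10.0], 3.0))             # 30.0, computed with b = 10.0
print(obj.mult(3.0))              # 12.0, since obj.b was restored to 4.0
```

Because every parameter the method reads arrives through `params`, the result of `f` depends only on its inputs and not on hidden state stored in `obj`.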
abstract getparamnames(methodname: str, prefix: str = '') → List[str]
This method should list the names of the tensors that affect the output of the method named by methodname. If methodname is not among the methods listed in this function, it should raise KeyError.
- Parameters
methodname (str) – The name of the method of the class.
prefix (str) – The prefix to be prepended to each parameter name. It usually contains dots, as in a dotted path to a nested attribute.
- Returns
Sequence of names of the parameters affecting the output of the method.
- Return type
Sequence of strings
- Raises
KeyError – If methodname is not among the methods listed in this function.
Example
>>> import xitorch
>>> class A(xitorch.EditableModule):
...     def __init__(self, a):
...         self.b = a*a
...
...     def mult(self, x):
...         return self.b * x
...
...     def getparamnames(self, methodname, prefix=""):
...         if methodname == "mult":
...             return [prefix+"b"]
...         else:
...             raise KeyError()
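The prefix argument matters when one object holds another. Here is a minimal sketch, assuming the convention that a container delegates to its child's getparamnames with the prefix extended by the child's attribute name and a dot. Plain classes stand in for EditableModule subclasses and floats stand in for tensors. `Child`, `Parent` and `resolve` are hypothetical names used only for illustration.

```python
from functools import reduce
from typing import Any, List


def resolve(obj: Any, dotted: str) -> Any:
    """Hypothetical helper: follow a dotted name such as "child.w" through attributes."""
    return reduce(getattr, dotted.split("."), obj)


class Child:
    def __init__(self, w: float) -> None:
        self.w = w

    def forward(self, x: float) -> float:
        return self.w * x

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "forward":
            return [prefix + "w"]
        raise KeyError(methodname)


class Parent:
    def __init__(self, w: float, bias: float) -> None:
        self.child = Child(w)
        self.bias = bias

    def forward(self, x: float) -> float:
        return self.child.forward(x) + self.bias

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "forward":
            # Delegate to the child, extending the prefix with "child.".
            return (self.child.getparamnames("forward", prefix=prefix + "child.")
                    + [prefix + "bias"])
        raise KeyError(methodname)


p = Parent(w=3.0, bias=1.0)
names = p.getparamnames("forward")       # ['child.w', 'bias']
values = [resolve(p, n) for n in names]  # [3.0, 1.0]
print(names, values)
```

Because each level only adds its own attribute name to the prefix, the returned names stay valid however deeply the objects are nested.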
getuniqueparams(methodname: str, onlyleaves: bool = False) → List[torch.Tensor]
Returns the list of unique parameters involved in the method specified by methodname.
- Parameters
methodname (str) – Name of the method in which the returned parameters are involved.
onlyleaves (bool) – If True, only returns leaf tensors. Otherwise, returns all tensors.
- Returns
List of tensors that are involved in the specified method of the object.
- Return type
list of tensors
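The following sketch shows what "unique" and onlyleaves could mean. It assumes that uniqueness is judged by object identity, since the same tensor can be reachable under two names (for example, an attribute and an alias on a child object), and that a leaf is what PyTorch reports through `Tensor.is_leaf`. `Param` and `unique_params` are hypothetical plain-Python stand-ins, not xitorch's implementation.

```python
from typing import List


class Param:
    """Hypothetical stand-in for a tensor: a value plus a leaf flag."""

    def __init__(self, value: float, is_leaf: bool = True) -> None:
        self.value = value
        self.is_leaf = is_leaf


def unique_params(params: List[Param], onlyleaves: bool = False) -> List[Param]:
    """Drop repeated objects (compared by identity, keeping first-seen order)
    and, if onlyleaves is True, drop non-leaf objects as well."""
    seen = set()
    out: List[Param] = []
    for p in params:
        if id(p) in seen:
            continue  # same object reached under another name
        seen.add(id(p))
        if onlyleaves and not p.is_leaf:
            continue
        out.append(p)
    return out


a = Param(2.0)                  # a leaf, like a user-created tensor
b = Param(4.0, is_leaf=False)   # derived from a, so not a leaf
# e.g. the values behind the names ["a", "b", "child.a"], where child.a is a
resolved = [a, b, a]
print([p.value for p in unique_params(resolved)])                   # [2.0, 4.0]
print([p.value for p in unique_params(resolved, onlyleaves=True)])  # [2.0]
```

Comparing by identity rather than by value keeps two distinct tensors that happen to hold equal numbers as separate parameters.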
assertparams(method, *args, **kwargs)
Performs a rigorous check of the getparamnames implementation of this class for the given method, called with the given positional and keyword arguments. It issues warnings if the getparamnames implementation has missing or excess parameters.
- Parameters
method (callable method) – The method of this class to be tested.
*args – Positional arguments passed to the method.
**kwargs – Keyword arguments passed to the method.
Example
>>> import torch
>>> import xitorch
>>> class AClass(xitorch.EditableModule):
...     def __init__(self, a):
...         self.a = a
...         self.b = a*a
...
...     def mult(self, x):
...         return self.b * x
...
...     def getparamnames(self, methodname, prefix=""):
...         if methodname == "mult":
...             return [prefix+"a"]  # intentionally wrong
...         else:
...             raise KeyError()
>>> a = torch.tensor(2.0).requires_grad_()
>>> x = torch.tensor(0.4).requires_grad_()
>>> A = AClass(a)
>>> A.assertparams(A.mult, x)
<...>:1: UserWarning: getparams for AClass.mult does not include: b
  A.assertparams(A.mult, x)
<...>:1: UserWarning: getparams for AClass.mult has excess parameters: a
  A.assertparams(A.mult, x)
"mult" method check done
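The kind of check that assertparams performs can be approximated in plain Python, with floats standing in for tensors. This sketch is not how xitorch implements it. It swaps in a different technique, attribute-access tracing: the method is called with a proxy as self, the proxy records which non-callable attributes the method reads, and those names are compared with getparamnames. `check_paramnames`, `_AttrTracer` and the warning wording are hypothetical. The tracer only sees attributes read directly inside the method. It does not see nested paths such as child.w or reads made inside other methods that the method calls.

```python
import warnings
from typing import Any, List


class _AttrTracer:
    """Hypothetical proxy that forwards attribute reads to a target and records their names."""

    def __init__(self, target: Any) -> None:
        self._target = target
        self._read = set()

    def __getattr__(self, name: str) -> Any:
        # Called only for names not found on the proxy itself,
        # i.e. for attributes the traced method reads from `self`.
        self._read.add(name)
        return getattr(self._target, name)


def check_paramnames(obj: Any, methodname: str, *args: Any, **kwargs: Any) -> List[str]:
    """Warn about names missing from, or excess in, obj.getparamnames(methodname)."""
    tracer = _AttrTracer(obj)
    # Call the plain function defined on the class, with the proxy as `self`.
    getattr(type(obj), methodname)(tracer, *args, **kwargs)
    # Count only non-callable attributes as parameters (method lookups are ignored).
    used = {n for n in tracer._read if not callable(getattr(obj, n))}
    declared = set(obj.getparamnames(methodname))
    clsname = type(obj).__name__
    messages = []
    missing = sorted(used - declared)
    excess = sorted(declared - used)
    if missing:
        messages.append(f"getparamnames for {clsname}.{methodname} does not include: {', '.join(missing)}")
    if excess:
        messages.append(f"getparamnames for {clsname}.{methodname} has excess parameters: {', '.join(excess)}")
    for msg in messages:
        warnings.warn(msg)  # UserWarning by default
    return messages


class AClass:
    def __init__(self, a: float) -> None:
        self.a = a
        self.b = a * a

    def mult(self, x: float) -> float:
        return self.b * x

    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        if methodname == "mult":
            return [prefix + "a"]  # intentionally wrong, as in the example above
        raise KeyError(methodname)


msgs = check_paramnames(AClass(2.0), "mult", 0.4)
print(msgs)
```

Like the documented example, this reports b as missing and a as excess, because mult reads only self.b.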
abstract