import inspect
import warnings
from abc import abstractmethod
import copy
import torch
from typing import Sequence, Union, Dict, List
from xitorch._utils.exceptions import GetSetParamsError
from xitorch._utils.attr import get_attr, set_attr, del_attr

__all__ = ["EditableModule"]

# Tensor dtypes regarded as floating-point parameters by the tensor traversal
# and parameter-checking code in this module (used only in membership tests,
# ``dtype in torch_float_type``). ``torch.float`` is documented by PyTorch as an
# alias of ``torch.float32``, so that entry is redundant but harmless.
# NOTE(review): ``torch.bfloat16`` is not listed, so bfloat16 tensors are not
# treated as parameters - confirm whether that is intended.
torch_float_type = [torch.float32, torch.float, torch.float64, torch.float16]

class EditableModule(object):
    """
    ``EditableModule`` is a base class that enables the classes inheriting
    from it to be converted to pure functions for higher-order derivative
    purposes.
    """

    def getparams(self, methodname: str) -> Sequence[torch.Tensor]:
        # Returns a list of tensor parameters used in the object's operations
        # (looked up by the names returned by ``cached_getparamnames``).
        paramnames = self.cached_getparamnames(methodname)
        return [get_attr(self, name) for name in paramnames]

    def setparams(self, methodname: str, *params) -> int:
        # Set the input parameters to the object's parameters to make a copy of
        # the operations.
        # *params may be an excessive list of the parameters to be set: only the
        # first ``len(paramnames)`` entries are used, and the method returns the
        # number of parameters it actually sets, so callers can use the result
        # as an offset into ``params``.
        paramnames = self.cached_getparamnames(methodname)
        for name, val in zip(paramnames, params):
            try:
                set_attr(self, name, val)
            except TypeError:
                # failed because val should be param (presumably the existing
                # attribute is a registered parameter), so delete the old
                # attribute and set the value again
                del_attr(self, name)
                set_attr(self, name, val)
        # zip stops at the shorter sequence, so this is the number set above
        return min(len(paramnames), len(params))

    def cached_getparamnames(self, methodname: str, refresh: bool = False) -> List[str]:
        # getparamnames, but cached, so it is only called once per method name.
        # With ``refresh=True`` the cached names are recomputed by calling
        # ``getparamnames`` again.
        if not hasattr(self, "_paramnames_"):
            self._paramnames_: Dict[str, List[str]] = {}
        if refresh:
            # the unique-parameter caches are derived from the parameter
            # names, so they become stale once the names are recomputed
            for cache_name in ("_unique_params_idxs", "_unique_params_maps", "_number_of_params"):
                getattr(self, cache_name, {}).pop(methodname, None)
        if refresh or methodname not in self._paramnames_:
            self._paramnames_[methodname] = self.getparamnames(methodname)
        return self._paramnames_[methodname]

    @abstractmethod
    def getparamnames(self, methodname: str, prefix: str = "") -> List[str]:
        """
        This method should list tensor names that affect the output of the
        method with name indicated in ``methodname``.
        If the ``methodname`` is not on the list in this function, it should
        raise ``KeyError``.

        Arguments
        ---------
        methodname: str
            The name of the method of the class.
        prefix: str
            The prefix to be appended in front of the parameters name.
            This usually contains the dots.

        Returns
        -------
        Sequence of string
            Sequence of name of parameters affecting the output of the method.

        Raises
        ------
        KeyError
            If the list in this function does not contain ``methodname``.

        Example
        -------
        .. testsetup::

            import torch
            import xitorch

        .. doctest::

            >>> class A(xitorch.EditableModule):
            ...     def __init__(self, a):
            ...         self.b = a*a
            ...
            ...     def mult(self, x):
            ...         return self.b * x
            ...
            ...     def getparamnames(self, methodname, prefix=""):
            ...         if methodname == "mult":
            ...             return [prefix+"b"]
            ...         else:
            ...             raise KeyError()
        """
        pass

    def getuniqueparams(self, methodname: str, onlyleaves: bool = False) -> List[torch.Tensor]:
        """
        Returns the list of unique parameters involved in the method specified
        by `methodname`.

        Arguments
        ---------
        methodname: str
            Name of the method where the returned parameters play roles.
        onlyleaves: bool
            If True, only returns leaf tensors. Otherwise, returns all tensors.

        Returns
        -------
        list of tensors
            List of tensors that are involved in the specified method of the
            object. A tensor object that appears several times in
            ``getparams(methodname)`` is returned only once (compared by
            identity).
        """
        allparams = self.getparams(methodname)
        idxs = self._get_unique_params_idxs(methodname, allparams)
        if onlyleaves:
            return [allparams[i] for i in idxs if allparams[i].is_leaf]
        else:
            return [allparams[i] for i in idxs]

    def setuniqueparams(self, methodname: str, *uniqueparams) -> int:
        # Inverse of ``getuniqueparams``: expand the unique parameters back into
        # the full (possibly duplicated) parameter list and set them with
        # ``setparams``. Returns the value returned by ``setparams``.

        # make sure the unique-parameter maps for this method exist, even if
        # ``getuniqueparams`` has not been called yet
        self._get_unique_params_idxs(methodname)
        nparams = self._number_of_params[methodname]
        maps = self._unique_params_maps[methodname]
        allparams = [None] * nparams
        for p, positions in zip(uniqueparams, maps):
            for i in positions:
                allparams[i] = p
        return self.setparams(methodname, *allparams)

    def _get_unique_params_idxs(self, methodname: str,
                                allparams: Union[Sequence[torch.Tensor], None] = None) -> Sequence[int]:
        # Compute (and cache per method name) the indices of the first
        # occurrence of every distinct tensor object in ``allparams``; tensors
        # are compared by identity (``id``), not by value. Also records, for
        # each unique parameter, all positions where it appears, and the total
        # number of parameters; ``setuniqueparams`` uses both to expand them.
        # If the method is already cached, ``allparams`` is ignored.
        if not hasattr(self, "_unique_params_idxs"):
            self._unique_params_idxs = {}  # type: Dict[str, Sequence[int]]
            self._unique_params_maps = {}  # type: Dict[str, List[List[int]]]
            self._number_of_params = {}  # type: Dict[str, int]
        if methodname in self._unique_params_idxs:
            return self._unique_params_idxs[methodname]
        if allparams is None:
            allparams = self.getparams(methodname)

        # id(tensor) -> position of that tensor in ``idxs`` / ``idx_map``
        id_to_pos = {}  # type: Dict[int, int]
        idxs = []  # type: List[int]
        idx_map = []  # type: List[List[int]]
        for i, param in enumerate(allparams):
            pos = id_to_pos.get(id(param))
            if pos is not None:
                idx_map[pos].append(i)
                continue
            id_to_pos[id(param)] = len(idxs)
            idxs.append(i)
            idx_map.append([i])

        self._number_of_params[methodname] = len(allparams)
        self._unique_params_idxs[methodname] = idxs
        self._unique_params_maps[methodname] = idx_map
        return idxs

    ############# debugging #############
    def assertparams(self, method, *args, **kwargs):
        """
        Perform a rigorous check on the implemented ``getparamnames``
        in the class for a given method and its arguments as well as keyword
        arguments.
        It raises warnings if there are missing or excess parameters in the
        ``getparamnames`` implementation.

        Arguments
        ---------
        method: callable method
            The method of this class to be tested
        *args:
            Arguments of the method
        **kwargs:
            Keyword arguments of the method

        Example
        -------
        .. testsetup:: assertparams

            import torch
            import xitorch
            import sys
            sys.stderr = sys.stdout

        .. doctest:: assertparams

            >>> class AClass(xitorch.EditableModule):
            ...     def __init__(self, a):
            ...         self.a = a
            ...         self.b = a*a
            ...
            ...     def mult(self, x):
            ...         return self.b * x
            ...
            ...     def getparamnames(self, methodname, prefix=""):
            ...         if methodname == "mult":
            ...             return [prefix+"a"]  # intentionally wrong
            ...         else:
            ...             raise KeyError()
            >>> a = torch.tensor(2.0).requires_grad_()
            >>> x = torch.tensor(0.4).requires_grad_()
            >>> A = AClass(a)
            >>> A.assertparams(A.mult, x) # doctest:+ELLIPSIS
            <...>:1: UserWarning: getparams for AClass.mult does not include: b
              A.assertparams(A.mult, x) # doctest:+ELLIPSIS
            <...>:1: UserWarning: getparams for AClass.mult has excess parameters: a
              A.assertparams(A.mult, x) # doctest:+ELLIPSIS
            "mult" method check done
        """
        # check the method input
        if not inspect.ismethod(method):
            raise TypeError("The input method must be a method")
        methodself = method.__self__
        if methodself is not self:
            raise RuntimeError("The method does not belong to the same instance")

        methodname = method.__name__

        # assert if the method preserve the float tensors of the object
        self.__assert_method_preserve(method, *args, **kwargs)
        # check if getparams returns the correct tensors (called directly from
        # here so that stacklevel=3 in its warnings points at our caller)
        self.__assert_get_correct_params(method, *args, **kwargs)
        print('"%s" method check done' % methodname)

    def __assert_method_preserve(self, method, *args, **kwargs):
        # this method asserts that the method does not change the float tensor
        # parameters of the object (i.e. it preserves the state of the object)
        all_params0, names0 = _get_tensors(self)
        all_params0 = [p.clone() for p in all_params0]
        method(*args, **kwargs)
        all_params1, names1 = _get_tensors(self)

        # now assert if all_params0 == all_params1
        clsname = method.__self__.__class__.__name__
        methodname = method.__name__
        msg = "The method %s.%s does not preserve the object's float tensors: \n" % (clsname, methodname)
        if len(all_params0) != len(all_params1):
            msg += "The number of parameters changed:\n"
            msg += "* number of object's parameters before: %d\n" % len(all_params0)
            msg += "* number of object's parameters after : %d\n" % len(all_params1)
            raise GetSetParamsError(msg)

        for pname, p0, p1 in zip(names0, all_params0, all_params1):
            if p0.shape != p1.shape:
                msg += "The shape of %s changed\n" % pname
                msg += "* (before) %s.shape: %s\n" % (pname, p0.shape)
                msg += "* (after ) %s.shape: %s\n" % (pname, p1.shape)
                raise GetSetParamsError(msg)
            # equal_nan: an unchanged parameter that contains NaN is preserved
            if not torch.allclose(p0, p1, equal_nan=True):
                msg += "The value of %s changed\n" % pname
                msg += "* (before) %s: %s\n" % (pname, p0)
                msg += "* (after ) %s: %s\n" % (pname, p1)
                raise GetSetParamsError(msg)

    def __assert_get_correct_params(self, method, *args, **kwargs):
        # this function checks whether getparamnames for the method lists the
        # tensors actually used by the method; it raises GetSetParamsError for
        # non-float entries and warns about missing / excess parameters
        methodname = method.__name__
        clsname = method.__self__.__class__.__name__

        # get the parameter tensors used in the operation and the tensors
        # specified by the developer
        oper_names, oper_params = self.__list_operating_params(method, *args, **kwargs)
        user_names = self.getparamnames(methodname)
        user_params = [get_attr(self, name) for name in user_names]
        # tensors are compared by identity, not by value
        user_params_id_set = {id(p) for p in user_params}
        oper_params_id_set = {id(p) for p in oper_params}

        # check if the userparams contains non-tensor
        for name, param in zip(user_names, user_params):
            if not isinstance(param, torch.Tensor) or param.dtype not in torch_float_type:
                msg = "Parameter %s is a non-floating point tensor" % name
                raise GetSetParamsError(msg)

        # check if there are missing parameters (present in operating params,
        # but not in the user params)
        missing_names = [name for name, p in zip(oper_names, oper_params)
                         if id(p) not in user_params_id_set]
        # if there are missing parameters, give a warning (because the program
        # can still run correctly, e.g. missing parameters are parameters that
        # are never set to require grad)
        if missing_names:
            msg = "getparams for %s.%s does not include: %s" % (clsname, methodname, ", ".join(missing_names))
            warnings.warn(msg, stacklevel=3)

        # check if there are excessive parameters (present in the user params,
        # but not in the operating params)
        excess_names = [name for name, p in zip(user_names, user_params)
                        if id(p) not in oper_params_id_set]
        # if there are excess parameters, give warnings
        if excess_names:
            msg = "getparams for %s.%s has excess parameters: %s" % \
                (clsname, methodname, ", ".join(excess_names))
            warnings.warn(msg, stacklevel=3)

    def __list_operating_params(self, method, *args, **kwargs):
        # List the tensors used in executing the method by calling the method
        # with differentiable copies of the object's float tensors and seeing
        # which copies receive a gradient from the (summed) output.
        # Returns ``(names, params)`` where ``params`` are the object's
        # original tensors (not the copies).

        # get all the tensors recursively
        all_tensors, all_names = _get_tensors(self)

        # copy the tensors and require them to be differentiable
        copy_tensors0 = [tensor.clone().detach().requires_grad_() for tensor in all_tensors]
        # pass a shallow copy so _set_tensors cannot modify copy_tensors0
        _set_tensors(self, list(copy_tensors0))
        try:
            # run the method and see which one has the gradients
            output = method(*args, **kwargs)
            if not isinstance(output, torch.Tensor):
                raise RuntimeError("The method to be asserted must have a tensor output")
            output = output.sum()
            if copy_tensors0 and output.requires_grad:
                grad_tensors = torch.autograd.grad(output, copy_tensors0,
                                                   retain_graph=True, allow_unused=True)
            else:
                # torch.autograd.grad rejects an empty input list and an output
                # without a graph; in both cases no gradient reaches the copies
                grad_tensors = [None] * len(copy_tensors0)
        finally:
            # always put the original tensors back, even if the method raised
            _set_tensors(self, list(all_tensors))

        names = []
        params = []
        for name, tensor, grad in zip(all_names, all_tensors, grad_tensors):
            if grad is not None:
                names.append(name)
                params.append(tensor)
        return names, params
############################ traversing functions ############################ def _traverse_obj(obj, prefix, action, crit, max_depth=20, exception_ids=None): """ Traverse an object to get/set variables that are accessible through the object. """ if exception_ids is None: # None is set as default arg to avoid expanding list for multiple # invokes of _get_tensors without exception_ids argument exception_ids = set() if isinstance(obj, torch.nn.Module): generators = [obj._parameters.items(), obj._modules.items()] name_format = "{prefix}{key}" objdicts = [obj._parameters, obj._modules] elif hasattr(obj, "__dict__"): generators = [obj.__dict__.items()] name_format = "{prefix}{key}" objdicts = [obj.__dict__] elif hasattr(obj, "__iter__"): generators = [obj.items() if isinstance(obj, dict) else enumerate(obj)] name_format = "{prefix}[{key}]" objdicts = [obj] else: raise RuntimeError("The object must be iterable or keyable") for generator, objdict in zip(generators, objdicts): for key, elmt in generator: name = name_format.format(prefix=prefix, key=key) if crit(elmt): action(elmt, name, objdict, key) continue hasdict = hasattr(elmt, "__dict__") hasiter = hasattr(elmt, "__iter__") if hasdict or hasiter: # add exception to avoid infinite loop if there is a mutual dependant on objects if id(elmt) in exception_ids: continue else: exception_ids.add(id(elmt)) if max_depth > 0: _traverse_obj(elmt, action=action, crit=crit, prefix=name + "." if hasdict else name, max_depth=max_depth - 1, exception_ids=exception_ids) else: raise RecursionError("Maximum number of recursion reached") def _get_tensors(obj, prefix="", max_depth=20): """ Collect all tensors in an object recursively and return the tensors as well as their "names" (names meaning the address, e.g. "self.a[0].elmt"). Arguments --------- * obj: an instance The object user wants to traverse down * prefix: str Prefix of the name of the collected tensors. 
Default: "" Returns ------- * res: list of torch.Tensor Sequence of tensors collected recursively in the object. * name: list of str Sequence of names of the collected tensors. """ # get the tensors recursively towards torch.nn.Module res = [] names = [] def action(elmt, name, objdict, key): res.append(elmt) names.append(name) # traverse down the object to collect the tensors crit = lambda elmt: isinstance(elmt, torch.Tensor) and elmt.dtype in torch_float_type _traverse_obj(obj, action=action, crit=crit, prefix=prefix, max_depth=max_depth) return res, names def _set_tensors(obj, all_params, max_depth=20): """ Set the tensors in an object to new tensor object listed in `all_params`. Arguments --------- * obj: an instance The object user wants to traverse down * all_params: list of torch.Tensor Sequence of tensors to be put in the object. * max_depth: int Maximum recursive depth to avoid infinitely running program. If the maximum depth is reached, then raise a RecursionError. """ def action(elmt, name, objdict, key): objdict[key] = all_params.pop(0) # traverse down the object to collect the tensors crit = lambda elmt: isinstance(elmt, torch.Tensor) and elmt.dtype in torch_float_type _traverse_obj(obj, action=action, crit=crit, prefix="", max_depth=max_depth)