Source code for xitorch._core.pure_function

import torch
import inspect
from typing import Callable, List, Tuple, Union, Sequence
from xitorch._utils.attr import set_attr, del_attr
from xitorch._utils.unique import Uniquifier
from xitorch._core.editable_module import EditableModule
from contextlib import contextmanager
from abc import abstractmethod

__all__ = ["get_pure_function", "make_sibling"]

############################ functional ###############################
class PureFunction(object):
    """
    PureFunction class wraps methods to make it stateless and expose the pure
    function to take inputs of the original inputs (`params`) and the object's
    states (`objparams`).
    For functions, this class only acts as a thin wrapper.
    """

    def __init__(self, fcntocall: Callable):
        self._state_change_allowed = True
        self._allobjparams = self._get_all_obj_params_init()
        self._uniq = Uniquifier(self._allobjparams)
        self._cur_objparams = self._uniq.get_unique_objs()
        self._fcntocall = fcntocall

        # restore stack stores list of (objparams, identical)
        # everytime the objparams are set, it will store the old objparams
        # and indication if the old and new objparams are identical
        self._restore_stack: List[Tuple[List, bool]] = []

    def __call__(self, *params):
        return self._fcntocall(*params)

    @abstractmethod
    def _get_all_obj_params_init(self):
        pass

    @abstractmethod
    def _set_all_obj_params(self, allobjparams):
        pass

    def objparams(self) -> List:
        return self._cur_objparams

    def set_objparams(self, objparams: List):
        # TODO: check if identical with current object parameters
        identical = _check_identical_objs(objparams, self._cur_objparams)
        self._restore_stack.append((self._cur_objparams, identical))
        if not identical:
            allobjparams = self._uniq.map_unique_objs(objparams)
            self._set_all_obj_params(allobjparams)
            self._cur_objparams = list(objparams)

    def restore_objparams(self):
        old_objparams, identical = self._restore_stack.pop(-1)
        if not identical:
            allobjparams = self._uniq.map_unique_objs(old_objparams)
            self._set_all_obj_params(allobjparams)
            self._cur_objparams = old_objparams

    @contextmanager
    def useobjparams(self, objparams: List):
        if not self._state_change_allowed:
            raise RuntimeError("The state change is disabled")
        try:
            self.set_objparams(objparams)
            yield
        finally:
            self.restore_objparams()

    @contextmanager
    def disable_state_change(self):
        try:
            prev_status = self._state_change_allowed
            self._state_change_allowed = False
            yield
        finally:
            self._state_change_allowed = prev_status

class FunctionPureFunction(PureFunction):
    def _get_all_obj_params_init(self):
        return []

    def _set_all_obj_params(self, objparams):
        pass

class EditableModulePureFunction(PureFunction):
    def __init__(self, obj: EditableModule, method: Callable):
        self.obj = obj
        self.method = method
        super().__init__(method)

    def _get_all_obj_params_init(self) -> List:
        return list(self.obj.getparams(self.method.__name__))

    def _set_all_obj_params(self, allobjparams: List):
        self.obj.setparams(self.method.__name__, *allobjparams)

class TorchNNPureFunction(PureFunction):
    def __init__(self, obj: torch.nn.Module, method: Callable):
        self.obj = obj
        self.method = method
        super().__init__(method)

    def _get_all_obj_params_init(self) -> List:
        # get the tensors in the torch.nn.Module to be used as params
        named_params = list(self.obj.named_parameters())
        if len(named_params) == 0:
            paramnames: List[str] = []
            obj_params: List[Union[torch.Tensor, torch.nn.Parameter]] = []
        else:
            paramnames_temp, obj_params_temp = zip(*named_params)
            paramnames = list(paramnames_temp)
            obj_params = list(obj_params_temp)
        self.names = paramnames
        return obj_params

    def _set_all_obj_params(self, objparams: List):
        for (name, param) in zip(self.names, objparams):
            del_attr(self.obj, name)  # delete required in case the param is not a torch.nn.Parameter
            set_attr(self.obj, name, param)

class SingleSiblingPureFunction(PureFunction):
    def __init__(self, fcn: Callable, fcntocall: Callable):
        self.pfunc = get_pure_function(fcn)
        super().__init__(fcntocall)

    def _get_all_obj_params_init(self) -> List:
        return self.pfunc._get_all_obj_params_init()

    def _set_all_obj_params(self, allobjparams: List):
        self.pfunc._set_all_obj_params(allobjparams)

class MultiSiblingPureFunction(PureFunction):
    def __init__(self, fcns: Sequence[Callable], fcntocall: Callable):
        self.pfuncs = [get_pure_function(fcn) for fcn in fcns]
        self.npfuncs = len(self.pfuncs)
        super().__init__(fcntocall)

    def _get_all_obj_params_init(self) -> List:
        res: List[Union[torch.Tensor, torch.nn.Parameter]] = []
        self.cumsum_idx = [0] * (self.npfuncs + 1)
        for i, pfunc in enumerate(self.pfuncs):
            objparams = pfunc._get_all_obj_params_init()
            res = res + objparams
            self.cumsum_idx[i + 1] = self.cumsum_idx[i] + len(objparams)
        return res

    def _set_all_obj_params(self, allobjparams: List):
        for i, pfunc in enumerate(self.pfuncs):
            pfunc._set_all_obj_params(allobjparams[self.cumsum_idx[i]:self.cumsum_idx[i + 1]])

def _check_identical_objs(objs1: List, objs2: List) -> bool:
    for obj1, obj2 in zip(objs1, objs2):
        if id(obj1) != id(obj2):
            return False
    return True

def get_pure_function(fcn) -> PureFunction:
    """
    Get the pure function form of the function or method ``fcn``.

    Arguments
    ---------
    fcn: function or method
        Function or method to be converted into a ``PureFunction`` by exposing
        the hidden parameters affecting its outputs.

    Returns
    -------
    PureFunction
        The pure function wrapper
    """

    errmsg = "The input function must be a function, a method of " \
        "torch.nn.Module, a method of xitorch.EditableModule, or a sibling method"

    if isinstance(fcn, PureFunction):
        return fcn

    elif inspect.isfunction(fcn) or isinstance(fcn, torch.jit.ScriptFunction):
        return FunctionPureFunction(fcn)

    # if it is a method from an object, unroll the parameters and add
    # the object's parameters as well
    elif inspect.ismethod(fcn) or hasattr(fcn, "__call__"):
        if inspect.ismethod(fcn):
            obj = fcn.__self__
        else:
            obj = fcn
            fcn = fcn.__call__

        if isinstance(obj, EditableModule):
            return EditableModulePureFunction(obj, fcn)
        elif isinstance(obj, torch.nn.Module):
            return TorchNNPureFunction(obj, fcn)
        else:
            raise RuntimeError(errmsg)

    else:
        raise RuntimeError(errmsg)

[docs]def make_sibling(*pfuncs) -> Callable[[Callable], PureFunction]: """ Used as a decor to mark the decorated function as a sibling method of the input ``pfunc``. Sibling method is a method that is virtually belong to the same object, but behaves differently. Changing the state of the decorated function will also change the state of ``pfunc`` and its other siblings. """ if len(pfuncs) == 0: raise TypeError("At least 1 function is required as the argument") elif len(pfuncs) == 1: return lambda fcn: SingleSiblingPureFunction(pfuncs[0], fcntocall=fcn) else: return lambda fcn: MultiSiblingPureFunction(pfuncs, fcntocall=fcn)