Source code for xitorch.optimize.rootfinder

from typing import Callable, Mapping, Any, Sequence, Union, List
import torch
from xitorch._utils.misc import TensorNonTensorSeparator, get_method
from xitorch._utils.assertfuncs import assert_fcn_params
from xitorch._impls.optimize.root.rootsolver import broyden1, broyden2, \
    linearmixing
from xitorch._impls.optimize.minimizer import gd, adam
from xitorch.linalg.solve import solve
from xitorch.grad.jachess import jac
from xitorch._core.pure_function import get_pure_function, make_sibling
from xitorch._docstr.api_docstr import get_methods_docstr
from xitorch.debug.modes import is_debug_enabled

__all__ = ["equilibrium", "rootfinder", "minimize"]

_RF_METHODS = {
    "broyden1": broyden1,
    "broyden2": broyden2,
    "linearmixing": linearmixing,
}

_OPT_METHODS = {
    "gd": gd,
    "adam": adam,
}

def rootfinder(
        fcn: Callable[..., torch.Tensor],
        y0: torch.Tensor,
        params: Sequence[Any] = [],
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> torch.Tensor:
    r"""
    Solving the rootfinder equation of a given function,

    .. math::

        \mathbf{0} = \mathbf{f}(\mathbf{y}, \theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produces output of the same shape as :math:`\mathbf{y}`, and :math:`\theta`
    are the other parameters required by the function.
    The output of this block is the :math:`\mathbf{y}`
    that produces :math:`\mathbf{0}` as the output.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        Sequence of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method : str or callable or None
        Rootfinder method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{0} = \mathbf{f}(\mathbf{y},\theta)`
        with shape ``(*ny)``

    Example
    -------
    .. testsetup:: root1

        import torch
        from xitorch.optimize import rootfinder

    .. doctest:: root1

        >>> def func1(y, A):  # example function
        ...     return torch.tanh(A @ y + 0.1) + y / 2.0
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2, 1))  # zeros as the initial guess
        >>> yroot = rootfinder(func1, y0, params=(A,))
        >>> print(yroot)
        tensor([[-0.0459],
                [-0.0663]], grad_fn=<_RootFinderBackward>)
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)
    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(pfunc, y0, pfunc, False, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
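
# Illustrative sketch (not part of the library): the returned root is differentiable
# w.r.t. the parameters via the implicit backward implemented in _RootFinder below,
# so `yroot` can be used like any other tensor in an autograd graph. Continuing the
# doctest example from `rootfinder`:
#
#   >>> yroot = rootfinder(func1, torch.zeros((2, 1)), params=(A,))
#   >>> loss = yroot.sum()
#   >>> loss.backward()
#   >>> A.grad.shape   # gradient w.r.t. the parameter A
#   torch.Size([2, 2])
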
def equilibrium(
        fcn: Callable[..., torch.Tensor],
        y0: torch.Tensor,
        params: Sequence[Any] = [],
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> torch.Tensor:
    r"""
    Solving the equilibrium equation of a given function,

    .. math::

        \mathbf{y} = \mathbf{f}(\mathbf{y}, \theta)

    where :math:`\mathbf{f}` is a function that can be non-linear and
    produces output of the same shape as :math:`\mathbf{y}`, and :math:`\theta`
    are the other parameters required by the function.
    The output of this block is the :math:`\mathbf{y}`
    that produces the same :math:`\mathbf{y}` as the output.

    Arguments
    ---------
    fcn : callable
        The function :math:`\mathbf{f}` with output tensor ``(*ny)``
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        Sequence of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method : str or callable or None
        Rootfinder method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution which satisfies
        :math:`\mathbf{y} = \mathbf{f}(\mathbf{y},\theta)`
        with shape ``(*ny)``

    Example
    -------
    .. testsetup:: equil1

        import torch
        from xitorch.optimize import equilibrium

    .. doctest:: equil1

        >>> def func1(y, A):  # example function
        ...     return torch.tanh(A @ y + 0.1) + y / 2.0
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2, 1))  # zeros as the initial guess
        >>> yequil = equilibrium(func1, y0, params=(A,))
        >>> print(yequil)
        tensor([[ 0.2313],
                [-0.5957]], grad_fn=<_RootFinderBackward>)

    Note
    ----
    * This is a direct implementation of finding the root of
      :math:`\mathbf{g}(\mathbf{y}, \theta) = \mathbf{y} - \mathbf{f}(\mathbf{y}, \theta)`
    """
    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)

    @make_sibling(pfunc)
    def new_fcn(y, *params):
        return y - pfunc(y, *params)

    fwd_options["method"] = _get_rootfinder_default_method(method)
    return _RootFinder.apply(new_fcn, y0, new_fcn, False, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
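
# Illustrative sketch (not part of the library): as the Note in `equilibrium` states,
# the equilibrium point is simply the root of g(y, theta) = y - f(y, theta), so the
# doctest above is equivalent to calling the rootfinder on the subtracted function:
#
#   >>> yequil = rootfinder(lambda y, A: y - func1(y, A), y0, params=(A,))
#
# `equilibrium` performs this subtraction internally (see `new_fcn` above) so the
# caller only has to supply f itself.
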
def minimize(
        fcn: Callable[..., torch.Tensor],
        y0: torch.Tensor,
        params: Sequence[Any] = [],
        bck_options: Mapping[str, Any] = {},
        method: Union[str, Callable, None] = None,
        **fwd_options) -> torch.Tensor:
    r"""
    Solve the unbounded minimization problem:

    .. math::

        \mathbf{y^*} = \arg\min_\mathbf{y} f(\mathbf{y}, \theta)

    to find the best :math:`\mathbf{y}` that minimizes the output of the
    function :math:`f`.

    Arguments
    ---------
    fcn : callable
        The function to be minimized, with an output tensor containing a single element.
    y0 : torch.tensor
        Initial guess of the solution with shape ``(*ny)``
    params : list
        Sequence of any other parameters to be put in ``fcn``
    bck_options : dict
        Method-specific options for the backward solve (see :func:`xitorch.linalg.solve`)
    method : str or callable or None
        Minimization method. If None, it will choose ``"broyden1"``.
    **fwd_options
        Method-specific options (see method section)

    Returns
    -------
    torch.tensor
        The solution of the minimization with shape ``(*ny)``

    Example
    -------
    .. testsetup:: root1

        import torch
        from xitorch.optimize import minimize

    .. doctest:: root1

        >>> def func1(y, A):  # example function
        ...     return torch.sum((A @ y)**2 + y / 2.0)
        >>> A = torch.tensor([[1.1, 0.4], [0.3, 0.8]]).requires_grad_()
        >>> y0 = torch.zeros((2, 1))  # zeros as the initial guess
        >>> ymin = minimize(func1, y0, params=(A,))
        >>> print(ymin)
        tensor([[-0.0519],
                [-0.2684]], grad_fn=<_RootFinderBackward>)
    """
    assert not torch.is_complex(y0), \
        "complex number is not supported on xitorch.optimize.rootfinder at the moment"

    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (y0, *params))

    pfunc = get_pure_function(fcn)
    fwd_options["method"] = _get_minimizer_default_method(method)
    method = fwd_options["method"]

    # minimization can use a rootfinder algorithm, so check whether it is actually
    # using an optimization algorithm rather than a rootfinder algorithm
    opt_method = method not in _RF_METHODS.keys()

    # the rootfinder algorithms are designed to move in the opposite direction
    # of the output of the function, so the output of this function is just
    # the grad of z w.r.t. y;
    # if an optimization method is used, it also returns the value
    @make_sibling(pfunc)
    def _min_fwd_fcn(y, *params):
        with torch.enable_grad():
            y1 = y.clone().requires_grad_()
            z = pfunc(y1, *params)
        grady, = torch.autograd.grad(z, (y1,), retain_graph=True,
                                     create_graph=torch.is_grad_enabled())
        return z, grady

    @make_sibling(_min_fwd_fcn)
    def _rf_fcn(y, *params):
        z, grady = _min_fwd_fcn(y, *params)
        return grady

    # if using an optimization algorithm, the forward function is the one that
    # returns both f and grad
    if opt_method:
        _fwd_fcn = _min_fwd_fcn
    # if using a rootfinder algorithm, the forward function returns only the grad
    else:
        _fwd_fcn = _rf_fcn

    return _RootFinder.apply(_rf_fcn, y0, _fwd_fcn, opt_method, fwd_options, bck_options,
                             len(params), *params, *pfunc.objparams())
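
# Illustrative sketch (not part of the library): besides the rootfinder methods
# (which solve grad(f) = 0), `minimize` can dispatch to the gradient-based methods
# registered in _OPT_METHODS. Switching the doctest above to Adam could look like:
#
#   >>> ymin = minimize(func1, y0, params=(A,), method="adam")
#
# Any method-specific options are passed through **fwd_options; the accepted options
# per method are appended to the docstrings by `get_methods_docstr` at the bottom of
# this module.
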
class _RootFinder(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fcn, y0, fwd_fcn, is_opt_method, options, bck_options, nparams,
                *allparams):
        # fcn: a function that returns what has to be 0 (used in the backward,
        #      not in the forward). For minimization, it is the gradient
        # fwd_fcn: a function that will be executed in the forward method
        #      (unused in the backward)
        # This class is also used for minimization, where fcn and fwd_fcn might
        # be slightly different

        # set default options
        config = options
        ctx.bck_options = bck_options

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        with fwd_fcn.useobjparams(objparams):
            method = config.pop("method")
            methods = _RF_METHODS if not is_opt_method else _OPT_METHODS
            name = "rootfinder" if not is_opt_method else "minimizer"
            method_fcn = get_method(name, methods, method)
            y = method_fcn(fwd_fcn, y0, params, **config)

        ctx.fcn = fcn
        ctx.is_opt_method = is_opt_method

        # split the tensor and non-tensor params
        ctx.nparams = nparams
        ctx.param_sep = TensorNonTensorSeparator(allparams)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(y, *tensor_params)

        return y

    @staticmethod
    def backward(ctx, grad_yout):
        param_sep = ctx.param_sep
        yout = ctx.saved_tensors[0]
        nparams = ctx.nparams
        fcn = ctx.fcn

        # merge the tensor and non-tensor parameters
        tensor_params = ctx.saved_tensors[1:]
        allparams = param_sep.reconstruct_params(tensor_params)
        params = allparams[:nparams]
        objparams = allparams[nparams:]

        # dL/df
        with ctx.fcn.useobjparams(objparams):
            jac_dfdy = jac(fcn, params=(yout, *params), idxs=[0])[0]
            gyfcn = solve(A=jac_dfdy.H, B=-grad_yout.reshape(-1, 1),
                          bck_options=ctx.bck_options, **ctx.bck_options)
            gyfcn = gyfcn.reshape(grad_yout.shape)

            # get the grad for the params
            with torch.enable_grad():
                tensor_params_copy = [p.clone().requires_grad_() for p in tensor_params]
                allparams_copy = param_sep.reconstruct_params(tensor_params_copy)
                params_copy = allparams_copy[:nparams]
                objparams_copy = allparams_copy[nparams:]
                with ctx.fcn.useobjparams(objparams_copy):
                    yfcn = fcn(yout, *params_copy)

            grad_tensor_params = torch.autograd.grad(yfcn, tensor_params_copy,
                                                     grad_outputs=gyfcn,
                                                     create_graph=torch.is_grad_enabled(),
                                                     allow_unused=True)
            grad_nontensor_params = [None for _ in range(param_sep.nnontensors())]
            grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                                       grad_nontensor_params)

        return (None, None, None, None, None, None, None, *grad_params)


def _get_rootfinder_default_method(method):
    if method is None:
        return "broyden1"
    else:
        return method


def _get_minimizer_default_method(method):
    if method is None:
        return "broyden1"
    else:
        return method


# docstring completion
rf_methods: List[Callable] = [broyden1, broyden2, linearmixing]
opt_methods: List[Callable] = [gd, adam]
rootfinder.__doc__ = get_methods_docstr(rootfinder, rf_methods)
equilibrium.__doc__ = get_methods_docstr(equilibrium, rf_methods)
minimize.__doc__ = get_methods_docstr(minimize, rf_methods + opt_methods)
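
# Illustrative sketch (not part of the library): the backward pass above never
# backpropagates through the iterative solver; it solves the linear system
# (df/dy)^H v = -dL/dy once and chains v into the parameters via autograd.
# A quick way to sanity-check this implicit gradient is torch.autograd.gradcheck
# on a small double-precision problem (a hypothetical standalone check, not a
# doctest; tolerances may need loosening depending on the solver's convergence):
#
#   >>> A = torch.randn(2, 2, dtype=torch.float64).requires_grad_()
#   >>> def fcn(y, A):
#   ...     return torch.tanh(A @ y + 0.1) + y / 2.0
#   >>> y0 = torch.zeros((2, 1), dtype=torch.float64)
#   >>> torch.autograd.gradcheck(lambda A: rootfinder(fcn, y0, params=(A,)), (A,))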