from abc import ABC, abstractmethod
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import torch

from xitorch._utils.assertfuncs import assert_fcn_params, assert_runtime
from xitorch._core.pure_function import get_pure_function, make_sibling
from xitorch._utils.misc import set_default_option, TensorNonTensorSeparator, \
    TensorPacker, get_method
from xitorch._impls.integrate.fixed_quad import leggauss
from xitorch._docstr.api_docstr import get_methods_docstr
from xitorch.debug.modes import is_debug_enabled
# Public API of this module: only ``quad`` is exported by ``from ... import *``.
__all__ = ["quad"]
def quad(
        fcn: Union[Callable[..., torch.Tensor], Callable[..., Sequence[torch.Tensor]]],
        xl: Union[float, int, torch.Tensor],
        xu: Union[float, int, torch.Tensor],
        params: Sequence[Any] = (),
        bck_options: Optional[Mapping[str, Any]] = None,
        method: Union[str, Callable, None] = None,
        **fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    r"""
    Calculate the quadrature:

    .. math::

        y = \int_{x_l}^{x_u} f(x, \theta)\ \mathrm{d}x

    Arguments
    ---------
    fcn: callable
        The function to be integrated. It is called as ``fcn(x, *params)``
        and its output must be a tensor with shape ``(*nout)`` or a
        non-empty list of tensors.
    xl: float, int or 1-element torch.Tensor
        The lower bound of the integration. It may be infinite.
    xu: float, int or 1-element torch.Tensor
        The upper bound of the integration. It may be infinite.
    params: list
        Sequence of any other parameters for the function ``fcn``.
    bck_options: dict or None
        Options for the backward quadrature method. ``None`` (default) means
        no backward-specific options.
    method: str or callable or None
        Quadrature method. If None, it will choose ``"leggauss"``.
    **fwd_options
        Method-specific options (see method section).

    Returns
    -------
    torch.tensor or a list of tensors
        The quadrature results with shape ``(*nout)`` or list of tensors.
    """
    # avoid a shared mutable default argument
    if bck_options is None:
        bck_options = {}

    # perform implementation check if debug mode is enabled
    if is_debug_enabled():
        assert_fcn_params(fcn, (xl, *params))
    if isinstance(xl, torch.Tensor):
        assert_runtime(torch.numel(xl) == 1, "xl must be a 1-element tensors")
    if isinstance(xu, torch.Tensor):
        assert_runtime(torch.numel(xu) == 1, "xu must be a 1-element tensors")

    if method is None:
        method = "leggauss"
    fwd_options["method"] = method

    # evaluate fcn once at the lower bound to learn whether it returns a
    # tensor or a list of tensors, and which dtype/device the bounds must use
    out = fcn(xl, *params)
    if isinstance(out, torch.Tensor):
        dtype = out.dtype
        device = out.device
        is_tuple_out = False
    elif len(out) > 0:
        dtype = out[0].dtype
        device = out[0].device
        is_tuple_out = True
    else:
        raise RuntimeError("The output of the fcn must be non-empty")

    pfunc = get_pure_function(fcn)
    nparams = len(params)
    if is_tuple_out:
        # flatten the list of outputs with TensorPacker before integrating
        # and unpack the integrated result back into a list afterwards
        packer = TensorPacker(out)

        @make_sibling(pfunc)
        def pfunc2(x, *params):
            y = fcn(x, *params)
            return packer.flatten(y)

        res = _Quadrature.apply(pfunc2, xl, xu, fwd_options, bck_options, nparams,
                                dtype, device, *params, *pfunc.objparams())
        return packer.pack(res)
    else:
        return _Quadrature.apply(pfunc, xl, xu, fwd_options, bck_options, nparams,
                                 dtype, device, *params, *pfunc.objparams())
class _Quadrature(torch.autograd.Function):
    """Autograd function computing the quadrature of a pure function.

    The inputs after ``ctx`` are ``(fcn, xl, xu, fwd_options, bck_options,
    nparams, dtype, device, *all_params)``. ``all_params`` holds the
    ``nparams`` user parameters followed by the object parameters of ``fcn``.
    The backward pass returns gradients for the bounds and for the tensor
    parameters selected by ``TensorNonTensorSeparator``.
    """
    # NOTE: _Quadrature methods do not change the state (objparams) of fcn,
    # so there is no need to use `with fcn.useobjparams(objparams)`
    # statements.
    # `disable_state_change()` is used to disable state changes of the pure
    # function during the forward and backward calculations.

    @staticmethod
    def forward(ctx, fcn, xl, xu, fwd_options, bck_options, nparams,
                dtype, device, *all_params):
        """Integrate ``fcn(x, *params)`` from ``xl`` to ``xu``.

        ``fwd_options`` must contain the ``"method"`` key. That key is
        popped here and the remaining options are forwarded to the method
        implementation. ``xl`` and ``xu`` are converted to tensors with
        ``dtype`` and ``device``.
        """
        with fcn.disable_state_change():
            config = fwd_options
            # built from the full forward options, before "method" is popped
            ctx.bck_config = set_default_option(config, bck_options)

            params = all_params[:nparams]

            # convert to tensor
            xl = torch.as_tensor(xl, dtype=dtype, device=device)
            xu = torch.as_tensor(xu, dtype=dtype, device=device)

            # apply a change of variables x = tfm.forward(t) if the boundaries
            # contain inf; the integrand becomes f(x(t)) * dx/dt
            if _isinf(xl) or _isinf(xu):
                tfm = _TanInfTransform()

                @make_sibling(fcn)
                def fcn2(t, *params):
                    ys = fcn(tfm.forward(t), *params)
                    dxdt = tfm.dxdt(t)
                    return ys * dxdt

                tl = tfm.x2t(xl)
                tu = tfm.x2t(xu)
            else:
                fcn2 = fcn
                tl = xl
                tu = xu

            method = config.pop("method")
            methods = {
                "leggauss": leggauss,
            }
            method_fcn = get_method("quad", methods, method)
            y = method_fcn(fcn2, tl, tu, params, **config)

            # save the parameters for backward. The converted bounds are
            # always saved first so backward can split ctx.saved_tensors by
            # position; this stays correct when there are no tensor params.
            ctx.param_sep = TensorNonTensorSeparator(all_params)
            tensor_params = ctx.param_sep.get_tensor_params()
            ctx.save_for_backward(xl, xu, *tensor_params)
            ctx.fcn = fcn
            ctx.nparams = nparams
        return y

    @staticmethod
    def backward(ctx, grad_ys):
        """Return one gradient (or None) per forward input."""
        # saved_tensors layout (see forward): (xl, xu, *tensor_params)
        saved = ctx.saved_tensors
        xl, xu = saved[0], saved[1]
        tensor_params = saved[2:]
        allparams = ctx.param_sep.reconstruct_params(tensor_params)
        params = allparams[:ctx.nparams]
        fcn = ctx.fcn

        with fcn.disable_state_change():
            # gradients for the boundaries (Leibniz rule): dy/dxu = f(xu) and
            # dy/dxl = -f(xl), contracted with grad_ys. needs_input_grad is
            # False for bounds given as Python numbers, which should receive
            # a None gradient, and for bounds that do not require grad.
            need_xl = ctx.needs_input_grad[1]
            need_xu = ctx.needs_input_grad[2]
            grad_xl = None
            if need_xl:
                grad_xl = -torch.dot(grad_ys.reshape(-1),
                                     fcn(xl, *params).reshape(-1)).reshape(xl.shape)
            grad_xu = None
            if need_xu:
                grad_xu = torch.dot(grad_ys.reshape(-1),
                                    fcn(xu, *params).reshape(-1)).reshape(xu.shape)

            def new_fcn(x, *grad_y_params):
                # grad_y_params = (grad_ys, *tensor_params); only grad_ys is
                # read here, the tensor params are listed so that the nested
                # quad gets the gradient calculated for them
                grad_ys = grad_y_params[0]
                # not setting objparams and params because the params and
                # objparams are still the same objects as the objects outside
                with torch.enable_grad():
                    f = fcn(x, *params)
                    # NOTE(review): is_grad_enabled() is evaluated inside
                    # enable_grad(), so create_graph is always True here
                    dfdts = torch.autograd.grad(f, tensor_params,
                                                grad_outputs=grad_ys,
                                                retain_graph=True,
                                                create_graph=torch.is_grad_enabled())
                return dfdts

            # gradients for the tensor parameters: integrate the
            # vector-Jacobian product of fcn over the same interval, using
            # the backward options as the nested forward options
            if tensor_params:
                dydts = quad(new_fcn, xl, xu, params=(grad_ys, *tensor_params),
                             bck_options=ctx.bck_config, **ctx.bck_config)
            else:
                # no tensor parameters: only the bounds can get a gradient
                dydts = []
            dydns = [None] * ctx.param_sep.nnontensors()
            grad_params = ctx.param_sep.reconstruct_params(dydts, dydns)
        return (None, grad_xl, grad_xu, None, None, None, None, None, *grad_params)
def _isinf(x):
return torch.any(torch.isinf(x))
class _BaseInfTransform(object):
@abstractmethod
def forward(self, t):
pass
@abstractmethod
def dxdt(self, t):
pass
@abstractmethod
def x2t(self, x):
pass
class _TanInfTransform(_BaseInfTransform):
    """Change of variables ``x = tan(t)``.

    ``tan`` maps the open interval ``(-pi/2, pi/2)`` onto the whole real line,
    so integration bounds at +-inf become the finite values +-pi/2 in ``t``.
    """

    def x2t(self, x):
        # inverse map t = arctan(x); arctan(+-inf) = +-pi/2
        return torch.atan(x)

    def forward(self, t):
        # x = tan(t)
        return torch.tan(t)

    def dxdt(self, t):
        # d tan(t) / dt = sec(t)^2 = 1 / cos(t)^2
        inv_cos = torch.cos(t).reciprocal()
        return inv_cos * inv_cos
# docstring completion: rebuild ``quad.__doc__`` with get_methods_docstr from
# quad's own docstring and the available method implementations (currently
# only ``leggauss``). Presumably this fills in the "method section" that
# quad's docstring refers to for ``**fwd_options`` — confirm in api_docstr.
quad.__doc__ = get_methods_docstr(quad, [leggauss])