import torch
from typing import Union, Sequence, Any, Callable, Mapping
from xitorch.debug.modes import is_debug_enabled
from xitorch._core.pure_function import get_pure_function, make_sibling
from xitorch._utils.misc import set_default_option, TensorNonTensorSeparator, \
TensorPacker, get_method
from xitorch._utils.assertfuncs import assert_fcn_params
from xitorch._impls.integrate.mcsamples.mcmc import mh, mhcustom, dummy1d
from xitorch._docstr.api_docstr import get_methods_docstr
__all__ = ["mcquad"]
def mcquad(
ffcn: Union[Callable[..., torch.Tensor], Callable[..., Sequence[torch.Tensor]]],
log_pfcn: Callable[..., torch.Tensor],
x0: torch.Tensor,
fparams: Sequence[Any] = [],
pparams: Sequence[Any] = [],
bck_options: Mapping[str, Any] = {},
method: Union[str, Callable, None] = None,
**fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
r"""
    Perform Monte Carlo quadrature to calculate the expectation value:

    .. math::

        \mathbb{E}_p[f] = \frac{\int f(\mathbf{x},\theta_f) p(\mathbf{x},\theta_p)
        \ \mathrm{d}\mathbf{x} }{ \int p(\mathbf{x},\theta_p)\ \mathrm{d}\mathbf{x} }
Arguments
---------
ffcn: Callable
        The function to be integrated. Its output is a tensor or a
        list of tensors. To call the function: ``ffcn(x, *fparams)``
log_pfcn: Callable
The natural logarithm of the probability function. The output should be
a one-element tensor. To call the function: ``log_pfcn(x, *pparams)``
x0: torch.Tensor
        Tensor of any shape to be used as the initial position.
        The call ``ffcn(x0, *fparams)`` must work.
fparams: list
Sequence of any other parameters for ``ffcn``.
pparams: list
        Sequence of any other parameters for ``log_pfcn``.
bck_options: dict
Options for the backward mcquad operation. Unspecified fields will be
taken from ``fwd_options``.
method: str or callable or None
Monte Carlo quadrature method. If None, it will choose ``"mh"``.
**fwd_options: dict
Method-specific options (see method section below).
Returns
-------
torch.Tensor or a list of torch.Tensor
The expectation values of the function ``ffcn`` over the space of ``x``.
If the output of ``ffcn`` is a list, then this is also a list.
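    Example
    -------
    A minimal sketch of the intended usage (illustrative, using the default
    ``"mh"`` method with its default options): estimating
    :math:`\mathbb{E}_p[x^2]` under an unnormalized standard normal
    log-probability.

    >>> import torch
    >>> from xitorch.integrate import mcquad
    >>> def ffcn(x):
    ...     return x ** 2
    >>> def log_pfcn(x):
    ...     return -0.5 * x ** 2  # unnormalized N(0, 1); normalization cancels
    >>> x0 = torch.zeros(1)
    >>> epf = mcquad(ffcn, log_pfcn, x0)  # approximately 1
    >>> epf.shape
    torch.Size([1])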
"""
if method is None:
method = "mh"
return _mcquad(ffcn, log_pfcn, x0, None, None, fparams, pparams,
method, bck_options, **fwd_options)
def _mcquad(ffcn, log_pfcn, x0, xsamples, wsamples, fparams, pparams, method,
bck_options, **fwd_options):
    # this is mcquad with additional xsamples and wsamples arguments, kept
    # private so that users cannot set the samples directly
if is_debug_enabled():
assert_fcn_params(ffcn, (x0, *fparams))
assert_fcn_params(log_pfcn, (x0, *pparams))
# check if ffcn produces a list / tuple
out = ffcn(x0, *fparams)
is_tuple_out = isinstance(out, list) or isinstance(out, tuple)
# get the pure functions
pure_ffcn = get_pure_function(ffcn)
pure_logpfcn = get_pure_function(log_pfcn)
nfparams = len(fparams)
npparams = len(pparams)
fobjparams = pure_ffcn.objparams()
pobjparams = pure_logpfcn.objparams()
nf_objparams = len(fobjparams)
if is_tuple_out:
packer = TensorPacker(out)
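        # flatten the list/tuple output into a single tensor so that it can go
        # through torch.autograd.Function; it is unpacked again below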
@make_sibling(pure_ffcn)
def pure_ffcn2(x, *fparams):
y = pure_ffcn(x, *fparams)
return packer.flatten(y)
res = _MCQuad.apply(pure_ffcn2, pure_logpfcn, x0, None, None,
method, fwd_options, bck_options,
nfparams, nf_objparams, npparams, *fparams, *fobjparams, *pparams, *pobjparams)
return packer.pack(res)
else:
return _MCQuad.apply(pure_ffcn, pure_logpfcn, x0, None, None,
method, fwd_options, bck_options,
nfparams, nf_objparams, npparams, *fparams, *fobjparams, *pparams, *pobjparams)
class _MCQuad(torch.autograd.Function):
@staticmethod
def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples,
method, fwd_options, bck_options,
nfparams, nf_objparams, npparams, *all_fpparams):
# set up the default options
config = fwd_options
ctx.bck_config = set_default_option(config, bck_options)
# split the parameters
fparams = all_fpparams[:nfparams]
fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
pparams = all_fpparams[nfparams + nf_objparams:nfparams + nf_objparams + npparams]
pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]
# select the method for the sampling
if xsamples is None:
methods = {
"mh": mh,
"_dummy1d": dummy1d,
"mhcustom": mhcustom,
}
method_fcn = get_method("mcquad", methods, method)
xsamples, wsamples = method_fcn(log_pfcn, x0, pparams, **config)
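        # quadrature estimate: epf = sum_i w_i * f(x_i) over the drawn samples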
epf = _integrate(ffcn, xsamples, wsamples, fparams)
# save parameters for backward calculations
ctx.xsamples = xsamples
ctx.wsamples = wsamples
ctx.ffcn = ffcn
ctx.log_pfcn = log_pfcn
ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
ctx.nfparams = len(fparams)
ctx.npparams = len(pparams)
ctx.method = method
# save for backward
ftensor_params = ctx.fparam_sep.get_tensor_params()
ptensor_params = ctx.pparam_sep.get_tensor_params()
ctx.nftensorparams = len(ftensor_params)
ctx.nptensorparams = len(ptensor_params)
ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)
return epf
@staticmethod
def backward(ctx, grad_epf):
# restore the parameters
alltensors = ctx.saved_tensors
nftensorparams = ctx.nftensorparams
nptensorparams = ctx.nptensorparams
epf = alltensors[0]
ftensor_params = alltensors[1:1 + nftensorparams]
ptensor_params = alltensors[1 + nftensorparams:]
fptensor_params = alltensors[1:]
# get the parameters and the object parameters
nfparams = ctx.nfparams
npparams = ctx.npparams
fall_params = ctx.fparam_sep.reconstruct_params(ftensor_params)
pall_params = ctx.pparam_sep.reconstruct_params(ptensor_params)
fparams = fall_params[:nfparams]
fobjparams = fall_params[nfparams:]
pparams = pall_params[:npparams]
pobjparams = pall_params[npparams:]
# get other things from the forward
ffcn = ctx.ffcn
log_pfcn = ctx.log_pfcn
xsamples = ctx.xsamples
wsamples = ctx.wsamples
grad_enabled = torch.is_grad_enabled()
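        # The gradients below are estimated as expectations over the *same*
        # samples drawn in the forward pass:
        # - w.r.t. f's tensor params: E_p[(dL/d epf)^T df/dtheta_f]
        # - w.r.t. p's tensor params, via the score-function identity
        #   d E_p[f]/dtheta_p = E_p[(f - E_p[f]) dlog p/dtheta_p],
        #   which is why the integrand dots (fout - epf) with grad_epf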
def function_wrap(fcn, param_sep, nparams, x, tensor_params):
all_params = param_sep.reconstruct_params(tensor_params)
params = all_params[:nparams]
objparams = all_params[nparams:]
with fcn.useobjparams(objparams):
f = fcn(x, *params)
return f
def aug_function(x, *grad_and_fptensor_params):
local_grad_enabled = torch.is_grad_enabled()
grad_epf = grad_and_fptensor_params[0]
epf = grad_and_fptensor_params[1]
fptensor_params = grad_and_fptensor_params[2:]
ftensor_params = fptensor_params[:nftensorparams]
ptensor_params = fptensor_params[nftensorparams:]
with torch.enable_grad():
                # if the graph is being constructed, fptensor_params here is a
                # clone of the outside fptensor_params, so it has to be put
                # into the pure functions' object params (what function_wrap does)
if grad_enabled:
fout = function_wrap(ffcn, ctx.fparam_sep, nfparams, x, ftensor_params)
pout = function_wrap(log_pfcn, ctx.pparam_sep, npparams, x, ptensor_params)
                # if the graph is not constructed, fptensor_params inside this
                # function *is* the outside fptensor_params, so the outside
                # fparams and pparams can be used directly
else:
fout = ffcn(x, *fparams)
pout = log_pfcn(x, *pparams)
# derivative of fparams
dLdthetaf = []
if len(ftensor_params) > 0:
dLdthetaf = torch.autograd.grad(fout, ftensor_params,
grad_outputs=grad_epf,
retain_graph=True,
create_graph=local_grad_enabled)
# derivative of pparams
dLdthetap = []
if len(ptensor_params) > 0:
dLdef = torch.dot((fout - epf).reshape(-1), grad_epf.reshape(-1))
dLdthetap = torch.autograd.grad(pout, ptensor_params,
grad_outputs=dLdef.reshape(pout.shape),
retain_graph=True,
create_graph=local_grad_enabled)
# combine the states needed for backward
outs = (
*dLdthetaf,
*dLdthetap,
)
return outs
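        # when a higher-order graph is needed, work on fresh clones that
        # require grad so the augmented integrand can differentiate w.r.t. them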
if grad_enabled:
fptensor_params_copy = [y.clone().requires_grad_() for y in fptensor_params]
else:
fptensor_params_copy = fptensor_params
aug_epfs = _mcquad(aug_function, log_pfcn,
x0=xsamples[0], # unused because xsamples is set
xsamples=xsamples,
wsamples=wsamples,
fparams=(grad_epf, epf, *fptensor_params_copy),
pparams=pparams,
method=ctx.method,
bck_options=ctx.bck_config,
**ctx.bck_config)
dLdthetaf = aug_epfs[:nftensorparams]
dLdthetap = aug_epfs[nftensorparams:]
        # reassemble the gradients for all fparams and pparams, with None
        # for the non-tensor parameters
dLdfnontensor = [None for _ in range(ctx.fparam_sep.nnontensors())]
dLdpnontensor = [None for _ in range(ctx.pparam_sep.nnontensors())]
dLdtf = ctx.fparam_sep.reconstruct_params(dLdthetaf, dLdfnontensor)
dLdtp = ctx.pparam_sep.reconstruct_params(dLdthetap, dLdpnontensor)
return (None, None, None, None, None, None, None, None, None, None, None,
*dLdtf, *dLdtp)
def _integrate(ffcn, xsamples, wsamples, fparams):
    # weighted sum of ffcn over the pre-drawn samples: sum_i w_i * f(x_i)
    res = 0.0
    for x, w in zip(xsamples, wsamples):
        res = res + ffcn(x, *fparams) * w
    return res
# docstring completion
mcquad.__doc__ = get_methods_docstr(mcquad, [mh, mhcustom])