import torch
import copy
from typing import Callable, Union, Mapping, Any, Sequence, Dict
from xitorch._utils.assertfuncs import assert_fcn_params, assert_runtime
from xitorch._core.pure_function import get_pure_function, make_sibling
from xitorch._impls.integrate.ivp.explicit_rk import rk4_ivp, rk38_ivp, fwd_euler_ivp
from xitorch._impls.integrate.ivp.adaptive_rk import rk23_adaptive, rk45_adaptive
from xitorch._utils.misc import set_default_option, TensorNonTensorSeparator, \
TensorPacker, get_method
from xitorch._utils.tensor import convert_none_grads_to_zeros
from xitorch._docstr.api_docstr import get_methods_docstr
from xitorch.debug.modes import is_debug_enabled
__all__ = ["solve_ivp"]
def solve_ivp(fcn: Union[Callable[..., torch.Tensor], Callable[..., Sequence[torch.Tensor]]],
ts: torch.Tensor,
              y0: Union[torch.Tensor, Sequence[torch.Tensor]],
params: Sequence[Any] = [],
bck_options: Mapping[str, Any] = {},
method: Union[str, Callable, None] = None,
**fwd_options) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
r"""
    Solve the initial value problem (IVP), also commonly known as an ordinary
    differential equation (ODE): given the initial value :math:`\mathbf{y_0}`,
    it solves

    .. math::

        \mathbf{y}(t) = \mathbf{y_0} + \int_{t_0}^{t} \mathbf{f}(t', \mathbf{y}, \theta)\ \mathrm{d}t'
    Although ``solve_ivp`` itself does not accept a batched ``ts``, it can be
    batched using functorch's ``vmap`` (only for the explicit solvers, e.g.
    ``rk38``, ``rk4``, and ``euler``); the adaptive-step solvers cannot be
    vmapped at the moment.
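
    A minimal sketch of that batching (illustrative only, assuming ``functorch``
    provides ``vmap`` and that ``fcn`` and ``y0`` are defined as in the example
    below)::

        from functorch import vmap

        batched_ts = torch.stack([torch.linspace(0, 1, 100),
                                  torch.linspace(0, 2, 100)])
        # each row of batched_ts is solved independently with a fixed-step solver
        yts = vmap(lambda ts_: solve_ivp(fcn, ts_, y0, method="rk4"))(batched_ts)
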
Arguments
---------
    fcn: callable
        The function that represents :math:`\mathrm{d}\mathbf{y}/\mathrm{d}t`.
        It takes a single time ``t`` and a tensor ``y`` with shape ``(*ny)``
        (or a list of tensors, matching ``y0``) and produces the derivative
        with the same shape as ``y``, i.e. a tensor with shape ``(*ny)`` or
        a list of tensors.
ts: torch.tensor
The time points where the value of `y` will be returned.
It must be monotonically increasing or decreasing.
It is a tensor with shape ``(nt,)``.
y0: torch.tensor
The initial value of ``y``, i.e. ``y(t[0]) == y0``.
It is a tensor with shape ``(*ny)`` or a list of tensors.
params: list
Sequence of other parameters required in the function.
bck_options: dict
Options for the backward solve_ivp method. If not specified, it will
take the same options as fwd_options.
method: str or callable or None
Initial value problem solver. If None, it will choose ``"rk45"``.
    **fwd_options
        Method-specific options (see method section below).

    Returns
    -------
    torch.tensor or a list of tensors
        The values of ``y`` for each time step in ``ts``.
        It is a tensor with shape ``(nt, *ny)`` or a list of tensors.
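
    Example
    -------
    A minimal, illustrative sketch (assuming the usual ``xitorch.integrate``
    entry point) solving :math:`\mathrm{d}y/\mathrm{d}t = -y`, whose exact
    solution is :math:`y_0 e^{-t}`:

    >>> import torch
    >>> from xitorch.integrate import solve_ivp
    >>> def fcn(t, y):
    ...     return -y
    >>> ts = torch.linspace(0, 2, 5)
    >>> y0 = torch.tensor([1.0])
    >>> yt = solve_ivp(fcn, ts, y0, method="rk45")
    >>> yt.shape
    torch.Size([5, 1])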
"""
if is_debug_enabled():
assert_fcn_params(fcn, (ts[0], y0, *params))
assert_runtime(len(ts.shape) == 1, "Argument ts must be a 1D tensor")
if method is None: # set the default method
method = "rk45"
fwd_options["method"] = method
is_y0_list = isinstance(y0, (list, tuple))
pfcn = get_pure_function(fcn)
if is_y0_list:
nt = len(ts)
roller = TensorPacker(y0)
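        # the underlying solvers operate on a single tensor, so flatten the list of
        # tensors into one tensor and pack it back into a list around every fcn call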
@make_sibling(pfcn)
def pfcn2(t, ytensor, *params):
ylist = roller.pack(ytensor)
res_list = pfcn(t, ylist, *params)
if not isinstance(res_list, (list, tuple)):
                raise RuntimeError(
                    "y0 and the output of fcn must both be lists/tuples or both be tensors")
res = roller.flatten(res_list)
return res
y0 = roller.flatten(y0)
res = _SolveIVP.apply(pfcn2, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams())
return roller.pack(res)
else:
return _SolveIVP.apply(pfcn, ts, fwd_options, bck_options, len(params), y0, *params, *pfcn.objparams())


class _SolveIVP(torch.autograd.Function):
@staticmethod
def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
config = fwd_options
ctx.bck_config = set_default_option(config, bck_options)
params = allparams[:nparams]
objparams = allparams[nparams:]
method = config.pop("method")
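        # map the method name to a solver implementation
        # (rk4, rk38 and euler are fixed-step explicit solvers; rk23 and rk45 are adaptive)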
methods = {
"rk4": rk4_ivp,
"rk38": rk38_ivp,
"rk23": rk23_adaptive,
"rk45": rk45_adaptive,
"euler": fwd_euler_ivp,
}
solver = get_method("solve_ivp", methods, method)
yt = solver(pfcn, ts, y0, params, **config)
# save the parameters for backward
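        # separate the tensor parameters (which may carry gradients) from the
        # non-tensor parameters; only the tensors are saved for the backward pass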
ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
tensor_params = ctx.param_sep.get_tensor_params()
ctx.save_for_backward(ts, y0, *tensor_params)
ctx.pfcn = pfcn
ctx.nparams = nparams
ctx.yt = yt
ctx.ts_requires_grad = ts.requires_grad
return yt
@staticmethod
def backward(ctx, grad_yt):
# grad_yt: (nt, *ny)
nparams = ctx.nparams
pfcn = ctx.pfcn
param_sep = ctx.param_sep
yt = ctx.yt
ts_requires_grad = ctx.ts_requires_grad
# restore the parameters
saved_tensors = ctx.saved_tensors
ts = saved_tensors[0]
y0 = saved_tensors[1]
tensor_params = list(saved_tensors[2:])
allparams = param_sep.reconstruct_params(tensor_params)
ntensor_params = len(tensor_params)
params = allparams[:nparams]
objparams = allparams[nparams:]
grad_enabled = torch.is_grad_enabled()
# custom function to evaluate the input `pfcn` based on whether we want
# to connect the graph or not
def pfunc2(t, y, tensor_params):
if not grad_enabled:
# if graph is not constructed, then use the default tensor_params
ycopy = y.detach().requires_grad_() # [yi.detach().requires_grad_() for yi in y]
tcopy = t.detach().requires_grad_()
f = pfcn(tcopy, ycopy, *params)
return f, tcopy, ycopy, tensor_params
else:
# if graph is constructed, then use the clone of the tensor params
# so that infinite loop of backward can be avoided
tensor_params_copy = [p.clone().requires_grad_() for p in tensor_params]
ycopy = y.clone().requires_grad_()
tcopy = t.clone().requires_grad_()
allparams_copy = param_sep.reconstruct_params(tensor_params_copy)
params_copy = allparams_copy[:nparams]
objparams_copy = allparams_copy[nparams:]
with pfcn.useobjparams(objparams_copy):
f = pfcn(tcopy, ycopy, *params_copy)
return f, tcopy, ycopy, tensor_params_copy
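
        # The backward pass integrates the augmented (adjoint) system backwards in
        # time. The augmented state holds, in order:
        #   states[0]  : y(t)
        #   states[1]  : dL/dy(t), the adjoint of the state
        #   states[2]  : accumulator for dL/dt (used for the gradient w.r.t. ts)
        #   states[3:] : dL/dp for every tensor parameter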
# slices and indices definitions on the augmented states
y_index = 0
dLdy_index = 1
dLdt_index = 2
dLdt_slice = slice(dLdt_index, dLdt_index + 1, None) # [2:3]
        dLdp_slice = (slice(-ntensor_params, None, None) if ntensor_params > 0
                      else slice(0, 0, None))  # [-ntensor_params:]
state_size = 3 + ntensor_params
states = [None for _ in range(state_size)]
def new_pfunc(t, states, *tensor_params):
# t: single-element
y = states[y_index]
dLdy = -states[dLdy_index]
with torch.enable_grad():
f, t2, y2, tensor_params2 = pfunc2(t, y, tensor_params)
allgradinputs = ([y2] + [t2] + list(tensor_params2))
allgrads = torch.autograd.grad(f,
inputs=allgradinputs,
grad_outputs=dLdy,
retain_graph=True,
allow_unused=True,
create_graph=torch.is_grad_enabled()) # list of (*ny)
allgrads = convert_none_grads_to_zeros(allgrads, allgradinputs)
outs = (
f, # dydt
*allgrads,
)
return outs
ts_flip = ts.flip(0)
t_flip_idx = -1
states[y_index] = yt[t_flip_idx]
states[dLdy_index] = grad_yt[t_flip_idx]
states[dLdt_index] = torch.zeros_like(ts[0])
states[dLdp_slice] = [torch.zeros_like(tp) for tp in tensor_params]
grad_ts = [None for _ in range(len(ts))] if ts_requires_grad else None
# define a new function for the augmented dynamics
bkw_roller = TensorPacker(states)
@make_sibling(new_pfunc)
def pfcn_back(t, ytensor, *params):
ylist = bkw_roller.pack(ytensor)
res_list = new_pfunc(t, ylist, *params)
res = bkw_roller.flatten(res_list)
return res
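
        # integrate the augmented system backwards, one interval at a time (from the
        # latest time point to the earliest), adding the incoming grad_yt contribution
        # to the adjoint at each requested time point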
for i in range(len(ts_flip) - 1):
if ts_requires_grad:
feval = pfunc2(ts_flip[i], states[y_index], tensor_params)[0]
dLdt1 = torch.dot(feval.reshape(-1), grad_yt[t_flip_idx].reshape(-1))
states[dLdt_index] -= dLdt1
grad_ts[t_flip_idx] = dLdt1.reshape(-1)
t_flip_idx -= 1
states_flatten = bkw_roller.flatten(states)
fwd_config = copy.copy(ctx.bck_config)
bck_config = copy.copy(ctx.bck_config)
outs_flatten = _SolveIVP.apply(
pfcn_back, ts_flip[i:i + 2], fwd_config, bck_config, len(tensor_params),
states_flatten, *tensor_params)
outs = bkw_roller.pack(outs_flatten)
# only take the output for the earliest time
states = [out[-1] for out in outs]
states[y_index] = yt[t_flip_idx]
            # grad_yt[t_flip_idx] is the incoming gradient for this time point;
            # states[dLdy_index] holds the gradient propagated from later time steps
            states[dLdy_index] = grad_yt[t_flip_idx] + states[dLdy_index]
if ts_requires_grad:
grad_ts[0] = states[dLdt_index].reshape(-1)
grad_y0 = states[dLdy_index] # dL/dy0, (*ny)
if ts_requires_grad:
grad_ts = torch.cat(grad_ts).reshape(*ts.shape)
grad_tensor_params = states[dLdp_slice]
grad_ntensor_params = [None for _ in range(len(allparams) - ntensor_params)]
grad_params = param_sep.reconstruct_params(grad_tensor_params, grad_ntensor_params)
return (None, grad_ts, None, None, None, grad_y0, *grad_params)
# docstring completion
ivp_methods: Dict[str, Callable] = {
"rk45": rk45_adaptive,
"rk23": rk23_adaptive,
"rk4": rk4_ivp,
"rk38": rk38_ivp,
"euler": fwd_euler_ivp,
}
solve_ivp.__doc__ = get_methods_docstr(solve_ivp, ivp_methods)