from __future__ import annotations
from typing import Any, Optional, List, Tuple, Dict
from copy import deepcopy, copy
import torch
__all__ = ["Packer"]
class Packer(object):
    """
    Packer is an object that could extract the tensors in a structure and
    rebuild the structure from the given tensors.
    This object preserves the structure of the object by performing the deepcopy
    of the object, except for the tensor.

    Arguments
    ---------
    obj: Any
        Any structure object that contains tensors.

    Example
    -------

    .. testsetup::

        import torch
        import xitorch

    .. doctest::

        >>> a = torch.tensor(1.0)
        >>> obj = {
        ...     "a": a,
        ...     "b": a * 3,
        ...     "c": a,
        ... }
        >>> packer = xitorch.Packer(obj)
        >>> tensors = packer.get_param_tensor_list()
        >>> print(tensors)
        [tensor(1.), tensor(3.)]
        >>> new_tensors = [torch.tensor(2.0), torch.tensor(4.0)]
        >>> new_obj = packer.construct_from_tensor_list(new_tensors)
        >>> print(new_obj)
        {'a': tensor(2.), 'b': tensor(4.), 'c': tensor(2.)}
    """
    def __init__(self, obj: Any):
        # deep copy the object, except the tensors: pre-seeding deepcopy's
        # memo with id(tensor) -> tensor makes deepcopy treat every tensor as
        # "already copied", so the stored object shares the caller's tensors
        # while the rest of the structure is an independent copy
        tensor_lists = _extract_tensors(obj)
        memo = {id(t): t for t in tensor_lists}
        # keep a pristine copy: deepcopy mutates the memo dict it is given
        # by adding an entry for every object it copies
        self._tensor_memo = copy(memo)  # shallow copy
        self._obj = deepcopy(obj, memo)

        # caches, filled lazily by the getter methods below
        self._params_tensor_list: Optional[List[torch.Tensor]] = tensor_lists
        self._unique_params_idxs: Optional[List[int]] = None
        self._unique_inverse_idxs: Optional[List[int]] = None
        self._unique_tensor_shapes: Optional[List[torch.Size]] = None
        self._tensor_shapes: Optional[List[torch.Size]] = None
        self._unique_tensor_numels: Optional[List[int]] = None
        self._unique_tensor_numel_tot: Optional[int] = None
        self._tensor_numels: Optional[List[int]] = None
        self._tensor_numel_tot: Optional[int] = None

    def get_param_tensor_list(self, unique: bool = True) -> List[torch.Tensor]:
        """
        Returns the list of tensors contained in the object. It will traverse
        down the object via elements for list, values for dictionary, or
        ``__dict__`` for object that has ``__dict__`` attribute.

        Arguments
        ---------
        unique: bool
            If True, then only returns the unique tensors. Otherwise, duplicates
            can also be returned.

        Returns
        -------
        list of torch.Tensor
            List of tensors contained in the object.
        """
        # get the params tensor list
        if self._params_tensor_list is not None:
            # get it from cache if available
            params_tensors = self._params_tensor_list
        else:
            params_tensors = _extract_tensors(self._obj)
            self._params_tensor_list = params_tensors

        # only take the unique tensors if required
        if unique:
            if self._unique_params_idxs is not None:
                unique_idxs = self._unique_params_idxs
                unique_inverse = self._unique_inverse_idxs
            else:
                # uniqueness is decided by object identity (id), not by value,
                # so two equal-valued but distinct tensors stay separate
                unique_idxs, unique_inverse = _get_unique_idxs(params_tensors)
                self._unique_params_idxs = unique_idxs
                self._unique_inverse_idxs = unique_inverse
            params_tensors = [params_tensors[i] for i in unique_idxs]

        # remember the shapes so that construct_from_tensor_list can validate
        # its input later
        if unique:
            self._unique_tensor_shapes = [p.shape for p in params_tensors]
        else:
            self._tensor_shapes = [p.shape for p in params_tensors]
        return params_tensors

    def get_param_tensor(self, unique: bool = True) -> Optional[torch.Tensor]:
        """
        Returns the tensor parameters as a single tensor. This can be used,
        for example, if there are multiple parameters to be optimized using
        ``xitorch.optimize.minimize``.

        Arguments
        ---------
        unique: bool
            If True, then only returns the tensor from unique tensors list.
            Otherwise, duplicates can also be returned.

        Returns
        -------
        torch.Tensor or None
            The parameters of the object in a single tensor or None if there
            is no tensor contained in the object.
        """
        params = self.get_param_tensor_list(unique=unique)
        if len(params) == 0:
            return None
        else:
            # cache the per-tensor numels so construct_from_tensor can split
            # the flat tensor back into the original pieces
            if unique:
                self._unique_tensor_numels = [p.numel() for p in params]
                self._unique_tensor_numel_tot = sum(self._unique_tensor_numels)
            else:
                self._tensor_numels = [p.numel() for p in params]
                self._tensor_numel_tot = sum(self._tensor_numels)
            if len(params) == 1:
                # single parameter: return it as-is (not flattened)
                return params[0]
            else:
                tparam = torch.cat([p.reshape(-1) for p in params])
                return tparam

    def construct_from_tensor_list(self, tensors: List[torch.Tensor], unique: bool = True) -> Any:
        """
        Construct the object from the tensor list and returns the object structure
        with the new tensors. Executing this does not change the state of the
        Packer object.

        Arguments
        ---------
        tensors: list of torch.Tensor
            The tensor parameters to be filled into the object.
        unique: bool
            Indicating if the tensor list ``tensors`` is from the unique
            parameters of the object.

        Returns
        -------
        Any
            A new object with the same structure as the input to ``__init__``
            object except the tensor is changed according to ``tensors``.

        Raises
        ------
        RuntimeError
            If ``get_param_tensor_list`` has not been called with the same
            ``unique`` flag, or if the number or shapes of ``tensors`` do not
            match the original tensors.
        """
        if unique:
            tensor_shapes = self._unique_tensor_shapes
        else:
            tensor_shapes = self._tensor_shapes
        if tensor_shapes is None:
            raise RuntimeError("Please execute self.get_param_tensor_list(%s) first" % str(unique))
        else:
            # make sure the length matches
            if len(tensor_shapes) != len(tensors):
                raise RuntimeError("Mismatch length of the tensors")
            if len(tensor_shapes) == 0:
                return self._obj

            # check the tensor shapes
            for i, (tens, shape) in enumerate(zip(tensors, tensor_shapes)):
                if tens.shape != shape:
                    # BUGFIX: the expected shape is the cached original shape,
                    # and the got shape is the incoming tensor's shape (the
                    # arguments were previously swapped)
                    msg = "The tensors[%d] has mismatch shape from the original. Expected: %s, got: %s" % \
                        (i, shape, tens.shape)
                    raise RuntimeError(msg)

            # duplicate the tensors if the input is unique list of tensors
            if unique:
                assert self._unique_inverse_idxs is not None, "Please report to Github"
                tensors = [tensors[self._unique_inverse_idxs[i]] for i in range(len(self._unique_inverse_idxs))]
            else:
                # _put_tensors will change the tensors, so this is just to preserve
                # the input
                tensors = copy(tensors)

            # deepcopy the object, except the tensors (same memo trick as in
            # __init__), then substitute the new tensors in traversal order
            memo = copy(self._tensor_memo)
            new_obj = deepcopy(self._obj, memo)
            new_obj = _put_tensors(new_obj, tensors)
            return new_obj

    def construct_from_tensor(self, a: torch.Tensor, unique: bool = True) -> Any:
        """
        Construct the object from the single tensor (i.e. it is the parameters
        tensor merged into a single tensor) and returns the object structure
        with the new tensor. Executing this does not change the state of the
        Packer object.

        Arguments
        ---------
        a: torch.Tensor
            The single tensor parameter to be filled.
        unique: bool
            Indicating if the tensor ``a`` is from the unique parameters of the
            object.

        Returns
        -------
        Any
            A new object with the same structure as the input to ``__init__``
            object except the tensor is changed according to ``a``.

        Raises
        ------
        RuntimeError
            If ``get_param_tensor`` has not been called with the same
            ``unique`` flag, or if ``a`` has the wrong number of elements.
        """
        if unique:
            tensor_shapes = self._unique_tensor_shapes
            tensor_numel_tot = self._unique_tensor_numel_tot
            tensor_numels = self._unique_tensor_numels
        else:
            tensor_shapes = self._tensor_shapes
            tensor_numel_tot = self._tensor_numel_tot
            tensor_numels = self._tensor_numels
        if tensor_shapes is None:
            raise RuntimeError("Please execute self.get_param_tensor(%s) first" % str(unique))
        elif len(tensor_shapes) == 0:
            return self._obj
        else:
            # the numel caches are always set together with the shape cache in
            # get_param_tensor, so they must be present here
            assert tensor_numel_tot is not None, "Please report to Github"
            assert tensor_numels is not None, "Please report to Github"
            if a.numel() != tensor_numel_tot:
                msg = "The number of element does not match. Expected: %d, got: %d" % \
                    (tensor_numel_tot, a.numel())
                raise RuntimeError(msg)
            if len(tensor_numels) == 1:
                # single parameter: get_param_tensor returned it unflattened,
                # so pass it through unchanged
                params: List[torch.Tensor] = [a]
            else:
                # split the flat tensor and reshape each slice back to its
                # original shape
                ioffset = 0
                params = []
                for i in range(len(tensor_numels)):
                    p = a[ioffset:ioffset + tensor_numels[i]].reshape(tensor_shapes[i])
                    ioffset += tensor_numels[i]
                    params.append(p)
            return self.construct_from_tensor_list(params, unique=unique)
def _extract_tensors(b: Any) -> List[torch.Tensor]:
# extract all the tensors from the given object
# this function traverses down the object to collect all the tensors
res: List[torch.Tensor] = []
if isinstance(b, torch.Tensor):
res.append(b)
elif isinstance(b, list):
for elmt in b:
res.extend(_extract_tensors(elmt))
elif isinstance(b, dict):
for elmt in b.values():
res.extend(_extract_tensors(elmt))
elif hasattr(b, "__dict__"):
for elmt in b.__dict__.values():
res.extend(_extract_tensors(elmt))
return res
def _put_tensors(b: Any, tensors: List) -> Any:
# put the tensors recursively in the object, with the same order as
# _extract_tensors.
# the tensors will be changed in this class, so make sure to have
# a shallow copy if you want to preserve your input
if isinstance(b, torch.Tensor):
b = tensors.pop(0)
elif isinstance(b, list):
for i, elmt in enumerate(b):
b[i] = _put_tensors(elmt, tensors)
elif isinstance(b, dict):
for key, elmt in b.items():
b[key] = _put_tensors(elmt, tensors)
elif hasattr(b, "__dict__"):
for key, elmt in b.__dict__.items():
b.__dict__[key] = _put_tensors(elmt, tensors)
return b
def _get_unique_idxs(b: List) -> Tuple[List[int], List[int]]:
# get unique indices based on the ids of the b's elements
# and the index for inversing the unique process
ids_list = [id(bb) for bb in b]
unique_ids: Dict[int, int] = {}
unique_idxs: List[int] = []
unique_inverse: List[int] = []
for i, idnum in enumerate(ids_list):
if idnum in unique_ids:
unique_inverse.append(unique_ids[idnum])
else:
unique_ids[idnum] = len(unique_idxs)
unique_idxs.append(i)
unique_inverse.append(unique_ids[idnum])
return unique_idxs, unique_inverse