ray/rllib/models/repeated_values.py

173 lines
6.4 KiB
Python

from typing import List, Optional

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import TensorType, TensorStructType
@PublicAPI
class RepeatedValues:
    """Represents a variable-length list of items from spaces.Repeated.

    RepeatedValues are created when you use spaces.Repeated, and are
    accessible as part of input_dict["obs"] in ModelV2 forward functions.

    Example:
        Suppose the gym space definition was:
            Repeated(Repeated(Box(K), N), M)

        Then in the model forward function, input_dict["obs"] is of type:
            RepeatedValues(RepeatedValues(<Tensor shape=(B, M, N, K)>))

        The tensor is accessible via:
            input_dict["obs"].values.values

        And the actual data lengths via:
            # outer repetition, shape [B], range [0, M]
            input_dict["obs"].lengths
                -and-
            # inner repetition, shape [B, M], range [0, N]
            input_dict["obs"].values.lengths

    Attributes:
        values: The padded data tensor of shape [B, max_len, ..., sz],
            where B is the batch dimension, max_len is the max length of this
            list, followed by any number of sub list max lens, followed by the
            actual data size.
        lengths: Tensor of shape [B, ...] that represents the number of
            valid items in each list. When the list is nested within other
            lists, there will be extra dimensions for the parent list
            max lens. A plain Python sequence of ints is also accepted by
            `unbatch_all()` for the non-nested case.
        max_len: The max number of items allowed in each list.

    TODO(ekl): support conversion to tf.RaggedTensor.
    """

    def __init__(self, values: TensorType, lengths: List[int], max_len: int):
        self.values = values
        self.lengths = lengths
        self.max_len = max_len
        # Lazily filled cache for the result of `unbatch_all()`.
        self._unbatched_repr = None

    @staticmethod
    def _as_python_int(length) -> int:
        """Converts a single entry of `lengths` into a Python int.

        Handles objects exposing `.item()`, objects exposing `.numpy()`,
        and anything else that `int()` accepts directly (e.g. a plain int
        coming from a Python list of lengths).
        """
        if hasattr(length, "item"):
            return int(length.item())
        if hasattr(length, "numpy"):
            return int(length.numpy())
        return int(length)

    def unbatch_all(self) -> List[List[TensorType]]:
        """Unbatch both the repeat and batch dimensions into Python lists.

        This is only supported in PyTorch / TF eager mode.

        This lets you view the data unbatched in its original form, but is
        not efficient for processing. The result is computed once and cached
        on the instance; subsequent calls return the same list object.

        Examples:
            >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
            >>> items = batch.unbatch_all()
            >>> print(len(items) == B)
            True
            >>> print(max(len(x) for x in items) <= N)
            True
            >>> print(items)
            ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>, ..., <Tensor_N shape=(K)>]]

        Returns:
            A list with one entry per batch item; each entry is a list
            holding only the valid (non-padding) items of that batch row.

        Raises:
            ValueError: If the batch dimension of `values` cannot be
                determined.
        """
        if self._unbatched_repr is None:
            B = _get_batch_dim_helper(self.values)
            if B is None:
                raise ValueError(
                    "Cannot call unbatch_all() when batch_dim is unknown. "
                    "This is probably because you are using TF graph mode."
                )
            B = int(B)
            slices = self.unbatch_repeat_dim()
            result = []
            for i in range(B):
                # Only the first `dynamic_len` repeat slots of row i hold
                # real data; everything after that is padding.
                dynamic_len = self._as_python_int(self.lengths[i])
                result.append(
                    [_batch_index_helper(slices, i, j) for j in range(dynamic_len)]
                )
            self._unbatched_repr = result
        return self._unbatched_repr

    def unbatch_repeat_dim(self) -> List[TensorType]:
        """Unbatches the repeat dimension (the one `max_len` in size).

        This removes the repeat dimension. The result will be a Python list
        with length `self.max_len`. Note that the data is still padded.

        Examples:
            >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
            >>> items = batch.unbatch_repeat_dim()
            >>> len(items) == batch.max_len
            True
            >>> print(items)
            ... [<Tensor_1 shape=(B, K)>, ..., <Tensor_N shape=(B, K)>]
        """
        return _unbatch_helper(self.values, self.max_len)

    def __repr__(self):
        return "RepeatedValues(value={}, lengths={}, max_len={})".format(
            repr(self.values), repr(self.lengths), self.max_len
        )

    def __str__(self):
        return repr(self)
def _get_batch_dim_helper(v: TensorStructType) -> int:
"""Tries to find the batch dimension size of v, or None."""
if isinstance(v, dict):
for u in v.values():
return _get_batch_dim_helper(u)
elif isinstance(v, tuple):
return _get_batch_dim_helper(v[0])
elif isinstance(v, RepeatedValues):
return _get_batch_dim_helper(v.values)
else:
B = v.shape[0]
if hasattr(B, "value"):
B = B.value # TensorFlow
return B
def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
"""Recursively unpacks the repeat dimension (max_len)."""
if isinstance(v, dict):
return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
elif isinstance(v, tuple):
return tuple(_unbatch_helper(u, max_len) for u in v)
elif isinstance(v, RepeatedValues):
unbatched = _unbatch_helper(v.values, max_len)
return [
RepeatedValues(u, v.lengths[:, i, ...], v.max_len)
for i, u in enumerate(unbatched)
]
else:
return [v[:, i, ...] for i in range(max_len)]
def _batch_index_helper(v: TensorStructType, i: int, j: int) -> TensorStructType:
"""Selects the item at the ith batch index and jth repetition."""
if isinstance(v, dict):
return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
elif isinstance(v, tuple):
return tuple(_batch_index_helper(u, i, j) for u in v)
elif isinstance(v, list):
# This is the output of unbatch_repeat_dim(). Unfortunately we have to
# process it here instead of in unbatch_all(), since it may be buried
# under a dict / tuple.
return _batch_index_helper(v[j], i, j)
elif isinstance(v, RepeatedValues):
unbatched = v.unbatch_all()
# Don't need to select j here; that's already done in unbatch_all.
return unbatched[i]
else:
return v[i, ...]