mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
92 lines
3.4 KiB
Python
92 lines
3.4 KiB
Python
import numpy as np
|
|
from typing import Dict, Optional
|
|
|
|
from ray.rllib.utils.types import TensorType
|
|
|
|
|
|
class ViewRequirement:
    """Single view requirement (for one column in a ModelV2 input_dict).

    Note: This is an experimental class.

    ModelV2 returns a Dict[str, ViewRequirement] upon calling
    `ModelV2.get_view_requirements()`, where the str key represents the column
    name (C) under which the view is available in the `input_dict` and
    ViewRequirement specifies the actual underlying column names (in the
    original data buffer), timesteps, and other options to build the view
    for N.

    Examples:
        >>> # The default ViewRequirement for a Model is:
        >>> req = [ModelV2].get_view_requirements(is_training=False)
        >>> print(req)
        {"obs": ViewRequirement(timesteps=0)}
    """

    def __init__(self,
                 data_col: Optional[str] = None,
                 timesteps: int = 0,
                 fill_mode: str = "zeros",
                 repeat_mode: str = "all"):
        """Initializes a ViewRequirement object.

        Args:
            data_col (Optional[str]): The data column name from the
                SampleBatch (str key). If None, use the dict key under which
                this ViewRequirement resides.
            timesteps (int): Relative timestep offset to read for this
                column. It is added to a trajectory's current cursor position
                (e.g. 0 reads the entry at the cursor, -1 the entry just
                before it). Lists of timesteps and absolute timesteps are not
                supported yet (see TODO below).
            fill_mode (str): The fill mode in case t<0 or t>H.
                One of "zeros", "tile". Stored as given; this constructor
                does not validate it.
            repeat_mode (str): The repeat-mode (one of "all" or "only_first").
                E.g. for training, we only want the first internal state
                timestep (the NN will calculate all others again anyways).
                Stored as given; this constructor does not validate it.
        """
        self.data_col = data_col
        self.timesteps = timesteps

        # Switch on absolute timestep mode. Default: False.
        # TODO: (sven)
        #  "absolute_timesteps",

        self.fill_mode = fill_mode
        self.repeat_mode = repeat_mode

        # Provide all data as time major (default: False).
        # TODO: (sven)
        #  "time_major",
|
|
|
|
|
|
def get_trajectory_view(
        model,
        trajectories,
        is_training: bool = False) -> Dict[str, TensorType]:
    """Returns an input_dict for a Model's forward pass given some data.

    Args:
        model (ModelV2): The ModelV2 object for which to generate the view
            (input_dict) from `data`. Only its
            `get_view_requirements(is_training=...)` method is used here.
        trajectories (List[Trajectory]): The data from which to generate
            an input_dict. Each item must expose a `buffers` mapping
            (column name -> indexable sequence) and a `cursor` to which the
            requirement's `timesteps` offset is added.
        is_training (bool): Whether the view should be generated for training
            purposes or inference (default).

    Returns:
        Dict[str, TensorType]: The input_dict to be passed into the ModelV2
            for inference/training. It has one key per view requirement. Each
            value is a numpy array with one entry per trajectory, in the
            order of `trajectories`.
    """
    # Get ModelV2's view requirements.
    view_reqs = model.get_view_requirements(is_training=is_training)
    # Construct the view dict.
    view = {}
    for view_col, view_req in view_reqs.items():
        # A `data_col` of None means "read from the buffer column that has
        # the same name as the view key". Without this fallback, every
        # default-constructed requirement would fail with `KeyError: None`.
        data_col = view_req.data_col if view_req.data_col is not None \
            else view_col
        offset = view_req.timesteps
        # Create the batch of data from the different buffers in `data`.
        # TODO: (sven): Here, we actually do create a copy of the data (from a
        #  list). The only way to avoid this entirely would be to keep a
        #  single(!) np buffer per column across all currently ongoing
        #  agents + episodes (which seems very hard to realize).
        # NOTE(review): `fill_mode` is not applied here. A negative
        #  `cursor + offset` silently wraps around to the end of the buffer
        #  (standard Python/numpy negative indexing) instead of being
        #  zero- or tile-filled. Confirm whether callers rely on this.
        view[view_col] = np.array([
            t.buffers[data_col][t.cursor + offset]
            for t in trajectories
        ])
    return view
|