[RLlib] Trajectory view API docs. (#12718)

This commit is contained in:
Sven Mika 2020-12-30 20:32:21 -05:00 committed by GitHub
parent 28ac4243f4
commit 391cdfae8c
31 changed files with 571 additions and 173 deletions

Three new SVG figures are added by this commit (binary image diffs suppressed; sizes: 131 KiB, 111 KiB, and 207 KiB).

View file

@ -259,6 +259,7 @@ Papers
rllib-env.rst
rllib-models.rst
rllib-algorithms.rst
rllib-sample-collection.rst
rllib-offline.rst
rllib-concepts.rst
rllib-examples.rst

View file

@ -50,7 +50,7 @@ Exploration-based plug-ins (can be combined with any algo)
============================= ========== ======================= ================== =========== =====================
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
============================= ========== ======================= ================== =========== =====================
`Curiosity`_ torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
`Curiosity`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
============================= ========== ======================= ================== =========== =====================
.. _`A2C, A3C`: rllib-algorithms.html#a3c

View file

@ -5,7 +5,7 @@ RLlib works with several different types of environments, including `OpenAI Gym
.. tip::
Not all environments work with all algorithms. Check out the algorithm `feature compatibility matrix <rllib-algorithms.html#feature-compatibility-matrix>`__ for more information.
Not all environments work with all algorithms. Check out the `algorithm overview <rllib-algorithms.html#available-algorithms-overview>`__ for more information.
.. image:: rllib-envs.svg

View file

@ -389,7 +389,7 @@ Custom models can be used to work with environments where (1) the set of valid a
return action_logits + inf_mask, state
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview <rllib-algorithms.html#available-algorithms-overview>`__.
Autoregressive Action Distributions

View file

@ -0,0 +1,337 @@
RLlib Sample Collection and Trajectory Views
============================================
The SampleCollector Class is Used to Store and Retrieve Temporary Data
----------------------------------------------------------------------
RLlib's `RolloutWorkers <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__,
when running against a live environment, use the ``SamplerInput`` class to interact
with that environment and produce batches of experiences.
The two implemented sub-classes of ``SamplerInput`` are ``SyncSampler`` and ``AsyncSampler``
(residing under the ``RolloutWorker.sampler`` property).
If the ``_use_trajectory_view_api`` top-level config key is set to True
(the default since version 1.1.0), every such sampler object will use the
``SampleCollector`` API to store and retrieve temporary environment-, model-, and other data
during rollouts (see figure below).
.. Edit figure below at: https://docs.google.com/drawings/d/1ZdNUU3ChwiUeT-DBRxvLAsbEPEqEFWSPZcOyVy3KxVg/edit
.. image:: images/rllib-sample-collection.svg
**Sample collection process implemented by RLlib:**
The Policy's model tells the Sampler and its SampleCollector object which data to store and
how to present it back to the dependent methods (e.g. `Model.compute_actions()`).
This is done using a dict that maps strings (column names) to `ViewRequirement` objects (see details below).
The exact behavior for a single such rollout and the number of environment transitions therein
are determined by the following Trainer config keys:
**batch_mode [truncate_episodes|complete_episodes]**:
*truncate_episodes (default value)*:
Rollouts are performed
over exactly ``rollout_fragment_length`` (see below) steps. Steps are counted
either as environment steps or as individual agent steps (see ``count_steps_by`` below).
It does not matter whether one or more episodes end within this rollout or whether
the rollout starts in the middle of an already ongoing episode.
*complete_episodes*:
Each rollout is exactly one episode long and always starts
at the beginning of an episode. It does not matter how long an episode lasts.
The ``rollout_fragment_length`` setting will be ignored. Note that you have to be
careful when choosing ``complete_episodes`` as the ``batch_mode``: If your environment does not
terminate easily, this setting could lead to enormous batch sizes.
**rollout_fragment_length [int]**:
The exact number of environment- or agent steps to
be performed per rollout, if the ``batch_mode`` setting (see above) is "truncate_episodes".
If ``batch_mode`` is "complete_episodes", ``rollout_fragment_length`` is ignored.
The unit in which fragments are counted is set via ``multiagent.count_steps_by=[env_steps|agent_steps]``
(within the ``multiagent`` config dict).
.. Edit figure below at: https://docs.google.com/drawings/d/1uRNGImBNq8gv3bBoFX_HernGyeovtCB3wKpZ71c0VE4/edit
.. image:: images/rllib-batch-modes.svg
**Above:** The two supported batch modes in RLlib. For "truncate_episodes",
batches can a) span more than one episode, b) end in the middle of an episode, and
c) start in the middle of an episode. Also, `Policy.postprocess_trajectory()` is always
called at the end of a rollout-fragment (red lines on right side) as well as at the end
of each episode (arrow heads). This way, RLlib makes sure that the
`Policy.postprocess_trajectory()` method never sees data from more than one episode.
**multiagent.count_steps_by [env_steps|agent_steps]**:
Within the Trainer's ``multiagent`` config dict, you can set the unit by which RLlib counts a) rollout fragment lengths as well as b) the size of the final train batch (see below). The two supported values are:
*env_steps (default)*:
Each call to ``[Env].step()`` is counted as one. It does not
matter how many individual agents step simultaneously in this very call
(not all agents in the environment necessarily step at the same time).
*agent_steps*:
In a multi-agent environment, count each individual agent's step
as one. For example, if N agents are in an environment and all these N agents
always step at the same time, a single env step corresponds to N agent steps.
Note that in the single-agent case, ``env_steps`` and ``agent_steps`` are the same thing.
**horizon [int]**:
Some environments have a built-in limit on how many timesteps an episode may last.
This limit is called the "horizon" of an episode.
For example, for CartPole-v0, the maximum number of steps per episode is 200 by default.
You can override this limit by using the ``horizon`` config.
If the provided horizon is larger than the env-specific setting, RLlib will first try to
increase the environment's built-in horizon (e.g. openAI gym Envs have a
``spec.max_episode_steps`` property). In either case, no episode
is allowed to exceed the given ``horizon`` number of timesteps (RLlib will
artificially terminate an episode if this limit is hit).
**soft_horizon [bool]**:
False by default. If set to True, the environment will
a) not be reset when reaching ``horizon`` and b) ``done=True`` will not be set
in the trajectory data sent to the postprocessors and training (``done`` will remain
False at the horizon).
**no_done_at_end [bool]**:
Never set ``done=True`` at the end of an episode or when any
artificial horizon is reached.
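To make these settings concrete, below is a hedged sketch (the values are arbitrary examples, not recommendations) of how the keys described above might be combined in a Trainer config:
.. code-block:: python
import ray
import ray.rllib.agents.ppo as ppo

ray.init()
trainer = ppo.PPOTrainer(
    env="CartPole-v0",
    config={
        "batch_mode": "truncate_episodes",  # or "complete_episodes"
        "rollout_fragment_length": 100,
        "horizon": 200,         # hard upper limit on episode length
        "soft_horizon": False,  # True: neither reset nor set done=True at the horizon
        "no_done_at_end": False,
        "multiagent": {
            "count_steps_by": "env_steps",  # or "agent_steps"
        },
    })
print(trainer.train())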
To trigger a single rollout, RLlib calls ``RolloutWorker.sample()``, which returns
a SampleBatch or MultiAgentBatch object representing all the data collected during that
rollout. These batches are then usually further concatenated (from the ``num_workers``
parallelized RolloutWorkers) to form a final train batch. The size of that train batch is determined
by the ``train_batch_size`` config parameter. Train batches are usually sent to the Policy's
``learn_on_batch`` method, which handles loss- and gradient calculations, and optimizer stepping.
RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which appends single timestep data (e.g. actions)
to lists, then builds SampleBatches from these and sends them to the downstream processing functions.
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS use the same underlying list).
If you want to implement your own collection logic and data structures, you can sub-class ``SampleCollector``
and specify that new class under the Trainer's "sample_collector" config key.
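For illustration, here is a minimal, hedged sketch of such a plug-in (``MyCollector`` is a made-up name; it simply reuses the default logic by sub-classing ``SimpleListCollector``, which itself is a ``SampleCollector`` sub-class):
.. code-block:: python
from ray.rllib.evaluation.collectors.simple_list_collector import \
    SimpleListCollector


class MyCollector(SimpleListCollector):
    """Re-uses the default list-based collection logic unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Custom buffers, statistics, or logging could be set up here.


config = {
    # Point the Trainer to the custom collector class.
    "sample_collector": MyCollector,
}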
Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
know what data in the ongoing episode/trajectory to use for the different required method calls
during rollouts. In particular, these method calls are:
``Policy.compute_actions_from_input_dict()`` to compute actions to be taken in an episode;
``Policy.postprocess_trajectory()``, which is called after an episode ends or a rollout hits its
``rollout_fragment_length`` limit (in ``batch_mode=truncate_episodes``); and ``Policy.learn_on_batch()``,
which is called with a "train_batch" to improve the policy.
Trajectory View API
-------------------
The trajectory view API allows custom models to define what parts of the trajectory they
require in order to execute the forward pass. For example, in the simplest case, a model might
only look at the latest observation. However, an RNN- or attention-based model could look
at previous states emitted by the model, concatenate previously seen rewards with the current observation,
or require the entire range of the n most recent observations.
The trajectory view API lets models define these requirements and lets RLlib gather the required
data for the forward pass in an efficient way.
Since the following methods all call into the model class, they all indirectly use the trajectory view API.
It is important to note that the API is only accessible to the user via the model classes
(see below on how to set up trajectory view requirements for a custom model).
In particular, the methods receiving inputs that depend on a Model's trajectory view rules are:
a) ``Policy.compute_actions_from_input_dict()``,
b) ``Policy.postprocess_trajectory()``, and
c) ``Policy.learn_on_batch()`` (and consequently: the Policy's loss function).
The input data to these methods can stem from either the environment (observations, rewards, and env infos),
the model itself (previously computed actions, internal state outputs, action probs, etc.),
or the Sampler (e.g. agent index, env ID, episode ID, timestep, etc.).
All data has an associated time axis, which is 0-based, meaning that the first action taken, the
first reward received in an episode, and the first observation (directly after a reset)
all have t=0.
The idea is to allow more flexibility and standardization in how a model defines required
"views" on the ongoing trajectory (during action computations/inference), past episodes (training
on a batch), or even trajectories of other agents in the same episode, some of which
may even use a different policy.
Such a "view requirements" formalism is helpful when having to support more complex model
setups like RNNs, attention nets, observation image framestacking (e.g. for Atari),
and building multi-agent communication channels.
The rules that determine which data the Model gets to see are defined in a
"view requirements dict", residing in the ``Policy.model.view_requirements``
property.
View requirements dicts map strings (column names), such as "obs" or "actions",
to a ``ViewRequirement`` object, which defines the exact conditions by which this column
should be populated with data.
View Requirement Dictionaries
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
View requirements are stored within the ``view_requirements`` property of the ``ModelV2``
class.
You can access it like this:
.. code-block:: python
my_simple_model = ModelV2(...)
print(my_simple_model.view_requirements)
>>>{"obs": ViewRequirement(shift=0, space=[observation space])}
my_lstm_model = LSTMModel(...)
print(my_lstm_model.view_requirements)
>>>{
>>> "obs": ViewRequirement(shift=0, space=[observation space]),
>>> "prev_actions": ViewRequirement(shift=-1, data_col="actions", space=[action space]),
>>> "prev_rewards": ViewRequirement(shift=-1, data_col="rewards"),
>>>}
The ``view_requirements`` property holds a dictionary mapping
string keys (e.g. "actions", "rewards", "next_obs", etc.)
to a ``ViewRequirement`` object. This ``ViewRequirement`` object determines what exact data to
provide under the given key in case a SampleBatch or a single-timestep (action computing) "input dict"
needs to be built and fed into one of the above ModelV2- or Policy methods.
.. Edit figure below at: https://docs.google.com/drawings/d/1YEPUtMrRXmWfvM0E6mD3VsOaRlLV7DtctF-yL96VHGg/edit
.. image:: images/rllib-trajectory-view-example.svg
**Above:** An example `ViewRequirements` dict that causes the current observation
and the previous action to be available in each compute_action call, as
well as for the Policy's `postprocess_trajectory()` function (and train batch).
A similar setup is often used by LSTM/RNN-based models.
The ViewRequirement class
~~~~~~~~~~~~~~~~~~~~~~~~~
Here is a description of the constructor-settable properties of a ViewRequirement
object and what each of these properties controls.
**data_col**:
An optional string key referencing the underlying data to use to
create the view. If not provided, RLlib assumes the data lives under the
dict key under which this ViewRequirement resides.
Examples:
.. code-block:: python
ModelV2.view_requirements = {"rewards": ViewRequirements(shift=0)}
# -> implies that the underlying data to use are the collected rewards
# from the environment.
ModelV2.view_requirements = {"prev_rewards": ViewRequirements(data_col="rewards", shift=-1)}
# -> means that the actual data used to create the "prev_rewards" column
# is the "rewards" data from the environment (shifted by 1 timestep).
**space**:
An optional gym.Space used as a hint for the SampleCollector to know
how to fill timesteps before the episode actually started (e.g. if
shift=-2, we need dummy data at timesteps -2 and -1).
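Example (a hedged snippet; the ``Discrete(2)`` space is just an illustration):
.. code-block:: python
from gym.spaces import Discrete
from ray.rllib.policy.view_requirement import ViewRequirement

# For a "prev_actions" view (shift=-1), the given space lets the
# SampleCollector produce a dummy "previous action" for timestep -1,
# i.e. for the step right before the episode has actually started.
prev_actions_view = ViewRequirement(
    data_col="actions", shift=-1, space=Discrete(2))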
**shift [int]**:
An int, a list of ints, or a range string (e.g. "-50:-1") to indicate
which time offsets or ranges of the underlying data to use for the view.
Examples:
.. code-block:: python
shift=0 # -> Use the data under ``data_col`` as is.
shift=1 # -> Use the data under ``data_col``, but shifted by +1 timestep
# (used by e.g. next_obs views).
shift=-1 # -> Use the data under ``data_col``, but shifted by -1 timestep
# (used by e.g. prev_actions views).
shift=[-2, -1] # -> Use the data under ``data_col``, but always provide 2 values
# at each timestep (the two previous ones).
# Could be used e.g. to feed the last two actions or rewards into an LSTM.
shift="-50:-1" # -> Use the data under ``data_col``, but always provide a range of
# the last 50 timesteps (used by our attention nets).
**used_for_training [bool]**:
True by default. If False, the column will not be available inside the train batch (arriving in the
Policy's loss function).
RLlib will automatically switch this to False for a given column if it detects during
Policy initialization that the column is not accessed inside the loss function (see below).
How does RLlib determine which Views are required?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
During its initialization, a Policy automatically determines how to later build batches
for postprocessing, loss function calls, and action computations, based on
the Model's ``view_requirements`` dict. It does so by sending generic dummy batches
through its ``compute_actions_from_input_dict``, ``postprocess_trajectory``, and loss functions
and then checking which fields in these dummy batches get accessed, overwritten, deleted, or added.
Based on these test passes, the Policy then throws out, from an initially very broad list,
those ViewRequirements it deems unnecessary. This procedure saves a lot of data copying
during later rollouts, batch transfers (via ray), and loss calculations, and makes things like
manually deleting columns from a SampleBatch (e.g. PPO used to delete the "next_obs" column
inside the postprocessing function) unnecessary.
Note that the "rewards" and "dones" columns are never discarded and thus should always
arrive in your loss function's SampleBatch (``train_batch`` arg).
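To illustrate this, here is a hedged, torch-style sketch of a policy loss function (the function itself is made up; only the ``train_batch`` access pattern matters): columns accessed here keep their ViewRequirements, while unused ones (other than "rewards" and "dones") may be dropped from future train batches.
.. code-block:: python
from ray.rllib.policy.sample_batch import SampleBatch


def my_policy_loss(policy, model, dist_class, train_batch):
    # Accessing a column during the dummy test pass marks it as
    # required for training.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    # "rewards" (and "dones") are always kept, accessed or not.
    return -(log_probs * train_batch[SampleBatch.REWARDS]).mean()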
Setting ViewRequirements manually in your Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you need to specify special view requirements for your model, you can add
columns to the Model's ``view_requirements`` dict in the
Model's constructor.
For example, our auto-LSTM wrapper classes (tf and torch) have these additional
lines in their constructors (torch version shown here):
.. literalinclude:: ../../rllib/models/torch/recurrent_net.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
This makes sure that, if the user requires this via the model config, previous rewards
and/or previous actions are added properly to the ``compute_actions`` input-dicts and SampleBatches
used for postprocessing and training.
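Beyond the built-in wrappers, a custom model can declare its own view requirements in its constructor. Below is a hedged sketch (the class, the "prev_n_obs" column name, and the network layout are made up for illustration; the pattern mirrors the observation frame-stacking use case mentioned above) of a torch model that asks for the last four observations:
.. code-block:: python
import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class FrameStackingModel(TorchModelV2, nn.Module):
    """Feeds the last `num_frames` observations (t-3 .. t) into an FC net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, num_frames=4):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        in_size = int(np.prod(obs_space.shape)) * num_frames
        self.logits_layer = nn.Linear(in_size, num_outputs)
        self.value_layer = nn.Linear(in_size, 1)
        self._last_flat_in = None
        # Ask the trajectory view API for the last `num_frames` observations.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col=SampleBatch.OBS,
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)

    def forward(self, input_dict, state, seq_lens):
        # Shape: [B, num_frames, obs_dim] -> flatten the time axis.
        self._last_flat_in = torch.flatten(
            input_dict["prev_n_obs"].float(), start_dim=1)
        return self.logits_layer(self._last_flat_in), state

    def value_function(self):
        return torch.squeeze(self.value_layer(self._last_flat_in), -1)
Such a model would then be registered via ``ModelCatalog.register_custom_model`` and selected through the Trainer's ``model: {custom_model: ...}`` config, as with any other custom RLlib model.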
Another example is our attention nets, which make sure the last n (memory) model outputs
are always fed back into the model on the next time step (tf version shown here).
.. literalinclude:: ../../rllib/models/tf/attention_net.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Setting ViewRequirements manually after Policy construction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is a simple example of how you can modify and add to the ViewRequirements dict
even after Policy (or RolloutWorker) creation. However, note that it's usually better to
make these modifications to your batches in your postprocessing function:
.. code-block:: python
import gym
import numpy as np
import ray.rllib.agents.ppo as ppo
from gym.spaces import Discrete
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override

# Modify view_requirements in the Policy object.
action_space = Discrete(2)
rollout_worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy_config=ppo.DEFAULT_CONFIG,
    policy_spec=ppo.PPOTorchPolicy,
)
policy = rollout_worker.policy_map["default_policy"]

# Add the next action to the view reqs of the policy.
# This should then be visible in postprocessing and train batches.
policy.view_requirements["next_actions"] = ViewRequirement(
    SampleBatch.ACTIONS, shift=1, space=action_space)

# Check whether a sampled batch has the requested `next_actions` view.
batch = rollout_worker.sample()
assert "next_actions" in batch.data

# Doing the same in a custom postprocessing callback function:
class MyCallback(DefaultCallbacks):
    # ...

    @override(DefaultCallbacks)
    def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                  policies, postprocessed_batch,
                                  original_batches, **kwargs):
        postprocessed_batch["next_actions"] = np.concatenate(
            [postprocessed_batch["actions"][1:],
             np.zeros_like([policies[policy_id].action_space.sample()])])
The above two examples add a "next_actions" view to the postprocessed SampleBatch
used by the Policy for training. The "next_actions" data will not be fed
into the Model's action-computation calls (it cannot be, because the next action is
of course not known at that point).

View file

@ -60,7 +60,6 @@ Training APIs
Environments
------------
* `RLlib Environments Overview <rllib-env.html>`__
* `Feature Compatibility Matrix <rllib-env.html#feature-compatibility-matrix>`__
* `OpenAI Gym <rllib-env.html#openai-gym>`__
* `Vectorized <rllib-env.html#vectorized>`__
* `Multi-Agent and Hierarchical <rllib-env.html#multi-agent-and-hierarchical>`__
@ -144,6 +143,11 @@ Algorithms
- |pytorch| :ref:`Curiosity (ICM: Intrinsic Curiosity Module) <curiosity>`
Sample Collection
-----------------
* `The SampleCollector Class is Used to Store and Retrieve Temporary Data <rllib-sample-collection.html#the-samplecollector-class-is-used-to-store-and-retrieve-temporary-data>`__
* `Trajectory View API <rllib-sample-collection.html#trajectory-view-api>`__
Offline Datasets
----------------

View file

@ -110,7 +110,12 @@ Beyond environments defined in Python, RLlib supports batch training on `offline
Customization
~~~~~~~~~~~~~
RLlib provides ways to customize almost all aspects of training, including the `environment <rllib-env.html#configuring-environments>`__, `neural network model <rllib-models.html#tensorflow-models>`__, `action distribution <rllib-models.html#custom-action-distributions>`__, and `policy definitions <rllib-concepts.html#policies>`__:
RLlib provides ways to customize almost all aspects of training, including
`neural network models <rllib-models.html#tensorflow-models>`__,
`action distributions <rllib-models.html#custom-action-distributions>`__,
`policy definitions <rllib-concepts.html#policies>`__,
the `environment <rllib-env.html#configuring-environments>`__,
and the `sample collection process <rllib-sample-collection.html>`__.
.. image:: rllib-components.svg

View file

@ -14,6 +14,8 @@ from ray.exceptions import RayError
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.normalize_actions import NormalizeActionWrapper
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.collectors.simple_list_collector import \
SimpleListCollector
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy import Policy
@ -231,6 +233,10 @@ COMMON_CONFIG: TrainerConfigDict = {
# generic ModelV2 `input_dicts` that can be requested by the model to
# contain different information on the ongoing episode.
"_use_trajectory_view_api": True,
# The SampleCollector class to be used to collect and retrieve
# environment-, model-, and sampler data. Override the SampleCollector base
# class to implement your own collection/buffering/retrieval logic.
"sample_collector": SimpleListCollector,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter".
"observation_filter": "NoFilter",

View file

@ -1,16 +1,22 @@
from abc import abstractmethod, ABCMeta
import logging
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
if TYPE_CHECKING:
from ray.rllib.agents.callbacks import DefaultCallbacks
logger = logging.getLogger(__name__)
class _SampleCollector(metaclass=ABCMeta):
# yapf: disable
# __sphinx_doc_begin__
class SampleCollector(metaclass=ABCMeta):
"""Collects samples for all policies and agents from a multi-agent env.
Note: This is an experimental class only used when
@ -29,6 +35,34 @@ class _SampleCollector(metaclass=ABCMeta):
communication channel).
"""
def __init__(self,
policy_map: Dict[PolicyID, Policy],
clip_rewards: Union[bool, float],
callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = True,
rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps"):
"""Initializes a SampleCollector instance.
Args:
policy_map (Dict[str, Policy]): Maps policy ids to policy
instances.
clip_rewards (Union[bool, float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks.
multiple_episodes_in_batch (bool): Whether it's allowed to pack
multiple episodes into the same built batch.
rollout_fragment_length (int): The length (in env or agent steps;
see `count_steps_by`) of rollout fragments to collect before
a SampleBatch may be built from the data.
count_steps_by (str): Either "env_steps" or "agent_steps": The unit
by which rollout fragment lengths are counted.
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
self.callbacks = callbacks
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.rollout_fragment_length = rollout_fragment_length
self.count_steps_by = count_steps_by
@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
policy_id: PolicyID, t: int,
@ -55,7 +89,7 @@ class _SampleCollector(metaclass=ABCMeta):
Examples:
>>> obs = env.reset()
>>> collector.add_init_obs(12345, 0, "pol0", obs)
>>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
>>> obs, r, done, info = env.step(action)
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
... "action": action, "obs": obs, "reward": r, "done": done
@ -227,3 +261,4 @@ class _SampleCollector(metaclass=ABCMeta):
`self.rollout_fragment_length` has not been reached yet.
"""
raise NotImplementedError
# __sphinx_doc_end__

View file

@ -6,7 +6,7 @@ import numpy as np
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
@ -379,7 +379,7 @@ class _PolicyCollectorGroup:
self.agent_steps = 0
class _SimpleListCollector(_SampleCollector):
class SimpleListCollector(SampleCollector):
"""Util to build SampleBatches for each policy in a multi-agent env.
Input data is per-agent, while output data is per-policy. There is an M:N
@ -395,29 +395,15 @@ class _SimpleListCollector(_SampleCollector):
multiple_episodes_in_batch: bool = True,
rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps"):
"""Initializes a _SimpleListCollector instance.
"""Initializes a SimpleListCollector instance."""
Args:
policy_map (Dict[str, Policy]): Maps policy ids to policy
instances.
clip_rewards (Union[bool, float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks.
multiple_episodes_in_batch (bool): Whether it's allowed to pack
multiple episodes into the same built batch.
rollout_fragment_length (int): The
super().__init__(policy_map, clip_rewards, callbacks,
multiple_episodes_in_batch, rollout_fragment_length,
count_steps_by)
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
self.callbacks = callbacks
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.rollout_fragment_length = rollout_fragment_length
self.count_steps_by = count_steps_by
self.large_batch_threshold: int = max(
1000, rollout_fragment_length *
10) if rollout_fragment_length != float("inf") else 5000
1000, self.rollout_fragment_length *
10) if self.rollout_fragment_length != float("inf") else 5000
# Whenever we observe a new episode+agent, add a new
# _SingleTrajectoryCollector.
@ -430,8 +416,9 @@ class _SimpleListCollector(_SampleCollector):
self.policy_collector_groups = []
# Agents to collect data from for the next forward pass (per policy).
self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()}
self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}
self.forward_pass_agent_keys = \
{pid: [] for pid in self.policy_map.keys()}
self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()}
# Maps episode ID to the (non-built) env steps taken in this episode.
self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
@ -441,7 +428,7 @@ class _SimpleListCollector(_SampleCollector):
# Maps episode ID to MultiAgentEpisode.
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
@override(_SampleCollector)
@override(SampleCollector)
def episode_step(self, episode_id: EpisodeID) -> None:
episode = self.episodes[episode_id]
self.episode_steps[episode_id] += 1
@ -470,7 +457,7 @@ class _SimpleListCollector(_SampleCollector):
"does at some point."
if not self.multiple_episodes_in_batch else ""))
@override(_SampleCollector)
@override(SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
@ -481,7 +468,7 @@ class _SimpleListCollector(_SampleCollector):
else:
assert self.agent_key_to_policy_id[agent_key] == policy_id
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements if \
view_reqs = policy.model.view_requirements if \
getattr(policy, "model", None) else policy.view_requirements
# Add initial obs to Trajectory.
@ -503,7 +490,7 @@ class _SimpleListCollector(_SampleCollector):
self._add_to_next_inference_call(agent_key)
@override(_SampleCollector)
@override(SampleCollector)
def add_action_reward_next_obs(self, episode_id: EpisodeID,
agent_id: AgentID, env_id: EnvID,
policy_id: PolicyID, agent_done: bool,
@ -525,14 +512,14 @@ class _SimpleListCollector(_SampleCollector):
if not agent_done:
self._add_to_next_inference_call(agent_key)
@override(_SampleCollector)
@override(SampleCollector)
def total_env_steps(self) -> int:
# Add the non-built ongoing-episode env steps + the already built
# env-steps.
return sum(self.episode_steps.values()) + sum(
pg.env_steps for pg in self.policy_collector_groups.values())
@override(_SampleCollector)
@override(SampleCollector)
def total_agent_steps(self) -> int:
# Add the non-built ongoing-episode agent steps (still in the agent
# collectors) + the already built agent steps.
@ -540,13 +527,13 @@ class _SimpleListCollector(_SampleCollector):
sum(pg.agent_steps for pg in
self.policy_collector_groups.values())
@override(_SampleCollector)
@override(SampleCollector)
def get_inference_input_dict(self, policy_id: PolicyID) -> \
Dict[str, TensorType]:
policy = self.policy_map[policy_id]
keys = self.forward_pass_agent_keys[policy_id]
buffers = {k: self.agent_collectors[k].buffers for k in keys}
view_reqs = policy.model.inference_view_requirements if \
view_reqs = policy.model.view_requirements if \
getattr(policy, "model", None) else policy.view_requirements
input_dict = {}
@ -592,7 +579,7 @@ class _SimpleListCollector(_SampleCollector):
return input_dict
@override(_SampleCollector)
@override(SampleCollector)
def postprocess_episode(
self,
episode: MultiAgentEpisode,
@ -725,7 +712,7 @@ class _SimpleListCollector(_SampleCollector):
return ma_batch
@override(_SampleCollector)
@override(SampleCollector)
def try_build_truncated_episode_multi_agent_batch(self) -> \
List[Union[MultiAgentBatch, SampleBatch]]:
batches = []

View file

@ -28,7 +28,10 @@ class MultiAgentEpisode:
episode_id (int): Unique id identifying this trajectory.
agent_rewards (dict): Summed rewards broken down by agent.
custom_metrics (dict): Dict where the you can add custom metrics.
user_data (dict): Dict that you can use for temporary storage.
user_data (dict): Dict that you can use for temporary storage. E.g.
in between two custom callbacks referring to the same episode.
hist_data (dict): Dict mapping str keys to List[float] for storage of
per-timestep float data throughout the episode.
Use case 1: Model-based rollouts in multi-agent:
A custom compute_actions() function in a policy can inspect the

View file

@ -473,7 +473,7 @@ class RolloutWorker(ParallelIteratorWorker):
# state.
for pol in self.policy_map.values():
if not pol._model_init_state_automatically_added:
pol._update_model_inference_view_requirements_from_init_state()
pol._update_model_view_requirements_from_init_state()
if (ray.is_initialized()
and ray.worker._mode() != ray.worker.LOCAL_MODE):
@ -543,6 +543,7 @@ class RolloutWorker(ParallelIteratorWorker):
raise ValueError("Unsupported batch mode: {}".format(
self.batch_mode))
# Create the IOContext for this worker.
self.io_context: IOContext = IOContext(log_dir, policy_config,
worker_index, self)
self.reward_estimators: List[OffPolicyEstimator] = []
@ -586,6 +587,8 @@ class RolloutWorker(ParallelIteratorWorker):
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=_use_trajectory_view_api,
sample_collector_class=policy_config.get(
"sample_collector_class"),
)
# Start the Sampler thread.
self.sampler.start()
@ -609,6 +612,8 @@ class RolloutWorker(ParallelIteratorWorker):
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_use_trajectory_view_api=_use_trajectory_view_api,
sample_collector_class=policy_config.get(
"sample_collector_class"),
)
self.input_reader: InputReader = input_creator(self.io_context)

View file

@ -6,13 +6,13 @@ import queue
import threading
import time
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
TYPE_CHECKING, Union
Type, TYPE_CHECKING, Union
from ray.util.debug import log_once
from ray.rllib.evaluation.collectors.sample_collector import \
_SampleCollector
SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
_SimpleListCollector
SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
@ -119,26 +119,28 @@ class SyncSampler(SamplerInput):
"""Sync SamplerInput that collects experiences when `get_data()` is called.
"""
def __init__(self,
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False):
def __init__(
self,
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False,
sample_collector_class: Optional[Type[SampleCollector]] = None):
"""Initializes a SyncSampler object.
Args:
@ -178,6 +180,9 @@ class SyncSampler(SamplerInput):
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
sample_collector_class (Optional[Type[SampleCollector]]): An
optional SampleCollector sub-class to use to collect, store,
and retrieve environment-, model-, and sampler data.
"""
self.base_env = BaseEnv.to_base_env(env)
@ -190,7 +195,9 @@ class SyncSampler(SamplerInput):
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats()
if _use_trajectory_view_api:
self.sample_collector = _SimpleListCollector(
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
policies,
clip_rewards,
callbacks,
@ -249,27 +256,30 @@ class AsyncSampler(threading.Thread, SamplerInput):
from where they can be unqueued by the caller of `get_data()`.
"""
def __init__(self,
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
blackhole_outputs: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False):
def __init__(
self,
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
clip_actions: bool = True,
blackhole_outputs: bool = False,
soft_horizon: bool = False,
no_done_at_end: bool = False,
observation_fn: "ObservationFunction" = None,
_use_trajectory_view_api: bool = False,
sample_collector_class: Optional[Type[SampleCollector]] = None,
):
"""Initializes a AsyncSampler object.
Args:
@ -313,6 +323,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
sample_collector_class (Optional[Type[SampleCollector]]): An
optional SampleCollector sub-class to use to collect, store,
and retrieve environment-, model-, and sampler data.
"""
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
@ -343,7 +356,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.observation_fn = observation_fn
self._use_trajectory_view_api = _use_trajectory_view_api
if _use_trajectory_view_api:
self.sample_collector = _SimpleListCollector(
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
policies,
clip_rewards,
callbacks,
@ -441,7 +456,7 @@ def _env_runner(
no_done_at_end: bool,
observation_fn: "ObservationFunction",
_use_trajectory_view_api: bool = False,
_sample_collector: Optional[_SampleCollector] = None,
sample_collector: Optional[SampleCollector] = None,
) -> Iterable[SampleBatchType]:
"""This implements the common experience collection logic.
@ -480,8 +495,8 @@ def _env_runner(
_use_trajectory_view_api (bool): Whether to use the (experimental)
`_use_trajectory_view_api` to make generic trajectory views
available to Models. Default: False.
_sample_collector (Optional[_SampleCollector]): An optional
_SampleCollector object to use
sample_collector (Optional[SampleCollector]): An optional
SampleCollector object to use for storing and retrieving
environment-, model-, and sampler data.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@ -498,20 +513,27 @@ def _env_runner(
# Trainer has a given `horizon` setting.
if horizon:
# `horizon` is larger than env's limit -> Error and explain how
# to increase Env's own episode limit.
# `horizon` is larger than env's limit.
if max_episode_steps and horizon > max_episode_steps:
raise ValueError(
"Your `horizon` setting ({}) is larger than the Env's own "
"timestep limit ({})! Try to increase the Env's limit via "
"setting its `spec.max_episode_steps` property.".format(
horizon, max_episode_steps))
# Try to override the env's own max-step setting with our horizon.
# If this won't work, throw an error.
try:
base_env.get_unwrapped()[0].spec.max_episode_steps = horizon
base_env.get_unwrapped()[0]._max_episode_steps = horizon
except Exception:
raise ValueError(
"Your `horizon` setting ({}) is larger than the Env's own "
"timestep limit ({}), which seems to be unsettable! Try "
"to increase the Env's built-in limit to be at least as "
"large as your wanted `horizon`.".format(
horizon, max_episode_steps))
# Otherwise, set Trainer's horizon to env's max-steps.
elif max_episode_steps:
horizon = max_episode_steps
logger.debug(
"No episode horizon specified, setting it to Env's limit ({}).".
format(max_episode_steps))
# No horizon/max_episode_steps -> Episodes may be infinitely long.
else:
horizon = float("inf")
logger.debug("No episode horizon specified, assuming inf.")
@ -594,7 +616,7 @@ def _env_runner(
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
_sample_collector=_sample_collector,
sample_collector=sample_collector,
)
else:
active_envs, to_eval, outputs = _process_observations(
@ -628,7 +650,7 @@ def _env_runner(
eval_results = _do_policy_eval_w_trajectory_view_api(
to_eval=to_eval,
policies=policies,
_sample_collector=_sample_collector,
sample_collector=sample_collector,
active_episodes=active_episodes,
tf_sess=tf_sess,
)
@ -653,7 +675,7 @@ def _env_runner(
policies=policies,
clip_actions=clip_actions,
_use_trajectory_view_api=_use_trajectory_view_api,
_sample_collector=_sample_collector,
sample_collector=sample_collector,
)
perf_stats.action_processing_time += time.time() - t3
@ -968,7 +990,7 @@ def _process_observations_w_trajectory_view_api(
soft_horizon: bool,
no_done_at_end: bool,
observation_fn: "ObservationFunction",
_sample_collector: _SampleCollector,
sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
RolloutMetrics, SampleBatchType]]]:
"""Trajectory View API version of `_process_observations()`.
@ -987,7 +1009,7 @@ def _process_observations_w_trajectory_view_api(
episode: MultiAgentEpisode = active_episodes[env_id]
if not is_new_episode:
_sample_collector.episode_step(episode.episode_id)
sample_collector.episode_step(episode.episode_id)
episode._add_agent_rewards(rewards[env_id])
# Check episode termination conditions.
@ -1051,9 +1073,9 @@ def _process_observations_w_trajectory_view_api(
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, episode.length - 1,
filtered_obs)
sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, episode.length - 1,
filtered_obs)
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
@ -1079,7 +1101,7 @@ def _process_observations_w_trajectory_view_api(
# Env infos for this agent.
if "infos" in pol.view_requirements:
values_dict["infos"] = agent_infos
_sample_collector.add_action_reward_next_obs(
sample_collector.add_action_reward_next_obs(
episode.episode_id, agent_id, env_id, policy_id,
agent_done, values_dict)
@ -1111,7 +1133,7 @@ def _process_observations_w_trajectory_view_api(
# MultiAgentBatch from a single episode and add it to "outputs".
# Otherwise, just postprocess and continue collecting across
# episodes.
ma_sample_batch = _sample_collector.postprocess_episode(
ma_sample_batch = sample_collector.postprocess_episode(
episode,
is_done=is_done or (hit_horizon and not soft_horizon),
check_dones=check_dones,
@ -1170,7 +1192,7 @@ def _process_observations_w_trajectory_view_api(
new_episode._set_last_observation(agent_id, filtered_obs)
# Add initial obs to buffer.
_sample_collector.add_init_obs(
sample_collector.add_init_obs(
new_episode, agent_id, env_id, policy_id,
new_episode.length - 1, filtered_obs)
@ -1183,7 +1205,7 @@ def _process_observations_w_trajectory_view_api(
# Try to build something.
if multiple_episodes_in_batch:
sample_batches = \
_sample_collector.try_build_truncated_episode_multi_agent_batch()
sample_collector.try_build_truncated_episode_multi_agent_batch()
if sample_batches:
outputs.extend(sample_batches)
@ -1279,7 +1301,7 @@ def _do_policy_eval_w_trajectory_view_api(
*,
to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: Dict[PolicyID, Policy],
_sample_collector,
sample_collector,
active_episodes: Dict[str, MultiAgentEpisode],
tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
@ -1291,7 +1313,7 @@ def _do_policy_eval_w_trajectory_view_api(
be the batch's items for the model forward pass).
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
obj.
_sample_collector (SampleCollector): The SampleCollector object to use.
sample_collector (SampleCollector): The SampleCollector object to use.
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
@ -1313,7 +1335,7 @@ def _do_policy_eval_w_trajectory_view_api(
for policy_id, eval_data in to_eval.items():
policy: Policy = _get_or_raise(policies, policy_id)
input_dict = _sample_collector.get_inference_input_dict(policy_id)
input_dict = sample_collector.get_inference_input_dict(policy_id)
eval_results[policy_id] = \
policy.compute_actions_from_input_dict(
input_dict,
@ -1343,7 +1365,7 @@ def _process_policy_eval_results(
policies: Dict[PolicyID, Policy],
clip_actions: bool,
_use_trajectory_view_api: bool = False,
_sample_collector=None,
sample_collector=None,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
"""Process the output of policy neural network evaluation.

View file

@ -63,7 +63,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
config,
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 8, view_req_policy
@ -108,7 +108,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
# 7=obs, prev-a + r, 2x state-in, 2x state-out.
assert len(view_req_model) == 7, view_req_model

View file

@ -21,7 +21,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
self.model = _fake_model()
self.model.time_major = True
self.model.inference_view_requirements = {
self.model.view_requirements = {
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
"env_id": ViewRequirement(),
@ -33,12 +33,12 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
SampleBatch.REWARDS, shift=-1),
}
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
self.model.view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
space=self.state_space)
self.model.inference_view_requirements[
self.model.view_requirements[
"state_out_{}".format(i)] = \
ViewRequirement(space=self.state_space)
@ -50,7 +50,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
},
**self.model.inference_view_requirements)
**self.model.view_requirements)
@override(Policy)
def is_recurrent(self):
@ -97,7 +97,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
pass
self.model = _fake_model()
self.model.inference_view_requirements = {
self.model.view_requirements = {
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
"env_id": ViewRequirement(),
@ -114,7 +114,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
}
self.view_requirements = dict(super()._get_default_view_requirements(),
**self.model.inference_view_requirements)
**self.model.view_requirements)
@override(Policy)
def is_recurrent(self):

View file

@ -62,7 +62,7 @@ class ModelV2:
self._last_output = None
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.inference_view_requirements = {
self.view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
}
@ -341,7 +341,7 @@ class ModelV2:
}
input_dict = {}
for view_col, view_req in self.inference_view_requirements.items():
for view_col, view_req in self.view_requirements.items():
# Create batches of size 1 (single-agent input-dict).
data_col = view_req.data_col or view_col
if index == "last":

View file

@ -284,22 +284,22 @@ class GTrXLNet(RecurrentNetwork):
self.register_variables(self.trxl_model.variables)
self.trxl_model.summary()
# Setup inference view (`memory-inference` x past observations +
# current one (0))
# 1 to `num_transformer_units`: Memory data (one per transformer unit).
# __sphinx_doc_begin__
# Setup trajectory views (`memory-inference` x past memory outs).
for i in range(self.num_transformer_units):
space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
self.inference_view_requirements["state_in_{}".format(i)] = \
self.view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift="-{}:-1".format(self.memory_inference),
# Repeat the incoming state every max-seq-len times.
batch_repeat_value=self.max_seq_len,
space=space)
self.inference_view_requirements["state_out_{}".format(i)] = \
self.view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=space,
used_for_training=False)
# __sphinx_doc_end__
@override(ModelV2)
def forward(self, input_dict, state: List[TensorType],

View file

@ -184,11 +184,11 @@ class LSTMWrapper(RecurrentNetwork):
# Add prev-a/r to this model's view, if required.
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
self.view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
self.view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
@override(RecurrentNetwork)

View file

@ -160,19 +160,17 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
self.values_out = SlimFC(
in_size=self.attn_dim, out_size=1, activation_fn=None)
# Setup inference view (`memory-inference` x past observations +
# current one (0))
# 1 to `num_transformer_units`: Memory data (one per transformer unit).
# Setup trajectory views (`memory-inference` x past memory outs).
for i in range(self.num_transformer_units):
space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
self.inference_view_requirements["state_in_{}".format(i)] = \
self.view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift="-{}:-1".format(self.memory_inference),
# Repeat the incoming state every max-seq-len times.
batch_repeat_value=self.max_seq_len,
space=space)
self.inference_view_requirements["state_out_{}".format(i)] = \
self.view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=space,
used_for_training=False)

View file

@ -166,14 +166,16 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_)
# __sphinx_doc_begin__
# Add prev-a/r to this model's view, if required.
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
self.view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
self.view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
# __sphinx_doc_end__
@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],

View file

@ -1,3 +1,4 @@
from abc import ABCMeta, abstractmethod
import logging
import numpy as np
import threading
@ -14,15 +15,16 @@ logger = logging.getLogger(__name__)
@PublicAPI
class InputReader:
class InputReader(metaclass=ABCMeta):
"""Input object for loading experiences in policy evaluation."""
@abstractmethod
@PublicAPI
def next(self):
"""Return the next batch of experiences read.
"""Returns the next batch of experiences read.
Returns:
SampleBatch or MultiAgentBatch read.
Union[SampleBatch, MultiAgentBatch]: The experience read.
"""
raise NotImplementedError

View file

@ -171,7 +171,7 @@ class DynamicTFPolicy(TFPolicy):
model_config=self.config["model"],
framework="tf")
# Auto-update model's inference view requirements, if recurrent.
self._update_model_inference_view_requirements_from_init_state()
self._update_model_view_requirements_from_init_state()
if existing_inputs:
self._state_inputs = [
@ -186,8 +186,7 @@ class DynamicTFPolicy(TFPolicy):
get_placeholder(
space=vr.space,
time_axis=not isinstance(vr.shift, int),
) for k, vr in
self.model.inference_view_requirements.items()
) for k, vr in self.model.view_requirements.items()
if k.startswith("state_in_")
]
else:
@ -200,7 +199,7 @@ class DynamicTFPolicy(TFPolicy):
# Add NEXT_OBS, STATE_IN_0.., and others.
self.view_requirements = self._get_default_view_requirements()
# Combine view_requirements for Model and Policy.
self.view_requirements.update(self.model.inference_view_requirements)
self.view_requirements.update(self.model.view_requirements)
# Setup standard placeholders.
if existing_inputs is not None:
@ -560,7 +559,7 @@ class DynamicTFPolicy(TFPolicy):
all_accessed_keys = \
train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
batch_for_postproc.added_keys | set(
self.model.inference_view_requirements.keys())
self.model.view_requirements.keys())
TFPolicy._initialize_loss(self, loss, [(k, v)
for k, v in train_batch.items()
@ -584,7 +583,7 @@ class DynamicTFPolicy(TFPolicy):
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key not in self.model.inference_view_requirements:
key not in self.model.view_requirements:
if key in self.view_requirements:
self.view_requirements[key].used_for_training = False
if key in self._loss_input_dict:
@ -597,7 +596,7 @@ class DynamicTFPolicy(TFPolicy):
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES,
SampleBatch.REWARDS] and \
key not in self.model.inference_view_requirements:
key not in self.model.view_requirements:
# If user deleted this key manually in postprocessing
# fn, warn about it and do not remove from
# view-requirements.

View file

@ -256,15 +256,14 @@ def build_eager_tf_policy(name,
framework=self.framework,
)
# Auto-update model's inference view requirements, if recurrent.
self._update_model_inference_view_requirements_from_init_state()
self._update_model_view_requirements_from_init_state()
self.exploration = self._create_exploration()
self._state_inputs = self.model.get_initial_state()
self._is_recurrent = len(self._state_inputs) > 0
# Combine view_requirements for Model and Policy.
self.view_requirements.update(
self.model.inference_view_requirements)
self.view_requirements.update(self.model.view_requirements)
if before_loss_init:
before_loss_init(self, observation_space, action_space, config)

View file

@ -669,7 +669,7 @@ class Policy(metaclass=ABCMeta):
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key in self.view_requirements and \
key not in self.model.inference_view_requirements:
key not in self.model.view_requirements:
self.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
@ -679,7 +679,7 @@ class Policy(metaclass=ABCMeta):
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES,
SampleBatch.REWARDS] and \
key not in self.model.inference_view_requirements:
key not in self.model.view_requirements:
# If user deleted this key manually in postprocessing
# fn, warn about it and do not remove from
# view-requirements.
@ -738,12 +738,12 @@ class Policy(metaclass=ABCMeta):
# columns in the resulting batch may not all have the same batch size.
return SampleBatch(ret, _dont_check_lens=True)
def _update_model_inference_view_requirements_from_init_state(self):
def _update_model_view_requirements_from_init_state(self):
"""Uses Model's (or this Policy's) init state to add needed ViewReqs.
Can be called from within a Policy to make sure RNNs automatically
update their internal state-related view requirements.
Changes the `self.inference_view_requirements` dict.
Changes the `self.view_requirements` dict.
"""
self._model_init_state_automatically_added = True
model = getattr(self, "model", None)
@ -752,7 +752,7 @@ class Policy(metaclass=ABCMeta):
for i, state in enumerate(obj.get_initial_state()):
space = Box(-1.0, 1.0, shape=state.shape) if \
hasattr(state, "shape") else state
view_reqs = model.inference_view_requirements if model else \
view_reqs = model.view_requirements if model else \
self.view_requirements
view_reqs["state_in_{}".format(i)] = ViewRequirement(
"state_out_{}".format(i),

View file

@ -9,7 +9,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import add_mixins, force_list, NullContextManager
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch, try_import_jax
@ -74,8 +73,6 @@ def build_policy_class(
apply_gradients_fn: Optional[Callable[
[Policy, "torch.optim.Optimizer"], None]] = None,
mixins: Optional[List[type]] = None,
view_requirements_fn: Optional[Callable[[Policy], Dict[
str, ViewRequirement]]] = None,
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
) -> Type[TorchPolicy]:
"""Helper function for creating a new Policy class at runtime.
@ -181,9 +178,6 @@ def build_policy_class(
mixins (Optional[List[type]]): Optional list of any class mixins for
the returned policy class. These mixins will be applied in order
and will have higher precedence than the TorchPolicy class.
view_requirements_fn (Optional[Callable[[Policy],
Dict[str, ViewRequirement]]]): An optional callable to retrieve
additional train view requirements for this policy.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
Optional callable that returns the divisibility requirement for
sample batches. If None, will assume a value of 1.
@ -260,12 +254,8 @@ def build_policy_class(
get_batch_divisibility_req=get_batch_divisibility_req,
)
# Update this Policy's ViewRequirements (if function given).
if callable(view_requirements_fn):
self.view_requirements.update(view_requirements_fn(self))
# Merge Model's view requirements into Policy's.
self.view_requirements.update(
self.model.inference_view_requirements)
self.view_requirements.update(self.model.view_requirements)
_before_loss_init = before_loss_init or after_init
if _before_loss_init:

View file

@ -147,7 +147,7 @@ class TFPolicy(Policy):
self.model = model
# Auto-update model's inference view requirements, if recurrent.
if self.model is not None:
self._update_model_inference_view_requirements_from_init_state()
self._update_model_view_requirements_from_init_state()
self.exploration = self._create_exploration()
self._sess = sess

View file

@ -113,9 +113,9 @@ class TorchPolicy(Policy):
self._state_inputs = self.model.get_initial_state()
self._is_recurrent = len(self._state_inputs) > 0
# Auto-update model's inference view requirements, if recurrent.
self._update_model_inference_view_requirements_from_init_state()
self._update_model_view_requirements_from_init_state()
# Combine view_requirements for Model and Policy.
self.view_requirements.update(self.model.inference_view_requirements)
self.view_requirements.update(self.model.view_requirements)
self.exploration = self._create_exploration()
self.unwrapped_model = model # used to support DistributedDataParallel

View file

@ -22,7 +22,7 @@ class ViewRequirement:
Examples:
>>> # The default ViewRequirement for a Model is:
>>> req = [ModelV2].inference_view_requirements
>>> req = [ModelV2].view_requirements
>>> print(req)
{"obs": ViewRequirement(shift=0)}
"""