[RLlib] Trajectory view API docs. (#12718)
This commit is contained in:
parent
28ac4243f4
commit
391cdfae8c
31 changed files with 571 additions and 173 deletions
1
doc/source/images/rllib-batch-modes.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 131 KiB |
1
doc/source/images/rllib-sample-collection.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 111 KiB |
1
doc/source/images/rllib-trajectory-view-example.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 207 KiB |
|
@ -259,6 +259,7 @@ Papers
|
|||
rllib-env.rst
|
||||
rllib-models.rst
|
||||
rllib-algorithms.rst
|
||||
rllib-sample-collection.rst
|
||||
rllib-offline.rst
|
||||
rllib-concepts.rst
|
||||
rllib-examples.rst
|
||||
|
|
|
@ -50,7 +50,7 @@ Exploration-based plug-ins (can be combined with any algo)
|
|||
============================= ========== ======================= ================== =========== =====================
|
||||
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
|
||||
============================= ========== ======================= ================== =========== =====================
|
||||
`Curiosity`_ torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
`Curiosity`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
============================= ========== ======================= ================== =========== =====================
|
||||
|
||||
.. _`A2C, A3C`: rllib-algorithms.html#a3c
|
||||
|
|
|
@ -5,7 +5,7 @@ RLlib works with several different types of environments, including `OpenAI Gym
|
|||
|
||||
.. tip::
|
||||
|
||||
Not all environments work with all algorithms. Check out the algorithm `feature compatibility matrix <rllib-algorithms.html#feature-compatibility-matrix>`__ for more information.
|
||||
Not all environments work with all algorithms. Check out the `algorithm overview <rllib-algorithms.html#available-algorithms-overview>`__ for more information.
|
||||
|
||||
.. image:: rllib-envs.svg
|
||||
|
||||
|
|
|
@ -389,7 +389,7 @@ Custom models can be used to work with environments where (1) the set of valid a
|
|||
return action_logits + inf_mask, state
|
||||
|
||||
|
||||
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
|
||||
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_actions_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview <rllib-algorithms.html#available-algorithms-overview>`__.
|
||||
|
||||
|
||||
Autoregressive Action Distributions
|
||||
|
|
337
doc/source/rllib-sample-collection.rst
Normal file
|
@ -0,0 +1,337 @@
|
|||
RLlib Sample Collection and Trajectory Views
|
||||
============================================
|
||||
|
||||
The SampleCollector Class is Used to Store and Retrieve Temporary Data
|
||||
----------------------------------------------------------------------
|
||||
|
||||
RLlib's `RolloutWorkers <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__,
|
||||
when running against a live environment, use the ``SamplerInput`` class to interact
|
||||
with that environment and produce batches of experiences.
|
||||
The two implemented sub-classes of ``SamplerInput`` are ``SyncSampler`` and ``AsyncSampler``
|
||||
(residing under the ``RolloutWorker.sampler`` property).
|
||||
|
||||
If the ``_use_trajectory_view_api`` top-level config key is set to True
|
||||
(the default since version 1.1.0), every such sampler object will use the
|
||||
``SampleCollector`` API to store and retrieve temporary environment-, model-, and other data
|
||||
during rollouts (see figure below).
|
||||
|
||||
.. Edit figure below at: https://docs.google.com/drawings/d/1ZdNUU3ChwiUeT-DBRxvLAsbEPEqEFWSPZcOyVy3KxVg/edit
|
||||
.. image:: images/rllib-sample-collection.svg
|
||||
|
||||
**Sample collection process implemented by RLlib:**
|
||||
The Policy's model tells the Sampler and its SampleCollector object which data to store and
|
||||
how to present it back to the dependent methods (e.g. `Model.compute_actions()`).
|
||||
This is done using a dict that maps strings (column names) to `ViewRequirement` objects (see details below).
|
||||
|
||||
|
||||
The exact behavior for a single such rollout and the number of environment transitions therein
|
||||
are determined by the following Trainer config keys:
|
||||
|
||||
**batch_mode [truncate_episodes|complete_episodes]**:
|
||||
*truncate_episodes (default value)*:
|
||||
Rollouts are performed
|
||||
over exactly ``rollout_fragment_length`` (see below) steps. Steps are
|
||||
counted either as environment steps or as individual agent steps (see ``count_steps_by`` below).
|
||||
It does not matter whether one or more episodes end within this rollout or whether
|
||||
the rollout starts in the middle of an already ongoing episode.
|
||||
*complete_episodes*:
|
||||
Each rollout is exactly one episode long and always starts
|
||||
at the beginning of an episode. It does not matter how long an episode lasts.
|
||||
The ``rollout_fragment_length`` setting will be ignored. Note that you have to be
|
||||
careful when choosing ``complete_episodes`` as ``batch_mode``: If your environment does not
|
||||
terminate easily, this setting could lead to enormous batch sizes.
|
||||
|
||||
**rollout_fragment_length [int]**:
|
||||
The exact number of environment- or agent steps to
|
||||
be performed per rollout, if the ``batch_mode`` setting (see above) is "truncate_episodes".
|
||||
If ``batch_mode`` is "complete_episodes", ``rollout_fragment_length`` is ignored.
|
||||
The unit in which fragments are counted is set via ``multiagent.count_steps_by=[env_steps|agent_steps]``
|
||||
(within the ``multiagent`` config dict).
|
||||
|
||||
.. Edit figure below at: https://docs.google.com/drawings/d/1uRNGImBNq8gv3bBoFX_HernGyeovtCB3wKpZ71c0VE4/edit
|
||||
.. image:: images/rllib-batch-modes.svg
|
||||
|
||||
**Above:** The two supported batch modes in RLlib. For "truncate_episodes",
|
||||
batches can a) span over more than one episode, b) end in the middle of an episode, and
|
||||
c) start in the middle of an episode. Also, `Policy.postprocess_trajectory()` is always
|
||||
called at the end of a rollout-fragment (red lines on right side) as well as at the end
|
||||
of each episode (arrow heads). This way, RLlib makes sure that the
|
||||
`Policy.postprocess_trajectory()` method never sees data from more than one episode.
|
||||
|
||||
|
||||
**multiagent.count_steps_by [env_steps|agent_steps]**:
|
||||
Within the Trainer's ``multiagent`` config dict, you can set the unit by which RLlib will count a) rollout fragment lengths as well as b) the size of the final train batch (see below). The two supported values are:
|
||||
|
||||
*env_steps (default)*:
|
||||
Each call to ``[Env].step()`` is counted as one. It does not
|
||||
matter how many individual agents are stepping simultaneously in this very call
|
||||
(not all existing agents in the environment may step at the same time).
|
||||
*agent_steps*:
|
||||
In a multi-agent environment, count each individual agent's step
|
||||
as one. For example, if N agents are in an environment and all these N agents
|
||||
always step at the same time, a single env step corresponds to N agent steps.
|
||||
Note that in the single-agent case, ``env_steps`` and ``agent_steps`` are the same thing.
|
||||
|
||||
**horizon [int]**:
|
||||
Some environments limit by default the maximum number of timesteps
|
||||
an episode can last. This limit is called the "horizon" of an episode.
|
||||
For example, for CartPole-v0, the maximum number of steps per episode is 200 by default.
|
||||
You can override this setting, however, by using the ``horizon`` config.
|
||||
If provided, RLlib will first try to increase the environment's built-in horizon
|
||||
setting (e.g. OpenAI Gym envs have a ``spec.max_episode_steps`` property), if the user-
|
||||
provided horizon is larger than this env-specific setting. In either case, no episode
|
||||
is allowed to exceed the given ``horizon`` number of timesteps (RLlib will
|
||||
artificially terminate an episode if this limit is hit).
|
||||
|
||||
**soft_horizon [bool]**:
|
||||
False by default. If set to True, the environment will
|
||||
a) not be reset when reaching ``horizon`` and b) no ``done=True`` will be set
|
||||
in the trajectory data sent to the postprocessors and training (``done`` will remain
|
||||
False at the horizon).
|
||||
|
||||
**no_done_at_end [bool]**:
|
||||
Never set ``done=True``, neither at the end of an episode nor when any
|
||||
artificial horizon is reached.
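
As a rough sketch of how these keys fit together, here is a hypothetical Trainer config (PPO is chosen arbitrarily, all values are illustrative only, and ``num_workers=0`` is set merely to keep the later sketches simple):

.. code-block:: python

import copy

import ray
from ray.rllib.agents import ppo

ray.init()

config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["num_workers"] = 0                    # sample on the local worker only
config["batch_mode"] = "truncate_episodes"   # or: "complete_episodes"
config["rollout_fragment_length"] = 200      # steps per rollout fragment
config["multiagent"]["count_steps_by"] = "env_steps"  # or: "agent_steps"
config["horizon"] = 100                      # hard episode-length limit
config["soft_horizon"] = False               # False: reset env + set done=True at the horizon
config["no_done_at_end"] = False

trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")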
|
||||
|
||||
To trigger a single rollout, RLlib calls ``RolloutWorker.sample()``, which returns
|
||||
a SampleBatch or MultiAgentBatch object representing all the data collected during that
|
||||
rollout. These batches are then usually further concatenated (from the ``num_workers``
|
||||
parallelized RolloutWorkers) to form a final train batch. The size of that train batch is determined
|
||||
by the ``train_batch_size`` config parameter. Train batches are usually sent to the Policy's
|
||||
``learn_on_batch`` method, which handles loss- and gradient calculations, and optimizer stepping.
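
Purely as an illustration of this pipeline (reusing the hypothetical ``trainer`` from the sketch above and bypassing RLlib's normal execution plan; not how you would usually train):

.. code-block:: python

from ray.rllib.policy.sample_batch import SampleBatch

local_worker = trainer.workers.local_worker()  # a RolloutWorker
fragment_1 = local_worker.sample()             # SampleBatch of ~rollout_fragment_length env steps
fragment_2 = local_worker.sample()

# Fragments (normally coming from `num_workers` parallel RolloutWorkers) are
# concatenated until `train_batch_size` is reached, then handed to the Policy.
train_batch = SampleBatch.concat_samples([fragment_1, fragment_2])
policy = local_worker.policy_map["default_policy"]
results = policy.learn_on_batch(train_batch)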
|
||||
|
||||
RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which appends single timestep data (e.g. actions)
|
||||
to lists, then builds SampleBatches from these and sends them to the downstream processing functions.
|
||||
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS use the same underlying list).
|
||||
If you want to implement your own collection logic and data structures, you can sub-class ``SampleCollector``
|
||||
and specify that new class under the Trainer's "sample_collector" config key.
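
For example, the following minimal sketch (the ``LoggingCollector`` class and its behavior are purely hypothetical) extends the default ``SimpleListCollector`` instead of implementing the whole ``SampleCollector`` interface from scratch:

.. code-block:: python

from ray.rllib.agents import ppo
from ray.rllib.evaluation.collectors.simple_list_collector import \
    SimpleListCollector


class LoggingCollector(SimpleListCollector):
    """Behaves exactly like the default collector, but logs every env step."""

    def episode_step(self, episode_id):
        print("Episode {} advanced by one env step.".format(episode_id))
        super().episode_step(episode_id)


config = ppo.DEFAULT_CONFIG.copy()
config["sample_collector"] = LoggingCollector  # pass the class, not an instance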
|
||||
|
||||
Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
|
||||
know what data in the ongoing episode/trajectory to use for the different required method calls
|
||||
during rollouts. These method calls in particular are:
|
||||
``Policy.compute_actions_from_input_dict()`` to compute actions to be taken in an episode.
|
||||
``Policy.postprocess_trajectory()``, which is called after an episode ends or a rollout hits its
|
||||
``rollout_fragment_length`` limit (in ``batch_mode=truncate_episodes``), and ``Policy.learn_on_batch()``,
|
||||
which is called with a "train_batch" to improve the policy.
|
||||
|
||||
Trajectory View API
|
||||
-------------------
|
||||
|
||||
The trajectory view API allows custom models to define what parts of the trajectory they
|
||||
require in order to execute the forward pass. For example, in the simplest case, a model might
|
||||
only look at the latest observation. However, an RNN- or attention-based model could look
|
||||
at previous states emitted by the model, concatenate previously seen rewards with the current observation,
|
||||
or require the entire range of the n most recent observations.
|
||||
|
||||
The trajectory view API lets models define these requirements and lets RLlib gather the required
|
||||
data for the forward pass in an efficient way.
|
||||
|
||||
Since the following methods all call into the model class, they are all indirectly using the trajectory view API.
|
||||
It is important to note that the API is only accessible to the user via the model classes
|
||||
(see below on how to set up trajectory view requirements for a custom model).
|
||||
|
||||
In particular, the methods receiving inputs that depend on a Model's trajectory view rules are:
|
||||
|
||||
a) ``Policy.compute_actions_from_input_dict()``
|
||||
b) ``Policy.postprocess_trajectory()`` and
|
||||
c) ``Policy.learn_on_batch()`` (and consecutively: the Policy's loss function).
|
||||
|
||||
The input data to these methods can stem from either the environment (observations, rewards, and env infos),
|
||||
the model itself (previously computed actions, internal state outputs, action-probs, etc.)
|
||||
or the Sampler (e.g. agent index, env ID, episode ID, timestep, etc.).
|
||||
All data has an associated time axis, which is 0-based, meaning that the first action taken, the
|
||||
first reward received in an episode, and the first observation (directly after a reset)
|
||||
all have t=0.
|
||||
|
||||
The idea is to allow more flexibility and standardization in how a model defines required
|
||||
"views" on the ongoing trajectory (during action computations/inference), past episodes (training
|
||||
on a batch), or even trajectories of other agents in the same episode, some of which
|
||||
may even use a different policy.
|
||||
|
||||
Such a "view requirements" formalism is helpful when having to support more complex model
|
||||
setups like RNNs, attention nets, observation image framestacking (e.g. for Atari),
|
||||
and building multi-agent communication channels.
|
||||
|
||||
The way to define a set of rules used for making the Model see certain
|
||||
data is through a "view requirements dict", residing in the ``Policy.model.view_requirements``
|
||||
property.
|
||||
View requirements dicts map strings (column names), such as "obs" or "actions" to
|
||||
a ``ViewRequirement`` object, which defines the exact conditions by which this column
|
||||
should be populated with data.
|
||||
|
||||
View Requirement Dictionaries
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
View requirements are stored within the ``view_requirements`` property of the ``ModelV2``
|
||||
class.
|
||||
You can access it like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
my_simple_model = ModelV2(...)
|
||||
print(my_simple_model.view_requirements)
|
||||
>>>{"obs": ViewRequirement(shift=0, space=[observation space])}
|
||||
|
||||
my_lstm_model = LSTMModel(...)
|
||||
print(my_lstm_model.view_requirements)
|
||||
>>>{
|
||||
>>> "obs": ViewRequirement(shift=0, space=[observation space]),
|
||||
>>> "prev_actions": ViewRequirement(shift=-1, data_col="actions", space=[action space]),
|
||||
>>> "prev_rewards": ViewRequirement(shift=-1, data_col="rewards"),
|
||||
>>>}
|
||||
|
||||
The ``view_requirements`` property holds a dictionary mapping
|
||||
string keys (e.g. "actions", "rewards", "next_obs", etc.)
|
||||
to a ``ViewRequirement`` object. This ``ViewRequirement`` object determines what exact data to
|
||||
provide under the given key in case a SampleBatch or a single-timestep (action computing) "input dict"
|
||||
needs to be built and fed into one of the above ModelV2 or Policy methods.
|
||||
|
||||
.. Edit figure below at: https://docs.google.com/drawings/d/1YEPUtMrRXmWfvM0E6mD3VsOaRlLV7DtctF-yL96VHGg/edit
|
||||
.. image:: images/rllib-trajectory-view-example.svg
|
||||
|
||||
**Above:** An example `ViewRequirements` dict that causes the current observation
|
||||
and the previous action to be available in each compute_action call, as
|
||||
well as for the Policy's `postprocess_trajectory()` function (and train batch).
|
||||
A similar setup is often used by LSTM/RNN-based models.
|
||||
|
||||
|
||||
The ViewRequirement class
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is a description of the constructor-settable properties of a ViewRequirement
|
||||
object and what each of these properties controls.
|
||||
|
||||
**data_col**:
|
||||
An optional string key referencing the underlying data to use to
|
||||
create the view. If not provided, assumes that there is data under the
|
||||
dict-key under which this ViewRequirement resides.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ModelV2.view_requirements = {"rewards": ViewRequirements(shift=0)}
|
||||
# -> implies that the underlying data to use are the collected rewards
|
||||
# from the environment.
|
||||
|
||||
ModelV2.view_requirements = {"prev_rewards": ViewRequirements(data_col="rewards", shift=-1)}
|
||||
# -> means that the actual data used to create the "prev_rewards" column
|
||||
# is the "rewards" data from the environment (shifted by 1 timestep).
|
||||
|
||||
**space**:
|
||||
An optional gym.Space used as a hint for the SampleCollector to know
|
||||
how to fill timesteps before the episode actually started (e.g. if
|
||||
shift=-2, we need dummy data at timesteps -2 and -1).
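
For example (following the illustrative style used above; ``my_observation_space`` stands for the model's actual observation space):

.. code-block:: python

# Hypothetical column that provides the two previous observations. The `space`
# hint lets the SampleCollector create zero-filled dummy observations for the
# (non-existent) timesteps -2 and -1 at the very beginning of an episode.
ModelV2.view_requirements["prev_two_obs"] = ViewRequirement(
    data_col="obs", shift=[-2, -1], space=my_observation_space)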
|
||||
|
||||
**shift [int]**:
|
||||
An int, a list of ints, or a range string (e.g. "-50:-1") to indicate
|
||||
which time offsets or ranges of the underlying data to use for the view.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
shift=0 # -> Use the data under ``data_col`` as is.
|
||||
shift=1 # -> Use the data under ``data_col``, but shifted by +1 timestep
|
||||
# (used by e.g. next_obs views).
|
||||
shift=-1 # -> Use the data under ``data_col``, but shifted by -1 timestep
|
||||
# (used by e.g. prev_actions views).
|
||||
shift=[-2, -1] # -> Use the data under ``data_col``, but always provide 2 values
|
||||
# at each timestep (the two previous ones).
|
||||
# Could be used e.g. to feed the last two actions or rewards into an LSTM.
|
||||
shift="-50:-1" # -> Use the data under ``data_col``, but always provide a range of
|
||||
# the last 50 timesteps (used by our attention nets).
|
||||
|
||||
**used_for_training [bool]**:
|
||||
True by default. If False, the column will not be available inside the train batch (arriving in the
|
||||
Policy's loss function).
|
||||
RLlib will automatically switch this to False for a given column if it detects during
|
||||
Policy initialization that the column is not accessed inside the loss function (see below).
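
For example (illustrative only, mirroring what the built-in attention nets do; ``memory_space`` stands for an assumed ``gym.spaces.Box`` describing one memory slot):

.. code-block:: python

# A memory output that is fed back into the model for the next inference step,
# but is never read inside the loss -> exclude it from train batches.
ModelV2.view_requirements["state_out_0"] = ViewRequirement(
    space=memory_space, used_for_training=False)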
|
||||
|
||||
How does RLlib determine which Views are required?
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
When a Policy is initialized, it automatically determines how to later build batches
|
||||
for postprocessing, loss function calls, and action computations, based on
|
||||
the Model's ``view_requirements`` dict. It does so by sending generic dummy batches
|
||||
through its ``compute_actions_from_input_dict``, ``postprocess_trajectory``, and loss functions
|
||||
and then checking which fields in these dummy batches get accessed, overwritten, deleted, or added.
|
||||
Based on these test passes, the Policy then removes from an initially very broad list all
|
||||
ViewRequirements that it deems unnecessary. This procedure saves a lot of data copying
|
||||
during later rollouts, batch transfers (via Ray), and loss calculations, and makes things like
|
||||
manually deleting columns from a SampleBatch (e.g. PPO used to delete the "next_obs" column
|
||||
inside the postprocessing function) unnecessary.
|
||||
Note that the "rewards" and "dones" columns are never discarded and thus should always
|
||||
arrive in your loss function's SampleBatch (``train_batch`` arg).
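
For illustration, here is a minimal sketch of a custom torch loss function (of the kind you could pass to ``build_torch_policy(loss_fn=...)``; it is not one of RLlib's built-in losses). Because it reads the "actions" and "rewards" columns of ``train_batch``, the dummy test pass detects those accesses and keeps both columns in real train batches:

.. code-block:: python

import torch


def my_policy_loss(policy, model, dist_class, train_batch):
    # Accessing train_batch["actions"] and train_batch["rewards"] here marks
    # those columns as required for training.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch["actions"])
    # Plain REINFORCE-style objective on raw rewards, purely for illustration.
    return -torch.mean(log_probs * train_batch["rewards"])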
|
||||
|
||||
|
||||
Setting ViewRequirements manually in your Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
If you need to specify special view requirements for your model, you can add
|
||||
columns to the Model's ``view_requirements`` dict in the
|
||||
Model's constructor.
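
For instance, a minimal hypothetical torch model could request the previous reward as an additional view in its constructor like this (forward pass omitted for brevity); the built-in wrappers below do essentially the same thing:

.. code-block:: python

import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class PrevRewardModel(TorchModelV2, nn.Module):
    """Hypothetical model that also wants the previous reward as an input."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Ask RLlib to also collect each timestep's previous reward under the
        # "prev_rewards" key of input dicts and train batches.
        self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
            SampleBatch.REWARDS, shift=-1)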
|
||||
|
||||
For example, our auto-LSTM wrapper classes (tf and torch) have these additional
|
||||
lines in their constructors (torch version shown here):
|
||||
|
||||
.. literalinclude:: ../../rllib/models/torch/recurrent_net.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
This makes sure that, if the user requires this via the model config, previous rewards
|
||||
and/or previous actions are added properly to the ``compute_actions`` input-dicts and SampleBatches
|
||||
used for postprocessing and training.
|
||||
|
||||
|
||||
Another example is our attention nets, which make sure that the last n (memory) model outputs
|
||||
are always fed back into the model on the next time step (tf version shown here):
|
||||
|
||||
.. literalinclude:: ../../rllib/models/tf/attention_net.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
|
||||
Setting ViewRequirements manually after Policy construction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is a simple example of how you can modify and add to the ViewRequirements dict
|
||||
even after policy (or RolloutWorker) creation. However, note that it's better to
|
||||
make these modifications to your batches in your postprocessing function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
import numpy as np

from gym.spaces import Discrete

from ray.rllib.agents import ppo
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override

# Modify view_requirements in the Policy object.
|
||||
action_space = Discrete(2)
|
||||
rollout_worker = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_config=ppo.DEFAULT_CONFIG,
|
||||
policy_spec=ppo.PPOTorchPolicy,
|
||||
)
|
||||
policy = rollout_worker.policy_map["default_policy"]
|
||||
# Add the next action to the view reqs of the policy.
|
||||
# This should be visible then in postprocessing and train batches.
|
||||
policy.view_requirements["next_actions"] = ViewRequirement(
|
||||
SampleBatch.ACTIONS, shift=1, space=action_space)
|
||||
# Check, whether a sampled batch has the requested `next_actions` view.
|
||||
batch = rollout_worker.sample()
|
||||
assert "next_actions" in batch.data
|
||||
|
||||
# Doing the same in a custom postprocessing callback function:
|
||||
class MyCallback(DefaultCallbacks):
|
||||
# ...
|
||||
|
||||
@override(DefaultCallbacks)
|
||||
def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
|
||||
policies, postprocessed_batch, original_batches,
|
||||
**kwargs):
|
||||
postprocessed_batch["next_actions"] = np.concatenate(
|
||||
[postprocessed_batch["actions"][1:],
|
||||
np.zeros_like([policies[policy_id].action_space.sample()])])
|
||||
|
||||
The above two examples add a "next_actions" view to the postprocessed SampleBatch
used by the Policy for training. They will not feed "next_actions" into the Model's
``compute_actions`` calls (this is impossible, because the next action is of course
not known at that point).
|
|
@ -60,7 +60,6 @@ Training APIs
|
|||
Environments
|
||||
------------
|
||||
* `RLlib Environments Overview <rllib-env.html>`__
|
||||
* `Feature Compatibility Matrix <rllib-env.html#feature-compatibility-matrix>`__
|
||||
* `OpenAI Gym <rllib-env.html#openai-gym>`__
|
||||
* `Vectorized <rllib-env.html#vectorized>`__
|
||||
* `Multi-Agent and Hierarchical <rllib-env.html#multi-agent-and-hierarchical>`__
|
||||
|
@ -144,6 +143,11 @@ Algorithms
|
|||
|
||||
- |pytorch| :ref:`Curiosity (ICM: Intrinsic Curiosity Module) <curiosity>`
|
||||
|
||||
Sample Collection
|
||||
-----------------
|
||||
* `The SampleCollector Class is Used to Store and Retrieve Temporary Data <rllib-sample-collection.html#the-samplecollector-class-is-used-to-store-and-retrieve-temporary-data>`__
|
||||
* `Trajectory View API <rllib-sample-collection.html#trajectory-view-api>`__
|
||||
|
||||
|
||||
Offline Datasets
|
||||
----------------
|
||||
|
|
|
@ -110,7 +110,12 @@ Beyond environments defined in Python, RLlib supports batch training on `offline
|
|||
Customization
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
RLlib provides ways to customize almost all aspects of training, including the `environment <rllib-env.html#configuring-environments>`__, `neural network model <rllib-models.html#tensorflow-models>`__, `action distribution <rllib-models.html#custom-action-distributions>`__, and `policy definitions <rllib-concepts.html#policies>`__:
|
||||
RLlib provides ways to customize almost all aspects of training, including
|
||||
`neural network models <rllib-models.html#tensorflow-models>`__,
|
||||
`action distributions <rllib-models.html#custom-action-distributions>`__,
|
||||
`policy definitions <rllib-concepts.html#policies>`__:
|
||||
the `environment <rllib-env.html#configuring-environments>`__,
|
||||
and the `sample collection process <rllib-sample-collection.html>`__
|
||||
|
||||
.. image:: rllib-components.svg
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@ from ray.exceptions import RayError
|
|||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
from ray.rllib.env.normalize_actions import NormalizeActionWrapper
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.evaluation.collectors.simple_list_collector import \
|
||||
SimpleListCollector
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.policy import Policy
|
||||
|
@ -231,6 +233,10 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
# generic ModelV2 `input_dicts` that can be requested by the model to
|
||||
# contain different information on the ongoing episode.
|
||||
"_use_trajectory_view_api": True,
|
||||
# The SampleCollector class to be used to collect and retrieve
|
||||
# environment-, model-, and sampler data. Override the SampleCollector base
|
||||
# class to implement your own collection/buffering/retrieval logic.
|
||||
"sample_collector": SimpleListCollector,
|
||||
|
||||
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter".
|
||||
"observation_filter": "NoFilter",
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
from abc import abstractmethod, ABCMeta
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SampleCollector(metaclass=ABCMeta):
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
class SampleCollector(metaclass=ABCMeta):
|
||||
"""Collects samples for all policies and agents from a multi-agent env.
|
||||
|
||||
Note: This is an experimental class only used when
|
||||
|
@ -29,6 +35,34 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
communication channel).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy_map: Dict[PolicyID, Policy],
|
||||
clip_rewards: Union[bool, float],
|
||||
callbacks: "DefaultCallbacks",
|
||||
multiple_episodes_in_batch: bool = True,
|
||||
rollout_fragment_length: int = 200,
|
||||
count_steps_by: str = "env_steps"):
|
||||
"""Initializes a SampleCollector instance.
|
||||
|
||||
Args:
|
||||
policy_map (Dict[str, Policy]): Maps policy ids to policy
|
||||
instances.
|
||||
clip_rewards (Union[bool, float]): Whether to clip rewards before
|
||||
postprocessing (at +/-1.0) or the actual value to +/- clip.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks.
|
||||
multiple_episodes_in_batch (bool): Whether it's allowed to pack
|
||||
multiple episodes into the same built batch.
|
||||
rollout_fragment_length (int): The
|
||||
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
self.callbacks = callbacks
|
||||
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.count_steps_by = count_steps_by
|
||||
|
||||
@abstractmethod
|
||||
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
|
||||
policy_id: PolicyID, t: int,
|
||||
|
@ -55,7 +89,7 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
|
||||
Examples:
|
||||
>>> obs = env.reset()
|
||||
>>> collector.add_init_obs(12345, 0, "pol0", obs)
|
||||
>>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
|
||||
>>> obs, r, done, info = env.step(action)
|
||||
>>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
|
||||
... "action": action, "obs": obs, "reward": r, "done": done
|
||||
|
@ -227,3 +261,4 @@ class _SampleCollector(metaclass=ABCMeta):
|
|||
`self.rollout_fragment_length` has not been reached yet.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# __sphinx_doc_end__
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector
|
||||
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
@ -379,7 +379,7 @@ class _PolicyCollectorGroup:
|
|||
self.agent_steps = 0
|
||||
|
||||
|
||||
class _SimpleListCollector(_SampleCollector):
|
||||
class SimpleListCollector(SampleCollector):
|
||||
"""Util to build SampleBatches for each policy in a multi-agent env.
|
||||
|
||||
Input data is per-agent, while output data is per-policy. There is an M:N
|
||||
|
@ -395,29 +395,15 @@ class _SimpleListCollector(_SampleCollector):
|
|||
multiple_episodes_in_batch: bool = True,
|
||||
rollout_fragment_length: int = 200,
|
||||
count_steps_by: str = "env_steps"):
|
||||
"""Initializes a _SimpleListCollector instance.
|
||||
"""Initializes a SimpleListCollector instance."""
|
||||
|
||||
Args:
|
||||
policy_map (Dict[str, Policy]): Maps policy ids to policy
|
||||
instances.
|
||||
clip_rewards (Union[bool, float]): Whether to clip rewards before
|
||||
postprocessing (at +/-1.0) or the actual value to +/- clip.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks.
|
||||
multiple_episodes_in_batch (bool): Whether it's allowed to pack
|
||||
multiple episodes into the same built batch.
|
||||
rollout_fragment_length (int): The
|
||||
super().__init__(policy_map, clip_rewards, callbacks,
|
||||
multiple_episodes_in_batch, rollout_fragment_length,
|
||||
count_steps_by)
|
||||
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
self.callbacks = callbacks
|
||||
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.count_steps_by = count_steps_by
|
||||
self.large_batch_threshold: int = max(
|
||||
1000, rollout_fragment_length *
|
||||
10) if rollout_fragment_length != float("inf") else 5000
|
||||
1000, self.rollout_fragment_length *
|
||||
10) if self.rollout_fragment_length != float("inf") else 5000
|
||||
|
||||
# Whenever we observe a new episode+agent, add a new
|
||||
# _SingleTrajectoryCollector.
|
||||
|
@ -430,8 +416,9 @@ class _SimpleListCollector(_SampleCollector):
|
|||
self.policy_collector_groups = []
|
||||
|
||||
# Agents to collect data from for the next forward pass (per policy).
|
||||
self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()}
|
||||
self.forward_pass_size = {pid: 0 for pid in policy_map.keys()}
|
||||
self.forward_pass_agent_keys = \
|
||||
{pid: [] for pid in self.policy_map.keys()}
|
||||
self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()}
|
||||
|
||||
# Maps episode ID to the (non-built) env steps taken in this episode.
|
||||
self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
|
||||
|
@ -441,7 +428,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
# Maps episode ID to MultiAgentEpisode.
|
||||
self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {}
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def episode_step(self, episode_id: EpisodeID) -> None:
|
||||
episode = self.episodes[episode_id]
|
||||
self.episode_steps[episode_id] += 1
|
||||
|
@ -470,7 +457,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
"does at some point."
|
||||
if not self.multiple_episodes_in_batch else ""))
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
|
||||
env_id: EnvID, policy_id: PolicyID, t: int,
|
||||
init_obs: TensorType) -> None:
|
||||
|
@ -481,7 +468,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
else:
|
||||
assert self.agent_key_to_policy_id[agent_key] == policy_id
|
||||
policy = self.policy_map[policy_id]
|
||||
view_reqs = policy.model.inference_view_requirements if \
|
||||
view_reqs = policy.model.view_requirements if \
|
||||
getattr(policy, "model", None) else policy.view_requirements
|
||||
|
||||
# Add initial obs to Trajectory.
|
||||
|
@ -503,7 +490,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
|
||||
self._add_to_next_inference_call(agent_key)
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def add_action_reward_next_obs(self, episode_id: EpisodeID,
|
||||
agent_id: AgentID, env_id: EnvID,
|
||||
policy_id: PolicyID, agent_done: bool,
|
||||
|
@ -525,14 +512,14 @@ class _SimpleListCollector(_SampleCollector):
|
|||
if not agent_done:
|
||||
self._add_to_next_inference_call(agent_key)
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def total_env_steps(self) -> int:
|
||||
# Add the non-built ongoing-episode env steps + the already built
|
||||
# env-steps.
|
||||
return sum(self.episode_steps.values()) + sum(
|
||||
pg.env_steps for pg in self.policy_collector_groups.values())
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def total_agent_steps(self) -> int:
|
||||
# Add the non-built ongoing-episode agent steps (still in the agent
|
||||
# collectors) + the already built agent steps.
|
||||
|
@ -540,13 +527,13 @@ class _SimpleListCollector(_SampleCollector):
|
|||
sum(pg.agent_steps for pg in
|
||||
self.policy_collector_groups.values())
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def get_inference_input_dict(self, policy_id: PolicyID) -> \
|
||||
Dict[str, TensorType]:
|
||||
policy = self.policy_map[policy_id]
|
||||
keys = self.forward_pass_agent_keys[policy_id]
|
||||
buffers = {k: self.agent_collectors[k].buffers for k in keys}
|
||||
view_reqs = policy.model.inference_view_requirements if \
|
||||
view_reqs = policy.model.view_requirements if \
|
||||
getattr(policy, "model", None) else policy.view_requirements
|
||||
|
||||
input_dict = {}
|
||||
|
@ -592,7 +579,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
|
||||
return input_dict
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def postprocess_episode(
|
||||
self,
|
||||
episode: MultiAgentEpisode,
|
||||
|
@ -725,7 +712,7 @@ class _SimpleListCollector(_SampleCollector):
|
|||
|
||||
return ma_batch
|
||||
|
||||
@override(_SampleCollector)
|
||||
@override(SampleCollector)
|
||||
def try_build_truncated_episode_multi_agent_batch(self) -> \
|
||||
List[Union[MultiAgentBatch, SampleBatch]]:
|
||||
batches = []
|
||||
|
|
|
@ -28,7 +28,10 @@ class MultiAgentEpisode:
|
|||
episode_id (int): Unique id identifying this trajectory.
|
||||
agent_rewards (dict): Summed rewards broken down by agent.
|
||||
custom_metrics (dict): Dict where the you can add custom metrics.
|
||||
user_data (dict): Dict that you can use for temporary storage.
|
||||
user_data (dict): Dict that you can use for temporary storage. E.g.
|
||||
in between two custom callbacks referring to the same episode.
|
||||
hist_data (dict): Dict mapping str keys to List[float] for storage of
|
||||
per-timestep float data throughout the episode.
|
||||
|
||||
Use case 1: Model-based rollouts in multi-agent:
|
||||
A custom compute_actions() function in a policy can inspect the
|
||||
|
|
|
@ -473,7 +473,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
# state.
|
||||
for pol in self.policy_map.values():
|
||||
if not pol._model_init_state_automatically_added:
|
||||
pol._update_model_inference_view_requirements_from_init_state()
|
||||
pol._update_model_view_requirements_from_init_state()
|
||||
|
||||
if (ray.is_initialized()
|
||||
and ray.worker._mode() != ray.worker.LOCAL_MODE):
|
||||
|
@ -543,6 +543,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
self.batch_mode))
|
||||
|
||||
# Create the IOContext for this worker.
|
||||
self.io_context: IOContext = IOContext(log_dir, policy_config,
|
||||
worker_index, self)
|
||||
self.reward_estimators: List[OffPolicyEstimator] = []
|
||||
|
@ -586,6 +587,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn,
|
||||
_use_trajectory_view_api=_use_trajectory_view_api,
|
||||
sample_collector_class=policy_config.get(
|
||||
"sample_collector_class"),
|
||||
)
|
||||
# Start the Sampler thread.
|
||||
self.sampler.start()
|
||||
|
@ -609,6 +612,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn,
|
||||
_use_trajectory_view_api=_use_trajectory_view_api,
|
||||
sample_collector_class=policy_config.get(
|
||||
"sample_collector_class"),
|
||||
)
|
||||
|
||||
self.input_reader: InputReader = input_creator(self.io_context)
|
||||
|
|
|
@ -6,13 +6,13 @@ import queue
|
|||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
|
||||
TYPE_CHECKING, Union
|
||||
Type, TYPE_CHECKING, Union
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.evaluation.collectors.sample_collector import \
|
||||
_SampleCollector
|
||||
SampleCollector
|
||||
from ray.rllib.evaluation.collectors.simple_list_collector import \
|
||||
_SimpleListCollector
|
||||
SimpleListCollector
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
|
@ -119,26 +119,28 @@ class SyncSampler(SamplerInput):
|
|||
"""Sync SamplerInput that collects experiences when `get_data()` is called.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
clip_actions: bool = True,
|
||||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
_use_trajectory_view_api: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
clip_actions: bool = True,
|
||||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
_use_trajectory_view_api: bool = False,
|
||||
sample_collector_class: Optional[Type[SampleCollector]] = None):
|
||||
"""Initializes a SyncSampler object.
|
||||
|
||||
Args:
|
||||
|
@ -178,6 +180,9 @@ class SyncSampler(SamplerInput):
|
|||
_use_trajectory_view_api (bool): Whether to use the (experimental)
|
||||
`_use_trajectory_view_api` to make generic trajectory views
|
||||
available to Models. Default: False.
|
||||
sample_collector_class (Optional[Type[SampleCollector]]): An
|
||||
optional SampleCollector sub-class to use to collect, store,
|
||||
and retrieve environment-, model-, and sampler data.
|
||||
"""
|
||||
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
|
@ -190,7 +195,9 @@ class SyncSampler(SamplerInput):
|
|||
self.extra_batches = queue.Queue()
|
||||
self.perf_stats = _PerfStats()
|
||||
if _use_trajectory_view_api:
|
||||
self.sample_collector = _SimpleListCollector(
|
||||
if not sample_collector_class:
|
||||
sample_collector_class = SimpleListCollector
|
||||
self.sample_collector = sample_collector_class(
|
||||
policies,
|
||||
clip_rewards,
|
||||
callbacks,
|
||||
|
@ -249,27 +256,30 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
from where they can be unqueued by the caller of `get_data()`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
clip_actions: bool = True,
|
||||
blackhole_outputs: bool = False,
|
||||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
_use_trajectory_view_api: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
clip_actions: bool = True,
|
||||
blackhole_outputs: bool = False,
|
||||
soft_horizon: bool = False,
|
||||
no_done_at_end: bool = False,
|
||||
observation_fn: "ObservationFunction" = None,
|
||||
_use_trajectory_view_api: bool = False,
|
||||
sample_collector_class: Optional[Type[SampleCollector]] = None,
|
||||
):
|
||||
"""Initializes a AsyncSampler object.
|
||||
|
||||
Args:
|
||||
|
@ -313,6 +323,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
_use_trajectory_view_api (bool): Whether to use the (experimental)
|
||||
`_use_trajectory_view_api` to make generic trajectory views
|
||||
available to Models. Default: False.
|
||||
sample_collector_class (Optional[Type[SampleCollector]]): An
|
||||
optional SampleCollector sub-class to use to collect, store,
|
||||
and retrieve environment-, model-, and sampler data.
|
||||
"""
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
|
@ -343,7 +356,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
self.observation_fn = observation_fn
|
||||
self._use_trajectory_view_api = _use_trajectory_view_api
|
||||
if _use_trajectory_view_api:
|
||||
self.sample_collector = _SimpleListCollector(
|
||||
if not sample_collector_class:
|
||||
sample_collector_class = SimpleListCollector
|
||||
self.sample_collector = sample_collector_class(
|
||||
policies,
|
||||
clip_rewards,
|
||||
callbacks,
|
||||
|
@ -441,7 +456,7 @@ def _env_runner(
|
|||
no_done_at_end: bool,
|
||||
observation_fn: "ObservationFunction",
|
||||
_use_trajectory_view_api: bool = False,
|
||||
_sample_collector: Optional[_SampleCollector] = None,
|
||||
sample_collector: Optional[SampleCollector] = None,
|
||||
) -> Iterable[SampleBatchType]:
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
|
@ -480,8 +495,8 @@ def _env_runner(
|
|||
_use_trajectory_view_api (bool): Whether to use the (experimental)
|
||||
`_use_trajectory_view_api` to make generic trajectory views
|
||||
available to Models. Default: False.
|
||||
_sample_collector (Optional[_SampleCollector]): An optional
|
||||
_SampleCollector object to use
|
||||
sample_collector (Optional[SampleCollector]): An optional
|
||||
SampleCollector object to use
|
||||
|
||||
Yields:
|
||||
rollout (SampleBatch): Object containing state, action, reward,
|
||||
|
@ -498,20 +513,27 @@ def _env_runner(
|
|||
|
||||
# Trainer has a given `horizon` setting.
|
||||
if horizon:
|
||||
# `horizon` is larger than env's limit -> Error and explain how
|
||||
# to increase Env's own episode limit.
|
||||
# `horizon` is larger than env's limit.
|
||||
if max_episode_steps and horizon > max_episode_steps:
|
||||
raise ValueError(
|
||||
"Your `horizon` setting ({}) is larger than the Env's own "
|
||||
"timestep limit ({})! Try to increase the Env's limit via "
|
||||
"setting its `spec.max_episode_steps` property.".format(
|
||||
horizon, max_episode_steps))
|
||||
# Try to override the env's own max-step setting with our horizon.
|
||||
# If this won't work, throw an error.
|
||||
try:
|
||||
base_env.get_unwrapped()[0].spec.max_episode_steps = horizon
|
||||
base_env.get_unwrapped()[0]._max_episode_steps = horizon
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Your `horizon` setting ({}) is larger than the Env's own "
|
||||
"timestep limit ({}), which seems to be unsettable! Try "
|
||||
"to increase the Env's built-in limit to be at least as "
|
||||
"large as your wanted `horizon`.".format(
|
||||
horizon, max_episode_steps))
|
||||
# Otherwise, set Trainer's horizon to env's max-steps.
|
||||
elif max_episode_steps:
|
||||
horizon = max_episode_steps
|
||||
logger.debug(
|
||||
"No episode horizon specified, setting it to Env's limit ({}).".
|
||||
format(max_episode_steps))
|
||||
# No horizon/max_episode_steps -> Episodes may be infinitely long.
|
||||
else:
|
||||
horizon = float("inf")
|
||||
logger.debug("No episode horizon specified, assuming inf.")
|
||||
|
@ -594,7 +616,7 @@ def _env_runner(
|
|||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn,
|
||||
_sample_collector=_sample_collector,
|
||||
sample_collector=sample_collector,
|
||||
)
|
||||
else:
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
|
@ -628,7 +650,7 @@ def _env_runner(
|
|||
eval_results = _do_policy_eval_w_trajectory_view_api(
|
||||
to_eval=to_eval,
|
||||
policies=policies,
|
||||
_sample_collector=_sample_collector,
|
||||
sample_collector=sample_collector,
|
||||
active_episodes=active_episodes,
|
||||
tf_sess=tf_sess,
|
||||
)
|
||||
|
@ -653,7 +675,7 @@ def _env_runner(
|
|||
policies=policies,
|
||||
clip_actions=clip_actions,
|
||||
_use_trajectory_view_api=_use_trajectory_view_api,
|
||||
_sample_collector=_sample_collector,
|
||||
sample_collector=sample_collector,
|
||||
)
|
||||
perf_stats.action_processing_time += time.time() - t3
|
||||
|
||||
|
@ -968,7 +990,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
soft_horizon: bool,
|
||||
no_done_at_end: bool,
|
||||
observation_fn: "ObservationFunction",
|
||||
_sample_collector: _SampleCollector,
|
||||
sample_collector: SampleCollector,
|
||||
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
|
||||
RolloutMetrics, SampleBatchType]]]:
|
||||
"""Trajectory View API version of `_process_observations()`.
|
||||
|
@ -987,7 +1009,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
episode: MultiAgentEpisode = active_episodes[env_id]
|
||||
|
||||
if not is_new_episode:
|
||||
_sample_collector.episode_step(episode.episode_id)
|
||||
sample_collector.episode_step(episode.episode_id)
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
# Check episode termination conditions.
|
||||
|
@ -1051,9 +1073,9 @@ def _process_observations_w_trajectory_view_api(
|
|||
|
||||
# Record transition info if applicable.
|
||||
if last_observation is None:
|
||||
_sample_collector.add_init_obs(episode, agent_id, env_id,
|
||||
policy_id, episode.length - 1,
|
||||
filtered_obs)
|
||||
sample_collector.add_init_obs(episode, agent_id, env_id,
|
||||
policy_id, episode.length - 1,
|
||||
filtered_obs)
|
||||
else:
|
||||
# Add actions, rewards, next-obs to collectors.
|
||||
values_dict = {
|
||||
|
@ -1079,7 +1101,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
# Env infos for this agent.
|
||||
if "infos" in pol.view_requirements:
|
||||
values_dict["infos"] = agent_infos
|
||||
_sample_collector.add_action_reward_next_obs(
|
||||
sample_collector.add_action_reward_next_obs(
|
||||
episode.episode_id, agent_id, env_id, policy_id,
|
||||
agent_done, values_dict)
|
||||
|
||||
|
@ -1111,7 +1133,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
# MultiAgentBatch from a single episode and add it to "outputs".
|
||||
# Otherwise, just postprocess and continue collecting across
|
||||
# episodes.
|
||||
ma_sample_batch = _sample_collector.postprocess_episode(
|
||||
ma_sample_batch = sample_collector.postprocess_episode(
|
||||
episode,
|
||||
is_done=is_done or (hit_horizon and not soft_horizon),
|
||||
check_dones=check_dones,
|
||||
|
@ -1170,7 +1192,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
new_episode._set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Add initial obs to buffer.
|
||||
_sample_collector.add_init_obs(
|
||||
sample_collector.add_init_obs(
|
||||
new_episode, agent_id, env_id, policy_id,
|
||||
new_episode.length - 1, filtered_obs)
|
||||
|
||||
|
@ -1183,7 +1205,7 @@ def _process_observations_w_trajectory_view_api(
|
|||
# Try to build something.
|
||||
if multiple_episodes_in_batch:
|
||||
sample_batches = \
|
||||
_sample_collector.try_build_truncated_episode_multi_agent_batch()
|
||||
sample_collector.try_build_truncated_episode_multi_agent_batch()
|
||||
if sample_batches:
|
||||
outputs.extend(sample_batches)
|
||||
|
||||
|
@ -1279,7 +1301,7 @@ def _do_policy_eval_w_trajectory_view_api(
|
|||
*,
|
||||
to_eval: Dict[PolicyID, List[PolicyEvalData]],
|
||||
policies: Dict[PolicyID, Policy],
|
||||
_sample_collector,
|
||||
sample_collector,
|
||||
active_episodes: Dict[str, MultiAgentEpisode],
|
||||
tf_sess: Optional["tf.Session"] = None,
|
||||
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
|
||||
|
@ -1291,7 +1313,7 @@ def _do_policy_eval_w_trajectory_view_api(
|
|||
be the batch's items for the model forward pass).
|
||||
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
|
||||
obj.
|
||||
_sample_collector (SampleCollector): The SampleCollector object to use.
|
||||
sample_collector (SampleCollector): The SampleCollector object to use.
|
||||
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
|
||||
batching TF policy evaluations.
|
||||
|
||||
|
@ -1313,7 +1335,7 @@ def _do_policy_eval_w_trajectory_view_api(
|
|||
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
policy: Policy = _get_or_raise(policies, policy_id)
|
||||
input_dict = _sample_collector.get_inference_input_dict(policy_id)
|
||||
input_dict = sample_collector.get_inference_input_dict(policy_id)
|
||||
eval_results[policy_id] = \
|
||||
policy.compute_actions_from_input_dict(
|
||||
input_dict,
|
||||
|
@ -1343,7 +1365,7 @@ def _process_policy_eval_results(
|
|||
policies: Dict[PolicyID, Policy],
|
||||
clip_actions: bool,
|
||||
_use_trajectory_view_api: bool = False,
|
||||
_sample_collector=None,
|
||||
sample_collector=None,
|
||||
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
|
||||
"""Process the output of policy neural network evaluation.
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
config,
|
||||
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
|
||||
policy = trainer.get_policy()
|
||||
view_req_model = policy.model.inference_view_requirements
|
||||
view_req_model = policy.model.view_requirements
|
||||
view_req_policy = policy.view_requirements
|
||||
assert len(view_req_model) == 1, view_req_model
|
||||
assert len(view_req_policy) == 8, view_req_policy
|
||||
|
@ -108,7 +108,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
for _ in framework_iterator(config):
|
||||
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
|
||||
policy = trainer.get_policy()
|
||||
view_req_model = policy.model.inference_view_requirements
|
||||
view_req_model = policy.model.view_requirements
|
||||
view_req_policy = policy.view_requirements
|
||||
# 7=obs, prev-a + r, 2x state-in, 2x state-out.
|
||||
assert len(view_req_model) == 7, view_req_model
|
||||
|
|
|
@ -21,7 +21,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
|
||||
self.model = _fake_model()
|
||||
self.model.time_major = True
|
||||
self.model.inference_view_requirements = {
|
||||
self.model.view_requirements = {
|
||||
SampleBatch.AGENT_INDEX: ViewRequirement(),
|
||||
SampleBatch.EPS_ID: ViewRequirement(),
|
||||
"env_id": ViewRequirement(),
|
||||
|
@ -33,12 +33,12 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
SampleBatch.REWARDS, shift=-1),
|
||||
}
|
||||
for i in range(2):
|
||||
self.model.inference_view_requirements["state_in_{}".format(i)] = \
|
||||
self.model.view_requirements["state_in_{}".format(i)] = \
|
||||
ViewRequirement(
|
||||
"state_out_{}".format(i),
|
||||
shift=-1,
|
||||
space=self.state_space)
|
||||
self.model.inference_view_requirements[
|
||||
self.model.view_requirements[
|
||||
"state_out_{}".format(i)] = \
|
||||
ViewRequirement(space=self.state_space)
|
||||
|
||||
|
@ -50,7 +50,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
|||
SampleBatch.REWARDS: ViewRequirement(),
|
||||
SampleBatch.DONES: ViewRequirement(),
|
||||
},
|
||||
**self.model.inference_view_requirements)
|
||||
**self.model.view_requirements)
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
|
@ -97,7 +97,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
|
|||
pass
|
||||
|
||||
self.model = _fake_model()
|
||||
self.model.inference_view_requirements = {
|
||||
self.model.view_requirements = {
|
||||
SampleBatch.AGENT_INDEX: ViewRequirement(),
|
||||
SampleBatch.EPS_ID: ViewRequirement(),
|
||||
"env_id": ViewRequirement(),
|
||||
|
@ -114,7 +114,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
|
|||
}
|
||||
|
||||
self.view_requirements = dict(super()._get_default_view_requirements(),
|
||||
**self.model.inference_view_requirements)
|
||||
**self.model.view_requirements)
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
|
|
|
@@ -62,7 +62,7 @@ class ModelV2:
         self._last_output = None
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
-        self.inference_view_requirements = {
+        self.view_requirements = {
             SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
         }
 
@@ -341,7 +341,7 @@ class ModelV2:
         }
 
         input_dict = {}
-        for view_col, view_req in self.inference_view_requirements.items():
+        for view_col, view_req in self.view_requirements.items():
             # Create batches of size 1 (single-agent input-dict).
             data_col = view_req.data_col or view_col
             if index == "last":

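Note (not part of the diff): after this change every ModelV2 starts out with a `view_requirements` dict that holds only the current observation, and a custom model can extend that dict in its constructor to request additional columns. A minimal sketch, assuming the standard RLlib imports; the class name `PrevRewardModel` is hypothetical.

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class PrevRewardModel(TFModelV2):
    """Hypothetical custom model that also wants the previous reward."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # The base constructor already added {"obs": ViewRequirement(shift=0)};
        # additionally request the reward from one timestep back.
        self.view_requirements[SampleBatch.PREV_REWARDS] = \
            ViewRequirement(SampleBatch.REWARDS, shift=-1)
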
@@ -284,22 +284,22 @@ class GTrXLNet(RecurrentNetwork):
         self.register_variables(self.trxl_model.variables)
         self.trxl_model.summary()
 
-        # Setup inference view (`memory-inference` x past observations +
-        # current one (0))
-        # 1 to `num_transformer_units`: Memory data (one per transformer unit).
+        # __sphinx_doc_begin__
+        # Setup trajectory views (`memory-inference` x past memory outs).
         for i in range(self.num_transformer_units):
             space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
-            self.inference_view_requirements["state_in_{}".format(i)] = \
+            self.view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
                     shift="-{}:-1".format(self.memory_inference),
                     # Repeat the incoming state every max-seq-len times.
                     batch_repeat_value=self.max_seq_len,
                     space=space)
-            self.inference_view_requirements["state_out_{}".format(i)] = \
+            self.view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(
                     space=space,
                     used_for_training=False)
+        # __sphinx_doc_end__
 
     @override(ModelV2)
     def forward(self, input_dict, state: List[TensorType],

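Note on the shift string used above (illustrative, not part of the diff): a shift such as "-50:-1" requests a range view, i.e. the most recent 50 `state_out_i` values stacked along a time axis, and `batch_repeat_value=max_seq_len` only materializes that view once per `max_seq_len` timesteps. A standalone sketch with assumed example values (attn_dim=64, memory_inference=50, max_seq_len=20):

from gym.spaces import Box
from ray.rllib.policy.view_requirement import ViewRequirement

space = Box(-1.0, 1.0, shape=(64, ))
state_in_0 = ViewRequirement(
    "state_out_0",          # source column to read the view from
    shift="-50:-1",         # the 50 most recent state outputs
    batch_repeat_value=20,  # repeat the view only every max_seq_len steps
    space=space)
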
@@ -184,11 +184,11 @@ class LSTMWrapper(RecurrentNetwork):
 
         # Add prev-a/r to this model's view, if required.
         if model_config["lstm_use_prev_action"]:
-            self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                 shift=-1)
         if model_config["lstm_use_prev_reward"]:
-            self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                 ViewRequirement(SampleBatch.REWARDS, shift=-1)
 
     @override(RecurrentNetwork)

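These two view requirements are only registered when the corresponding model-config flags are enabled. A minimal config sketch (values illustrative; assumes the auto-LSTM wrapper is turned on via `use_lstm`) that would trigger them:

config = {
    "model": {
        "use_lstm": True,
        # Feed the previous action and reward into the LSTM wrapper, which
        # in turn registers the PREV_ACTIONS / PREV_REWARDS view requirements.
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
}
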
@@ -160,19 +160,17 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
         self.values_out = SlimFC(
             in_size=self.attn_dim, out_size=1, activation_fn=None)
 
-        # Setup inference view (`memory-inference` x past observations +
-        # current one (0))
-        # 1 to `num_transformer_units`: Memory data (one per transformer unit).
+        # Setup trajectory views (`memory-inference` x past memory outs).
         for i in range(self.num_transformer_units):
             space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
-            self.inference_view_requirements["state_in_{}".format(i)] = \
+            self.view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
                     shift="-{}:-1".format(self.memory_inference),
                     # Repeat the incoming state every max-seq-len times.
                     batch_repeat_value=self.max_seq_len,
                     space=space)
-            self.inference_view_requirements["state_out_{}".format(i)] = \
+            self.view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(
                     space=space,
                     used_for_training=False)

@@ -166,14 +166,16 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
             activation_fn=None,
             initializer=torch.nn.init.xavier_uniform_)
 
+        # __sphinx_doc_begin__
         # Add prev-a/r to this model's view, if required.
         if model_config["lstm_use_prev_action"]:
-            self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                 shift=-1)
         if model_config["lstm_use_prev_reward"]:
-            self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                 ViewRequirement(SampleBatch.REWARDS, shift=-1)
+        # __sphinx_doc_end__
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],

@@ -1,3 +1,4 @@
+from abc import ABCMeta, abstractmethod
 import logging
 import numpy as np
 import threading
@@ -14,15 +15,16 @@ logger = logging.getLogger(__name__)
 
 
 @PublicAPI
-class InputReader:
+class InputReader(metaclass=ABCMeta):
     """Input object for loading experiences in policy evaluation."""
 
+    @abstractmethod
     @PublicAPI
     def next(self):
-        """Return the next batch of experiences read.
+        """Returns the next batch of experiences read.
 
         Returns:
-            SampleBatch or MultiAgentBatch read.
+            Union[SampleBatch, MultiAgentBatch]: The experience read.
         """
         raise NotImplementedError
 

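With `InputReader` now an explicit ABC, custom readers must implement `next()`. A minimal sketch of such a subclass; the class name and the fixed batch contents are made up for illustration.

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.policy.sample_batch import SampleBatch


class ConstantReader(InputReader):
    """Hypothetical reader that returns the same one-step batch every time."""

    def next(self):
        return SampleBatch({
            "obs": [[0.0]],
            "actions": [0],
            "rewards": [0.0],
            "dones": [True],
        })
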
@@ -171,7 +171,7 @@ class DynamicTFPolicy(TFPolicy):
                 model_config=self.config["model"],
                 framework="tf")
         # Auto-update model's inference view requirements, if recurrent.
-        self._update_model_inference_view_requirements_from_init_state()
+        self._update_model_view_requirements_from_init_state()
 
         if existing_inputs:
             self._state_inputs = [
@@ -186,8 +186,7 @@ class DynamicTFPolicy(TFPolicy):
                 get_placeholder(
                     space=vr.space,
                     time_axis=not isinstance(vr.shift, int),
-                ) for k, vr in
-                self.model.inference_view_requirements.items()
+                ) for k, vr in self.model.view_requirements.items()
                 if k.startswith("state_in_")
             ]
         else:
@@ -200,7 +199,7 @@ class DynamicTFPolicy(TFPolicy):
             # Add NEXT_OBS, STATE_IN_0.., and others.
             self.view_requirements = self._get_default_view_requirements()
         # Combine view_requirements for Model and Policy.
-        self.view_requirements.update(self.model.inference_view_requirements)
+        self.view_requirements.update(self.model.view_requirements)
 
         # Setup standard placeholders.
         if existing_inputs is not None:
@@ -560,7 +559,7 @@ class DynamicTFPolicy(TFPolicy):
         all_accessed_keys = \
             train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
             batch_for_postproc.added_keys | set(
-                self.model.inference_view_requirements.keys())
+                self.model.view_requirements.keys())
 
         TFPolicy._initialize_loss(self, loss, [(k, v)
                                                for k, v in train_batch.items()
@@ -584,7 +583,7 @@ class DynamicTFPolicy(TFPolicy):
         # Tag those only needed for post-processing.
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 if key in self.view_requirements:
                     self.view_requirements[key].used_for_training = False
                 if key in self._loss_input_dict:
@@ -597,7 +596,7 @@ class DynamicTFPolicy(TFPolicy):
                     SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                     SampleBatch.UNROLL_ID, SampleBatch.DONES,
                     SampleBatch.REWARDS] and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 # If user deleted this key manually in postprocessing
                 # fn, warn about it and do not remove from
                 # view-requirements.

@@ -256,15 +256,14 @@ def build_eager_tf_policy(name,
                 framework=self.framework,
             )
             # Auto-update model's inference view requirements, if recurrent.
-            self._update_model_inference_view_requirements_from_init_state()
+            self._update_model_view_requirements_from_init_state()
 
             self.exploration = self._create_exploration()
             self._state_inputs = self.model.get_initial_state()
             self._is_recurrent = len(self._state_inputs) > 0
 
             # Combine view_requirements for Model and Policy.
-            self.view_requirements.update(
-                self.model.inference_view_requirements)
+            self.view_requirements.update(self.model.view_requirements)
 
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)

@@ -669,7 +669,7 @@ class Policy(metaclass=ABCMeta):
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
                     key in self.view_requirements and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 self.view_requirements[key].used_for_training = False
         # Remove those not needed at all (leave those that are needed
         # by Sampler to properly execute sample collection).
@@ -679,7 +679,7 @@ class Policy(metaclass=ABCMeta):
                     SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                     SampleBatch.UNROLL_ID, SampleBatch.DONES,
                     SampleBatch.REWARDS] and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 # If user deleted this key manually in postprocessing
                 # fn, warn about it and do not remove from
                 # view-requirements.
@@ -738,12 +738,12 @@ class Policy(metaclass=ABCMeta):
         # columns in the resulting batch may not all have the same batch size.
         return SampleBatch(ret, _dont_check_lens=True)
 
-    def _update_model_inference_view_requirements_from_init_state(self):
+    def _update_model_view_requirements_from_init_state(self):
         """Uses Model's (or this Policy's) init state to add needed ViewReqs.
 
         Can be called from within a Policy to make sure RNNs automatically
         update their internal state-related view requirements.
-        Changes the `self.inference_view_requirements` dict.
+        Changes the `self.view_requirements` dict.
         """
         self._model_init_state_automatically_added = True
         model = getattr(self, "model", None)
@@ -752,7 +752,7 @@ class Policy(metaclass=ABCMeta):
             for i, state in enumerate(obj.get_initial_state()):
                 space = Box(-1.0, 1.0, shape=state.shape) if \
                     hasattr(state, "shape") else state
-                view_reqs = model.inference_view_requirements if model else \
+                view_reqs = model.view_requirements if model else \
                     self.view_requirements
                 view_reqs["state_in_{}".format(i)] = ViewRequirement(
                     "state_out_{}".format(i),

|
|||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils import add_mixins, force_list, NullContextManager
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch, try_import_jax
|
||||
|
@ -74,8 +73,6 @@ def build_policy_class(
|
|||
apply_gradients_fn: Optional[Callable[
|
||||
[Policy, "torch.optim.Optimizer"], None]] = None,
|
||||
mixins: Optional[List[type]] = None,
|
||||
view_requirements_fn: Optional[Callable[[Policy], Dict[
|
||||
str, ViewRequirement]]] = None,
|
||||
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
||||
) -> Type[TorchPolicy]:
|
||||
"""Helper function for creating a new Policy class at runtime.
|
||||
|
@ -181,9 +178,6 @@ def build_policy_class(
|
|||
mixins (Optional[List[type]]): Optional list of any class mixins for
|
||||
the returned policy class. These mixins will be applied in order
|
||||
and will have higher precedence than the TorchPolicy class.
|
||||
view_requirements_fn (Optional[Callable[[Policy],
|
||||
Dict[str, ViewRequirement]]]): An optional callable to retrieve
|
||||
additional train view requirements for this policy.
|
||||
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
|
||||
Optional callable that returns the divisibility requirement for
|
||||
sample batches. If None, will assume a value of 1.
|
||||
|
@ -260,12 +254,8 @@ def build_policy_class(
|
|||
get_batch_divisibility_req=get_batch_divisibility_req,
|
||||
)
|
||||
|
||||
# Update this Policy's ViewRequirements (if function given).
|
||||
if callable(view_requirements_fn):
|
||||
self.view_requirements.update(view_requirements_fn(self))
|
||||
# Merge Model's view requirements into Policy's.
|
||||
self.view_requirements.update(
|
||||
self.model.inference_view_requirements)
|
||||
self.view_requirements.update(self.model.view_requirements)
|
||||
|
||||
_before_loss_init = before_loss_init or after_init
|
||||
if _before_loss_init:
|
||||
|
|
|
@@ -147,7 +147,7 @@ class TFPolicy(Policy):
         self.model = model
         # Auto-update model's inference view requirements, if recurrent.
         if self.model is not None:
-            self._update_model_inference_view_requirements_from_init_state()
+            self._update_model_view_requirements_from_init_state()
 
         self.exploration = self._create_exploration()
         self._sess = sess

@@ -113,9 +113,9 @@ class TorchPolicy(Policy):
         self._state_inputs = self.model.get_initial_state()
         self._is_recurrent = len(self._state_inputs) > 0
         # Auto-update model's inference view requirements, if recurrent.
-        self._update_model_inference_view_requirements_from_init_state()
+        self._update_model_view_requirements_from_init_state()
         # Combine view_requirements for Model and Policy.
-        self.view_requirements.update(self.model.inference_view_requirements)
+        self.view_requirements.update(self.model.view_requirements)
 
         self.exploration = self._create_exploration()
         self.unwrapped_model = model  # used to support DistributedDataParallel

@@ -22,7 +22,7 @@ class ViewRequirement:
 
     Examples:
         >>> # The default ViewRequirement for a Model is:
-        >>> req = [ModelV2].inference_view_requirements
+        >>> req = [ModelV2].view_requirements
         >>> print(req)
         {"obs": ViewRequirement(shift=0)}
     """
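To recap the renamed API in one place, here is an illustrative dict built with the same `ViewRequirement` patterns that appear throughout this diff; the "prev_4_obs" key is made up for the example.

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

view_requirements = {
    # Default model view: the current observation.
    SampleBatch.OBS: ViewRequirement(shift=0),
    # Previous reward, read from the REWARDS column shifted back by one step.
    SampleBatch.PREV_REWARDS: ViewRequirement(SampleBatch.REWARDS, shift=-1),
    # A range view: the last four observations stacked together.
    "prev_4_obs": ViewRequirement(SampleBatch.OBS, shift="-4:-1"),
}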