From 391cdfae8c566c141c03832cfff3c3e0cb9e61eb Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 30 Dec 2020 20:32:21 -0500 Subject: [PATCH] [RLlib] Trajectory view API docs. (#12718) --- doc/source/images/rllib-batch-modes.svg | 1 + doc/source/images/rllib-sample-collection.svg | 1 + .../images/rllib-trajectory-view-example.svg | 1 + doc/source/index.rst | 1 + doc/source/rllib-algorithms.rst | 2 +- doc/source/rllib-env.rst | 2 +- doc/source/rllib-models.rst | 2 +- doc/source/rllib-sample-collection.rst | 337 ++++++++++++++++++ doc/source/rllib-toc.rst | 6 +- doc/source/rllib.rst | 7 +- rllib/agents/trainer.py | 6 + .../evaluation/collectors/sample_collector.py | 41 ++- .../collectors/simple_list_collector.py | 55 ++- rllib/evaluation/episode.py | 5 +- rllib/evaluation/rollout_worker.py | 7 +- rllib/evaluation/sampler.py | 166 +++++---- .../tests/test_trajectory_view_api.py | 4 +- .../policy/episode_env_aware_policy.py | 12 +- rllib/models/modelv2.py | 4 +- rllib/models/tf/attention_net.py | 10 +- rllib/models/tf/recurrent_net.py | 4 +- rllib/models/torch/attention_net.py | 8 +- rllib/models/torch/recurrent_net.py | 6 +- rllib/offline/input_reader.py | 8 +- rllib/policy/dynamic_tf_policy.py | 13 +- rllib/policy/eager_tf_policy.py | 5 +- rllib/policy/policy.py | 10 +- rllib/policy/policy_template.py | 12 +- rllib/policy/tf_policy.py | 2 +- rllib/policy/torch_policy.py | 4 +- rllib/policy/view_requirement.py | 2 +- 31 files changed, 571 insertions(+), 173 deletions(-) create mode 100644 doc/source/images/rllib-batch-modes.svg create mode 100644 doc/source/images/rllib-sample-collection.svg create mode 100644 doc/source/images/rllib-trajectory-view-example.svg create mode 100644 doc/source/rllib-sample-collection.rst diff --git a/doc/source/images/rllib-batch-modes.svg b/doc/source/images/rllib-batch-modes.svg new file mode 100644 index 000000000..4dd7d5978 --- /dev/null +++ b/doc/source/images/rllib-batch-modes.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/images/rllib-sample-collection.svg b/doc/source/images/rllib-sample-collection.svg new file mode 100644 index 000000000..585e70434 --- /dev/null +++ b/doc/source/images/rllib-sample-collection.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/images/rllib-trajectory-view-example.svg b/doc/source/images/rllib-trajectory-view-example.svg new file mode 100644 index 000000000..f58aaaae1 --- /dev/null +++ b/doc/source/images/rllib-trajectory-view-example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/doc/source/index.rst b/doc/source/index.rst index 8d1e1aed8..a2edb5432 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -259,6 +259,7 @@ Papers rllib-env.rst rllib-models.rst rllib-algorithms.rst + rllib-sample-collection.rst rllib-offline.rst rllib-concepts.rst rllib-examples.rst diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index aac235e94..dfa1bde00 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -50,7 +50,7 @@ Exploration-based plug-ins (can be combined with any algo) ============================= ========== ======================= ================== =========== ===================== Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support ============================= ========== ======================= ================== =========== ===================== -`Curiosity`_ torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ +`Curiosity`_ tf + torch **Yes** `+parametric`_ 
**Yes** **Yes** `+RNN`_ ============================= ========== ======================= ================== =========== ===================== .. _`A2C, A3C`: rllib-algorithms.html#a3c diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index efc9b5815..3b59778fc 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -5,7 +5,7 @@ RLlib works with several different types of environments, including `OpenAI Gym .. tip:: - Not all environments work with all algorithms. Check out the algorithm `feature compatibility matrix `__ for more information. + Not all environments work with all algorithms. Check out the `algorithm overview `__ for more information. .. image:: rllib-envs.svg diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 139f7c4c2..01ce506e1 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -389,7 +389,7 @@ Custom models can be used to work with environments where (1) the set of valid a return action_logits + inf_mask, state -Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix `__. +Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_actions_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `algorithm overview `__. Autoregressive Action Distributions diff --git a/doc/source/rllib-sample-collection.rst b/doc/source/rllib-sample-collection.rst new file mode 100644 index 000000000..c46a1d94d --- /dev/null +++ b/doc/source/rllib-sample-collection.rst @@ -0,0 +1,337 @@ +RLlib Sample Collection and Trajectory Views +============================================ + +The SampleCollector Class is Used to Store and Retrieve Temporary Data +---------------------------------------------------------------------- + +RLlib's `RolloutWorkers `__, +when running against a live environment, use the ``SamplerInput`` class to interact +with that environment and produce batches of experiences. +The two implemented sub-classes of ``SamplerInput`` are ``SyncSampler`` and ``AsyncSampler`` +(residing under the ``RolloutWorker.sampler`` property). + +In case the "_use_trajectory_view_api" top-level config key is set to True +(by default since version >=1.1.0), every such sampler object will use the +``SampleCollector`` API to store and retrieve temporary environment-, model-, and other data +during rollouts (see figure below). + +.. 
Edit figure below at: https://docs.google.com/drawings/d/1ZdNUU3ChwiUeT-DBRxvLAsbEPEqEFWSPZcOyVy3KxVg/edit
.. image:: images/rllib-sample-collection.svg

**Sample collection process implemented by RLlib:**
The Policy's model tells the Sampler and its SampleCollector object which data to store and
how to present it back to the dependent methods (e.g. `Model.compute_actions()`).
This is done via a dict that maps strings (column names) to `ViewRequirement` objects (see details below).


The exact behavior for a single such rollout and the number of environment transitions therein
are determined by the following Trainer config keys:

**batch_mode [truncate_episodes|complete_episodes]**:
  *truncate_episodes (default value)*:
    Rollouts are performed over exactly ``rollout_fragment_length`` (see below) steps. Steps are
    counted either as environment steps or as individual agent steps (see ``count_steps_by`` below).
    It does not matter whether one or more episodes end within this rollout or whether
    the rollout starts in the middle of an already ongoing episode.
  *complete_episodes*:
    Each rollout is exactly one episode long and always starts
    at the beginning of an episode. It does not matter how long an episode lasts.
    The ``rollout_fragment_length`` setting is ignored. Note that you have to be
    careful when choosing ``complete_episodes`` as ``batch_mode``: If your environment does not
    terminate easily, this setting could lead to enormous batch sizes.

**rollout_fragment_length [int]**:
  The exact number of environment or agent steps to
  be performed per rollout, if the ``batch_mode`` setting (see above) is "truncate_episodes".
  If ``batch_mode`` is "complete_episodes", ``rollout_fragment_length`` is ignored.
  The unit in which fragments are counted is set via ``multiagent.count_steps_by=[env_steps|agent_steps]``
  (within the ``multiagent`` config dict).

.. Edit figure below at: https://docs.google.com/drawings/d/1uRNGImBNq8gv3bBoFX_HernGyeovtCB3wKpZ71c0VE4/edit
.. image:: images/rllib-batch-modes.svg

**Above:** The two supported batch modes in RLlib. For "truncate_episodes",
batches can a) span more than one episode, b) end in the middle of an episode, and
c) start in the middle of an episode. Also, `Policy.postprocess_trajectory()` is always
called at the end of a rollout fragment (red lines on the right side) as well as at the end
of each episode (arrow heads). This way, RLlib makes sure that the
`Policy.postprocess_trajectory()` method never sees data from more than one episode.


**multiagent.count_steps_by [env_steps|agent_steps]**:
  Within the Trainer's ``multiagent`` config dict, you can set the unit by which RLlib counts a) rollout fragment lengths as well as b) the size of the final train batch (see below). The two supported values are:

  *env_steps (default)*:
    Each call to ``[Env].step()`` is counted as one. It does not
    matter how many individual agents step simultaneously in that call
    (not all existing agents in the environment may step at the same time).
  *agent_steps*:
    In a multi-agent environment, count each individual agent's step
    as one. For example, if N agents are in an environment and all these N agents
    always step at the same time, a single env step corresponds to N agent steps.
    Note that in the single-agent case, ``env_steps`` and ``agent_steps`` are the same thing.
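
Putting the above settings together, a minimal config sketch could look like this
(values are purely illustrative, not tuned recommendations):

.. code-block:: python

    # Collect rollout fragments of exactly 100 steps each (fragments may
    # start and end in the middle of episodes), counted in env steps.
    config = {
        "batch_mode": "truncate_episodes",
        "rollout_fragment_length": 100,
        "multiagent": {
            "count_steps_by": "env_steps",
        },
    }
    # Pass this (partial) config to any Trainer, e.g.:
    # trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
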
**horizon [int]**:
  Some environments limit, by default, the maximum number of timesteps
  an episode can last. This limit is called the "horizon" of an episode.
  For example, for CartPole-v0, the maximum number of steps per episode is 200 by default.
  You can override this setting, however, via the ``horizon`` config key.
  If the provided ``horizon`` is larger than the env-specific limit, RLlib will first try to
  increase the environment's built-in horizon setting (e.g. OpenAI gym Envs have a
  ``spec.max_episode_steps`` property). In either case, no episode
  is allowed to exceed the given ``horizon`` number of timesteps (RLlib will
  artificially terminate an episode if this limit is hit).

**soft_horizon [bool]**:
  False by default. If set to True, a) the environment will not be reset
  when the ``horizon`` is reached and b) no ``done=True`` will be set
  in the trajectory data sent to the postprocessors and training (``done`` will remain
  False at the horizon).

**no_done_at_end [bool]**:
  Never set ``done=True`` at the end of an episode or when any
  artificial horizon is reached.

To trigger a single rollout, RLlib calls ``RolloutWorker.sample()``, which returns
a SampleBatch or MultiAgentBatch object representing all the data collected during that
rollout. These batches are then usually concatenated further (from the ``num_workers``
parallelized RolloutWorkers) to form a final train batch, whose size is determined
by the ``train_batch_size`` config parameter. Train batches are usually sent to the Policy's
``learn_on_batch`` method, which handles loss and gradient calculations as well as optimizer stepping.

RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which appends single-timestep data (e.g. actions)
to lists, builds SampleBatches from these, and sends them to the downstream processing functions.
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS use the same underlying list).
If you want to implement your own collection logic and data structures, you can subclass ``SampleCollector``
and specify that new class under the Trainer's "sample_collector" config key.

Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
know what data in the ongoing episode/trajectory to use for the different required method calls
during rollouts. In particular, these method calls are:
``Policy.compute_actions_from_input_dict()``, which computes actions to be taken in an episode;
``Policy.postprocess_trajectory()``, which is called after an episode ends or a rollout hits its
``rollout_fragment_length`` limit (in ``batch_mode=truncate_episodes``); and ``Policy.learn_on_batch()``,
which is called with a "train_batch" to improve the policy.

Trajectory View API
-------------------

The trajectory view API allows custom models to define what parts of the trajectory they
require in order to execute the forward pass. For example, in the simplest case, a model might
only look at the latest observation. However, an RNN- or attention-based model could look
at previous states emitted by the model, concatenate previously seen rewards with the current observation,
or require the entire range of the n most recent observations.

The trajectory view API lets models define these requirements and lets RLlib gather the required
data for the forward pass in an efficient way.
Since the following methods all call into the model class, they all use the trajectory view API indirectly.
It is important to note that the API is only accessible to the user via the model classes
(see below on how to set up trajectory view requirements for a custom model).

In particular, the methods receiving inputs that depend on a Model's trajectory view rules are:

a) ``Policy.compute_actions_from_input_dict()``,
b) ``Policy.postprocess_trajectory()``, and
c) ``Policy.learn_on_batch()`` (and consequently the Policy's loss function).

The input data to these methods can stem from either the environment (observations, rewards, and env infos),
the model itself (previously computed actions, internal state outputs, action probs, etc.),
or the Sampler (e.g. agent index, env ID, episode ID, timestep, etc.).
All data has an associated time axis, which is 0-based, meaning that the first action taken, the
first reward received in an episode, and the first observation (directly after a reset)
all have t=0.

The idea is to allow more flexibility and standardization in how a model defines required
"views" on the ongoing trajectory (during action computations/inference), past episodes (training
on a batch), or even trajectories of other agents in the same episode, some of which
may even use a different policy.

Such a "view requirements" formalism is helpful when having to support more complex model
setups like RNNs, attention nets, observation image framestacking (e.g. for Atari),
and building multi-agent communication channels.

The set of rules that makes the Model see certain data is defined through a
"view requirements dict", residing in the ``Policy.model.view_requirements`` property.
View requirements dicts map strings (column names), such as "obs" or "actions", to
a ``ViewRequirement`` object, which defines the exact conditions by which this column
should be populated with data.

View Requirement Dictionaries
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View requirements are stored within the ``view_requirements`` property of the ``ModelV2``
class. You can access it like this:

.. code-block:: python

    my_simple_model = ModelV2(...)
    print(my_simple_model.view_requirements)
    >>>{"obs": ViewRequirement(shift=0, space=[observation space])}

    my_lstm_model = LSTMModel(...)
    print(my_lstm_model.view_requirements)
    >>>{
    >>> "obs": ViewRequirement(shift=0, space=[observation space]),
    >>> "prev_actions": ViewRequirement(shift=-1, data_col="actions", space=[action space]),
    >>> "prev_rewards": ViewRequirement(shift=-1, data_col="rewards"),
    >>>}

The ``view_requirements`` property holds a dictionary mapping
string keys (e.g. "actions", "rewards", "next_obs", etc.)
to a ``ViewRequirement`` object. This ``ViewRequirement`` object determines what exact data to
provide under the given key in case a SampleBatch or a single-timestep (action computing) "input dict"
needs to be built and fed into one of the above ModelV2 or Policy methods.

.. Edit figure below at: https://docs.google.com/drawings/d/1YEPUtMrRXmWfvM0E6mD3VsOaRlLV7DtctF-yL96VHGg/edit
.. image:: images/rllib-trajectory-view-example.svg

**Above:** An example `ViewRequirements` dict that causes the current observation
and the previous action to be available in each `compute_actions` call, as
well as in the Policy's `postprocess_trajectory()` function (and train batch).
A similar setup is often used by LSTM/RNN-based models.
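
In code, a dict roughly equivalent to the one shown in the figure above could be written as
follows. This is only a sketch: the spaces used here are stand-ins for your environment's
actual observation and action spaces.

.. code-block:: python

    from gym.spaces import Box, Discrete

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    # Example spaces (replace with your env's spaces).
    obs_space = Box(-1.0, 1.0, shape=(4, ))
    action_space = Discrete(2)

    view_requirements = {
        # The current observation at each timestep.
        SampleBatch.OBS: ViewRequirement(space=obs_space),
        # The previous action, taken from the "actions" column, shifted by -1.
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            data_col=SampleBatch.ACTIONS, shift=-1, space=action_space),
    }
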
The ViewRequirement class
~~~~~~~~~~~~~~~~~~~~~~~~~

Here is a description of the constructor-settable properties of a ViewRequirement
object and what each of these properties controls.

**data_col**:
  An optional string key referencing the underlying data to use to
  create the view. If not provided, assumes that there is data under the
  dict key under which this ViewRequirement resides.

  Examples:

.. code-block:: python

    ModelV2.view_requirements = {"rewards": ViewRequirement(shift=0)}
    # -> implies that the underlying data to use are the collected rewards
    # from the environment.

    ModelV2.view_requirements = {"prev_rewards": ViewRequirement(data_col="rewards", shift=-1)}
    # -> means that the actual data used to create the "prev_rewards" column
    # is the "rewards" data from the environment (shifted back by one timestep).

**space**:
  An optional gym.Space used as a hint for the SampleCollector on
  how to fill timesteps before the episode actually started (e.g. if
  shift=-2, dummy data is needed at timesteps -2 and -1).

**shift [int]**:
  An int, a list of ints, or a range string (e.g. "-50:-1") indicating
  which time offsets or ranges of the underlying data to use for the view.

  Examples:

.. code-block:: python

    shift=0 # -> Use the data under ``data_col`` as is.
    shift=1 # -> Use the data under ``data_col``, but shifted by +1 timestep
            #    (used by e.g. next_obs views).
    shift=-1 # -> Use the data under ``data_col``, but shifted by -1 timestep
             #    (used by e.g. prev_actions views).
    shift=[-2, -1] # -> Use the data under ``data_col``, but always provide 2 values
                   #    at each timestep (the two previous ones).
                   #    Could be used e.g. to feed the last two actions or rewards into an LSTM.
    shift="-50:-1" # -> Use the data under ``data_col``, but always provide a range of
                   #    the last 50 timesteps (used by our attention nets).

**used_for_training [bool]**:
  True by default. If False, the column will not be available inside the train batch (which
  arrives in the Policy's loss function).
  RLlib will automatically switch this to False for a given column if it detects during
  Policy initialization that the column is not accessed inside the loss function (see below).

How does RLlib determine which Views are required?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

During its initialization, a Policy automatically determines how to later build batches
for postprocessing, loss function calls, and action computations, based on
the Model's ``view_requirements`` dict. It does so by sending generic dummy batches
through its ``compute_actions_from_input_dict``, ``postprocess_trajectory``, and loss functions
and then checking which fields in these dummy batches get accessed, overwritten, deleted, or added.
Based on these test passes, the Policy then removes those ViewRequirements from an initially
very broad list that it deems unnecessary. This procedure saves a lot of data copying
during later rollouts, batch transfers (via Ray), and loss calculations, and it makes manual tricks
like deleting columns from a SampleBatch (e.g. PPO used to delete the "next_obs" column
inside the postprocessing function) unnecessary.
Note that the "rewards" and "dones" columns are never discarded and thus should always
arrive in your loss function's SampleBatch (``train_batch`` arg).
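
If you are curious which views actually survived this pruning step for your particular setup,
you can inspect the Policy's ``view_requirements`` dict after the Trainer has been built.
A quick sketch (PPO on CartPole chosen arbitrarily here):

.. code-block:: python

    import ray.rllib.agents.ppo as ppo

    trainer = ppo.PPOTrainer(config={"framework": "torch"}, env="CartPole-v0")
    policy = trainer.get_policy()
    # Print each column's data source, shift, and whether it will still be
    # part of the train batch arriving in the loss function.
    for col, req in policy.view_requirements.items():
        print(col, req.data_col, req.shift, req.used_for_training)
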
Setting ViewRequirements manually in your Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you need to specify special view requirements for your model, you can add
columns to the Model's ``view_requirements`` dict in the
Model's constructor.

For example, our auto-LSTM wrapper classes (tf and torch) have these additional
lines in their constructors (torch version shown here):

.. literalinclude:: ../../rllib/models/torch/recurrent_net.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

This makes sure that, if the user requires this via the model config, previous rewards
and/or previous actions are properly added to the ``compute_actions`` input dicts and to the
SampleBatches used for postprocessing and training.


Another example is our attention nets, which make sure the last n (memory) model outputs
are always fed back into the model on the next time step (tf version shown here):

.. literalinclude:: ../../rllib/models/tf/attention_net.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__


Setting ViewRequirements manually after Policy construction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here is a simple example of how you can modify and add to the ViewRequirements dict
even after policy (or RolloutWorker) creation. However, note that it's usually better to
make these modifications to your batches in your postprocessing function:

.. code-block:: python

    import gym
    import numpy as np
    from gym.spaces import Discrete

    import ray.rllib.agents.ppo as ppo
    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.evaluation import RolloutWorker
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement
    from ray.rllib.utils.annotations import override

    # Modify view_requirements in the Policy object.
    action_space = Discrete(2)
    rollout_worker = RolloutWorker(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_config=ppo.DEFAULT_CONFIG,
        policy_spec=ppo.PPOTorchPolicy,
    )
    policy = rollout_worker.policy_map["default_policy"]
    # Add the next action to the view reqs of the policy.
    # This should then be visible in postprocessing and train batches.
    policy.view_requirements["next_actions"] = ViewRequirement(
        SampleBatch.ACTIONS, shift=1, space=action_space)
    # Check whether a sampled batch has the requested `next_actions` view.
    batch = rollout_worker.sample()
    assert "next_actions" in batch.data

    # Doing the same in a custom postprocessing callback function:
    class MyCallback(DefaultCallbacks):
        # ...

        @override(DefaultCallbacks)
        def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                      policies, postprocessed_batch, original_batches,
                                      **kwargs):
            postprocessed_batch["next_actions"] = np.concatenate(
                [postprocessed_batch["actions"][1:],
                 np.zeros_like([policies[policy_id].action_space.sample()])])

The above two examples add a "next_actions" view to the postprocessed SampleBatch
used by the Policy for training. The next actions are not fed into the Model's
``compute_actions`` calls (they can't be, because the next action is of course not known
at that point).
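
Finally, coming back to the in-model approach above, here is a compact sketch of a custom
torch model that requests a stack of the last four observations at each timestep via the
range-string ``shift`` syntax described earlier. The class and the key name ``"prev_n_obs"``
are made up for this example (see the example models under ``rllib/examples/models/`` for
fully worked versions):

.. code-block:: python

    import numpy as np
    from torch import nn

    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.policy.view_requirement import ViewRequirement


    class FrameStackingModel(TorchModelV2, nn.Module):
        """Sketch: feeds the last `num_frames` observations into a linear head."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name, num_frames=4):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            self.num_frames = num_frames
            # Request the last `num_frames` observations (timesteps -3 .. 0 for
            # num_frames=4) under the (arbitrary) key "prev_n_obs".
            self.view_requirements["prev_n_obs"] = ViewRequirement(
                data_col="obs",
                shift="-{}:0".format(num_frames - 1),
                space=obs_space)
            in_size = int(np.prod(obs_space.shape)) * num_frames
            self.logits = nn.Linear(in_size, num_outputs)
            self.value_branch = nn.Linear(in_size, 1)
            self._last_flat_in = None

        def forward(self, input_dict, state, seq_lens):
            # Shape: [B, num_frames, *obs_shape] -> flatten the frame axis.
            obs_stack = input_dict["prev_n_obs"].float()
            self._last_flat_in = obs_stack.reshape(obs_stack.shape[0], -1)
            return self.logits(self._last_flat_in), state

        def value_function(self):
            return self.value_branch(self._last_flat_in).squeeze(-1)

To use such a model, register it via ``ModelCatalog.register_custom_model()`` and point
``config["model"]["custom_model"]`` to the registered name.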
diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst index 1989d0508..f4331eb6b 100644 --- a/doc/source/rllib-toc.rst +++ b/doc/source/rllib-toc.rst @@ -60,7 +60,6 @@ Training APIs Environments ------------ * `RLlib Environments Overview `__ -* `Feature Compatibility Matrix `__ * `OpenAI Gym `__ * `Vectorized `__ * `Multi-Agent and Hierarchical `__ @@ -144,6 +143,11 @@ Algorithms - |pytorch| :ref:`Curiosity (ICM: Intrinsic Curiosity Module) ` +Sample Collection +----------------- +* `The SampleCollector Class is Used to Store and Retrieve Temporary Data `__ +* `Trajectory View API `__ + Offline Datasets ---------------- diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 984622f27..bbe35f36e 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -110,7 +110,12 @@ Beyond environments defined in Python, RLlib supports batch training on `offline Customization ~~~~~~~~~~~~~ -RLlib provides ways to customize almost all aspects of training, including the `environment `__, `neural network model `__, `action distribution `__, and `policy definitions `__: +RLlib provides ways to customize almost all aspects of training, including +`neural network models `__, +`action distributions `__, +`policy definitions `__: +the `environment `__, +and the `sample collection process `__ .. image:: rllib-components.svg diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 0162186e8..7d9d3f69d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -14,6 +14,8 @@ from ray.exceptions import RayError from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env.normalize_actions import NormalizeActionWrapper from ray.rllib.env.env_context import EnvContext +from ray.rllib.evaluation.collectors.simple_list_collector import \ + SimpleListCollector from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.policy import Policy @@ -231,6 +233,10 @@ COMMON_CONFIG: TrainerConfigDict = { # generic ModelV2 `input_dicts` that can be requested by the model to # contain different information on the ongoing episode. "_use_trajectory_view_api": True, + # The SampleCollector class to be used to collect and retrieve + # environment-, model-, and sampler data. Override the SampleCollector base + # class to implement your own collection/buffering/retrieval logic. + "sample_collector": SimpleListCollector, # Element-wise observation filter, either "NoFilter" or "MeanStdFilter". 
"observation_filter": "NoFilter", diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index da188e938..443265322 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -1,16 +1,22 @@ from abc import abstractmethod, ABCMeta import logging -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, TYPE_CHECKING, Union from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \ TensorType +if TYPE_CHECKING: + from ray.rllib.agents.callbacks import DefaultCallbacks + logger = logging.getLogger(__name__) -class _SampleCollector(metaclass=ABCMeta): +# yapf: disable +# __sphinx_doc_begin__ +class SampleCollector(metaclass=ABCMeta): """Collects samples for all policies and agents from a multi-agent env. Note: This is an experimental class only used when @@ -29,6 +35,34 @@ class _SampleCollector(metaclass=ABCMeta): communication channel). """ + def __init__(self, + policy_map: Dict[PolicyID, Policy], + clip_rewards: Union[bool, float], + callbacks: "DefaultCallbacks", + multiple_episodes_in_batch: bool = True, + rollout_fragment_length: int = 200, + count_steps_by: str = "env_steps"): + """Initializes a SampleCollector instance. + + Args: + policy_map (Dict[str, Policy]): Maps policy ids to policy + instances. + clip_rewards (Union[bool, float]): Whether to clip rewards before + postprocessing (at +/-1.0) or the actual value to +/- clip. + callbacks (DefaultCallbacks): RLlib callbacks. + multiple_episodes_in_batch (bool): Whether it's allowed to pack + multiple episodes into the same built batch. + rollout_fragment_length (int): The + + """ + + self.policy_map = policy_map + self.clip_rewards = clip_rewards + self.callbacks = callbacks + self.multiple_episodes_in_batch = multiple_episodes_in_batch + self.rollout_fragment_length = rollout_fragment_length + self.count_steps_by = count_steps_by + @abstractmethod def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, policy_id: PolicyID, t: int, @@ -55,7 +89,7 @@ class _SampleCollector(metaclass=ABCMeta): Examples: >>> obs = env.reset() - >>> collector.add_init_obs(12345, 0, "pol0", obs) + >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs) >>> obs, r, done, info = env.step(action) >>> collector.add_action_reward_next_obs(12345, 0, "pol0", False, { ... "action": action, "obs": obs, "reward": r, "done": done @@ -227,3 +261,4 @@ class _SampleCollector(metaclass=ABCMeta): `self.rollout_fragment_length` has not been reached yet. 
""" raise NotImplementedError +# __sphinx_doc_end__ diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 96e6d0624..eb063e843 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -6,7 +6,7 @@ import numpy as np from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.evaluation.collectors.sample_collector import _SampleCollector +from ray.rllib.evaluation.collectors.sample_collector import SampleCollector from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch @@ -379,7 +379,7 @@ class _PolicyCollectorGroup: self.agent_steps = 0 -class _SimpleListCollector(_SampleCollector): +class SimpleListCollector(SampleCollector): """Util to build SampleBatches for each policy in a multi-agent env. Input data is per-agent, while output data is per-policy. There is an M:N @@ -395,29 +395,15 @@ class _SimpleListCollector(_SampleCollector): multiple_episodes_in_batch: bool = True, rollout_fragment_length: int = 200, count_steps_by: str = "env_steps"): - """Initializes a _SimpleListCollector instance. + """Initializes a SimpleListCollector instance.""" - Args: - policy_map (Dict[str, Policy]): Maps policy ids to policy - instances. - clip_rewards (Union[bool, float]): Whether to clip rewards before - postprocessing (at +/-1.0) or the actual value to +/- clip. - callbacks (DefaultCallbacks): RLlib callbacks. - multiple_episodes_in_batch (bool): Whether it's allowed to pack - multiple episodes into the same built batch. - rollout_fragment_length (int): The + super().__init__(policy_map, clip_rewards, callbacks, + multiple_episodes_in_batch, rollout_fragment_length, + count_steps_by) - """ - - self.policy_map = policy_map - self.clip_rewards = clip_rewards - self.callbacks = callbacks - self.multiple_episodes_in_batch = multiple_episodes_in_batch - self.rollout_fragment_length = rollout_fragment_length - self.count_steps_by = count_steps_by self.large_batch_threshold: int = max( - 1000, rollout_fragment_length * - 10) if rollout_fragment_length != float("inf") else 5000 + 1000, self.rollout_fragment_length * + 10) if self.rollout_fragment_length != float("inf") else 5000 # Whenever we observe a new episode+agent, add a new # _SingleTrajectoryCollector. @@ -430,8 +416,9 @@ class _SimpleListCollector(_SampleCollector): self.policy_collector_groups = [] # Agents to collect data from for the next forward pass (per policy). - self.forward_pass_agent_keys = {pid: [] for pid in policy_map.keys()} - self.forward_pass_size = {pid: 0 for pid in policy_map.keys()} + self.forward_pass_agent_keys = \ + {pid: [] for pid in self.policy_map.keys()} + self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()} # Maps episode ID to the (non-built) env steps taken in this episode. self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int) @@ -441,7 +428,7 @@ class _SimpleListCollector(_SampleCollector): # Maps episode ID to MultiAgentEpisode. self.episodes: Dict[EpisodeID, MultiAgentEpisode] = {} - @override(_SampleCollector) + @override(SampleCollector) def episode_step(self, episode_id: EpisodeID) -> None: episode = self.episodes[episode_id] self.episode_steps[episode_id] += 1 @@ -470,7 +457,7 @@ class _SimpleListCollector(_SampleCollector): "does at some point." 
if not self.multiple_episodes_in_batch else "")) - @override(_SampleCollector) + @override(SampleCollector) def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, env_id: EnvID, policy_id: PolicyID, t: int, init_obs: TensorType) -> None: @@ -481,7 +468,7 @@ class _SimpleListCollector(_SampleCollector): else: assert self.agent_key_to_policy_id[agent_key] == policy_id policy = self.policy_map[policy_id] - view_reqs = policy.model.inference_view_requirements if \ + view_reqs = policy.model.view_requirements if \ getattr(policy, "model", None) else policy.view_requirements # Add initial obs to Trajectory. @@ -503,7 +490,7 @@ class _SimpleListCollector(_SampleCollector): self._add_to_next_inference_call(agent_key) - @override(_SampleCollector) + @override(SampleCollector) def add_action_reward_next_obs(self, episode_id: EpisodeID, agent_id: AgentID, env_id: EnvID, policy_id: PolicyID, agent_done: bool, @@ -525,14 +512,14 @@ class _SimpleListCollector(_SampleCollector): if not agent_done: self._add_to_next_inference_call(agent_key) - @override(_SampleCollector) + @override(SampleCollector) def total_env_steps(self) -> int: # Add the non-built ongoing-episode env steps + the already built # env-steps. return sum(self.episode_steps.values()) + sum( pg.env_steps for pg in self.policy_collector_groups.values()) - @override(_SampleCollector) + @override(SampleCollector) def total_agent_steps(self) -> int: # Add the non-built ongoing-episode agent steps (still in the agent # collectors) + the already built agent steps. @@ -540,13 +527,13 @@ class _SimpleListCollector(_SampleCollector): sum(pg.agent_steps for pg in self.policy_collector_groups.values()) - @override(_SampleCollector) + @override(SampleCollector) def get_inference_input_dict(self, policy_id: PolicyID) -> \ Dict[str, TensorType]: policy = self.policy_map[policy_id] keys = self.forward_pass_agent_keys[policy_id] buffers = {k: self.agent_collectors[k].buffers for k in keys} - view_reqs = policy.model.inference_view_requirements if \ + view_reqs = policy.model.view_requirements if \ getattr(policy, "model", None) else policy.view_requirements input_dict = {} @@ -592,7 +579,7 @@ class _SimpleListCollector(_SampleCollector): return input_dict - @override(_SampleCollector) + @override(SampleCollector) def postprocess_episode( self, episode: MultiAgentEpisode, @@ -725,7 +712,7 @@ class _SimpleListCollector(_SampleCollector): return ma_batch - @override(_SampleCollector) + @override(SampleCollector) def try_build_truncated_episode_multi_agent_batch(self) -> \ List[Union[MultiAgentBatch, SampleBatch]]: batches = [] diff --git a/rllib/evaluation/episode.py b/rllib/evaluation/episode.py index 4bf5a7172..f4a79ba12 100644 --- a/rllib/evaluation/episode.py +++ b/rllib/evaluation/episode.py @@ -28,7 +28,10 @@ class MultiAgentEpisode: episode_id (int): Unique id identifying this trajectory. agent_rewards (dict): Summed rewards broken down by agent. custom_metrics (dict): Dict where the you can add custom metrics. - user_data (dict): Dict that you can use for temporary storage. + user_data (dict): Dict that you can use for temporary storage. E.g. + in between two custom callbacks referring to the same episode. + hist_data (dict): Dict mapping str keys to List[float] for storage of + per-timestep float data throughout the episode. 
Use case 1: Model-based rollouts in multi-agent: A custom compute_actions() function in a policy can inspect the diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 370b4896f..729105371 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -473,7 +473,7 @@ class RolloutWorker(ParallelIteratorWorker): # state. for pol in self.policy_map.values(): if not pol._model_init_state_automatically_added: - pol._update_model_inference_view_requirements_from_init_state() + pol._update_model_view_requirements_from_init_state() if (ray.is_initialized() and ray.worker._mode() != ray.worker.LOCAL_MODE): @@ -543,6 +543,7 @@ class RolloutWorker(ParallelIteratorWorker): raise ValueError("Unsupported batch mode: {}".format( self.batch_mode)) + # Create the IOContext for this worker. self.io_context: IOContext = IOContext(log_dir, policy_config, worker_index, self) self.reward_estimators: List[OffPolicyEstimator] = [] @@ -586,6 +587,8 @@ class RolloutWorker(ParallelIteratorWorker): no_done_at_end=no_done_at_end, observation_fn=observation_fn, _use_trajectory_view_api=_use_trajectory_view_api, + sample_collector_class=policy_config.get( + "sample_collector_class"), ) # Start the Sampler thread. self.sampler.start() @@ -609,6 +612,8 @@ class RolloutWorker(ParallelIteratorWorker): no_done_at_end=no_done_at_end, observation_fn=observation_fn, _use_trajectory_view_api=_use_trajectory_view_api, + sample_collector_class=policy_config.get( + "sample_collector_class"), ) self.input_reader: InputReader = input_creator(self.io_context) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index a115a0149..0461bcd11 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -6,13 +6,13 @@ import queue import threading import time from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\ - TYPE_CHECKING, Union + Type, TYPE_CHECKING, Union from ray.util.debug import log_once from ray.rllib.evaluation.collectors.sample_collector import \ - _SampleCollector + SampleCollector from ray.rllib.evaluation.collectors.simple_list_collector import \ - _SimpleListCollector + SimpleListCollector from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.evaluation.sample_batch_builder import \ @@ -119,26 +119,28 @@ class SyncSampler(SamplerInput): """Sync SamplerInput that collects experiences when `get_data()` is called. 
""" - def __init__(self, - *, - worker: "RolloutWorker", - env: BaseEnv, - policies: Dict[PolicyID, Policy], - policy_mapping_fn: Callable[[AgentID], PolicyID], - preprocessors: Dict[PolicyID, Preprocessor], - obs_filters: Dict[PolicyID, Filter], - clip_rewards: bool, - rollout_fragment_length: int, - count_steps_by: str = "env_steps", - callbacks: "DefaultCallbacks", - horizon: int = None, - multiple_episodes_in_batch: bool = False, - tf_sess=None, - clip_actions: bool = True, - soft_horizon: bool = False, - no_done_at_end: bool = False, - observation_fn: "ObservationFunction" = None, - _use_trajectory_view_api: bool = False): + def __init__( + self, + *, + worker: "RolloutWorker", + env: BaseEnv, + policies: Dict[PolicyID, Policy], + policy_mapping_fn: Callable[[AgentID], PolicyID], + preprocessors: Dict[PolicyID, Preprocessor], + obs_filters: Dict[PolicyID, Filter], + clip_rewards: bool, + rollout_fragment_length: int, + count_steps_by: str = "env_steps", + callbacks: "DefaultCallbacks", + horizon: int = None, + multiple_episodes_in_batch: bool = False, + tf_sess=None, + clip_actions: bool = True, + soft_horizon: bool = False, + no_done_at_end: bool = False, + observation_fn: "ObservationFunction" = None, + _use_trajectory_view_api: bool = False, + sample_collector_class: Optional[Type[SampleCollector]] = None): """Initializes a SyncSampler object. Args: @@ -178,6 +180,9 @@ class SyncSampler(SamplerInput): _use_trajectory_view_api (bool): Whether to use the (experimental) `_use_trajectory_view_api` to make generic trajectory views available to Models. Default: False. + sample_collector_class (Optional[Type[SampleCollector]]): An + optional Samplecollector sub-class to use to collect, store, + and retrieve environment-, model-, and sampler data. """ self.base_env = BaseEnv.to_base_env(env) @@ -190,7 +195,9 @@ class SyncSampler(SamplerInput): self.extra_batches = queue.Queue() self.perf_stats = _PerfStats() if _use_trajectory_view_api: - self.sample_collector = _SimpleListCollector( + if not sample_collector_class: + sample_collector_class = SimpleListCollector + self.sample_collector = sample_collector_class( policies, clip_rewards, callbacks, @@ -249,27 +256,30 @@ class AsyncSampler(threading.Thread, SamplerInput): from where they can be unqueued by the caller of `get_data()`. 
""" - def __init__(self, - *, - worker: "RolloutWorker", - env: BaseEnv, - policies: Dict[PolicyID, Policy], - policy_mapping_fn: Callable[[AgentID], PolicyID], - preprocessors: Dict[PolicyID, Preprocessor], - obs_filters: Dict[PolicyID, Filter], - clip_rewards: bool, - rollout_fragment_length: int, - count_steps_by: str = "env_steps", - callbacks: "DefaultCallbacks", - horizon: int = None, - multiple_episodes_in_batch: bool = False, - tf_sess=None, - clip_actions: bool = True, - blackhole_outputs: bool = False, - soft_horizon: bool = False, - no_done_at_end: bool = False, - observation_fn: "ObservationFunction" = None, - _use_trajectory_view_api: bool = False): + def __init__( + self, + *, + worker: "RolloutWorker", + env: BaseEnv, + policies: Dict[PolicyID, Policy], + policy_mapping_fn: Callable[[AgentID], PolicyID], + preprocessors: Dict[PolicyID, Preprocessor], + obs_filters: Dict[PolicyID, Filter], + clip_rewards: bool, + rollout_fragment_length: int, + count_steps_by: str = "env_steps", + callbacks: "DefaultCallbacks", + horizon: int = None, + multiple_episodes_in_batch: bool = False, + tf_sess=None, + clip_actions: bool = True, + blackhole_outputs: bool = False, + soft_horizon: bool = False, + no_done_at_end: bool = False, + observation_fn: "ObservationFunction" = None, + _use_trajectory_view_api: bool = False, + sample_collector_class: Optional[Type[SampleCollector]] = None, + ): """Initializes a AsyncSampler object. Args: @@ -313,6 +323,9 @@ class AsyncSampler(threading.Thread, SamplerInput): _use_trajectory_view_api (bool): Whether to use the (experimental) `_use_trajectory_view_api` to make generic trajectory views available to Models. Default: False. + sample_collector_class (Optional[Type[SampleCollector]]): An + optional Samplecollector sub-class to use to collect, store, + and retrieve environment-, model-, and sampler data. """ for _, f in obs_filters.items(): assert getattr(f, "is_concurrent", False), \ @@ -343,7 +356,9 @@ class AsyncSampler(threading.Thread, SamplerInput): self.observation_fn = observation_fn self._use_trajectory_view_api = _use_trajectory_view_api if _use_trajectory_view_api: - self.sample_collector = _SimpleListCollector( + if not sample_collector_class: + sample_collector_class = SimpleListCollector + self.sample_collector = sample_collector_class( policies, clip_rewards, callbacks, @@ -441,7 +456,7 @@ def _env_runner( no_done_at_end: bool, observation_fn: "ObservationFunction", _use_trajectory_view_api: bool = False, - _sample_collector: Optional[_SampleCollector] = None, + sample_collector: Optional[SampleCollector] = None, ) -> Iterable[SampleBatchType]: """This implements the common experience collection logic. @@ -480,8 +495,8 @@ def _env_runner( _use_trajectory_view_api (bool): Whether to use the (experimental) `_use_trajectory_view_api` to make generic trajectory views available to Models. Default: False. - _sample_collector (Optional[_SampleCollector]): An optional - _SampleCollector object to use + sample_collector (Optional[SampleCollector]): An optional + SampleCollector object to use Yields: rollout (SampleBatch): Object containing state, action, reward, @@ -498,20 +513,27 @@ def _env_runner( # Trainer has a given `horizon` setting. if horizon: - # `horizon` is larger than env's limit -> Error and explain how - # to increase Env's own episode limit. + # `horizon` is larger than env's limit. 
if max_episode_steps and horizon > max_episode_steps: - raise ValueError( - "Your `horizon` setting ({}) is larger than the Env's own " - "timestep limit ({})! Try to increase the Env's limit via " - "setting its `spec.max_episode_steps` property.".format( - horizon, max_episode_steps)) + # Try to override the env's own max-step setting with our horizon. + # If this won't work, throw an error. + try: + base_env.get_unwrapped()[0].spec.max_episode_steps = horizon + base_env.get_unwrapped()[0]._max_episode_steps = horizon + except Exception: + raise ValueError( + "Your `horizon` setting ({}) is larger than the Env's own " + "timestep limit ({}), which seems to be unsettable! Try " + "to increase the Env's built-in limit to be at least as " + "large as your wanted `horizon`.".format( + horizon, max_episode_steps)) # Otherwise, set Trainer's horizon to env's max-steps. elif max_episode_steps: horizon = max_episode_steps logger.debug( "No episode horizon specified, setting it to Env's limit ({}).". format(max_episode_steps)) + # No horizon/max_episode_steps -> Episodes may be infinitely long. else: horizon = float("inf") logger.debug("No episode horizon specified, assuming inf.") @@ -594,7 +616,7 @@ def _env_runner( soft_horizon=soft_horizon, no_done_at_end=no_done_at_end, observation_fn=observation_fn, - _sample_collector=_sample_collector, + sample_collector=sample_collector, ) else: active_envs, to_eval, outputs = _process_observations( @@ -628,7 +650,7 @@ def _env_runner( eval_results = _do_policy_eval_w_trajectory_view_api( to_eval=to_eval, policies=policies, - _sample_collector=_sample_collector, + sample_collector=sample_collector, active_episodes=active_episodes, tf_sess=tf_sess, ) @@ -653,7 +675,7 @@ def _env_runner( policies=policies, clip_actions=clip_actions, _use_trajectory_view_api=_use_trajectory_view_api, - _sample_collector=_sample_collector, + sample_collector=sample_collector, ) perf_stats.action_processing_time += time.time() - t3 @@ -968,7 +990,7 @@ def _process_observations_w_trajectory_view_api( soft_horizon: bool, no_done_at_end: bool, observation_fn: "ObservationFunction", - _sample_collector: _SampleCollector, + sample_collector: SampleCollector, ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[ RolloutMetrics, SampleBatchType]]]: """Trajectory View API version of `_process_observations()`. @@ -987,7 +1009,7 @@ def _process_observations_w_trajectory_view_api( episode: MultiAgentEpisode = active_episodes[env_id] if not is_new_episode: - _sample_collector.episode_step(episode.episode_id) + sample_collector.episode_step(episode.episode_id) episode._add_agent_rewards(rewards[env_id]) # Check episode termination conditions. @@ -1051,9 +1073,9 @@ def _process_observations_w_trajectory_view_api( # Record transition info if applicable. if last_observation is None: - _sample_collector.add_init_obs(episode, agent_id, env_id, - policy_id, episode.length - 1, - filtered_obs) + sample_collector.add_init_obs(episode, agent_id, env_id, + policy_id, episode.length - 1, + filtered_obs) else: # Add actions, rewards, next-obs to collectors. values_dict = { @@ -1079,7 +1101,7 @@ def _process_observations_w_trajectory_view_api( # Env infos for this agent. 
if "infos" in pol.view_requirements: values_dict["infos"] = agent_infos - _sample_collector.add_action_reward_next_obs( + sample_collector.add_action_reward_next_obs( episode.episode_id, agent_id, env_id, policy_id, agent_done, values_dict) @@ -1111,7 +1133,7 @@ def _process_observations_w_trajectory_view_api( # MultiAgentBatch from a single episode and add it to "outputs". # Otherwise, just postprocess and continue collecting across # episodes. - ma_sample_batch = _sample_collector.postprocess_episode( + ma_sample_batch = sample_collector.postprocess_episode( episode, is_done=is_done or (hit_horizon and not soft_horizon), check_dones=check_dones, @@ -1170,7 +1192,7 @@ def _process_observations_w_trajectory_view_api( new_episode._set_last_observation(agent_id, filtered_obs) # Add initial obs to buffer. - _sample_collector.add_init_obs( + sample_collector.add_init_obs( new_episode, agent_id, env_id, policy_id, new_episode.length - 1, filtered_obs) @@ -1183,7 +1205,7 @@ def _process_observations_w_trajectory_view_api( # Try to build something. if multiple_episodes_in_batch: sample_batches = \ - _sample_collector.try_build_truncated_episode_multi_agent_batch() + sample_collector.try_build_truncated_episode_multi_agent_batch() if sample_batches: outputs.extend(sample_batches) @@ -1279,7 +1301,7 @@ def _do_policy_eval_w_trajectory_view_api( *, to_eval: Dict[PolicyID, List[PolicyEvalData]], policies: Dict[PolicyID, Policy], - _sample_collector, + sample_collector, active_episodes: Dict[str, MultiAgentEpisode], tf_sess: Optional["tf.Session"] = None, ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]: @@ -1291,7 +1313,7 @@ def _do_policy_eval_w_trajectory_view_api( be the batch's items for the model forward pass). policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy obj. - _sample_collector (SampleCollector): The SampleCollector object to use. + sample_collector (SampleCollector): The SampleCollector object to use. tf_sess (Optional[tf.Session]): Optional tensorflow session to use for batching TF policy evaluations. @@ -1313,7 +1335,7 @@ def _do_policy_eval_w_trajectory_view_api( for policy_id, eval_data in to_eval.items(): policy: Policy = _get_or_raise(policies, policy_id) - input_dict = _sample_collector.get_inference_input_dict(policy_id) + input_dict = sample_collector.get_inference_input_dict(policy_id) eval_results[policy_id] = \ policy.compute_actions_from_input_dict( input_dict, @@ -1343,7 +1365,7 @@ def _process_policy_eval_results( policies: Dict[PolicyID, Policy], clip_actions: bool, _use_trajectory_view_api: bool = False, - _sample_collector=None, + sample_collector=None, ) -> Dict[EnvID, Dict[AgentID, EnvActionType]]: """Process the output of policy neural network evaluation. 
diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py
index 927e74903..9b03960dc 100644
--- a/rllib/evaluation/tests/test_trajectory_view_api.py
+++ b/rllib/evaluation/tests/test_trajectory_view_api.py
@@ -63,7 +63,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
                 config,
                 env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
             policy = trainer.get_policy()
-            view_req_model = policy.model.inference_view_requirements
+            view_req_model = policy.model.view_requirements
             view_req_policy = policy.view_requirements
             assert len(view_req_model) == 1, view_req_model
             assert len(view_req_policy) == 8, view_req_policy
@@ -108,7 +108,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
         for _ in framework_iterator(config):
             trainer = ppo.PPOTrainer(config, env="CartPole-v0")
             policy = trainer.get_policy()
-            view_req_model = policy.model.inference_view_requirements
+            view_req_model = policy.model.view_requirements
             view_req_policy = policy.view_requirements
             # 7=obs, prev-a + r, 2x state-in, 2x state-out.
             assert len(view_req_model) == 7, view_req_model
diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py
index e632e1694..ed92ceb1a 100644
--- a/rllib/examples/policy/episode_env_aware_policy.py
+++ b/rllib/examples/policy/episode_env_aware_policy.py
@@ -21,7 +21,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
 
         self.model = _fake_model()
         self.model.time_major = True
-        self.model.inference_view_requirements = {
+        self.model.view_requirements = {
             SampleBatch.AGENT_INDEX: ViewRequirement(),
             SampleBatch.EPS_ID: ViewRequirement(),
             "env_id": ViewRequirement(),
@@ -33,12 +33,12 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
                 SampleBatch.REWARDS, shift=-1),
         }
         for i in range(2):
-            self.model.inference_view_requirements["state_in_{}".format(i)] = \
+            self.model.view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
                     shift=-1,
                     space=self.state_space)
-            self.model.inference_view_requirements[
+            self.model.view_requirements[
                 "state_out_{}".format(i)] = \
                 ViewRequirement(space=self.state_space)
@@ -50,7 +50,7 @@ class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
             },
-            **self.model.inference_view_requirements)
+            **self.model.view_requirements)
 
     @override(Policy)
     def is_recurrent(self):
@@ -97,7 +97,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
             pass
 
         self.model = _fake_model()
-        self.model.inference_view_requirements = {
+        self.model.view_requirements = {
             SampleBatch.AGENT_INDEX: ViewRequirement(),
             SampleBatch.EPS_ID: ViewRequirement(),
             "env_id": ViewRequirement(),
@@ -114,7 +114,7 @@ class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
         }
 
         self.view_requirements = dict(super()._get_default_view_requirements(),
-                                      **self.model.inference_view_requirements)
+                                      **self.model.view_requirements)
 
     @override(Policy)
     def is_recurrent(self):
diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index a6c871d0f..43bde6a4b 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -62,7 +62,7 @@ class ModelV2:
         self._last_output = None
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
-        self.inference_view_requirements = {
+        self.view_requirements = {
             SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
         }
 
@@ -341,7 +341,7 @@ class ModelV2:
         }
 
         input_dict = {}
-        for view_col, view_req in self.inference_view_requirements.items():
+        for view_col, view_req in self.view_requirements.items():
             # Create batches of size 1 (single-agent input-dict).
             data_col = view_req.data_col or view_col
             if index == "last":
diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py
index ef49f4610..642e2c1b5 100644
--- a/rllib/models/tf/attention_net.py
+++ b/rllib/models/tf/attention_net.py
@@ -284,22 +284,22 @@ class GTrXLNet(RecurrentNetwork):
         self.register_variables(self.trxl_model.variables)
         self.trxl_model.summary()
 
-        # Setup inference view (`memory-inference` x past observations +
-        # current one (0))
-        # 1 to `num_transformer_units`: Memory data (one per transformer unit).
+        # __sphinx_doc_begin__
+        # Setup trajectory views (`memory-inference` x past memory outs).
         for i in range(self.num_transformer_units):
             space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
-            self.inference_view_requirements["state_in_{}".format(i)] = \
+            self.view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
                     shift="-{}:-1".format(self.memory_inference),
                     # Repeat the incoming state every max-seq-len times.
                     batch_repeat_value=self.max_seq_len,
                     space=space)
-            self.inference_view_requirements["state_out_{}".format(i)] = \
+            self.view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(
                     space=space,
                     used_for_training=False)
+        # __sphinx_doc_end__
 
     @override(ModelV2)
     def forward(self, input_dict, state: List[TensorType],
diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py
index 0dd27e6b3..fa51b54d0 100644
--- a/rllib/models/tf/recurrent_net.py
+++ b/rllib/models/tf/recurrent_net.py
@@ -184,11 +184,11 @@ class LSTMWrapper(RecurrentNetwork):
 
         # Add prev-a/r to this model's view, if required.
         if model_config["lstm_use_prev_action"]:
-            self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                 shift=-1)
         if model_config["lstm_use_prev_reward"]:
-            self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                 ViewRequirement(SampleBatch.REWARDS, shift=-1)
 
     @override(RecurrentNetwork)
diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py
index 27d2d494e..a6440ec6f 100644
--- a/rllib/models/torch/attention_net.py
+++ b/rllib/models/torch/attention_net.py
@@ -160,19 +160,17 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
         self.values_out = SlimFC(
             in_size=self.attn_dim, out_size=1, activation_fn=None)
 
-        # Setup inference view (`memory-inference` x past observations +
-        # current one (0))
-        # 1 to `num_transformer_units`: Memory data (one per transformer unit).
+        # Setup trajectory views (`memory-inference` x past memory outs).
         for i in range(self.num_transformer_units):
             space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
-            self.inference_view_requirements["state_in_{}".format(i)] = \
+            self.view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
                     "state_out_{}".format(i),
                     shift="-{}:-1".format(self.memory_inference),
                     # Repeat the incoming state every max-seq-len times.
                     batch_repeat_value=self.max_seq_len,
                     space=space)
-            self.inference_view_requirements["state_out_{}".format(i)] = \
+            self.view_requirements["state_out_{}".format(i)] = \
                 ViewRequirement(
                     space=space,
                     used_for_training=False)
diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py
index fd4679022..247c4e073 100644
--- a/rllib/models/torch/recurrent_net.py
+++ b/rllib/models/torch/recurrent_net.py
@@ -166,14 +166,16 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
             activation_fn=None,
             initializer=torch.nn.init.xavier_uniform_)
 
+        # __sphinx_doc_begin__
         # Add prev-a/r to this model's view, if required.
         if model_config["lstm_use_prev_action"]:
-            self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
+            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
                                 shift=-1)
         if model_config["lstm_use_prev_reward"]:
-            self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
+            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                 ViewRequirement(SampleBatch.REWARDS, shift=-1)
+        # __sphinx_doc_end__
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py
index e6db9e8b0..3b05e4772 100644
--- a/rllib/offline/input_reader.py
+++ b/rllib/offline/input_reader.py
@@ -1,3 +1,4 @@
+from abc import ABCMeta, abstractmethod
 import logging
 import numpy as np
 import threading
@@ -14,15 +15,16 @@ logger = logging.getLogger(__name__)
 
 
 @PublicAPI
-class InputReader:
+class InputReader(metaclass=ABCMeta):
     """Input object for loading experiences in policy evaluation."""
 
+    @abstractmethod
     @PublicAPI
     def next(self):
-        """Return the next batch of experiences read.
+        """Returns the next batch of experiences read.
 
         Returns:
-            SampleBatch or MultiAgentBatch read.
+            Union[SampleBatch, MultiAgentBatch]: The experience read.
         """
         raise NotImplementedError
 
diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py
index 39b31f63b..10ecf9931 100644
--- a/rllib/policy/dynamic_tf_policy.py
+++ b/rllib/policy/dynamic_tf_policy.py
@@ -171,7 +171,7 @@ class DynamicTFPolicy(TFPolicy):
                 model_config=self.config["model"],
                 framework="tf")
         # Auto-update model's inference view requirements, if recurrent.
-        self._update_model_inference_view_requirements_from_init_state()
+        self._update_model_view_requirements_from_init_state()
 
         if existing_inputs:
             self._state_inputs = [
@@ -186,8 +186,7 @@ class DynamicTFPolicy(TFPolicy):
                 get_placeholder(
                     space=vr.space,
                     time_axis=not isinstance(vr.shift, int),
-                ) for k, vr in
-                self.model.inference_view_requirements.items()
+                ) for k, vr in self.model.view_requirements.items()
                 if k.startswith("state_in_")
             ]
         else:
@@ -200,7 +199,7 @@ class DynamicTFPolicy(TFPolicy):
         # Add NEXT_OBS, STATE_IN_0.., and others.
         self.view_requirements = self._get_default_view_requirements()
         # Combine view_requirements for Model and Policy.
-        self.view_requirements.update(self.model.inference_view_requirements)
+        self.view_requirements.update(self.model.view_requirements)
 
         # Setup standard placeholders.
         if existing_inputs is not None:
@@ -560,7 +559,7 @@ class DynamicTFPolicy(TFPolicy):
         all_accessed_keys = \
             train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
             batch_for_postproc.added_keys | set(
-                self.model.inference_view_requirements.keys())
+                self.model.view_requirements.keys())
 
         TFPolicy._initialize_loss(self, loss, [(k, v)
                                                for k, v in train_batch.items()
@@ -584,7 +583,7 @@ class DynamicTFPolicy(TFPolicy):
         # Tag those only needed for post-processing.
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 if key in self.view_requirements:
                     self.view_requirements[key].used_for_training = False
                 if key in self._loss_input_dict:
@@ -597,7 +596,7 @@ class DynamicTFPolicy(TFPolicy):
                     SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                     SampleBatch.UNROLL_ID, SampleBatch.DONES,
                     SampleBatch.REWARDS] and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 # If user deleted this key manually in postprocessing
                 # fn, warn about it and do not remove from
                 # view-requirements.
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py
index af4fa512c..c21f31755 100644
--- a/rllib/policy/eager_tf_policy.py
+++ b/rllib/policy/eager_tf_policy.py
@@ -256,15 +256,14 @@ def build_eager_tf_policy(name,
                 framework=self.framework,
             )
             # Auto-update model's inference view requirements, if recurrent.
-            self._update_model_inference_view_requirements_from_init_state()
+            self._update_model_view_requirements_from_init_state()
 
             self.exploration = self._create_exploration()
             self._state_inputs = self.model.get_initial_state()
             self._is_recurrent = len(self._state_inputs) > 0
             # Combine view_requirements for Model and Policy.
-            self.view_requirements.update(
-                self.model.inference_view_requirements)
+            self.view_requirements.update(self.model.view_requirements)
 
             if before_loss_init:
                 before_loss_init(self, observation_space, action_space, config)
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index 4695e366f..ab7cf3de1 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -669,7 +669,7 @@ class Policy(metaclass=ABCMeta):
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
                     key in self.view_requirements and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 self.view_requirements[key].used_for_training = False
         # Remove those not needed at all (leave those that are needed
         # by Sampler to properly execute sample collection).
@@ -679,7 +679,7 @@ class Policy(metaclass=ABCMeta):
                     SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                     SampleBatch.UNROLL_ID, SampleBatch.DONES,
                     SampleBatch.REWARDS] and \
-                    key not in self.model.inference_view_requirements:
+                    key not in self.model.view_requirements:
                 # If user deleted this key manually in postprocessing
                 # fn, warn about it and do not remove from
                 # view-requirements.
@@ -738,12 +738,12 @@ class Policy(metaclass=ABCMeta):
         # columns in the resulting batch may not all have the same batch size.
         return SampleBatch(ret, _dont_check_lens=True)
 
-    def _update_model_inference_view_requirements_from_init_state(self):
+    def _update_model_view_requirements_from_init_state(self):
         """Uses Model's (or this Policy's) init state to add needed ViewReqs.
 
         Can be called from within a Policy to make sure RNNs automatically
        update their internal state-related view requirements.
-        Changes the `self.inference_view_requirements` dict.
+        Changes the `self.view_requirements` dict.
         """
         self._model_init_state_automatically_added = True
         model = getattr(self, "model", None)
@@ -752,7 +752,7 @@ class Policy(metaclass=ABCMeta):
         for i, state in enumerate(obj.get_initial_state()):
             space = Box(-1.0, 1.0, shape=state.shape) if \
                 hasattr(state, "shape") else state
-            view_reqs = model.inference_view_requirements if model else \
+            view_reqs = model.view_requirements if model else \
                 self.view_requirements
             view_reqs["state_in_{}".format(i)] = ViewRequirement(
                 "state_out_{}".format(i),
diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py
index 9c9557413..eb08ecc65 100644
--- a/rllib/policy/policy_template.py
+++ b/rllib/policy/policy_template.py
@@ -9,7 +9,6 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import TorchPolicy
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils import add_mixins, force_list, NullContextManager
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch, try_import_jax
@@ -74,8 +73,6 @@ def build_policy_class(
         apply_gradients_fn: Optional[Callable[
             [Policy, "torch.optim.Optimizer"], None]] = None,
         mixins: Optional[List[type]] = None,
-        view_requirements_fn: Optional[Callable[[Policy], Dict[
-            str, ViewRequirement]]] = None,
         get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
 ) -> Type[TorchPolicy]:
     """Helper function for creating a new Policy class at runtime.
@@ -181,9 +178,6 @@ def build_policy_class(
         mixins (Optional[List[type]]): Optional list of any class mixins for
             the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
-        view_requirements_fn (Optional[Callable[[Policy],
-            Dict[str, ViewRequirement]]]): An optional callable to retrieve
-            additional train view requirements for this policy.
         get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
             Optional callable that returns the divisibility requirement for
             sample batches. If None, will assume a value of 1.
@@ -260,12 +254,8 @@ def build_policy_class(
                 get_batch_divisibility_req=get_batch_divisibility_req,
             )
 
-            # Update this Policy's ViewRequirements (if function given).
-            if callable(view_requirements_fn):
-                self.view_requirements.update(view_requirements_fn(self))
             # Merge Model's view requirements into Policy's.
-            self.view_requirements.update(
-                self.model.inference_view_requirements)
+            self.view_requirements.update(self.model.view_requirements)
 
             _before_loss_init = before_loss_init or after_init
             if _before_loss_init:
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py
index fe6ec900b..ef211bf8d 100644
--- a/rllib/policy/tf_policy.py
+++ b/rllib/policy/tf_policy.py
@@ -147,7 +147,7 @@ class TFPolicy(Policy):
         self.model = model
         # Auto-update model's inference view requirements, if recurrent.
         if self.model is not None:
-            self._update_model_inference_view_requirements_from_init_state()
+            self._update_model_view_requirements_from_init_state()
 
         self.exploration = self._create_exploration()
         self._sess = sess
diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py
index e7a1c69ad..f81ac03ab 100644
--- a/rllib/policy/torch_policy.py
+++ b/rllib/policy/torch_policy.py
@@ -113,9 +113,9 @@ class TorchPolicy(Policy):
         self._state_inputs = self.model.get_initial_state()
         self._is_recurrent = len(self._state_inputs) > 0
         # Auto-update model's inference view requirements, if recurrent.
-        self._update_model_inference_view_requirements_from_init_state()
+        self._update_model_view_requirements_from_init_state()
         # Combine view_requirements for Model and Policy.
-        self.view_requirements.update(self.model.inference_view_requirements)
+        self.view_requirements.update(self.model.view_requirements)
 
         self.exploration = self._create_exploration()
         self.unwrapped_model = model  # used to support DistributedDataParallel
diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py
index 25a5e908a..7d361d8dd 100644
--- a/rllib/policy/view_requirement.py
+++ b/rllib/policy/view_requirement.py
@@ -22,7 +22,7 @@ class ViewRequirement:
 
     Examples:
         >>> # The default ViewRequirement for a Model is:
-        >>> req = [ModelV2].inference_view_requirements
+        >>> req = [ModelV2].view_requirements
         >>> print(req)
         {"obs": ViewRequirement(shift=0)}
     """
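
A quick way to see the effect of the rename outside the unit-test hunks above is to build a trainer and read ``view_requirements`` off its policy and model, exactly as the tests do. The snippet below is illustrative only and not part of this patch; it assumes a local Ray installation, and ``CartPole-v0`` is used simply because the PPO test uses it:

.. code-block:: python

    import ray
    from ray.rllib.agents import ppo

    ray.init()
    config = ppo.DEFAULT_CONFIG.copy()
    config["model"]["use_lstm"] = True  # Adds state_in_/state_out_ columns.
    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
    policy = trainer.get_policy()

    # Model-side view requirements: `obs` plus the LSTM state columns (and
    # prev_actions/prev_rewards, if lstm_use_prev_action/reward are enabled).
    print(policy.model.view_requirements.keys())
    # Policy-side view requirements: a superset, incl. actions, rewards, dones.
    print(policy.view_requirements.keys())

    trainer.stop()
    ray.shutdown()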
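
Custom models hook into the same renamed attribute that ``ModelV2.__init__`` seeds with the single ``obs`` entry. The following sketch is not part of this patch; it mirrors the ``LSTMWrapper`` hunks above for a hand-written model and assumes a flat ``Box`` observation space (the class name, layer sizes, and the simple value branch are placeholders):

.. code-block:: python

    import torch
    import torch.nn as nn

    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement


    class PrevRewardModel(TorchModelV2, nn.Module):
        """Toy model that also sees the reward received at t-1."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            # Flat Box obs assumed; one extra input for the previous reward.
            self.logits = nn.Linear(obs_space.shape[0] + 1, num_outputs)
            self.value_branch = nn.Linear(obs_space.shape[0] + 1, 1)
            self._features = None
            # Same pattern as the LSTMWrapper change above: ask the sample
            # collector for `rewards`, shifted by -1, under `prev_rewards`.
            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                ViewRequirement(SampleBatch.REWARDS, shift=-1)

        def forward(self, input_dict, state, seq_lens):
            obs = input_dict[SampleBatch.OBS].float()
            prev_r = input_dict[SampleBatch.PREV_REWARDS].float().unsqueeze(-1)
            self._features = torch.cat([obs, prev_r], dim=-1)
            return self.logits(self._features), state

        def value_function(self):
            return self.value_branch(self._features).squeeze(-1)

Registering the class via ``ModelCatalog.register_custom_model("prev_reward_model", PrevRewardModel)`` and setting ``config["model"]["custom_model"] = "prev_reward_model"`` should be all that is needed for the sample collector to start filling the extra ``prev_rewards`` column; no sampler changes are required.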