[rllib] RLlib in 60 seconds documentation (#5430)

Eric Liang 2019-08-12 17:39:02 -07:00 committed by GitHub
parent 3218ee389a
commit 79949fb8a0
5 changed files with 274 additions and 180 deletions

doc/source/index.rst

@@ -200,6 +200,7 @@ The following are good places to discuss Ray.
:caption: RLlib
rllib.rst
rllib-toc.rst
rllib-training.rst
rllib-env.rst
rllib-models.rst

File diff suppressed because one or more lines are too long (new image asset, 35 KiB).
doc/source/rllib-toc.rst (new file, 138 lines)

@@ -0,0 +1,138 @@
RLlib Table of Contents
=======================
Training APIs
-------------
* `Command-line <rllib-training.html>`__
* `Configuration <rllib-training.html#configuration>`__
* `Python API <rllib-training.html#python-api>`__
* `Debugging <rllib-training.html#debugging>`__
* `REST API <rllib-training.html#rest-api>`__
Environments
------------
* `RLlib Environments Overview <rllib-env.html>`__
* `Feature Compatibility Matrix <rllib-env.html#feature-compatibility-matrix>`__
* `OpenAI Gym <rllib-env.html#openai-gym>`__
* `Vectorized <rllib-env.html#vectorized>`__
* `Multi-Agent and Hierarchical <rllib-env.html#multi-agent-and-hierarchical>`__
* `Interfacing with External Agents <rllib-env.html#interfacing-with-external-agents>`__
* `Advanced Integrations <rllib-env.html#advanced-integrations>`__
Models, Preprocessors, and Action Distributions
-----------------------------------------------
* `RLlib Models, Preprocessors, and Action Distributions Overview <rllib-models.html>`__
* `TensorFlow Models <rllib-models.html#tensorflow-models>`__
* `PyTorch Models <rllib-models.html#pytorch-models>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Custom Action Distributions <rllib-models.html#custom-action-distributions>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Autoregressive Action Distributions <rllib-models.html#autoregressive-action-distributions>`__
Algorithms
----------
* High-throughput architectures
- `Distributed Prioritized Experience Replay (Ape-X) <rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x>`__
- `Importance Weighted Actor-Learner Architecture (IMPALA) <rllib-algorithms.html#importance-weighted-actor-learner-architecture-impala>`__
- `Asynchronous Proximal Policy Optimization (APPO) <rllib-algorithms.html#asynchronous-proximal-policy-optimization-appo>`__
* Gradient-based
- `Advantage Actor-Critic (A2C, A3C) <rllib-algorithms.html#advantage-actor-critic-a2c-a3c>`__
- `Deep Deterministic Policy Gradients (DDPG, TD3) <rllib-algorithms.html#deep-deterministic-policy-gradients-ddpg-td3>`__
- `Deep Q Networks (DQN, Rainbow, Parametric DQN) <rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn>`__
- `Policy Gradients <rllib-algorithms.html#policy-gradients>`__
- `Proximal Policy Optimization (PPO) <rllib-algorithms.html#proximal-policy-optimization-ppo>`__
- `Soft Actor Critic (SAC) <rllib-algorithms.html#soft-actor-critic-sac>`__
* Derivative-free
- `Augmented Random Search (ARS) <rllib-algorithms.html#augmented-random-search-ars>`__
- `Evolution Strategies <rllib-algorithms.html#evolution-strategies>`__
* Multi-agent specific
- `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__
- `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) <rllib-algorithms.html#multi-agent-deep-deterministic-policy-gradient-contrib-maddpg>`__
* Offline
- `Advantage Re-Weighted Imitation Learning (MARWIL) <rllib-algorithms.html#advantage-re-weighted-imitation-learning-marwil>`__
Offline Datasets
----------------
* `Working with Offline Datasets <rllib-offline.html>`__
* `Input Pipeline for Supervised Losses <rllib-offline.html#input-pipeline-for-supervised-losses>`__
* `Input API <rllib-offline.html#input-api>`__
* `Output API <rllib-offline.html#output-api>`__
Concepts and Custom Algorithms
------------------------------
* `Policies <rllib-concepts.html>`__
- `Policies in Multi-Agent <rllib-concepts.html#policies-in-multi-agent>`__
- `Building Policies in TensorFlow <rllib-concepts.html#building-policies-in-tensorflow>`__
- `Building Policies in TensorFlow Eager <rllib-concepts.html#building-policies-in-tensorflow-eager>`__
- `Building Policies in PyTorch <rllib-concepts.html#building-policies-in-pytorch>`__
- `Extending Existing Policies <rllib-concepts.html#extending-existing-policies>`__
* `Policy Evaluation <rllib-concepts.html#policy-evaluation>`__
* `Policy Optimization <rllib-concepts.html#policy-optimization>`__
* `Trainers <rllib-concepts.html#trainers>`__
Examples
--------
* `Tuned Examples <rllib-examples.html#tuned-examples>`__
* `Training Workflows <rllib-examples.html#training-workflows>`__
* `Custom Envs and Models <rllib-examples.html#custom-envs-and-models>`__
* `Serving and Offline <rllib-examples.html#serving-and-offline>`__
* `Multi-Agent and Hierarchical <rllib-examples.html#multi-agent-and-hierarchical>`__
* `Community Examples <rllib-examples.html#community-examples>`__
Development
-----------
* `Development Install <rllib-dev.html#development-install>`__
* `API Stability <rllib-dev.html#api-stability>`__
* `Features <rllib-dev.html#feature-development>`__
* `Benchmarks <rllib-dev.html#benchmarks>`__
* `Contributing Algorithms <rllib-dev.html#contributing-algorithms>`__
Package Reference
-----------------
* `ray.rllib.agents <rllib-package-ref.html#module-ray.rllib.agents>`__
* `ray.rllib.env <rllib-package-ref.html#module-ray.rllib.env>`__
* `ray.rllib.evaluation <rllib-package-ref.html#module-ray.rllib.evaluation>`__
* `ray.rllib.models <rllib-package-ref.html#module-ray.rllib.models>`__
* `ray.rllib.optimizers <rllib-package-ref.html#module-ray.rllib.optimizers>`__
* `ray.rllib.utils <rllib-package-ref.html#module-ray.rllib.utils>`__
Troubleshooting
---------------
If you encounter errors like
``blas_thread_init: pthread_create: Resource temporarily unavailable`` when using many workers,
try setting ``OMP_NUM_THREADS=1``. Similarly, check configured system limits with
``ulimit -a`` for other resource limit errors.
If you encounter out-of-memory errors, consider setting ``redis_max_memory`` and ``object_store_memory`` in ``ray.init()`` to reduce memory usage.
For debugging unexpected hangs or performance problems, you can run ``ray stack`` to dump
the stack traces of all Ray workers on the current node, and ``ray timeline`` to dump
a timeline visualization of tasks to a file.
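A rough sketch of applying these memory and threading settings (the byte limits below are placeholder values, not recommendations):

.. code-block:: python

    import os

    # Limit BLAS/OpenMP threads per worker; set before numpy/TF are imported.
    os.environ["OMP_NUM_THREADS"] = "1"

    import ray

    # Cap Redis and object store memory usage (values are in bytes).
    ray.init(
        redis_max_memory=1024 * 1024 * 1024,           # 1 GiB
        object_store_memory=2 * 1024 * 1024 * 1024)    # 2 GiB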

doc/source/rllib.rst

@@ -7,155 +7,109 @@ RLlib is an open-source library for reinforcement learning that offers both high
To get started, take a look over the `custom env example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__ and the `API documentation <rllib-training.html>`__. If you're looking to develop custom algorithms with RLlib, also check out `concepts and custom algorithms <rllib-concepts.html>`__.
RLlib in 60 seconds
-------------------

The following is a whirlwind overview of RLlib. See also the full `table of contents <rllib-toc.html>`__ for a more in-depth guide including the `list of built-in algorithms <rllib-toc.html#algorithms>`__.

Running RLlib
~~~~~~~~~~~~~

RLlib has extra dependencies on top of ``ray``. First, you'll need to install either `PyTorch <http://pytorch.org/>`__ or `TensorFlow <https://www.tensorflow.org>`__. Then, install the RLlib module:

.. code-block:: bash

    pip install tensorflow  # or tensorflow-gpu
    pip install ray[rllib]  # also recommended: ray[debug]

You might also want to clone the `Ray repo <https://github.com/ray-project/ray>`__ for convenient access to RLlib helper scripts:

.. code-block:: bash

    git clone https://github.com/ray-project/ray
    cd ray/rllib

Then, you can try out training in the following equivalent ways:

.. code-block:: bash

    rllib train --run=PPO --env=CartPole-v0

.. code-block:: python

    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer
    tune.run(PPOTrainer, config={"env": "CartPole-v0"})
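A rough equivalent using the Trainer `Python API <rllib-training.html#python-api>`__ directly (a sketch; the iteration count and ``num_workers`` value are arbitrary):

.. code-block:: python

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})
    for i in range(3):
        result = trainer.train()  # one round of sampling + optimization
        print(i, result["episode_reward_mean"])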
Next, we'll cover three key concepts in RLlib: Policies, Samples, and Trainers.

Policies
~~~~~~~~

`Policies <rllib-concepts.html#policies>`__ are a core concept in RLlib. In a nutshell, policies are Python classes that define how an agent acts in an environment. `Rollout workers <rllib-concepts.html#policy-evaluation>`__ query the policy to determine agent actions. In a `gym <rllib-env.html#openai-gym>`__ environment, there is a single agent and policy. In `vector envs <rllib-env.html#vectorized>`__, policy inference is for multiple agents at once, and in `multi-agent <rllib-env.html#multi-agent-and-hierarchical>`__, there may be multiple policies, each controlling one or more agents:

.. image:: multi-flat.svg

Policies can be implemented using `any framework <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py>`__. However, for TensorFlow and PyTorch, RLlib has `build_tf_policy <rllib-concepts.html#building-policies-in-tensorflow>`__ and `build_torch_policy <rllib-concepts.html#building-policies-in-pytorch>`__ helper functions that let you define a trainable policy with a functional-style API, for example:

.. code-block:: python

    def policy_gradient_loss(policy, batch_tensors):
        actions = batch_tensors[SampleBatch.ACTIONS]
        rewards = batch_tensors[SampleBatch.REWARDS]
        return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)

    # <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss)
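To make such a policy trainable, one option is to wrap it in a trainer class; the sketch below assumes the ``build_trainer`` helper covered under `Trainers <rllib-concepts.html#trainers>`__, and ``MyCustomTrainer`` is just an illustrative name:

.. code-block:: python

    from ray import tune
    from ray.rllib.agents.trainer_template import build_trainer

    # Wrap the custom policy (MyTFPolicy from above) in a Trainer class.
    MyTrainer = build_trainer(
        name="MyCustomTrainer",
        default_policy=MyTFPolicy)

    tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})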
Sample Batches
~~~~~~~~~~~~~~

Whether running in a single process or `large cluster <rllib-training.html#specifying-resources>`__, all data interchange in RLlib is in the form of `sample batches <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__. Sample batches encode one or more fragments of a trajectory. Typically, RLlib collects batches of size ``sample_batch_size`` from rollout workers, and concatenates one or more of these batches into a batch of size ``train_batch_size`` that is the input to SGD.

A typical sample batch looks something like the following when summarized. Since all values are kept in arrays, this allows for efficient encoding and transmission across the network:

.. code-block:: python

    { 'action_logp': np.ndarray((200,), dtype=float32, min=-0.701, max=-0.685, mean=-0.694),
      'actions': np.ndarray((200,), dtype=int64, min=0.0, max=1.0, mean=0.495),
      'dones': np.ndarray((200,), dtype=bool, min=0.0, max=1.0, mean=0.055),
      'infos': np.ndarray((200,), dtype=object, head={}),
      'new_obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.018),
      'obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.016),
      'rewards': np.ndarray((200,), dtype=float32, min=1.0, max=1.0, mean=1.0),
      't': np.ndarray((200,), dtype=int64, min=0.0, max=34.0, mean=9.14)}

In `multi-agent mode <rllib-concepts.html#policies-in-multi-agent>`__, sample batches are collected separately for each individual policy.
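As a contrived sketch of how this data structure behaves, sample batches can be built from plain dicts of arrays and concatenated:

.. code-block:: python

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Two trajectory fragments of 5 and 3 timesteps.
    b1 = SampleBatch({"obs": np.zeros((5, 4)), "rewards": np.ones(5)})
    b2 = SampleBatch({"obs": np.ones((3, 4)), "rewards": np.zeros(3)})

    batch = SampleBatch.concat_samples([b1, b2])
    print(batch.count)       # 8
    print(batch["rewards"])  # [1. 1. 1. 1. 1. 0. 0. 0.]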
Training
~~~~~~~~

Policies each define a ``learn_on_batch()`` method that improves the policy given a sample batch of input. For TF and Torch policies, this is implemented using a `loss function` that takes as input sample batch tensors and outputs a scalar loss. Here are a few example loss functions:

- Simple `policy gradient loss <https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg_policy.py>`__
- Simple `Q-function loss <https://github.com/ray-project/ray/blob/a1d2e1762325cd34e14dc411666d63bb15d6eaf0/rllib/agents/dqn/simple_q_policy.py#L136>`__
- Importance-weighted `APPO surrogate loss <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo_policy.py>`__

RLlib `Trainer classes <rllib-concepts.html#trainers>`__ coordinate the distributed workflow of running rollouts and optimizing policies. They do this by leveraging `policy optimizers <rllib-concepts.html#policy-optimization>`__ that implement the desired computation pattern (i.e., synchronous or asynchronous sampling, distributed replay, etc.):

.. figure:: a2c-arch.svg

    Synchronous Sampling (e.g., A2C, PG, PPO)

.. figure:: dqn-arch.svg

    Synchronous Replay (e.g., DQN, DDPG, TD3)

.. figure:: impala-arch.svg

    Asynchronous Sampling (e.g., IMPALA, APPO)

.. figure:: apex-arch.svg

    Asynchronous Replay (e.g., Ape-X)

RLlib uses `Ray actors <actors.html>`__ to scale these architectures from a single core to many thousands of cores in a cluster. You can `configure the parallelism <rllib-training.html#specifying-resources>`__ used for training by changing the ``num_workers`` parameter.
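For instance, a one-line sketch of scaling the earlier CartPole example across more rollout workers (assuming enough CPUs are available):

.. code-block:: python

    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer

    # Each rollout worker is a Ray actor; 8 workers collect samples in parallel.
    tune.run(PPOTrainer, config={"env": "CartPole-v0", "num_workers": 8})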
Customization
~~~~~~~~~~~~~

RLlib provides ways to customize almost all aspects of training, including the `environment <rllib-env.html#configuring-environments>`__, `neural network model <rllib-models.html#tensorflow-models>`__, `action distribution <rllib-models.html#custom-action-distributions>`__, and `policy definitions <rllib-concepts.html#policies>`__:

.. image:: rllib-components.svg
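As a brief sketch of what these hooks look like in code (``MyEnv`` and ``MyModelClass`` are hypothetical user-defined classes; see the environment and model pages for complete examples):

.. code-block:: python

    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.models import ModelCatalog
    from ray.tune.registry import register_env

    # Register a custom environment and model under string names.
    register_env("my_env", lambda env_config: MyEnv(env_config))
    ModelCatalog.register_custom_model("my_model", MyModelClass)

    tune.run(PPOTrainer, config={
        "env": "my_env",
        "model": {"custom_model": "my_model"},
    })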
To learn more, proceed to the `table of contents <rllib-toc.html>`__.

rllib/policy/sample_batch.py

@@ -14,77 +14,6 @@ from ray.rllib.utils.memory import concat_aligned
DEFAULT_POLICY_ID = "default_policy"
(The MultiAgentBatch class is removed here; it is re-added, unchanged, after SampleBatch in the hunk below.)
@PublicAPI
class SampleBatch(object):
"""Wrapper around a dictionary with string keys and array-like values.
@@ -294,3 +223,74 @@ class SampleBatch(object):
    def __contains__(self, x):
        return x in self.data


@PublicAPI
class MultiAgentBatch(object):
    """A batch of experiences from multiple policies in the environment.

    Attributes:
        policy_batches (dict): Mapping from policy id to a normal SampleBatch
            of experiences. Note that these batches may be of different length.
        count (int): The number of timesteps in the environment this batch
            contains. This will be less than the number of transitions this
            batch contains across all policies in total.
    """

    @PublicAPI
    def __init__(self, policy_batches, count):
        self.policy_batches = policy_batches
        self.count = count

    @staticmethod
    @PublicAPI
    def wrap_as_needed(batches, count):
        if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
            return batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(batches, count)

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        policy_batches = collections.defaultdict(list)
        total_count = 0
        for s in samples:
            assert isinstance(s, MultiAgentBatch)
            for policy_id, batch in s.policy_batches.items():
                policy_batches[policy_id].append(batch)
            total_count += s.count
        out = {}
        for policy_id, batches in policy_batches.items():
            out[policy_id] = SampleBatch.concat_samples(batches)
        return MultiAgentBatch(out, total_count)

    @PublicAPI
    def copy(self):
        return MultiAgentBatch(
            {k: v.copy()
             for (k, v) in self.policy_batches.items()}, self.count)

    @PublicAPI
    def total(self):
        ct = 0
        for batch in self.policy_batches.values():
            ct += batch.count
        return ct

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)

    def __str__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)

    def __repr__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)
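For orientation, a small usage sketch of this class (the policy ids and array shapes below are made up):

.. code-block:: python

    import numpy as np
    from ray.rllib.policy.sample_batch import (
        DEFAULT_POLICY_ID, MultiAgentBatch, SampleBatch)

    pol1 = SampleBatch({"obs": np.zeros((4, 2)), "rewards": np.ones(4)})
    pol2 = SampleBatch({"obs": np.zeros((6, 2)), "rewards": np.ones(6)})

    # Two policies stepped over 4 env timesteps -> a true multi-agent batch.
    ma = MultiAgentBatch.wrap_as_needed(
        {"policy_1": pol1, "policy_2": pol2}, count=4)
    print(ma.total())  # 10 transitions across all policies

    # With only the default policy, wrap_as_needed returns a plain SampleBatch.
    single = MultiAgentBatch.wrap_as_needed({DEFAULT_POLICY_ID: pol1}, count=4)
    print(type(single).__name__)  # SampleBatch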