[rllib] Deprecate policy optimizers (#8345)
This commit is contained in:
parent d27e6da1b2
commit 9a83908c46
80 changed files with 2194 additions and 1100 deletions
@@ -49,7 +49,7 @@ Distributed Prioritized Experience Replay (Ape-X)
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1803.00933>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/apex.py>`__
Ape-X variations of DQN, DDPG, and QMIX (`APEX_DQN <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/apex.py>`__, `APEX_DDPG <https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/apex.py>`__, `APEX_QMIX <https://github.com/ray-project/ray/blob/master/rllib/agents/qmix/apex.py>`__) use a single GPU learner and many CPU workers for experience collection. Experience collection can scale to hundreds of CPU workers due to the distributed prioritization of experience prior to storage in replay buffers.
Ape-X variations of DQN and DDPG (`APEX_DQN <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/apex.py>`__, `APEX_DDPG <https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/apex.py>`__) use a single GPU learner and many CPU workers for experience collection. Experience collection can scale to hundreds of CPU workers due to the distributed prioritization of experience prior to storage in replay buffers.
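A minimal Tune sketch of this scaling pattern (the environment, worker count, and GPU count are illustrative assumptions, not values taken from this page):

.. code-block:: python

    from ray import tune

    tune.run(
        "APEX",  # Ape-X DQN; "APEX_DDPG" selects the DDPG variant
        config={
            "env": "PongNoFrameskip-v4",
            "num_workers": 32,  # CPU workers collecting and prioritizing samples
            "num_gpus": 1,      # single GPU learner
        })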

.. figure:: apex-arch.svg

@@ -200,9 +200,9 @@ Advantage Actor-Critic (A2C, A3C)
---------------------------------
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1602.01783>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/a3c/a3c.py>`__
RLlib implements A2C and A3C using SyncSamplesOptimizer and AsyncGradientsOptimizer respectively for policy optimization. These algorithms scale to up to 16-32 worker processes depending on the environment.
RLlib implements both A2C and A3C. These algorithms scale to 16-32+ worker processes depending on the environment.

A2C also supports microbatching (i.e., gradient accumulation), which can be enabled by setting the ``microbatch_size`` config. Microbatching allows for training with a ``train_batch_size`` much larger than GPU memory. See also the `microbatch optimizer implementation <https://github.com/ray-project/ray/blob/master/rllib/optimizers/microbatch_optimizer.py>`__.
A2C also supports microbatching (i.e., gradient accumulation), which can be enabled by setting the ``microbatch_size`` config. Microbatching allows for training with a ``train_batch_size`` much larger than GPU memory.
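A minimal Tune sketch of enabling microbatching (only ``microbatch_size`` and ``train_batch_size`` come from the text above; the environment and the numbers are illustrative assumptions):

.. code-block:: python

    from ray import tune

    tune.run(
        "A2C",
        stop={"training_iteration": 1},
        config={
            "env": "CartPole-v0",
            "train_batch_size": 4000,  # effective SGD batch size
            "microbatch_size": 500,    # per-step batch that must fit on the GPU
        })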

.. figure:: a2c-arch.svg

@@ -237,7 +237,7 @@ Deep Deterministic Policy Gradients (DDPG, TD3)
-----------------------------------------------
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1509.02971>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/ddpg.py>`__
DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers, switching to AsyncGradientsOptimizer, or using Ape-X. The improvements from `TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`__ are available as ``TD3``.
DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers or using Ape-X. The improvements from `TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`__ are available as ``TD3``.

.. figure:: dqn-arch.svg

@@ -258,7 +258,7 @@ Deep Q Networks (DQN, Rainbow, Parametric DQN)
----------------------------------------------
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1312.5602>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py>`__
RLlib DQN is implemented using the SyncReplayOptimizer. The algorithm can be scaled by increasing the number of workers, using the AsyncGradientsOptimizer for async DQN, or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default. See also how to use `parametric-actions in DQN <rllib-models.html#variable-length-parametric-action-spaces>`__.
DQN can be scaled by increasing the number of workers or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default. See also how to use `parametric-actions in DQN <rllib-models.html#variable-length-parametric-action-spaces>`__.
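As a hedged sketch, scaling out and switching on a few of the Rainbow options might look as follows (the config keys are assumed from the DQN defaults of this RLlib generation and may differ in other versions):

.. code-block:: python

    from ray import tune

    tune.run(
        "DQN",
        config={
            "env": "CartPole-v0",
            "num_workers": 4,               # scale out sample collection
            "n_step": 3,                    # Rainbow: multi-step returns
            "noisy": True,                  # Rainbow: noisy exploration layers
            "num_atoms": 51,                # Rainbow: distributional Q-learning
            "compress_observations": True,  # LZ4-compress replay samples
        })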

.. figure:: dqn-arch.svg

@@ -495,7 +495,7 @@ Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rll
Single-Player Alpha Zero (contrib/AlphaZero)
--------------------------------------------
|pytorch|
`[paper] <https://arxiv.org/abs/1712.01815>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/alpha_zero>`__ AlphaZero is an RL agent originally designed for two-player games. This version adapts it to handle single player games. The code can be used with the SyncSamplesOptimizer as well as with a modified version of the SyncReplayOptimizer, and it scales to any number of workers. It also implements the ranked rewards `(R2) <https://arxiv.org/abs/1807.01672>`__ strategy to enable self-play even in the one-player setting. The code is mainly purposed to be used for combinatorial optimization.
`[paper] <https://arxiv.org/abs/1712.01815>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/alpha_zero>`__ AlphaZero is an RL agent originally designed for two-player games. This version adapts it to handle single player games. The code can be scaled to any number of workers. It also implements the ranked rewards `(R2) <https://arxiv.org/abs/1807.01672>`__ strategy to enable self-play even in the one-player setting. The code is mainly intended for combinatorial optimization.

Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/contrib/alpha_zero/examples/train_cartpole.py>`__

@@ -217,6 +217,10 @@ In the above section you saw how to compose a simple policy gradient algorithm w

Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer`` and ``after_optimizer_step=update_kl``.

.. warning::

    Policy optimizers are deprecated. This documentation will be updated in the future.

The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer and a multi-GPU optimizer that implements minibatch SGD (the default):

.. code-block:: python

@@ -581,6 +585,10 @@ Here is an example of creating a set of rollout workers and using them to gather ex
Policy Optimization
-------------------

.. warning::

    Policy optimizers are deprecated. This documentation will be updated in the future.

Similar to how a `gradient-descent optimizer <https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer>`__ can be used to improve a model, RLlib's `policy optimizers <https://github.com/ray-project/ray/tree/master/rllib/optimizers>`__ implement different strategies for improving a policy.

For example, in A3C you'd want to compute gradients asynchronously on different workers and apply them to a central policy replica. This strategy is implemented by the `AsyncGradientsOptimizer <https://github.com/ray-project/ray/blob/master/rllib/optimizers/async_gradients_optimizer.py>`__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in the `SyncSamplesOptimizer <https://github.com/ray-project/ray/blob/master/rllib/optimizers/sync_samples_optimizer.py>`__. Policy optimizers abstract these strategies away into reusable modules.
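For reference, a minimal sketch of driving these (now deprecated) optimizers directly; the constructor calls mirror the ones removed elsewhere in this commit, ``workers`` is assumed to be a ``WorkerSet`` built as shown earlier, and ``step()`` is assumed from the old ``PolicyOptimizer`` interface:

.. code-block:: python

    from ray.rllib.optimizers import AsyncGradientsOptimizer, SyncSamplesOptimizer

    # Synchronous pattern (A2C-style): sample in parallel, train centrally.
    sync_opt = SyncSamplesOptimizer(workers, train_batch_size=500)

    # Asynchronous pattern (A3C-style): gradients are computed on the workers
    # and applied to the central policy replica.
    async_opt = AsyncGradientsOptimizer(workers)

    for _ in range(10):
        sync_opt.step()  # one round of sampling plus optimization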

@@ -91,7 +91,7 @@ Instead of using the ``use_lstm: True`` option, it can be preferable to use a custo
Batch Normalization
~~~~~~~~~~~~~~~~~~~

You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example <https://github.com/ray-project/ray/blob/master/rllib/examples/batch_norm_model.py>`__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy.py <https://github.com/ray-project/ray/blob/master/rllib/policy/tf_policy.py>`__ and `multi_gpu_impl.py <https://github.com/ray-project/ray/blob/master/rllib/optimizers/multi_gpu_impl.py>`__ for the exact handling of these updates).
You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example <https://github.com/ray-project/ray/blob/master/rllib/examples/batch_norm_model.py>`__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy.py <https://github.com/ray-project/ray/blob/master/rllib/policy/tf_policy.py>`__ and `multi_gpu_impl.py <https://github.com/ray-project/ray/blob/master/rllib/execution/multi_gpu_impl.py>`__ for the exact handling of these updates).

In case RLlib does not properly detect the update ops for your custom model, you can override the ``update_ops()`` method to return the list of ops to run for updates.
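A condensed sketch of such a model (this assumes the ``TFModelV2`` ``forward()`` signature; ``__init__`` and ``value_function()`` are omitted for brevity and the layer sizes are arbitrary):

.. code-block:: python

    from ray.rllib.models.tf.tf_modelv2 import TFModelV2
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()

    class BatchNormModel(TFModelV2):
        def forward(self, input_dict, state, seq_lens):
            x = tf.layers.dense(input_dict["obs"], 64, activation=tf.nn.relu)
            # "is_training" switches batch norm between train and eval behavior.
            x = tf.layers.batch_normalization(
                x, training=input_dict["is_training"])
            logits = tf.layers.dense(x, self.num_outputs, activation=None)
            return logits, state

        def update_ops(self):
            # Fallback if RLlib does not detect the update ops automatically.
            return tf.get_collection(tf.GraphKeys.UPDATE_OPS)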

@@ -19,18 +19,18 @@ ray.rllib.evaluation
.. automodule:: ray.rllib.evaluation
    :members:

ray.rllib.execution
--------------------

.. automodule:: ray.rllib.execution
    :members:

ray.rllib.models
----------------

.. automodule:: ray.rllib.models
    :members:

ray.rllib.optimizers
--------------------

.. automodule:: ray.rllib.optimizers
    :members:

ray.rllib.utils
---------------
@@ -180,8 +180,8 @@ Package Reference
* `ray.rllib.agents <rllib-package-ref.html#module-ray.rllib.agents>`__
* `ray.rllib.env <rllib-package-ref.html#module-ray.rllib.env>`__
* `ray.rllib.evaluation <rllib-package-ref.html#module-ray.rllib.evaluation>`__
* `ray.rllib.execution <rllib-package-ref.html#module-ray.rllib.execution>`__
* `ray.rllib.models <rllib-package-ref.html#module-ray.rllib.models>`__
* `ray.rllib.optimizers <rllib-package-ref.html#module-ray.rllib.optimizers>`__
* `ray.rllib.utils <rllib-package-ref.html#module-ray.rllib.utils>`__

Troubleshooting
@@ -92,7 +92,7 @@ Policies each define a ``learn_on_batch()`` method that improves the policy give
- Simple `Q-function loss <https://github.com/ray-project/ray/blob/a1d2e1762325cd34e14dc411666d63bb15d6eaf0/rllib/agents/dqn/simple_q_policy.py#L136>`__
- Importance-weighted `APPO surrogate loss <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo_policy.py>`__

RLlib `Trainer classes <rllib-concepts.html#trainers>`__ coordinate the distributed workflow of running rollouts and optimizing policies. They do this by leveraging `policy optimizers <rllib-concepts.html#policy-optimization>`__ that implement the desired computation pattern. The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:
RLlib `Trainer classes <rllib-concepts.html#trainers>`__ coordinate the distributed workflow of running rollouts and optimizing policies. They do this by leveraging Ray `parallel iterators <iter.html>`__ to implement the desired computation pattern. The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:
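To make the parallel iterator pattern concrete, here is a small sketch using ``ray.util.iter`` (the calls mirror the iterator tests touched by this commit; the doubling transform is purely illustrative):

.. code-block:: python

    import ray
    from ray.util.iter import from_items

    ray.init()

    # Build a sharded iterator, transform each element in parallel,
    # then gather the results synchronously on the driver.
    it = from_items([1, 2, 3, 4], num_shards=2)
    doubled = it.for_each(lambda x: 2 * x).gather_sync()
    print(doubled.take(4))  # the four doubled values; order depends on shards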

.. figure:: a2c-arch.svg
@@ -29,18 +29,10 @@ def test_metrics(ray_start_regular_shared):
    it = it.gather_sync().for_each(f)
    it2 = it2.gather_sync().for_each(f)

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()

    # Tests iterators have isolated contexts.
    assert it.take(4) == [1, 3, 6, 10]
    assert it2.take(4) == [1, 3, 6, 10]

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()


def test_zip_with_source_actor(ray_start_regular_shared):
    it = from_items([1, 2, 3, 4], num_shards=2)
@@ -670,15 +670,8 @@ class LocalIterator(Generic[T]):

    @contextmanager
    def _metrics_context(self):
        if hasattr(self.thread_local, "metrics"):
            prev_metrics = self.thread_local.metrics
        else:
            prev_metrics = None
        try:
            self.thread_local.metrics = self.shared_metrics.get()
            yield
        finally:
            self.thread_local.metrics = prev_metrics
        self.thread_local.metrics = self.shared_metrics.get()
        yield

    def __iter__(self):
        self._build_once()
22 rllib/BUILD
|
@ -1166,30 +1166,23 @@ py_test(
|
|||
|
||||
# --------------------------------------------------------------------
|
||||
# Optimizers and Memories
|
||||
# rllib/optimizers/
|
||||
# rllib/execution/
|
||||
#
|
||||
# Tag: optimizers
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
py_test(
|
||||
name = "test_optimizers",
|
||||
tags = ["optimizers"],
|
||||
size = "large",
|
||||
srcs = ["optimizers/tests/test_optimizers.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_segment_tree",
|
||||
tags = ["optimizers"],
|
||||
size = "small",
|
||||
srcs = ["optimizers/tests/test_segment_tree.py"]
|
||||
srcs = ["execution/tests/test_segment_tree.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_prioritized_replay_buffer",
|
||||
tags = ["optimizers"],
|
||||
size = "small",
|
||||
srcs = ["optimizers/tests/test_prioritized_replay_buffer.py"]
|
||||
srcs = ["execution/tests/test_prioritized_replay_buffer.py"]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
@ -2054,15 +2047,6 @@ py_test(
|
|||
args = ["--stop-timesteps=2000", "--run=QMIX"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/twostep_game_apex_qmix",
|
||||
main = "examples/twostep_game.py",
|
||||
tags = ["examples", "examples_T"],
|
||||
size = "medium",
|
||||
srcs = ["examples/twostep_game.py"],
|
||||
args = ["--stop-timesteps=2000", "--run=APEX_QMIX", "--num-cpus=4"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "contrib/bandits/examples/lin_ts",
|
||||
main = "contrib/bandits/examples/simple_context_bandit.py",
|
||||
|
|
|
@ -2,7 +2,6 @@ import math
|
|||
|
||||
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
|
||||
validate_config, get_policy_class
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, MicrobatchOptimizer
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
|
@ -27,18 +26,6 @@ A2C_DEFAULT_CONFIG = merge_dicts(
|
|||
)
|
||||
|
||||
|
||||
def choose_policy_optimizer(workers, config):
|
||||
if config["microbatch_size"]:
|
||||
return MicrobatchOptimizer(
|
||||
workers,
|
||||
train_batch_size=config["train_batch_size"],
|
||||
microbatch_size=config["microbatch_size"])
|
||||
else:
|
||||
return SyncSamplesOptimizer(
|
||||
workers, train_batch_size=config["train_batch_size"])
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
|
@ -71,6 +58,5 @@ A2CTrainer = build_trainer(
|
|||
default_config=A2C_DEFAULT_CONFIG,
|
||||
default_policy=A3CTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
make_policy_optimizer=choose_policy_optimizer,
|
||||
validate_config=validate_config,
|
||||
execution_plan=execution_plan)
|
||||
|
|
|
@ -6,7 +6,6 @@ from ray.rllib.agents.trainer_template import build_trainer
|
|||
from ray.rllib.execution.rollout_ops import AsyncGradients
|
||||
from ray.rllib.execution.train_ops import ApplyGradients
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,11 +61,6 @@ def validate_config(config):
|
|||
"Multithreading can be lead to crashes if used with pytorch.")
|
||||
|
||||
|
||||
def make_async_optimizer(workers, config):
|
||||
return AsyncGradientsOptimizer(workers, **config["optimizer"])
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
# For A3C, compute policy gradients remotely on the rollout workers.
|
||||
grads = AsyncGradients(workers)
|
||||
|
@ -84,5 +78,4 @@ A3CTrainer = build_trainer(
|
|||
default_policy=A3CTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
validate_config=validate_config,
|
||||
make_policy_optimizer=make_async_optimizer,
|
||||
execution_plan=execution_plan)
|
||||
|
|
|
@ -16,10 +16,8 @@ class TestA2C(unittest.TestCase):
|
|||
|
||||
def test_a2c_exec_impl(ray_start_regular):
|
||||
trainer = A2CTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
env="CartPole-v0", config={
|
||||
"min_iter_time_s": 0,
|
||||
"use_exec_api": True
|
||||
})
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
|
@ -30,7 +28,6 @@ class TestA2C(unittest.TestCase):
|
|||
config={
|
||||
"min_iter_time_s": 0,
|
||||
"microbatch_size": 10,
|
||||
"use_exec_api": True,
|
||||
})
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
|
|
|
@ -17,7 +17,6 @@ from ray.rllib.agents.es.es_tf_policy import rollout
|
|||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -329,7 +328,7 @@ class ARSTrainer(Trainer):
|
|||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
for result in ray.get(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
|
||||
from ray.rllib.agents.dqn.apex import apex_execution_plan
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
|
||||
DEFAULT_CONFIG as DDPG_CONFIG
|
||||
|
||||
|
@ -30,4 +30,4 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
|
|||
ApexDDPGTrainer = DDPGTrainer.with_updates(
|
||||
name="APEX_DDPG",
|
||||
default_config=APEX_DDPG_DEFAULT_CONFIG,
|
||||
**APEX_TRAINER_PROPERTIES)
|
||||
execution_plan=apex_execution_plan)
|
||||
|
|
|
@ -6,7 +6,7 @@ import ray.rllib.agents.ddpg as ddpg
|
|||
from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
|
||||
loss_torch
|
||||
from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
|
||||
|
|
|
@ -4,6 +4,7 @@ import copy
|
|||
import ray
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, \
|
||||
DEFAULT_CONFIG as DQN_CONFIG, calculate_rr_weights
|
||||
from ray.rllib.agents.dqn.learner_thread import LearnerThread
|
||||
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
|
||||
SampleBatchType, _get_shared_metrics, _get_global_vars
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
|
@ -12,12 +13,9 @@ from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
|
|||
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
|
||||
from ray.rllib.execution.train_ops import UpdateTargetNetwork
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.optimizers import AsyncReplayOptimizer
|
||||
from ray.rllib.optimizers.async_replay_optimizer import ReplayActor
|
||||
from ray.rllib.execution.replay_buffer import ReplayActor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.actors import create_colocated
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LearnerThread
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -51,45 +49,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def defer_make_workers(trainer, env_creator, policy, config):
|
||||
# Hack to workaround https://github.com/ray-project/ray/issues/2541
|
||||
# The workers will be created later, after the optimizer is created
|
||||
return trainer._make_workers(env_creator, policy, config, num_workers=0)
|
||||
|
||||
|
||||
def make_async_optimizer(workers, config):
|
||||
assert len(workers.remote_workers()) == 0
|
||||
extra_config = config["optimizer"].copy()
|
||||
for key in [
|
||||
"prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "prioritized_replay_eps"
|
||||
]:
|
||||
if key in config:
|
||||
extra_config[key] = config[key]
|
||||
opt = AsyncReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
**extra_config)
|
||||
workers.add_workers(config["num_workers"])
|
||||
opt._set_workers(workers.remote_workers())
|
||||
return opt
|
||||
|
||||
|
||||
def update_target_based_on_num_steps_trained(trainer, fetches):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if (trainer.optimizer.num_steps_trained -
|
||||
trainer.state["last_target_update_ts"] >
|
||||
trainer.config["target_network_update_freq"]):
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
trainer.state["last_target_update_ts"] = (
|
||||
trainer.optimizer.num_steps_trained)
|
||||
trainer.state["num_target_updates"] += 1
|
||||
|
||||
|
||||
# Update worker weights as they finish generating experiences.
|
||||
class UpdateWorkerWeights:
|
||||
def __init__(self, learner_thread, workers, max_weight_sync_delay):
|
||||
|
@ -112,14 +71,12 @@ class UpdateWorkerWeights:
|
|||
actor.set_weights.remote(self.weights, _get_global_vars())
|
||||
self.steps_since_update[actor] = 0
|
||||
# Update metrics.
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters["num_weight_syncs"] += 1
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers: WorkerSet, config: dict):
|
||||
def apex_execution_plan(workers: WorkerSet, config: dict):
|
||||
# Create a number of replay buffer actors.
|
||||
# TODO(ekl) support batch replay options
|
||||
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
|
||||
replay_actors = create_colocated(ReplayActor, [
|
||||
num_replay_buffer_shards,
|
||||
|
@ -216,14 +173,7 @@ def execution_plan(workers: WorkerSet, config: dict):
|
|||
selected_workers=selected_workers).for_each(add_apex_metrics)
|
||||
|
||||
|
||||
APEX_TRAINER_PROPERTIES = {
|
||||
"make_workers": defer_make_workers,
|
||||
"make_policy_optimizer": make_async_optimizer,
|
||||
"after_optimizer_step": update_target_based_on_num_steps_trained,
|
||||
}
|
||||
|
||||
ApexTrainer = DQNTrainer.with_updates(
|
||||
name="APEX",
|
||||
default_config=APEX_DEFAULT_CONFIG,
|
||||
execution_plan=execution_plan,
|
||||
**APEX_TRAINER_PROPERTIES)
|
||||
execution_plan=apex_execution_plan)
|
||||
|
|
|
@ -4,11 +4,10 @@ from ray.rllib.agents.trainer import with_common_config
|
|||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
|
||||
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
|
||||
from ray.rllib.optimizers import SyncReplayOptimizer
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently
|
||||
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
|
||||
|
@ -140,35 +139,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def make_policy_optimizer(workers, config):
|
||||
"""Create the single process DQN policy optimizer.
|
||||
|
||||
Returns:
|
||||
SyncReplayOptimizer: Used for generic off-policy Trainers.
|
||||
"""
|
||||
# SimpleQ does not use a PR buffer.
|
||||
kwargs = {"prioritized_replay": config.get("prioritized_replay", False)}
|
||||
kwargs.update(**config["optimizer"])
|
||||
if "prioritized_replay" in config:
|
||||
kwargs.update({
|
||||
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
|
||||
"prioritized_replay_beta": config["prioritized_replay_beta"],
|
||||
"prioritized_replay_beta_annealing_timesteps": config[
|
||||
"prioritized_replay_beta_annealing_timesteps"],
|
||||
"final_prioritized_replay_beta": config[
|
||||
"final_prioritized_replay_beta"],
|
||||
"prioritized_replay_eps": config["prioritized_replay_eps"],
|
||||
})
|
||||
|
||||
return SyncReplayOptimizer(
|
||||
workers,
|
||||
# TODO(sven): Move all PR-beta decays into Schedule components.
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
**kwargs)
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
"""Checks and updates the config based on settings.
|
||||
|
||||
|
@ -258,54 +228,6 @@ def validate_config(config):
|
|||
config["rollout_fragment_length"] = adjusted_batch_size
|
||||
|
||||
|
||||
def get_initial_state(config):
|
||||
return {
|
||||
"last_target_update_ts": 0,
|
||||
"num_target_updates": 0,
|
||||
}
|
||||
|
||||
|
||||
# TODO(sven): Move this to generic Trainer. Every Algo should do this.
|
||||
def update_worker_exploration(trainer):
|
||||
"""Sets epsilon exploration values in all policies to updated values.
|
||||
|
||||
According to current time-step.
|
||||
|
||||
Args:
|
||||
trainer (Trainer): The Trainer object for the DQN.
|
||||
"""
|
||||
# Store some data for metrics after learning.
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
trainer.train_start_timestep = global_timestep
|
||||
|
||||
# Get all current exploration-infos (from Policies, which cache this info).
|
||||
trainer.exploration_infos = trainer.workers.foreach_trainable_policy(
|
||||
lambda p, _: p.get_exploration_info())
|
||||
|
||||
|
||||
def after_train_result(trainer, result):
|
||||
"""Add some DQN specific metrics to results."""
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
result.update(
|
||||
timesteps_this_iter=global_timestep - trainer.train_start_timestep,
|
||||
info=dict({
|
||||
"exploration_infos": trainer.exploration_infos,
|
||||
"num_target_updates": trainer.state["num_target_updates"],
|
||||
}, **trainer.optimizer.stats()))
|
||||
|
||||
|
||||
def update_target_if_needed(trainer, fetches):
|
||||
"""Update the target network in configured intervals."""
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
if global_timestep - trainer.state["last_target_update_ts"] > \
|
||||
trainer.config["target_network_update_freq"]:
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
trainer.state["last_target_update_ts"] = global_timestep
|
||||
trainer.state["num_target_updates"] += 1
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
if config.get("prioritized_replay"):
|
||||
prio_args = {
|
||||
|
@ -405,11 +327,6 @@ GenericOffPolicyTrainer = build_trainer(
|
|||
get_policy_class=get_policy_class,
|
||||
default_config=DEFAULT_CONFIG,
|
||||
validate_config=validate_config,
|
||||
get_initial_state=get_initial_state,
|
||||
make_policy_optimizer=make_policy_optimizer,
|
||||
before_train_step=update_worker_exploration,
|
||||
after_optimizer_step=update_target_if_needed,
|
||||
after_train_result=after_train_result,
|
||||
execution_plan=execution_plan)
|
||||
|
||||
DQNTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
|
|
59 rllib/agents/dqn/learner_thread.py Normal file
|
@ -0,0 +1,59 @@
|
|||
import threading
|
||||
from six.moves import queue
|
||||
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
LEARNER_QUEUE_MAX_SIZE = 16
|
||||
|
||||
|
||||
class LearnerThread(threading.Thread):
|
||||
"""Background thread that updates the local model from replay data.
|
||||
|
||||
The learner thread communicates with the main thread through Queues. This
|
||||
is needed since Ray operations can only be run on the main thread. In
|
||||
addition, moving heavyweight gradient ops session runs off the main thread
|
||||
improves overall throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, local_worker):
|
||||
threading.Thread.__init__(self)
|
||||
self.learner_queue_size = WindowStat("size", 50)
|
||||
self.local_worker = local_worker
|
||||
self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
|
||||
self.outqueue = queue.Queue()
|
||||
self.queue_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.overall_timer = TimerStat()
|
||||
self.daemon = True
|
||||
self.weights_updated = False
|
||||
self.stopped = False
|
||||
self.stats = {}
|
||||
|
||||
def run(self):
|
||||
while not self.stopped:
|
||||
self.step()
|
||||
|
||||
def step(self):
|
||||
with self.overall_timer:
|
||||
with self.queue_timer:
|
||||
ra, replay = self.inqueue.get()
|
||||
if replay is not None:
|
||||
prio_dict = {}
|
||||
with self.grad_timer:
|
||||
grad_out = self.local_worker.learn_on_batch(replay)
|
||||
for pid, info in grad_out.items():
|
||||
td_error = info.get(
|
||||
"td_error",
|
||||
info[LEARNER_STATS_KEY].get("td_error"))
|
||||
prio_dict[pid] = (replay.policy_batches[pid].data.get(
|
||||
"batch_indexes"), td_error)
|
||||
self.stats[pid] = get_learner_stats(info)
|
||||
self.grad_timer.push_units_processed(replay.count)
|
||||
self.outqueue.put((ra, prio_dict, replay.count))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
self.weights_updated = True
|
||||
self.overall_timer.push_units_processed(replay and replay.count
|
||||
or 0)
|
|
@ -8,7 +8,7 @@ from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
|
|||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -87,7 +87,6 @@ def get_policy_class(config):
|
|||
return SimpleQTFPolicy
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
local_replay_buffer = LocalReplayBuffer(
|
||||
num_shards=1,
|
||||
|
|
|
@ -14,7 +14,6 @@ from ray.rllib.env.env_context import EnvContext
|
|||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils import FilterManager
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -317,7 +316,7 @@ class ESTrainer(Trainer):
|
|||
worker.do_rollouts.remote(theta_id) for worker in self._workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
for result in ray.get(rollout_ids):
|
||||
results.append(result)
|
||||
# Update the number of episodes and the number of timesteps
|
||||
# keeping in mind that result.noisy_lengths is a list of lists,
|
||||
|
|
|
@ -1,26 +1,22 @@
|
|||
import copy
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
|
||||
from ray.rllib.agents.impala.tree_agg import \
|
||||
gather_experiences_tree_aggregation
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, _get_global_vars
|
||||
from ray.rllib.execution.learner_thread import LearnerThread
|
||||
from ray.rllib.execution.multi_gpu_learner import TFMultiGPULearner
|
||||
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
|
||||
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
|
||||
_get_global_vars, _get_shared_metrics
|
||||
from ray.rllib.execution.replay_ops import MixInReplay
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.optimizers.aso_learner import LearnerThread
|
||||
from ray.rllib.optimizers.aso_multi_gpu_learner import TFMultiGPULearner
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.resources import Resources
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -91,50 +87,15 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"vf_loss_coeff": 0.5,
|
||||
"entropy_coeff": 0.01,
|
||||
"entropy_coeff_schedule": None,
|
||||
|
||||
# Callback for APPO to use to update KL, target network periodically.
|
||||
# The input to the callback is the learner fetches dict.
|
||||
"after_train_step": None,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def defer_make_workers(trainer, env_creator, policy, config):
|
||||
# Defer worker creation to after the optimizer has been created.
|
||||
return trainer._make_workers(env_creator, policy, config, 0)
|
||||
|
||||
|
||||
def make_aggregators_and_optimizer(workers, config):
|
||||
if config["num_aggregation_workers"] > 0:
|
||||
# Create co-located aggregator actors first for placement pref
|
||||
aggregators = TreeAggregator.precreate_aggregators(
|
||||
config["num_aggregation_workers"])
|
||||
else:
|
||||
aggregators = None
|
||||
workers.add_workers(config["num_workers"])
|
||||
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
workers,
|
||||
lr=config["lr"],
|
||||
num_gpus=config["num_gpus"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
replay_buffer_num_slots=config["replay_buffer_num_slots"],
|
||||
replay_proportion=config["replay_proportion"],
|
||||
num_data_loader_buffers=config["num_data_loader_buffers"],
|
||||
max_sample_requests_in_flight_per_worker=config[
|
||||
"max_sample_requests_in_flight_per_worker"],
|
||||
broadcast_interval=config["broadcast_interval"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
minibatch_buffer_size=config["minibatch_buffer_size"],
|
||||
num_aggregation_workers=config["num_aggregation_workers"],
|
||||
learner_queue_size=config["learner_queue_size"],
|
||||
learner_queue_timeout=config["learner_queue_timeout"],
|
||||
**config["optimizer"])
|
||||
|
||||
if aggregators:
|
||||
# Assign the pre-created aggregators to the optimizer
|
||||
optimizer.aggregator.init(aggregators)
|
||||
return optimizer
|
||||
|
||||
|
||||
class OverrideDefaultResourceRequest:
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
|
@ -230,16 +191,18 @@ class BroadcastUpdateLearnerWeights:
|
|||
self.steps_since_broadcast = 0
|
||||
self.learner_thread.weights_updated = False
|
||||
# Update metrics.
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters["num_weight_broadcasts"] += 1
|
||||
actor.set_weights.remote(self.weights, _get_global_vars())
|
||||
|
||||
|
||||
def record_steps_trained(count):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
def record_steps_trained(item):
|
||||
count, fetches = item
|
||||
metrics = _get_shared_metrics()
|
||||
# Manually update the steps trained counter since the learner thread
|
||||
# is executing outside the pipeline.
|
||||
metrics.counters[STEPS_TRAINED_COUNTER] += count
|
||||
return item
|
||||
|
||||
|
||||
def gather_experiences_directly(workers, config):
|
||||
|
@ -261,7 +224,6 @@ def gather_experiences_directly(workers, config):
|
|||
return train_batches
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
if config["num_aggregation_workers"] > 0:
|
||||
train_batches = gather_experiences_tree_aggregation(workers, config)
|
||||
|
@ -290,26 +252,14 @@ def execution_plan(workers, config):
|
|||
merged_op = Concurrently(
|
||||
[enqueue_op, dequeue_op], mode="async", output_indexes=[1])
|
||||
|
||||
def add_learner_metrics(result):
|
||||
def timer_to_ms(timer):
|
||||
return round(1000 * timer.mean, 3)
|
||||
|
||||
result["info"].update({
|
||||
"learner_queue": learner_thread.learner_queue_size.stats(),
|
||||
"learner": copy.deepcopy(learner_thread.stats),
|
||||
"timing_breakdown": {
|
||||
"learner_grad_time_ms": timer_to_ms(learner_thread.grad_timer),
|
||||
"learner_load_time_ms": timer_to_ms(learner_thread.load_timer),
|
||||
"learner_load_wait_time_ms": timer_to_ms(
|
||||
learner_thread.load_wait_timer),
|
||||
"learner_dequeue_time_ms": timer_to_ms(
|
||||
learner_thread.queue_timer),
|
||||
}
|
||||
})
|
||||
return result
|
||||
# Callback for APPO to use to update KL, target network periodically.
|
||||
# The input to the callback is the learner fetches dict.
|
||||
if config["after_train_step"]:
|
||||
merged_op = merged_op.for_each(lambda t: t[1]).for_each(
|
||||
config["after_train_step"](workers, config))
|
||||
|
||||
return StandardMetricsReporting(merged_op, workers, config) \
|
||||
.for_each(add_learner_metrics)
|
||||
.for_each(learner_thread.add_learner_metrics)
|
||||
|
||||
|
||||
ImpalaTrainer = build_trainer(
|
||||
|
@ -318,7 +268,5 @@ ImpalaTrainer = build_trainer(
|
|||
default_policy=VTraceTFPolicy,
|
||||
validate_config=validate_config,
|
||||
get_policy_class=get_policy_class,
|
||||
make_workers=defer_make_workers,
|
||||
make_policy_optimizer=make_aggregators_and_optimizer,
|
||||
execution_plan=execution_plan,
|
||||
mixins=[OverrideDefaultResourceRequest])
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
|
||||
StoreToReplayBuffer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently
|
||||
from ray.rllib.execution.train_ops import TrainOneStep
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -37,15 +42,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def make_optimizer(workers, config):
|
||||
return SyncBatchReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["replay_buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
)
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config.get("use_pytorch") is True:
|
||||
from ray.rllib.agents.marwil.marwil_torch_policy import \
|
||||
|
@ -55,9 +51,27 @@ def get_policy_class(config):
|
|||
return MARWILTFPolicy
|
||||
|
||||
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])
|
||||
|
||||
store_op = rollouts \
|
||||
.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
|
||||
|
||||
replay_op = Replay(local_buffer=replay_buffer) \
|
||||
.combine(
|
||||
ConcatBatches(min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(workers))
|
||||
|
||||
train_op = Concurrently(
|
||||
[store_op, replay_op], mode="round_robin", output_indexes=[1])
|
||||
|
||||
return StandardMetricsReporting(train_op, workers, config)
|
||||
|
||||
|
||||
MARWILTrainer = build_trainer(
|
||||
name="MARWIL",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=MARWILTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
make_policy_optimizer=make_optimizer)
|
||||
execution_plan=execution_plan)
|
||||
|
|
|
@ -29,6 +29,7 @@ class TestMARWIL(unittest.TestCase):
|
|||
for i in range(num_iterations):
|
||||
trainer.train()
|
||||
check_compute_action(trainer, include_prev_action_reward=True)
|
||||
trainer.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.train_ops import TrainOneStep
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -25,26 +22,8 @@ def get_policy_class(config):
|
|||
return PGTFPolicy
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
# Collects experiences in parallel from multiple RolloutWorker actors.
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
# Combine experiences batches until we hit `train_batch_size` in size.
|
||||
# Then, train the policy on those experiences and update the workers.
|
||||
train_op = rollouts \
|
||||
.combine(ConcatBatches(
|
||||
min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(workers))
|
||||
|
||||
# Add on the standard episode reward, etc. metrics reporting. This returns
|
||||
# a LocalIterator[metrics_dict] representing metrics for each train step.
|
||||
return StandardMetricsReporting(train_op, workers, config)
|
||||
|
||||
|
||||
PGTrainer = build_trainer(
|
||||
name="PG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PGTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
execution_plan=execution_plan)
|
||||
get_policy_class=get_policy_class)
|
||||
|
|
|
@ -17,16 +17,6 @@ class TestPG(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def test_pg_exec_impl(ray_start_regular):
|
||||
trainer = pg.PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"min_iter_time_s": 0,
|
||||
"use_exec_api": True
|
||||
})
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
|
||||
def test_pg_compilation(self):
|
||||
"""Test whether a PGTrainer can be built with both frameworks."""
|
||||
config = pg.DEFAULT_CONFIG.copy()
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from ray.rllib.agents.impala.impala import validate_config
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo import update_kl
|
||||
from ray.rllib.agents.ppo.ppo import UpdateKL
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, _get_shared_metrics
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -54,35 +56,48 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
|||
"vf_loss_coeff": 0.5,
|
||||
"entropy_coeff": 0.01,
|
||||
"entropy_coeff_schedule": None,
|
||||
|
||||
# TODO: impl update target.
|
||||
"use_exec_api": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def update_target_and_kl(trainer, fetches):
|
||||
# Update the KL coeff depending on how many steps LearnerThread has stepped
|
||||
# through
|
||||
learner_steps = trainer.optimizer.learner.num_steps
|
||||
if learner_steps >= trainer.target_update_frequency:
|
||||
|
||||
# Update Target Network
|
||||
trainer.optimizer.learner.num_steps = 0
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
|
||||
# Also update KL Coeff
|
||||
if trainer.config["use_kl_loss"]:
|
||||
update_kl(trainer, trainer.optimizer.learner.stats)
|
||||
|
||||
|
||||
def initialize_target(trainer):
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
trainer.target_update_frequency = trainer.config["num_sgd_iter"] \
|
||||
* trainer.config["minibatch_buffer_size"]
|
||||
|
||||
|
||||
class UpdateTargetAndKL:
|
||||
def __init__(self, workers, config):
|
||||
self.workers = workers
|
||||
self.config = config
|
||||
self.update_kl = UpdateKL(workers)
|
||||
self.target_update_freq = config["num_sgd_iter"] \
|
||||
* config["minibatch_buffer_size"]
|
||||
|
||||
def __call__(self, fetches):
|
||||
metrics = _get_shared_metrics()
|
||||
cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
|
||||
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
|
||||
if cur_ts - last_update > self.target_update_freq:
|
||||
metrics.counters[NUM_TARGET_UPDATES] += 1
|
||||
metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
|
||||
# Update Target Network
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
# Also update KL Coeff
|
||||
if self.config["use_kl_loss"]:
|
||||
self.update_kl(fetches)
|
||||
|
||||
|
||||
def add_target_callback(config):
|
||||
"""Add the update target and kl hook.
|
||||
|
||||
This hook is called explicitly after each learner step in the execution
|
||||
setup for IMPALA.
|
||||
"""
|
||||
|
||||
config["after_train_step"] = UpdateTargetAndKL
|
||||
return validate_config(config)
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
|
@ -96,8 +111,7 @@ def get_policy_class(config):
|
|||
APPOTrainer = impala.ImpalaTrainer.with_updates(
|
||||
name="APPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
validate_config=validate_config,
|
||||
validate_config=add_target_callback,
|
||||
default_policy=AsyncPPOTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
after_init=initialize_target,
|
||||
after_optimizer_step=update_target_and_kl)
|
||||
after_init=initialize_target)
|
||||
|
|
|
@ -18,14 +18,13 @@ import logging
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.util.iter import LocalIterator
|
||||
from ray.rllib.agents.ppo import ppo
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.optimizers import TorchDistributedDataParallelOptimizer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
STEPS_TRAINED_COUNTER, LEARNER_INFO, LEARN_ON_BATCH_TIMER
|
||||
STEPS_TRAINED_COUNTER, LEARNER_INFO, LEARN_ON_BATCH_TIMER, \
|
||||
_get_shared_metrics
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
|
||||
|
@ -87,17 +86,6 @@ def validate_config(config):
|
|||
ppo.validate_config(config)
|
||||
|
||||
|
||||
def make_distributed_allreduce_optimizer(workers, config):
|
||||
return TorchDistributedDataParallelOptimizer(
|
||||
workers,
|
||||
expected_batch_size=config["rollout_fragment_length"] *
|
||||
config["num_envs_per_worker"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
sgd_minibatch_size=config["sgd_minibatch_size"],
|
||||
standardize_fields=["advantages"])
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="raw")
|
||||
|
||||
|
@ -141,7 +129,7 @@ def execution_plan(workers, config):
|
|||
def __call__(self, items):
|
||||
for item in items:
|
||||
info, count = item
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += count
|
||||
metrics.counters[STEPS_TRAINED_COUNTER] += count
|
||||
metrics.info[LEARNER_INFO] = info
|
||||
|
@ -190,6 +178,5 @@ def execution_plan(workers, config):
|
|||
DDPPOTrainer = ppo.PPOTrainer.with_updates(
|
||||
name="DDPPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
make_policy_optimizer=make_distributed_allreduce_optimizer,
|
||||
execution_plan=execution_plan,
|
||||
validate_config=validate_config)
|
||||
|
|
|
@ -7,7 +7,6 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
|
|||
StandardizeFields, SelectExperiences
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
@ -76,58 +75,12 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"_fake_gpus": False,
|
||||
# Use PyTorch as framework?
|
||||
"use_pytorch": False,
|
||||
# Use the execution plan API instead of policy optimizers.
|
||||
"use_exec_api": True,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def choose_policy_optimizer(workers, config):
|
||||
if config["simple_optimizer"]:
|
||||
return SyncSamplesOptimizer(
|
||||
workers,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
sgd_minibatch_size=config["sgd_minibatch_size"],
|
||||
standardize_fields=["advantages"])
|
||||
|
||||
return LocalMultiGPUOptimizer(
|
||||
workers,
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
shuffle_sequences=config["shuffle_sequences"],
|
||||
_fake_gpus=config["_fake_gpus"])
|
||||
|
||||
|
||||
def update_kl(trainer, fetches):
|
||||
# Single-agent.
|
||||
if "kl" in fetches:
|
||||
trainer.workers.local_worker().for_policy(
|
||||
lambda pi: pi.update_kl(fetches["kl"]))
|
||||
|
||||
# Multi-agent.
|
||||
else:
|
||||
|
||||
def update(pi, pi_id):
|
||||
if pi_id in fetches:
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.info("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
trainer.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
|
||||
def warn_about_bad_reward_scales(trainer, result):
|
||||
return _warn_about_bad_reward_scales(trainer.config, result)
|
||||
|
||||
|
||||
def _warn_about_bad_reward_scales(config, result):
|
||||
def warn_about_bad_reward_scales(config, result):
|
||||
if result["policy_reward_mean"]:
|
||||
return result # Punt on handling multiagent case.
|
||||
|
||||
|
@ -197,7 +150,25 @@ def get_policy_class(config):
|
|||
return PPOTFPolicy
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
class UpdateKL:
|
||||
"""Callback to update the KL based on optimization info."""
|
||||
|
||||
def __init__(self, workers):
|
||||
self.workers = workers
|
||||
|
||||
def __call__(self, fetches):
|
||||
def update(pi, pi_id):
|
||||
assert "kl" not in fetches, (
|
||||
"kl should be nested under policy id key", fetches)
|
||||
if pi_id in fetches:
|
||||
assert "kl" in fetches[pi_id], (fetches, pi_id)
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.warning("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
self.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
|
@ -227,23 +198,11 @@ def execution_plan(workers, config):
|
|||
shuffle_sequences=config["shuffle_sequences"],
|
||||
_fake_gpus=config["_fake_gpus"]))
|
||||
|
||||
# Callback to update the KL based on optimization info.
|
||||
def update_kl(item):
|
||||
_, fetches = item
|
||||
|
||||
def update(pi, pi_id):
|
||||
if pi_id in fetches:
|
||||
pi.update_kl(fetches[pi_id]["kl"])
|
||||
else:
|
||||
logger.warning("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
# Update KL after each round of training.
|
||||
train_op = train_op.for_each(update_kl)
|
||||
train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))
|
||||
|
||||
return StandardMetricsReporting(train_op, workers, config) \
|
||||
.for_each(lambda result: _warn_about_bad_reward_scales(config, result))
|
||||
.for_each(lambda result: warn_about_bad_reward_scales(config, result))
|
||||
|
||||
|
||||
PPOTrainer = build_trainer(
|
||||
|
@ -251,8 +210,5 @@ PPOTrainer = build_trainer(
|
|||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PPOTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
make_policy_optimizer=choose_policy_optimizer,
|
||||
execution_plan=execution_plan,
|
||||
validate_config=validate_config,
|
||||
after_optimizer_step=update_kl,
|
||||
after_train_result=warn_about_bad_reward_scales)
|
||||
validate_config=validate_config)

@@ -1,4 +1,3 @@
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
from ray.rllib.agents.qmix.apex import ApexQMixTrainer

__all__ = ["QMixTrainer", "ApexQMixTrainer", "DEFAULT_CONFIG"]
__all__ = ["QMixTrainer", "DEFAULT_CONFIG"]
@@ -1,35 +0,0 @@
|
|||
"""Experimental: scalable Ape-X variant of QMIX"""
|
||||
|
||||
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
|
||||
from ray.rllib.agents.qmix.qmix import QMixTrainer, \
|
||||
DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
||||
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
|
||||
{
|
||||
"optimizer": merge_dicts(
|
||||
QMIX_CONFIG["optimizer"],
|
||||
{
|
||||
"max_weight_sync_delay": 400,
|
||||
"num_replay_buffer_shards": 4,
|
||||
"batch_replay": True, # required for RNN. Disables prio.
|
||||
"debug": False
|
||||
}),
|
||||
"num_gpus": 0,
|
||||
"num_workers": 32,
|
||||
"buffer_size": 2000000,
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"rollout_fragment_length": 50,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"per_worker_exploration": True,
|
||||
"min_iter_time_s": 30,
|
||||
},
|
||||
)
|
||||
|
||||
ApexQMixTrainer = QMixTrainer.with_updates(
|
||||
name="APEX_QMIX",
|
||||
default_config=APEX_QMIX_DEFAULT_CONFIG,
|
||||
**APEX_TRAINER_PROPERTIES)
|
|
@@ -7,7 +7,6 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
|||
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
@@ -93,15 +92,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def make_sync_batch_optimizer(workers, config):
|
||||
return SyncBatchReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
|
||||
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
replay_buffer = SimpleReplayBuffer(config["buffer_size"])
|
||||
|
@@ -127,5 +117,4 @@ QMixTrainer = GenericOffPolicyTrainer.with_updates(
|
|||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=QMixTorchPolicy,
|
||||
get_policy_class=None,
|
||||
make_policy_optimizer=make_sync_batch_optimizer,
|
||||
execution_plan=execution_plan)
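
The QMIX execution plan is truncated in this hunk. A hedged sketch of the store/replay/train pattern such plans typically follow, built only from execution ops that appear elsewhere in this diff (the exact QMIX body may differ):

from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting


def sketch_off_policy_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    # Alternate between filling the buffer and training on replayed batches.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))
    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)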
|
||||
|
|
|
@@ -25,11 +25,6 @@ def _import_qmix():
|
|||
return qmix.QMixTrainer
|
||||
|
||||
|
||||
def _import_apex_qmix():
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixTrainer
|
||||
|
||||
|
||||
def _import_ddpg():
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGTrainer
|
||||
|
@@ -116,7 +111,6 @@ ALGORITHMS = {
|
|||
"PG": _import_pg,
|
||||
"IMPALA": _import_impala,
|
||||
"QMIX": _import_qmix,
|
||||
"APEX_QMIX": _import_apex_qmix,
|
||||
"APPO": _import_appo,
|
||||
"DDPPO": _import_ddppo,
|
||||
"MARWIL": _import_marwil,
|
||||
|
|
|
@@ -9,7 +9,7 @@ from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
|
|||
loss_torch
|
||||
from ray.rllib.models.tf.tf_action_dist import SquashedGaussian
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import fc, relu
|
||||
|
|
|
@@ -18,7 +18,6 @@ from ray.rllib.evaluation.worker_set import WorkerSet
|
|||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts, \
|
||||
try_import_tf
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
@@ -206,9 +205,6 @@ COMMON_CONFIG = {
|
|||
# trainer guarantees all eval workers have the latest policy state before
|
||||
# this function is called.
|
||||
"custom_eval_function": None,
|
||||
# EXPERIMENTAL: use the execution plan based API impl of the algo. Can also
|
||||
# be enabled by setting RLLIB_EXEC_API=1.
|
||||
"use_exec_api": True,
|
||||
|
||||
# === Advanced Rollout Settings ===
|
||||
# Use a background thread for sampling (slightly off-policy, usually not
|
||||
|
@@ -981,7 +977,7 @@ class Trainer(Trainable):
|
|||
for i, obj_id in enumerate(checks):
|
||||
w = workers.remote_workers()[i]
|
||||
try:
|
||||
ray_get_and_free(obj_id)
|
||||
ray.get(obj_id)
|
||||
healthy_workers.append(w)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
|
|
|
@@ -1,38 +1,57 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.train_ops import TrainOneStep
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def default_execution_plan(workers, config):
|
||||
# Collects experiences in parallel from multiple RolloutWorker actors.
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
# Combine experiences batches until we hit `train_batch_size` in size.
|
||||
# Then, train the policy on those experiences and update the workers.
|
||||
train_op = rollouts \
|
||||
.combine(ConcatBatches(
|
||||
min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(workers))
|
||||
|
||||
# Add on the standard episode reward, etc. metrics reporting. This returns
|
||||
# a LocalIterator[metrics_dict] representing metrics for each train step.
|
||||
return StandardMetricsReporting(train_op, workers, config)
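
A hedged note on how this iterator is consumed (`workers` and `config` come from trainer setup): each Trainer.train() call pulls exactly one metrics dict from it, see _train_exec_impl further down.

plan = default_execution_plan(workers, config)  # LocalIterator[dict]
result = next(plan)  # one training iteration worth of rollouts + SGD
print(result["episode_reward_mean"])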
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def build_trainer(name,
|
||||
default_policy,
|
||||
default_config=None,
|
||||
validate_config=None,
|
||||
get_initial_state=None,
|
||||
get_policy_class=None,
|
||||
before_init=None,
|
||||
make_workers=None,
|
||||
make_policy_optimizer=None,
|
||||
after_init=None,
|
||||
before_train_step=None,
|
||||
after_optimizer_step=None,
|
||||
after_train_result=None,
|
||||
collect_metrics_fn=None,
|
||||
before_evaluate_fn=None,
|
||||
mixins=None,
|
||||
execution_plan=None):
|
||||
def build_trainer(
|
||||
name,
|
||||
default_policy,
|
||||
default_config=None,
|
||||
validate_config=None,
|
||||
get_initial_state=None, # DEPRECATED
|
||||
get_policy_class=None,
|
||||
before_init=None,
|
||||
make_workers=None, # DEPRECATED
|
||||
make_policy_optimizer=None, # DEPRECATED
|
||||
after_init=None,
|
||||
before_train_step=None, # DEPRECATED
|
||||
after_optimizer_step=None, # DEPRECATED
|
||||
after_train_result=None, # DEPRECATED
|
||||
collect_metrics_fn=None, # DEPRECATED
|
||||
before_evaluate_fn=None,
|
||||
mixins=None,
|
||||
execution_plan=default_execution_plan):
|
||||
"""Helper function for defining a custom trainer.
|
||||
|
||||
Functions will be run in this order to initialize the trainer:
|
||||
1. Config setup: validate_config, get_initial_state, get_policy
|
||||
2. Worker setup: before_init, make_workers, make_policy_optimizer
|
||||
1. Config setup: validate_config, get_policy
|
||||
2. Worker setup: before_init, execution_plan
|
||||
3. Post setup: after_init
|
||||
|
||||
Arguments:
|
||||
|
@@ -42,37 +61,18 @@ def build_trainer(name,
|
|||
otherwise uses the Trainer default config.
|
||||
validate_config (func): optional callback that checks a given config
|
||||
for correctness. It may mutate the config as needed.
|
||||
get_initial_state (func): optional function that returns the initial
|
||||
state dict given the trainer instance as an argument. The state
|
||||
dict must be serializable so that it can be checkpointed, and will
|
||||
be available as the `trainer.state` variable.
|
||||
get_policy_class (func): optional callback that takes a config and
|
||||
returns the policy class to override the default with
|
||||
before_init (func): optional function to run at the start of trainer
|
||||
init that takes the trainer instance as argument
|
||||
make_workers (func): override the method that creates rollout workers.
|
||||
This takes in (trainer, env_creator, policy, config) as args.
|
||||
make_policy_optimizer (func): optional function that returns a
|
||||
PolicyOptimizer instance given (WorkerSet, config)
|
||||
after_init (func): optional function to run at the end of trainer init
|
||||
that takes the trainer instance as argument
|
||||
before_train_step (func): optional callback to run before each train()
|
||||
call. It takes the trainer instance as an argument.
|
||||
after_optimizer_step (func): optional callback to run after each
|
||||
step() call to the policy optimizer. It takes the trainer instance
|
||||
and the policy gradient fetches as arguments.
|
||||
after_train_result (func): optional callback to run at the end of each
|
||||
train() call. It takes the trainer instance and result dict as
|
||||
arguments, and may mutate the result dict as needed.
|
||||
collect_metrics_fn (func): override the method used to collect metrics.
|
||||
It takes the trainer instance as argument.
|
||||
before_evaluate_fn (func): callback to run before evaluation. This
|
||||
takes the trainer instance as argument.
|
||||
mixins (list): list of any class mixins for the returned trainer class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the Trainer class
|
||||
execution_plan (func): Experimental distributed execution
|
||||
API. This overrides `make_policy_optimizer`.
|
||||
execution_plan (func): Setup the distributed execution workflow.
|
||||
|
||||
Returns:
|
||||
a Trainer instance that uses the specified args.
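
For reference, a minimal sketch of defining a trainer against the non-deprecated arguments only (names here are illustrative; assumes the PG policy shipped with RLlib, any policy class would do):

from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

# execution_plan is omitted, so default_execution_plan defined above is used.
MyPGTrainer = build_trainer(
    name="MyPG",
    default_policy=PGTFPolicy,
    default_config=with_common_config({"num_workers": 0}))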
|
||||
|
@@ -94,6 +94,7 @@ def build_trainer(name,
|
|||
validate_config(config)
|
||||
|
||||
if get_initial_state:
|
||||
deprecation_warning("get_initial_state", "execution_plan")
|
||||
self.state = get_initial_state(self)
|
||||
else:
|
||||
self.state = {}
|
||||
|
@@ -103,12 +104,9 @@ def build_trainer(name,
|
|||
self._policy = get_policy_class(config)
|
||||
if before_init:
|
||||
before_init(self)
|
||||
use_exec_api = (execution_plan
|
||||
and (self.config["use_exec_api"]
|
||||
or "RLLIB_EXEC_API" in os.environ))
|
||||
|
||||
# Creating all workers (excluding evaluation workers).
|
||||
if make_workers and not use_exec_api:
|
||||
if make_workers and not execution_plan:
|
||||
deprecation_warning("make_workers", "execution_plan")
|
||||
self.workers = make_workers(self, env_creator, self._policy,
|
||||
config)
|
||||
else:
|
||||
|
@@ -119,16 +117,12 @@ def build_trainer(name,
|
|||
self.optimizer = None
|
||||
self.execution_plan = execution_plan
|
||||
|
||||
if use_exec_api:
|
||||
self.train_exec_impl = execution_plan(self.workers, config)
|
||||
elif make_policy_optimizer:
|
||||
if make_policy_optimizer:
|
||||
deprecation_warning("make_policy_optimizer", "execution_plan")
|
||||
self.optimizer = make_policy_optimizer(self.workers, config)
|
||||
else:
|
||||
optimizer_config = dict(
|
||||
config["optimizer"],
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(self.workers,
|
||||
**optimizer_config)
|
||||
assert execution_plan is not None
|
||||
self.train_exec_impl = execution_plan(self.workers, config)
|
||||
if after_init:
|
||||
after_init(self)
|
||||
|
||||
|
@@ -138,6 +132,7 @@ def build_trainer(name,
|
|||
return self._train_exec_impl()
|
||||
|
||||
if before_train_step:
|
||||
deprecation_warning("before_train_step", "execution_plan")
|
||||
before_train_step(self)
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
|
||||
|
@@ -147,6 +142,8 @@ def build_trainer(name,
|
|||
fetches = self.optimizer.step()
|
||||
optimizer_steps_this_iter += 1
|
||||
if after_optimizer_step:
|
||||
deprecation_warning("after_optimizer_step",
|
||||
"execution_plan")
|
||||
after_optimizer_step(self, fetches)
|
||||
if (time.time() - start >= self.config["min_iter_time_s"]
|
||||
and self.optimizer.num_steps_sampled - prev_steps >=
|
||||
|
@@ -154,6 +151,7 @@ def build_trainer(name,
|
|||
break
|
||||
|
||||
if collect_metrics_fn:
|
||||
deprecation_warning("collect_metrics_fn", "execution_plan")
|
||||
res = collect_metrics_fn(self)
|
||||
else:
|
||||
res = self.collect_metrics()
|
||||
|
@@ -164,15 +162,12 @@ def build_trainer(name,
|
|||
info=res.get("info", {}))
|
||||
|
||||
if after_train_result:
|
||||
deprecation_warning("after_train_result", "execution_plan")
|
||||
after_train_result(self, res)
|
||||
return res
|
||||
|
||||
def _train_exec_impl(self):
|
||||
if before_train_step:
|
||||
logger.debug("Ignoring before_train_step callback")
|
||||
res = next(self.train_exec_impl)
|
||||
if after_train_result:
|
||||
logger.debug("Ignoring after_train_result callback")
|
||||
return res
|
||||
|
||||
@override(Trainer)
|
||||
|
|
|
@@ -3,18 +3,21 @@ import logging
|
|||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
|
||||
StoreToReplayBuffer, WaitUntilTimestepsElapsed
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently
|
||||
from ray.rllib.execution.train_ops import TrainOneStep
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
|
||||
from ray.rllib.contrib.alpha_zero.core.alpha_zero_policy import AlphaZeroPolicy
|
||||
from ray.rllib.contrib.alpha_zero.core.mcts import MCTS
|
||||
from ray.rllib.contrib.alpha_zero.core.ranked_rewards import get_r2_env_wrapper
|
||||
from ray.rllib.contrib.alpha_zero.optimizer.sync_batches_replay_optimizer \
|
||||
import SyncBatchesReplayOptimizer
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
@@ -111,21 +114,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def choose_policy_optimizer(workers, config):
|
||||
if config["simple_optimizer"]:
|
||||
return SyncSamplesOptimizer(
|
||||
workers,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
else:
|
||||
return SyncBatchesReplayOptimizer(
|
||||
workers,
|
||||
num_gradient_descents=config["num_sgd_iter"],
|
||||
learning_starts=config["learning_starts"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
buffer_size=config["buffer_size"])
|
||||
|
||||
|
||||
def alpha_zero_loss(policy, model, dist_class, train_batch):
|
||||
# get inputs unflattened inputs
|
||||
input_dict = restore_original_dimensions(train_batch["obs"],
|
||||
|
@@ -172,8 +160,36 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
|
|||
_env_creator)
|
||||
|
||||
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
if config["simple_optimizer"]:
|
||||
train_op = rollouts \
|
||||
.combine(ConcatBatches(
|
||||
min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(
|
||||
workers, num_sgd_iter=config["num_sgd_iter"]))
|
||||
else:
|
||||
replay_buffer = SimpleReplayBuffer(config["buffer_size"])
|
||||
|
||||
store_op = rollouts \
|
||||
.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
|
||||
|
||||
replay_op = Replay(local_buffer=replay_buffer) \
|
||||
.filter(WaitUntilTimestepsElapsed(config["learning_starts"])) \
|
||||
.combine(
|
||||
ConcatBatches(min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(
|
||||
workers, num_sgd_iter=config["num_sgd_iter"]))
|
||||
|
||||
train_op = Concurrently(
|
||||
[store_op, replay_op], mode="round_robin", output_indexes=[1])
|
||||
|
||||
return StandardMetricsReporting(train_op, workers, config)
|
||||
|
||||
|
||||
AlphaZeroTrainer = build_trainer(
|
||||
name="AlphaZero",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=AlphaZeroPolicyWrapperClass,
|
||||
make_policy_optimizer=choose_policy_optimizer)
|
||||
execution_plan=execution_plan)
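
A self-contained illustration of the Concurrently(..., output_indexes=[1]) contract used above, with plain local iterators instead of rollout ops (my reading of the API; from_items is part of ray.util.iter):

import ray
from ray.util.iter import from_items
from ray.rllib.execution.concurrency_ops import Concurrently

ray.init()
store_op = from_items(["store"] * 4, num_shards=1).gather_sync()
train_op = from_items(["train"] * 4, num_shards=1).gather_sync()

# Round-robin between both ops, but only emit the results of op index 1,
# exactly like the store/replay pairing in the AlphaZero plan above.
out = Concurrently([store_op, train_op], mode="round_robin",
                   output_indexes=[1])
print(out.take(2))  # expected: ["train", "train"]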
|
||||
|
|
|
@@ -1,34 +0,0 @@
|
|||
import random
|
||||
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.sync_batch_replay_optimizer import \
|
||||
SyncBatchReplayOptimizer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class SyncBatchesReplayOptimizer(SyncBatchReplayOptimizer):
|
||||
def __init__(self,
|
||||
workers,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
train_batch_size=32,
|
||||
num_gradient_descents=10):
|
||||
super(SyncBatchesReplayOptimizer, self).__init__(
|
||||
workers, learning_starts, buffer_size, train_batch_size)
|
||||
self.num_sgds = num_gradient_descents
|
||||
|
||||
@override(SyncBatchReplayOptimizer)
|
||||
def _optimize(self):
|
||||
for _ in range(self.num_sgds):
|
||||
samples = [random.choice(self.replay_buffer)]
|
||||
while sum(s.count for s in samples) < self.train_batch_size:
|
||||
samples.append(random.choice(self.replay_buffer))
|
||||
samples = SampleBatch.concat_samples(samples)
|
||||
with self.grad_timer:
|
||||
info_dict = self.workers.local_worker().learn_on_batch(samples)
|
||||
for policy_id, info in info_dict.items():
|
||||
self.learner_stats[policy_id] = get_learner_stats(info)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
self.num_steps_trained += samples.count
|
||||
return info_dict
|
|
@@ -29,17 +29,5 @@ TS_CONFIG = with_common_config({
|
|||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def get_stats(trainer):
|
||||
env_metrics = trainer.collect_metrics()
|
||||
stats = trainer.optimizer.stats()
|
||||
# Uncomment if regret at each time step is needed
|
||||
# stats.update({"all_regrets": trainer.get_policy().regrets})
|
||||
return dict(env_metrics, **stats)
|
||||
|
||||
|
||||
LinTSTrainer = build_trainer(
|
||||
name="LinTS",
|
||||
default_config=TS_CONFIG,
|
||||
default_policy=BanditPolicy,
|
||||
collect_metrics_fn=get_stats)
|
||||
name="LinTS", default_config=TS_CONFIG, default_policy=BanditPolicy)
|
||||
|
|
|
@@ -29,17 +29,5 @@ UCB_CONFIG = with_common_config({
|
|||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def get_stats(trainer):
|
||||
env_metrics = trainer.collect_metrics()
|
||||
stats = trainer.optimizer.stats()
|
||||
# Uncomment if regret at each time step is needed
|
||||
# stats.update({"all_regrets": trainer.get_policy().regrets})
|
||||
return dict(env_metrics, **stats)
|
||||
|
||||
|
||||
LinUCBTrainer = build_trainer(
|
||||
name="LinUCB",
|
||||
default_config=UCB_CONFIG,
|
||||
default_policy=BanditPolicy,
|
||||
collect_metrics_fn=get_stats)
|
||||
name="LinUCB", default_config=UCB_CONFIG, default_policy=BanditPolicy)
|
||||
|
|
|
@@ -14,7 +14,6 @@ import logging
|
|||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.contrib.maddpg.maddpg_policy import MADDPGTFPolicy
|
||||
from ray.rllib.optimizers import SyncReplayOptimizer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@@ -112,11 +111,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
def set_global_timestep(trainer):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
trainer.train_start_timestep = global_timestep
|
||||
|
||||
|
||||
def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
|
||||
samples = {}
|
||||
|
||||
|
@@ -150,31 +144,6 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
|
|||
return MultiAgentBatch(policy_batches, train_batch_size)
|
||||
|
||||
|
||||
def make_optimizer(workers, config):
|
||||
return SyncReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
before_learn_on_batch=before_learn_on_batch,
|
||||
synchronize_sampling=True,
|
||||
prioritized_replay=False)
|
||||
|
||||
|
||||
def add_trainer_metrics(trainer, result):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
result.update(
|
||||
timesteps_this_iter=global_timestep - trainer.train_start_timestep,
|
||||
info=dict({
|
||||
"num_target_updates": trainer.state["num_target_updates"],
|
||||
}, **trainer.optimizer.stats()))
|
||||
|
||||
|
||||
def collect_metrics(trainer):
|
||||
result = trainer.collect_metrics()
|
||||
return result
|
||||
|
||||
|
||||
def add_maddpg_postprocessing(config):
|
||||
"""Add the before learn on batch hook.
|
||||
|
||||
|
@@ -196,11 +165,5 @@ MADDPGTrainer = GenericOffPolicyTrainer.with_updates(
|
|||
name="MADDPG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=MADDPGTFPolicy,
|
||||
validate_config=add_maddpg_postprocessing,
|
||||
get_policy_class=None,
|
||||
before_init=None,
|
||||
before_train_step=set_global_timestep,
|
||||
make_policy_optimizer=make_optimizer,
|
||||
after_train_result=add_trainer_metrics,
|
||||
collect_metrics_fn=collect_metrics,
|
||||
before_evaluate_fn=None)
|
||||
validate_config=add_maddpg_postprocessing)
|
||||
|
|
rllib/env/remote_vector_env.py
|
@@ -2,7 +2,6 @@ import logging
|
|||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -57,7 +56,7 @@ class RemoteVectorEnv(BaseEnv):
|
|||
actor = self.pending.pop(obj_id)
|
||||
env_id = self.actors.index(actor)
|
||||
env_ids.add(env_id)
|
||||
ob, rew, done, info = ray_get_and_free(obj_id)
|
||||
ob, rew, done, info = ray.get(obj_id)
|
||||
obs[env_id] = ob
|
||||
rewards[env_id] = rew
|
||||
dones[env_id] = done
|
||||
|
|
|
@@ -1,7 +1,6 @@
|
|||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
|
@@ -14,7 +13,6 @@ from ray.rllib.evaluation.metrics import collect_metrics
|
|||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorInterface",
|
||||
"RolloutWorker",
|
||||
"PolicyGraph",
|
||||
"TFPolicyGraph",
|
||||
|
|
|
@@ -1,124 +0,0 @@
|
|||
import os
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class EvaluatorInterface:
|
||||
"""This is the interface between policy optimizers and policy evaluation.
|
||||
|
||||
See also: RolloutWorker
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
"""Returns a batch of experience sampled from this evaluator.
|
||||
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
SampleBatch|MultiAgentBatch: A columnar batch of experiences
|
||||
(e.g., tensors), or a multi-agent batch.
|
||||
|
||||
Examples:
|
||||
>>> print(ev.sample())
|
||||
SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, samples):
|
||||
"""Update policies based on the given batch.
|
||||
|
||||
This is the equivalent to apply_gradients(compute_gradients(samples)),
|
||||
but can be optimized to avoid pulling gradients into CPU memory.
|
||||
|
||||
Either this or the combination of compute/apply grads must be
|
||||
implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
info: dictionary of extra metadata from compute_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
>>> ev.learn_on_batch(samples)
|
||||
"""
|
||||
|
||||
grads, info = self.compute_gradients(samples)
|
||||
self.apply_gradients(grads)
|
||||
return info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, samples):
|
||||
"""Returns a gradient computed w.r.t the specified samples.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
(grads, info): A list of gradients that can be applied on a
|
||||
compatible evaluator. In the multi-agent case, returns a dict
|
||||
of gradients keyed by policy ids. An info dictionary of
|
||||
extra metadata is also returned.
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
>>> grads, info = ev2.compute_gradients(samples)
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, grads):
|
||||
"""Applies the given gradients to this evaluator's weights.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Examples:
|
||||
>>> samples = ev1.sample()
|
||||
>>> grads, info = ev2.compute_gradients(samples)
|
||||
>>> ev1.apply_gradients(grads)
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_weights(self):
|
||||
"""Returns the model weights of this Evaluator.
|
||||
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
object: weights that can be set on a compatible evaluator.
|
||||
info: dictionary of extra metadata.
|
||||
|
||||
Examples:
|
||||
>>> weights = ev1.get_weights()
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights):
|
||||
"""Sets the model weights of this Evaluator.
|
||||
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Examples:
|
||||
>>> weights = ev1.get_weights()
|
||||
>>> ev2.set_weights(weights)
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_host(self):
|
||||
"""Returns the hostname of the process running this evaluator."""
|
||||
|
||||
return os.uname()[1]
|
||||
|
||||
@DeveloperAPI
|
||||
def apply(self, func, *args):
|
||||
"""Apply the given function to this evaluator instance."""
|
||||
|
||||
return func(self, *args)
|
|
@@ -8,7 +8,6 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
|||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -18,7 +17,8 @@ def get_learner_stats(grad_info):
|
|||
"""Return optimization stats reported from the policy.
|
||||
|
||||
Example:
|
||||
>>> grad_info = evaluator.learn_on_batch(samples)
|
||||
>>> grad_info = worker.learn_on_batch(samples)
|
||||
{"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}
|
||||
>>> print(get_stats(grad_info))
|
||||
{"vf_loss": ..., "policy_loss": ...}
|
||||
"""
|
||||
|
@@ -68,7 +68,7 @@ def collect_episodes(local_worker=None,
|
|||
logger.warning(
|
||||
"WARNING: collected no metrics in {} seconds".format(
|
||||
timeout_seconds))
|
||||
metric_lists = ray_get_and_free(collected)
|
||||
metric_lists = ray.get(collected)
|
||||
else:
|
||||
metric_lists = []
|
||||
|
||||
|
|
|
@@ -16,7 +16,6 @@ from ray.rllib.env.external_env import ExternalEnv
|
|||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
|
@@ -28,7 +27,7 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
|||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
|
@@ -55,7 +54,7 @@ def get_global_worker():
|
|||
|
||||
|
||||
@DeveloperAPI
|
||||
class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
class RolloutWorker(ParallelIteratorWorker):
|
||||
"""Common experience collection class.
|
||||
|
||||
This class wraps a policy instance and an environment class to
|
||||
|
@@ -497,12 +496,19 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
"Created rollout worker with env {} ({}), policies {}".format(
|
||||
self.async_env, self.env, self.policy_map))
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
"""Returns a batch of experience sampled from this worker.
|
||||
|
||||
Return:
|
||||
SampleBatch|MultiAgentBatch from evaluating the current policies.
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
SampleBatch|MultiAgentBatch: A columnar batch of experiences
|
||||
(e.g., tensors), or a multi-agent batch.
|
||||
|
||||
Examples:
|
||||
>>> print(worker.sample())
|
||||
SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
|
||||
"""
|
||||
|
||||
if self.fake_sampler and self.last_batch is not None:
|
||||
|
@@ -561,8 +567,17 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
batch = self.sample()
|
||||
return batch, batch.count
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def get_weights(self, policies=None):
|
||||
"""Returns the model weights of this worker.
|
||||
|
||||
Returns:
|
||||
object: weights that can be set on another worker.
|
||||
info: dictionary of extra metadata.
|
||||
|
||||
Examples:
|
||||
>>> weights = worker.get_weights()
|
||||
"""
|
||||
if policies is None:
|
||||
policies = self.policy_map.keys()
|
||||
return {
|
||||
|
@@ -570,15 +585,33 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
for pid, policy in self.policy_map.items() if pid in policies
|
||||
}
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights, global_vars=None):
|
||||
"""Sets the model weights of this worker.
|
||||
|
||||
Examples:
|
||||
>>> weights = worker.get_weights()
|
||||
>>> worker.set_weights(weights)
|
||||
"""
|
||||
for pid, w in weights.items():
|
||||
self.policy_map[pid].set_weights(w)
|
||||
if global_vars:
|
||||
self.set_global_vars(global_vars)
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, samples):
|
||||
"""Returns a gradient computed w.r.t the specified samples.
|
||||
|
||||
Returns:
|
||||
(grads, info): A list of gradients that can be applied on a
|
||||
compatible worker. In the multi-agent case, returns a dict
|
||||
of gradients keyed by policy ids. An info dictionary of
|
||||
extra metadata is also returned.
|
||||
|
||||
Examples:
|
||||
>>> batch = worker.sample()
|
||||
>>> grads, info = worker.compute_gradients(samples)
|
||||
"""
|
||||
if log_once("compute_gradients"):
|
||||
logger.info("Compute gradients on:\n\n{}\n".format(
|
||||
summarize(samples)))
|
||||
|
@@ -609,8 +642,15 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
summarize(info_out)))
|
||||
return grad_out, info_out
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, grads):
|
||||
"""Applies the given gradients to this worker's weights.
|
||||
|
||||
Examples:
|
||||
>>> samples = worker.sample()
|
||||
>>> grads, info = worker.compute_gradients(samples)
|
||||
>>> worker.apply_gradients(grads)
|
||||
"""
|
||||
if log_once("apply_gradients"):
|
||||
logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
|
||||
if isinstance(grads, dict):
|
||||
|
@@ -630,8 +670,20 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
else:
|
||||
return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, samples):
|
||||
"""Update policies based on the given batch.
|
||||
|
||||
This is the equivalent to apply_gradients(compute_gradients(samples)),
|
||||
but can be optimized to avoid pulling gradients into CPU memory.
|
||||
|
||||
Returns:
|
||||
info: dictionary of extra metadata from compute_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = worker.sample()
|
||||
>>> worker.learn_on_batch(samples)
|
||||
"""
|
||||
if log_once("learn_on_batch"):
|
||||
logger.info(
|
||||
"Training on concatenated sample batches:\n\n{}\n".format(
|
||||
|
@@ -654,8 +706,10 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
info_out[pid] = policy.learn_on_batch(batch)
|
||||
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
|
||||
else:
|
||||
info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
|
||||
samples)
|
||||
info_out = {
|
||||
DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
|
||||
.learn_on_batch(samples)
|
||||
}
|
||||
if log_once("learn_out"):
|
||||
logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
|
||||
return info_out
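
The docstrings above describe the sample/compute_gradients/apply_gradients contract that used to live on EvaluatorInterface. A minimal sketch of that round trip between two local workers (assumes Gym and the bundled PG policy):

import gym
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker

worker_1 = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"), policy=PGTFPolicy)
worker_2 = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"), policy=PGTFPolicy)

batch = worker_1.sample()
grads, info = worker_2.compute_gradients(batch)
worker_1.apply_gradients(grads)
worker_1.set_weights(worker_2.get_weights())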
|
||||
|
@@ -827,6 +881,18 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
|||
"""Returns the args used to create this worker."""
|
||||
return self._original_kwargs
|
||||
|
||||
@DeveloperAPI
|
||||
def get_host(self):
|
||||
"""Returns the hostname of the process running this evaluator."""
|
||||
|
||||
return os.uname()[1]
|
||||
|
||||
@DeveloperAPI
|
||||
def apply(self, func, *args):
|
||||
"""Apply the given function to this evaluator instance."""
|
||||
|
||||
return func(self, *args)
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
|
|
|
@@ -8,7 +8,6 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
|
|||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.utils import merge_dicts, try_import_tf
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@@ -115,7 +114,7 @@ class WorkerSet:
|
|||
"""Apply the given function to each worker instance."""
|
||||
|
||||
local_result = [func(self.local_worker())]
|
||||
remote_results = ray_get_and_free(
|
||||
remote_results = ray.get(
|
||||
[w.apply.remote(func) for w in self.remote_workers()])
|
||||
return local_result + remote_results
|
||||
|
||||
|
@@ -126,7 +125,7 @@ class WorkerSet:
|
|||
The index will be passed as the second arg to the given function.
|
||||
"""
|
||||
local_result = [func(self.local_worker(), 0)]
|
||||
remote_results = ray_get_and_free([
|
||||
remote_results = ray.get([
|
||||
w.apply.remote(func, i + 1)
|
||||
for i, w in enumerate(self.remote_workers())
|
||||
])
|
||||
|
@@ -147,7 +146,7 @@ class WorkerSet:
|
|||
local_results = self.local_worker().foreach_policy(func)
|
||||
remote_results = []
|
||||
for worker in self.remote_workers():
|
||||
res = ray_get_and_free(
|
||||
res = ray.get(
|
||||
worker.apply.remote(lambda w: w.foreach_policy(func)))
|
||||
remote_results.extend(res)
|
||||
return local_results + remote_results
|
||||
|
@@ -172,7 +171,7 @@ class WorkerSet:
|
|||
local_results = self.local_worker().foreach_trainable_policy(func)
|
||||
remote_results = []
|
||||
for worker in self.remote_workers():
|
||||
res = ray_get_and_free(
|
||||
res = ray.get(
|
||||
worker.apply.remote(
|
||||
lambda w: w.foreach_trainable_policy(func)))
|
||||
remote_results.extend(res)
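
A hedged usage note for the helpers above, now that results come back through plain ray.get (assumes `workers` is an already-built WorkerSet):

# Broadcast the local weights and confirm every remote worker is reachable.
weights = workers.local_worker().get_weights()
workers.foreach_worker(lambda w: w.set_weights(weights))
hosts = workers.foreach_worker(lambda w: w.get_host())
print(hosts)  # local worker's host first, then one entry per remote worker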
|
||||
|
|
|
@@ -25,8 +25,8 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
|
|||
StandardizeFields, SelectExperiences
|
||||
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
|
|
@@ -4,7 +4,6 @@ Configurations you can try:
|
|||
- normal policy gradients (PG)
|
||||
- contrib/MADDPG
|
||||
- QMIX
|
||||
- APEX_QMIX
|
||||
|
||||
See also: centralized_critic.py for centralized critic PPO on this game.
|
||||
"""
|
||||
|
@@ -95,27 +94,6 @@ if __name__ == "__main__":
|
|||
"use_pytorch": args.torch,
|
||||
}
|
||||
group = True
|
||||
elif args.run == "APEX_QMIX":
|
||||
config = {
|
||||
"num_gpus": 0,
|
||||
"num_workers": 2,
|
||||
"optimizer": {
|
||||
"num_replay_buffer_shards": 1,
|
||||
},
|
||||
"min_iter_time_s": 3,
|
||||
"buffer_size": 1000,
|
||||
"learning_starts": 1000,
|
||||
"train_batch_size": 128,
|
||||
"rollout_fragment_length": 32,
|
||||
"target_network_update_freq": 500,
|
||||
"timesteps_per_iteration": 1000,
|
||||
"env_config": {
|
||||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
}
|
||||
group = True
|
||||
else:
|
||||
config = {}
|
||||
group = False
|
||||
|
|
rllib/execution/learner_thread.py (new file)
|
@@ -0,0 +1,90 @@
|
|||
import threading
|
||||
import copy
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
|
||||
class LearnerThread(threading.Thread):
|
||||
"""Background thread that updates the local model from sample trajectories.
|
||||
|
||||
The learner thread communicates with the main thread through Queues. This
|
||||
is needed since Ray operations can only be run on the main thread. In
|
||||
addition, moving heavyweight gradient computations (session runs) off the main thread
|
||||
improves overall throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
|
||||
learner_queue_size, learner_queue_timeout):
|
||||
"""Initialize the learner thread.
|
||||
|
||||
Arguments:
|
||||
local_worker (RolloutWorker): process local rollout worker holding
|
||||
policies this thread will call learn_on_batch() on
|
||||
minibatch_buffer_size (int): max number of train batches to store
|
||||
in the minibatching buffer
|
||||
num_sgd_iter (int): number of passes to learn on per train batch
|
||||
learner_queue_size (int): max size of queue of inbound
|
||||
train batches to this thread
|
||||
learner_queue_timeout (int): raise an exception if the queue has
|
||||
been empty for this long in seconds
|
||||
"""
|
||||
threading.Thread.__init__(self)
|
||||
self.learner_queue_size = WindowStat("size", 50)
|
||||
self.local_worker = local_worker
|
||||
self.inqueue = queue.Queue(maxsize=learner_queue_size)
|
||||
self.outqueue = queue.Queue()
|
||||
self.minibatch_buffer = MinibatchBuffer(
|
||||
inqueue=self.inqueue,
|
||||
size=minibatch_buffer_size,
|
||||
timeout=learner_queue_timeout,
|
||||
num_passes=num_sgd_iter,
|
||||
init_num_passes=num_sgd_iter)
|
||||
self.queue_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.load_timer = TimerStat()
|
||||
self.load_wait_timer = TimerStat()
|
||||
self.daemon = True
|
||||
self.weights_updated = False
|
||||
self.stats = {}
|
||||
self.stopped = False
|
||||
self.num_steps = 0
|
||||
|
||||
def run(self):
|
||||
while not self.stopped:
|
||||
self.step()
|
||||
|
||||
def step(self):
|
||||
with self.queue_timer:
|
||||
batch, _ = self.minibatch_buffer.get()
|
||||
|
||||
with self.grad_timer:
|
||||
fetches = self.local_worker.learn_on_batch(batch)
|
||||
self.weights_updated = True
|
||||
self.stats = get_learner_stats(fetches)
|
||||
|
||||
self.num_steps += 1
|
||||
self.outqueue.put((batch.count, self.stats))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
|
||||
def add_learner_metrics(self, result):
|
||||
"""Add internal metrics to a trainer result dict."""
|
||||
|
||||
def timer_to_ms(timer):
|
||||
return round(1000 * timer.mean, 3)
|
||||
|
||||
result["info"].update({
|
||||
"learner_queue": self.learner_queue_size.stats(),
|
||||
"learner": copy.deepcopy(self.stats),
|
||||
"timing_breakdown": {
|
||||
"learner_grad_time_ms": timer_to_ms(self.grad_timer),
|
||||
"learner_load_time_ms": timer_to_ms(self.load_timer),
|
||||
"learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
|
||||
"learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
|
||||
}
|
||||
})
|
||||
return result
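
A hedged sketch of how this thread is typically driven (assumes `local_worker` is a RolloutWorker whose policies can learn on the sampled batches):

from ray.rllib.execution.learner_thread import LearnerThread

learner = LearnerThread(
    local_worker,
    minibatch_buffer_size=1,
    num_sgd_iter=1,
    learner_queue_size=16,
    learner_queue_timeout=300)
learner.start()

learner.inqueue.put(local_worker.sample())  # enqueue one train batch
count, stats = learner.outqueue.get()       # blocks until the step finishes
learner.stopped = True                      # let the background loop exit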
|
|
@@ -3,7 +3,8 @@ import time
|
|||
|
||||
from ray.util.iter import LocalIterator
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
_get_shared_metrics
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
|
||||
|
||||
|
@@ -86,7 +87,7 @@ class CollectMetrics:
|
|||
res = summarize_episodes(episodes, orig_episodes)
|
||||
|
||||
# Add in iterator metrics.
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
timers = {}
|
||||
counters = {}
|
||||
info = {}
|
||||
|
@@ -157,7 +158,7 @@ class OncePerTimestepsElapsed:
|
|||
def __call__(self, item):
|
||||
if self.delay_steps <= 0:
|
||||
return True
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
now = metrics.counters[STEPS_SAMPLED_COUNTER]
|
||||
if now - self.last_called >= self.delay_steps:
|
||||
self.last_called = now
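
A hedged sketch of using the same helper from a custom plan op, mirroring the counter read above:

from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    _get_shared_metrics


def log_steps_sampled(item):
    # Reads the metrics context shared by all ops of the current plan.
    metrics = _get_shared_metrics()
    print("steps sampled so far:", metrics.counters[STEPS_SAMPLED_COUNTER])
    return item

# e.g.: train_op = train_op.for_each(log_steps_sampled)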
|
||||
|
|
rllib/execution/minibatch_buffer.py (new file)
|
@@ -0,0 +1,44 @@
|
|||
class MinibatchBuffer:
|
||||
"""Ring buffer of recent data batches for minibatch SGD.
|
||||
|
||||
This is for use with AsyncSamplesOptimizer.
|
||||
"""
|
||||
|
||||
def __init__(self, inqueue, size, timeout, num_passes, init_num_passes=1):
|
||||
"""Initialize a minibatch buffer.
|
||||
|
||||
Arguments:
|
||||
inqueue: Queue to populate the internal ring buffer from.
|
||||
size: Max number of data items to buffer.
|
||||
timeout: Queue timeout
|
||||
num_passes: Max num times each data item should be emitted.
|
||||
init_num_passes: Initial max passes for each data item
|
||||
"""
|
||||
self.inqueue = inqueue
|
||||
self.size = size
|
||||
self.timeout = timeout
|
||||
self.max_ttl = num_passes
|
||||
self.cur_max_ttl = init_num_passes
|
||||
self.buffers = [None] * size
|
||||
self.ttl = [0] * size
|
||||
self.idx = 0
|
||||
|
||||
def get(self):
|
||||
"""Get a new batch from the internal ring buffer.
|
||||
|
||||
Returns:
|
||||
buf: Data item saved from inqueue.
|
||||
released: True if the item is now removed from the ring buffer.
|
||||
"""
|
||||
if self.ttl[self.idx] <= 0:
|
||||
self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
|
||||
self.ttl[self.idx] = self.cur_max_ttl
|
||||
if self.cur_max_ttl < self.max_ttl:
|
||||
self.cur_max_ttl += 1
|
||||
buf = self.buffers[self.idx]
|
||||
self.ttl[self.idx] -= 1
|
||||
released = self.ttl[self.idx] <= 0
|
||||
if released:
|
||||
self.buffers[self.idx] = None
|
||||
self.idx = (self.idx + 1) % len(self.buffers)
|
||||
return buf, released
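
A self-contained illustration of the pass-count semantics (with num_passes=2 and init_num_passes=2, each queued item is handed out twice before being released):

from six.moves import queue
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer

q = queue.Queue()
q.put("batch-0")
buf = MinibatchBuffer(
    inqueue=q, size=1, timeout=1, num_passes=2, init_num_passes=2)

print(buf.get())  # ("batch-0", False): one pass left, still buffered
print(buf.get())  # ("batch-0", True): released from the ring buffer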
|
rllib/execution/multi_gpu_impl.py (new file)
|
@@ -0,0 +1,358 @@
|
|||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
# Variable scope in which created variables will be placed under
|
||||
TOWER_SCOPE_NAME = "tower"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalSyncParallelOptimizer:
|
||||
"""Optimizer that runs in parallel across multiple local devices.
|
||||
|
||||
LocalSyncParallelOptimizer automatically splits up and loads training data
|
||||
onto specified local devices (e.g. GPUs) with `load_data()`. During a call
|
||||
to `optimize()`, the devices compute gradients over slices of the data in
|
||||
parallel. The gradients are then averaged and applied to the shared
|
||||
weights.
|
||||
|
||||
The data loaded is pinned in device memory until the next call to
|
||||
`load_data`, so you can make multiple passes (possibly in randomized order)
|
||||
over the same data once loaded.
|
||||
|
||||
This is similar to tf.train.SyncReplicasOptimizer, but works within a
|
||||
single TensorFlow graph, i.e. implements in-graph replicated training:
|
||||
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
|
||||
|
||||
Args:
|
||||
optimizer: Delegate TensorFlow optimizer object.
|
||||
devices: List of the names of TensorFlow devices to parallelize over.
|
||||
input_placeholders: List of input_placeholders for the loss function.
|
||||
Tensors of these shapes will be passed to build_graph() in order
|
||||
to define the per-device loss ops.
|
||||
rnn_inputs: Extra input placeholders for RNN inputs. These will have
|
||||
shape [BATCH_SIZE // MAX_SEQ_LEN, ...].
|
||||
max_per_device_batch_size: Number of tuples to optimize over at a time
|
||||
per device. In each call to `optimize()`,
|
||||
`len(devices) * per_device_batch_size` tuples of data will be
|
||||
processed. If this is larger than the total data size, it will be
|
||||
clipped.
|
||||
build_graph: Function that takes the specified inputs and returns a
|
||||
TF Policy instance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
devices,
|
||||
input_placeholders,
|
||||
rnn_inputs,
|
||||
max_per_device_batch_size,
|
||||
build_graph,
|
||||
grad_norm_clipping=None):
|
||||
self.optimizer = optimizer
|
||||
self.devices = devices
|
||||
self.max_per_device_batch_size = max_per_device_batch_size
|
||||
self.loss_inputs = input_placeholders + rnn_inputs
|
||||
self.build_graph = build_graph
|
||||
|
||||
# First initialize the shared loss network
|
||||
with tf.name_scope(TOWER_SCOPE_NAME):
|
||||
self._shared_loss = build_graph(self.loss_inputs)
|
||||
shared_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
|
||||
# Then setup the per-device loss graphs that use the shared weights
|
||||
self._batch_index = tf.placeholder(tf.int32, name="batch_index")
|
||||
|
||||
# Dynamic batch size, which may be shrunk if there isn't enough data
|
||||
self._per_device_batch_size = tf.placeholder(
|
||||
tf.int32, name="per_device_batch_size")
|
||||
self._loaded_per_device_batch_size = max_per_device_batch_size
|
||||
|
||||
# When loading RNN input, we dynamically determine the max seq len
|
||||
self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len")
|
||||
self._loaded_max_seq_len = 1
|
||||
|
||||
# Split on the CPU in case the data doesn't fit in GPU memory.
|
||||
with tf.device("/cpu:0"):
|
||||
data_splits = zip(
|
||||
*[tf.split(ph, len(devices)) for ph in self.loss_inputs])
|
||||
|
||||
self._towers = []
|
||||
for device, device_placeholders in zip(self.devices, data_splits):
|
||||
self._towers.append(
|
||||
self._setup_device(device, device_placeholders,
|
||||
len(input_placeholders)))
|
||||
|
||||
avg = average_gradients([t.grads for t in self._towers])
|
||||
if grad_norm_clipping:
|
||||
clipped = []
|
||||
for grad, _ in avg:
|
||||
clipped.append(grad)
|
||||
clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
|
||||
for i, (grad, var) in enumerate(avg):
|
||||
avg[i] = (clipped[i], var)
|
||||
|
||||
# gather update ops for any batch norm layers. TODO(ekl) here we will
|
||||
# use all the ops found which won't work for DQN / DDPG, but those
|
||||
# aren't supported with multi-gpu right now anyways.
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
for op in shared_ops:
|
||||
self._update_ops.remove(op) # only care about tower update ops
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
self._train_op = self.optimizer.apply_gradients(avg)
|
||||
|
||||
def load_data(self, sess, inputs, state_inputs):
|
||||
"""Bulk loads the specified inputs into device memory.
|
||||
|
||||
The shape of the inputs must conform to the shapes of the input
|
||||
placeholders this optimizer was constructed with.
|
||||
|
||||
The data is split equally across all the devices. If the data is not
|
||||
evenly divisible by the batch size, excess data will be discarded.
|
||||
|
||||
Args:
|
||||
sess: TensorFlow session.
|
||||
inputs: List of arrays matching the input placeholders, of shape
|
||||
[BATCH_SIZE, ...].
|
||||
state_inputs: List of RNN input arrays. These arrays have size
|
||||
[BATCH_SIZE / MAX_SEQ_LEN, ...].
|
||||
|
||||
Returns:
|
||||
The number of tuples loaded per device.
|
||||
"""
|
||||
|
||||
if log_once("load_data"):
|
||||
logger.info(
|
||||
"Training on concatenated sample batches:\n\n{}\n".format(
|
||||
summarize({
|
||||
"placeholders": self.loss_inputs,
|
||||
"inputs": inputs,
|
||||
"state_inputs": state_inputs
|
||||
})))
|
||||
|
||||
feed_dict = {}
|
||||
assert len(self.loss_inputs) == len(inputs + state_inputs), \
|
||||
(self.loss_inputs, inputs, state_inputs)
|
||||
|
||||
# Let's suppose we have the following input data, and 2 devices:
|
||||
# 1 2 3 4 5 6 7 <- state inputs shape
|
||||
# A A A B B B C C C D D D E E E F F F G G G <- inputs shape
|
||||
# The data is truncated and split across devices as follows:
|
||||
# |---| seq len = 3
|
||||
# |---------------------------------| seq batch size = 6 seqs
|
||||
# |----------------| per device batch size = 9 tuples
|
||||
|
||||
if len(state_inputs) > 0:
|
||||
smallest_array = state_inputs[0]
|
||||
seq_len = len(inputs[0]) // len(state_inputs[0])
|
||||
self._loaded_max_seq_len = seq_len
|
||||
else:
|
||||
smallest_array = inputs[0]
|
||||
self._loaded_max_seq_len = 1
|
||||
|
||||
sequences_per_minibatch = (
|
||||
self.max_per_device_batch_size // self._loaded_max_seq_len * len(
|
||||
self.devices))
|
||||
if sequences_per_minibatch < 1:
|
||||
logger.warning(
|
||||
("Target minibatch size is {}, however the rollout sequence "
|
||||
"length is {}, hence the minibatch size will be raised to "
|
||||
"{}.").format(self.max_per_device_batch_size,
|
||||
self._loaded_max_seq_len,
|
||||
self._loaded_max_seq_len * len(self.devices)))
|
||||
sequences_per_minibatch = 1
|
||||
|
||||
if len(smallest_array) < sequences_per_minibatch:
|
||||
# Dynamically shrink the batch size if insufficient data
|
||||
sequences_per_minibatch = make_divisible_by(
|
||||
len(smallest_array), len(self.devices))
|
||||
|
||||
if log_once("data_slicing"):
|
||||
logger.info(
|
||||
("Divided {} rollout sequences, each of length {}, among "
|
||||
"{} devices.").format(
|
||||
len(smallest_array), self._loaded_max_seq_len,
|
||||
len(self.devices)))
|
||||
|
||||
if sequences_per_minibatch < len(self.devices):
|
||||
raise ValueError(
|
||||
"Must load at least 1 tuple sequence per device. Try "
|
||||
"increasing `sgd_minibatch_size` or reducing `max_seq_len` "
|
||||
"to ensure that at least one sequence fits per device.")
|
||||
self._loaded_per_device_batch_size = (sequences_per_minibatch // len(
|
||||
self.devices) * self._loaded_max_seq_len)
|
||||
|
||||
if len(state_inputs) > 0:
|
||||
# First truncate the RNN state arrays to sequences_per_minibatch.
|
||||
state_inputs = [
|
||||
make_divisible_by(arr, sequences_per_minibatch)
|
||||
for arr in state_inputs
|
||||
]
|
||||
# Then truncate the data inputs to match
|
||||
inputs = [arr[:len(state_inputs[0]) * seq_len] for arr in inputs]
|
||||
assert len(state_inputs[0]) * seq_len == len(inputs[0]), \
|
||||
(len(state_inputs[0]), sequences_per_minibatch, seq_len,
|
||||
len(inputs[0]))
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
feed_dict[ph] = arr
|
||||
truncated_len = len(inputs[0])
|
||||
else:
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
truncated_arr = make_divisible_by(arr, sequences_per_minibatch)
|
||||
feed_dict[ph] = truncated_arr
|
||||
truncated_len = len(truncated_arr)
|
||||
|
||||
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
|
||||
|
||||
self.num_tuples_loaded = truncated_len
|
||||
tuples_per_device = truncated_len // len(self.devices)
|
||||
assert tuples_per_device > 0, "No data loaded?"
|
||||
assert tuples_per_device % self._loaded_per_device_batch_size == 0
|
||||
return tuples_per_device
|
||||
|
||||
def optimize(self, sess, batch_index):
|
||||
"""Run a single step of SGD.
|
||||
|
||||
Runs a SGD step over a slice of the preloaded batch with size given by
|
||||
self._loaded_per_device_batch_size and offset given by the batch_index
|
||||
argument.
|
||||
|
||||
Updates shared model weights based on the averaged per-device
|
||||
gradients.
|
||||
|
||||
Args:
|
||||
sess: TensorFlow session.
|
||||
batch_index: Offset into the preloaded data. This value must be
|
||||
between `0` and `tuples_per_device`. The amount of data to
|
||||
process is at most `max_per_device_batch_size`.
|
||||
|
||||
Returns:
|
||||
The outputs of extra_ops evaluated over the batch.
|
||||
"""
|
||||
feed_dict = {
|
||||
self._batch_index: batch_index,
|
||||
self._per_device_batch_size: self._loaded_per_device_batch_size,
|
||||
self._max_seq_len: self._loaded_max_seq_len,
|
||||
}
|
||||
for tower in self._towers:
|
||||
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
|
||||
|
||||
fetches = {"train": self._train_op}
|
||||
for tower in self._towers:
|
||||
fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
|
||||
|
||||
return sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
def get_common_loss(self):
|
||||
return self._shared_loss
|
||||
|
||||
def get_device_losses(self):
|
||||
return [t.loss_graph for t in self._towers]
|
||||
|
||||
def _setup_device(self, device, device_input_placeholders, num_data_in):
|
||||
assert num_data_in <= len(device_input_placeholders)
|
||||
with tf.device(device):
|
||||
with tf.name_scope(TOWER_SCOPE_NAME):
|
||||
device_input_batches = []
|
||||
device_input_slices = []
|
||||
for i, ph in enumerate(device_input_placeholders):
|
||||
current_batch = tf.Variable(
|
||||
ph,
|
||||
trainable=False,
|
||||
validate_shape=False,
|
||||
collections=[])
|
||||
device_input_batches.append(current_batch)
|
||||
if i < num_data_in:
|
||||
scale = self._max_seq_len
|
||||
granularity = self._max_seq_len
|
||||
else:
|
||||
scale = self._max_seq_len
|
||||
granularity = 1
|
||||
current_slice = tf.slice(
|
||||
current_batch,
|
||||
([self._batch_index // scale * granularity] +
|
||||
[0] * len(ph.shape[1:])),
|
||||
([self._per_device_batch_size // scale * granularity] +
|
||||
[-1] * len(ph.shape[1:])))
|
||||
current_slice.set_shape(ph.shape)
|
||||
device_input_slices.append(current_slice)
|
||||
graph_obj = self.build_graph(device_input_slices)
|
||||
device_grads = graph_obj.gradients(self.optimizer,
|
||||
graph_obj._loss)
|
||||
return Tower(
|
||||
tf.group(
|
||||
*[batch.initializer for batch in device_input_batches]),
|
||||
device_grads, graph_obj)
|
||||
|
||||
|
||||
# Each tower is a copy of the loss graph pinned to a specific device.
|
||||
Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"])
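To make the slice offsets in _setup_device easier to follow, here is a small arithmetic sketch; the max_seq_len, per_device_batch_size, and batch_index values below are illustrative, not taken from the source.

# Arithmetic sketch of the tf.slice offsets computed in _setup_device above.
# Timestep-level inputs are sliced in whole sequences of length max_seq_len;
# RNN state inputs have one row per sequence, so their offsets are divided
# by max_seq_len.
max_seq_len, per_device_batch_size, batch_index = 20, 100, 200

# Timestep-level data rows (i < num_data_in): round down to whole sequences.
data_begin = batch_index // max_seq_len * max_seq_len            # 200
data_size = per_device_batch_size // max_seq_len * max_seq_len   # 100

# Per-sequence RNN state rows: one row per sequence of the slice above.
state_begin = batch_index // max_seq_len                          # 10
state_size = per_device_batch_size // max_seq_len                 # 5
print(data_begin, data_size, state_begin, state_size)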
|
||||
|
||||
|
||||
def make_divisible_by(a, n):
|
||||
if type(a) is int:
|
||||
return a - a % n
|
||||
return a[0:a.shape[0] - a.shape[0] % n]
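A quick check of the truncation semantics above, with illustrative values: the trailing remainder is dropped so the result's length becomes a multiple of n.

import numpy as np

arr = np.arange(10)
print(arr[:len(arr) - len(arr) % 4])   # [0 1 2 3 4 5 6 7], same as make_divisible_by(arr, 4)
print(10 - 10 % 4)                     # 8, same as make_divisible_by(10, 4)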
|
||||
|
||||
|
||||
def average_gradients(tower_grads):
|
||||
"""Averages gradients across towers.
|
||||
|
||||
Calculate the average gradient for each shared variable across all towers.
|
||||
Note that this function provides a synchronization point across all towers.
|
||||
|
||||
Args:
|
||||
tower_grads: List of lists of (gradient, variable) tuples. The outer
|
||||
list is over individual gradients. The inner list is over the
|
||||
gradient calculation for each tower.
|
||||
|
||||
Returns:
|
||||
List of pairs of (gradient, variable) where the gradient has been
|
||||
averaged across all towers.
|
||||
|
||||
TODO(ekl): We could use NCCL if this becomes a bottleneck.
|
||||
"""
|
||||
|
||||
average_grads = []
|
||||
for grad_and_vars in zip(*tower_grads):
|
||||
|
||||
# Note that each grad_and_vars looks like the following:
|
||||
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
|
||||
grads = []
|
||||
for g, _ in grad_and_vars:
|
||||
if g is not None:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
|
||||
# Append on a 'tower' dimension which we will average over
|
||||
# below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
if not grads:
|
||||
continue
|
||||
|
||||
# Average over the 'tower' dimension.
|
||||
grad = tf.concat(axis=0, values=grads)
|
||||
grad = tf.reduce_mean(grad, 0)
|
||||
|
||||
# Keep in mind that the Variables are redundant because they are shared
|
||||
# across towers, so we just return the first tower's pointer to
|
||||
# the Variable.
|
||||
v = grad_and_vars[0][1]
|
||||
grad_and_var = (grad, v)
|
||||
average_grads.append(grad_and_var)
|
||||
|
||||
return average_grads
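For clarity, here is a minimal NumPy sketch of the same averaging semantics: stack each variable's per-tower gradients and take the mean over the tower axis. The tower count and values are made up for illustration.

import numpy as np

# Two towers, one shared variable; each inner list holds (grad, var) pairs.
tower_grads = [
    [(np.array([1.0, 2.0]), "var0")],   # tower 0
    [(np.array([3.0, 4.0]), "var0")],   # tower 1
]
averaged = []
for grad_and_vars in zip(*tower_grads):
    stacked = np.stack([g for g, _ in grad_and_vars], axis=0)
    averaged.append((stacked.mean(axis=0), grad_and_vars[0][1]))
print(averaged)   # [(array([2., 3.]), 'var0')]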
|
rllib/execution/multi_gpu_learner.py (new file, 171 lines)
@ -0,0 +1,171 @@
import logging
|
||||
import threading
|
||||
import math
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.execution.learner_thread import LearnerThread
|
||||
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
||||
from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TFMultiGPULearner(LearnerThread):
|
||||
"""Learner that can use multiple GPUs and parallel loading.
|
||||
|
||||
This is for use with AsyncSamplesOptimizer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_worker,
|
||||
num_gpus=1,
|
||||
lr=0.0005,
|
||||
train_batch_size=500,
|
||||
num_data_loader_buffers=1,
|
||||
minibatch_buffer_size=1,
|
||||
num_sgd_iter=1,
|
||||
learner_queue_size=16,
|
||||
learner_queue_timeout=300,
|
||||
num_data_load_threads=16,
|
||||
_fake_gpus=False):
|
||||
"""Initialize a multi-gpu learner thread.
|
||||
|
||||
Arguments:
|
||||
local_worker (RolloutWorker): process local rollout worker holding
|
||||
policies this thread will call learn_on_batch() on
|
||||
num_gpus (int): number of GPUs to use for data-parallel SGD
|
||||
lr (float): learning rate
|
||||
train_batch_size (int): size of batches to learn on
|
||||
num_data_loader_buffers (int): number of buffers to load data into
|
||||
in parallel. Each buffer has the size of train_batch_size and
|
||||
increases GPU memory usage proportionally.
|
||||
minibatch_buffer_size (int): max number of train batches to store
|
||||
in the minibatching buffer
|
||||
num_sgd_iter (int): number of passes to learn on per train batch
|
||||
learner_queue_size (int): max size of queue of inbound
|
||||
train batches to this thread
|
||||
num_data_load_threads (int): number of threads to use to load
|
||||
data into GPU memory in parallel
|
||||
"""
|
||||
LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
|
||||
num_sgd_iter, learner_queue_size,
|
||||
learner_queue_timeout)
|
||||
self.lr = lr
|
||||
self.train_batch_size = train_batch_size
|
||||
if not num_gpus:
|
||||
self.devices = ["/cpu:0"]
|
||||
elif _fake_gpus:
|
||||
self.devices = [
|
||||
"/cpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
|
||||
]
|
||||
else:
|
||||
self.devices = [
|
||||
"/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
|
||||
]
|
||||
logger.info("TFMultiGPULearner devices {}".format(self.devices))
|
||||
assert self.train_batch_size % len(self.devices) == 0
|
||||
assert self.train_batch_size >= len(self.devices), "batch too small"
|
||||
|
||||
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
|
||||
raise NotImplementedError("Multi-gpu mode for multi-agent")
|
||||
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.par_opt = []
|
||||
with self.local_worker.tf_sess.graph.as_default():
|
||||
with self.local_worker.tf_sess.as_default():
|
||||
with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
rnn_inputs = self.policy._state_inputs + [
|
||||
self.policy._seq_lens
|
||||
]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
adam = tf.train.AdamOptimizer(self.lr)
|
||||
for _ in range(num_data_loader_buffers):
|
||||
self.par_opt.append(
|
||||
LocalSyncParallelOptimizer(
|
||||
adam,
|
||||
self.devices,
|
||||
[v for _, v in self.policy._loss_inputs],
|
||||
rnn_inputs,
|
||||
999999, # it will get rounded down
|
||||
self.policy.copy))
|
||||
|
||||
self.sess = self.local_worker.tf_sess
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.idle_optimizers = queue.Queue()
|
||||
self.ready_optimizers = queue.Queue()
|
||||
for opt in self.par_opt:
|
||||
self.idle_optimizers.put(opt)
|
||||
for i in range(num_data_load_threads):
|
||||
self.loader_thread = _LoaderThread(self, share_stats=(i == 0))
|
||||
self.loader_thread.start()
|
||||
|
||||
self.minibatch_buffer = MinibatchBuffer(
|
||||
self.ready_optimizers, minibatch_buffer_size,
|
||||
learner_queue_timeout, num_sgd_iter)
|
||||
|
||||
@override(LearnerThread)
|
||||
def step(self):
|
||||
assert self.loader_thread.is_alive()
|
||||
with self.load_wait_timer:
|
||||
opt, released = self.minibatch_buffer.get()
|
||||
|
||||
with self.grad_timer:
|
||||
fetches = opt.optimize(self.sess, 0)
|
||||
self.weights_updated = True
|
||||
self.stats = get_learner_stats(fetches)
|
||||
|
||||
if released:
|
||||
self.idle_optimizers.put(opt)
|
||||
|
||||
self.outqueue.put((opt.num_tuples_loaded, self.stats))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
|
||||
|
||||
class _LoaderThread(threading.Thread):
|
||||
def __init__(self, learner, share_stats):
|
||||
threading.Thread.__init__(self)
|
||||
self.learner = learner
|
||||
self.daemon = True
|
||||
if share_stats:
|
||||
self.queue_timer = learner.queue_timer
|
||||
self.load_timer = learner.load_timer
|
||||
else:
|
||||
self.queue_timer = TimerStat()
|
||||
self.load_timer = TimerStat()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
self._step()
|
||||
|
||||
def _step(self):
|
||||
s = self.learner
|
||||
with self.queue_timer:
|
||||
batch = s.inqueue.get()
|
||||
|
||||
opt = s.idle_optimizers.get()
|
||||
|
||||
with self.load_timer:
|
||||
tuples = s.policy._get_loss_inputs_dict(batch, shuffle=False)
|
||||
data_keys = [ph for _, ph in s.policy._loss_inputs]
|
||||
if s.policy._state_inputs:
|
||||
state_keys = s.policy._state_inputs + [s.policy._seq_lens]
|
||||
else:
|
||||
state_keys = []
|
||||
opt.load_data(s.sess, [tuples[k] for k in data_keys],
|
||||
[tuples[k] for k in state_keys])
|
||||
|
||||
s.ready_optimizers.put(opt)
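The loader/learner hand-off above is a two-queue producer/consumer pattern. Below is a toy, framework-free sketch of that pattern using plain Python queues; the buffer dicts and batch lists are placeholders standing in for the LocalSyncParallelOptimizer buffers and sample batches used in the real class.

import queue
import threading

# Loader threads take an idle buffer, fill it, and hand it to the learner;
# the learner returns it to the idle pool after optimizing over it.
idle, ready = queue.Queue(), queue.Queue()
for buf_id in range(2):              # e.g. num_data_loader_buffers = 2
    idle.put({"id": buf_id, "data": None})

def loader(batches):
    for batch in batches:
        buf = idle.get()
        buf["data"] = batch          # stands in for opt.load_data(...)
        ready.put(buf)

threading.Thread(target=loader, args=([[1, 2], [3, 4]],), daemon=True).start()
for _ in range(2):                   # learner loop: consume loaded buffers
    buf = ready.get()
    print("training on", buf["data"])   # stands in for opt.optimize(...)
    idle.put(buf)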
|
rllib/execution/old_segment_tree.py (new file, 143 lines)
@ -0,0 +1,143 @@
import operator
|
||||
|
||||
|
||||
class OldSegmentTree(object):
|
||||
def __init__(self, capacity, operation, neutral_element):
|
||||
"""Build a Segment Tree data structure.
|
||||
|
||||
https://en.wikipedia.org/wiki/Segment_tree
|
||||
|
||||
Can be used as regular array, but with two
|
||||
important differences:
|
||||
|
||||
a) setting item's value is slightly slower.
|
||||
It is O(lg capacity) instead of O(1).
|
||||
b) user has access to an efficient `reduce`
|
||||
operation which reduces `operation` over
|
||||
a contiguous subsequence of items in the
|
||||
array.
|
||||
|
||||
Parameters
----------
|
||||
capacity: int
|
||||
Total size of the array - must be a power of two.
|
||||
operation: lambda obj, obj -> obj
|
||||
an operation for combining elements (e.g. sum, max);
must form a mathematical group together with the set of
|
||||
possible values for array elements.
|
||||
neutral_element: obj
|
||||
neutral element for the operation above. eg. float('-inf')
|
||||
for max and 0 for sum.
|
||||
"""
|
||||
|
||||
assert capacity > 0 and capacity & (capacity - 1) == 0, \
|
||||
"capacity must be positive and a power of 2."
|
||||
self._capacity = capacity
|
||||
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||
self._operation = operation
|
||||
|
||||
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||
if start == node_start and end == node_end:
|
||||
return self._value[node]
|
||||
mid = (node_start + node_end) // 2
|
||||
if end <= mid:
|
||||
return self._reduce_helper(start, end, 2 * node, node_start, mid)
|
||||
else:
|
||||
if mid + 1 <= start:
|
||||
return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
|
||||
node_end)
|
||||
else:
|
||||
return self._operation(
|
||||
self._reduce_helper(start, mid, 2 * node, node_start, mid),
|
||||
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
|
||||
node_end))
|
||||
|
||||
def reduce(self, start=0, end=None):
|
||||
"""Returns result of applying `self.operation`
|
||||
to a contiguous subsequence of the array.
|
||||
|
||||
self.operation(
|
||||
arr[start], operation(arr[start+1], operation(... arr[end])))
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start: int
|
||||
beginning of the subsequence
|
||||
end: int
|
||||
end of the subsequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
reduced: obj
|
||||
result of reducing self.operation over the specified range of array
|
||||
elements.
|
||||
"""
|
||||
if end is None:
|
||||
end = self._capacity
|
||||
if end < 0:
|
||||
end += self._capacity
|
||||
end -= 1
|
||||
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||
|
||||
def __setitem__(self, idx, val):
|
||||
# index of the leaf
|
||||
idx += self._capacity
|
||||
self._value[idx] = val
|
||||
idx //= 2
|
||||
while idx >= 1:
|
||||
self._value[idx] = self._operation(self._value[2 * idx],
|
||||
self._value[2 * idx + 1])
|
||||
idx //= 2
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert 0 <= idx < self._capacity
|
||||
return self._value[self._capacity + idx]
|
||||
|
||||
|
||||
class OldSumSegmentTree(OldSegmentTree):
|
||||
def __init__(self, capacity):
|
||||
super(OldSumSegmentTree, self).__init__(
|
||||
capacity=capacity, operation=operator.add, neutral_element=0.0)
|
||||
|
||||
def sum(self, start=0, end=None):
|
||||
"""Returns arr[start] + ... + arr[end]"""
|
||||
return super(OldSumSegmentTree, self).reduce(start, end)
|
||||
|
||||
def find_prefixsum_idx(self, prefixsum):
|
||||
"""Find the highest index `i` in the array such that
|
||||
sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
|
||||
|
||||
If the array values are probabilities, this function
allows sampling indexes according to the discrete
probability distribution efficiently.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prefixsum: float
upper bound on the sum of the array prefix
|
||||
|
||||
Returns
|
||||
-------
|
||||
idx: int
|
||||
highest index satisfying the prefixsum constraint
|
||||
"""
|
||||
assert 0 <= prefixsum <= self.sum() + 1e-5
|
||||
idx = 1
|
||||
while idx < self._capacity: # while non-leaf
|
||||
if self._value[2 * idx] > prefixsum:
|
||||
idx = 2 * idx
|
||||
else:
|
||||
prefixsum -= self._value[2 * idx]
|
||||
idx = 2 * idx + 1
|
||||
return idx - self._capacity
|
||||
|
||||
|
||||
class OldMinSegmentTree(OldSegmentTree):
|
||||
def __init__(self, capacity):
|
||||
super(OldMinSegmentTree, self).__init__(
|
||||
capacity=capacity, operation=min, neutral_element=float("inf"))
|
||||
|
||||
def min(self, start=0, end=None):
|
||||
"""Returns min(arr[start], ..., arr[end])"""
|
||||
|
||||
return super(OldMinSegmentTree, self).reduce(start, end)
|
rllib/execution/replay_buffer.py (new file, 419 lines)
@ -0,0 +1,419 @@
import numpy as np
|
||||
import random
|
||||
import os
|
||||
import collections
|
||||
import sys
|
||||
|
||||
import ray
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ReplayBuffer:
|
||||
@DeveloperAPI
|
||||
def __init__(self, size):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size: int
|
||||
Max number of transitions to store in the buffer. When the buffer
|
||||
overflows, the oldest memories are dropped.
|
||||
"""
|
||||
self._storage = []
|
||||
self._maxsize = size
|
||||
self._next_idx = 0
|
||||
self._hit_count = np.zeros(size)
|
||||
self._eviction_started = False
|
||||
self._num_added = 0
|
||||
self._num_sampled = 0
|
||||
self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
|
||||
self._est_size_bytes = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, obs_t, action, reward, obs_tp1, done, weight):
|
||||
data = (obs_t, action, reward, obs_tp1, done)
|
||||
self._num_added += 1
|
||||
|
||||
if self._next_idx >= len(self._storage):
|
||||
self._storage.append(data)
|
||||
self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
|
||||
else:
|
||||
self._storage[self._next_idx] = data
|
||||
if self._next_idx + 1 >= self._maxsize:
|
||||
self._eviction_started = True
|
||||
self._next_idx = (self._next_idx + 1) % self._maxsize
|
||||
if self._eviction_started:
|
||||
self._evicted_hit_stats.push(self._hit_count[self._next_idx])
|
||||
self._hit_count[self._next_idx] = 0
|
||||
|
||||
def _encode_sample(self, idxes):
|
||||
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
|
||||
for i in idxes:
|
||||
data = self._storage[i]
|
||||
obs_t, action, reward, obs_tp1, done = data
|
||||
obses_t.append(np.array(unpack_if_needed(obs_t), copy=False))
|
||||
actions.append(np.array(action, copy=False))
|
||||
rewards.append(reward)
|
||||
obses_tp1.append(np.array(unpack_if_needed(obs_tp1), copy=False))
|
||||
dones.append(done)
|
||||
self._hit_count[i] += 1
|
||||
return (np.array(obses_t), np.array(actions), np.array(rewards),
|
||||
np.array(obses_tp1), np.array(dones))
|
||||
|
||||
@DeveloperAPI
|
||||
def sample_idxes(self, batch_size):
|
||||
return np.random.randint(0, len(self._storage), batch_size)
|
||||
|
||||
@DeveloperAPI
|
||||
def sample_with_idxes(self, idxes):
|
||||
self._num_sampled += len(idxes)
|
||||
return self._encode_sample(idxes)
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self, batch_size):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
How many transitions to sample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
"""
|
||||
idxes = [
|
||||
random.randint(0,
|
||||
len(self._storage) - 1) for _ in range(batch_size)
|
||||
]
|
||||
self._num_sampled += batch_size
|
||||
return self._encode_sample(idxes)
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self, debug=False):
|
||||
data = {
|
||||
"added_count": self._num_added,
|
||||
"sampled_count": self._num_sampled,
|
||||
"est_size_bytes": self._est_size_bytes,
|
||||
"num_entries": len(self._storage),
|
||||
}
|
||||
if debug:
|
||||
data.update(self._evicted_hit_stats.stats())
|
||||
return data
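A short usage sketch of the plain (uniform) buffer above, with dummy transition data.

import numpy as np
from ray.rllib.execution.replay_buffer import ReplayBuffer

# `weight` is accepted by add() but unused by the uniform buffer.
buf = ReplayBuffer(size=100)
for _ in range(10):
    buf.add(obs_t=np.zeros(2), action=1, reward=0.5,
            obs_tp1=np.ones(2), done=False, weight=None)
obs, actions, rewards, next_obs, dones = buf.sample(batch_size=4)
print(buf.stats())   # added/sampled counts, num_entries, est_size_bytes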
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
@DeveloperAPI
|
||||
def __init__(self, size, alpha):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size: int
|
||||
Max number of transitions to store in the buffer. When the buffer
|
||||
overflows, the oldest memories are dropped.
|
||||
alpha: float
|
||||
how much prioritization is used
|
||||
(0 - no prioritization, 1 - full prioritization)
|
||||
|
||||
See Also
|
||||
--------
|
||||
ReplayBuffer.__init__
|
||||
"""
|
||||
super(PrioritizedReplayBuffer, self).__init__(size)
|
||||
assert alpha > 0
|
||||
self._alpha = alpha
|
||||
|
||||
it_capacity = 1
|
||||
while it_capacity < size:
|
||||
it_capacity *= 2
|
||||
|
||||
self._it_sum = SumSegmentTree(it_capacity)
|
||||
self._it_min = MinSegmentTree(it_capacity)
|
||||
self._max_priority = 1.0
|
||||
self._prio_change_stats = WindowStat("reprio", 1000)
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, obs_t, action, reward, obs_tp1, done, weight):
|
||||
"""See ReplayBuffer.store_effect"""
|
||||
|
||||
idx = self._next_idx
|
||||
super(PrioritizedReplayBuffer, self).add(obs_t, action, reward,
|
||||
obs_tp1, done, weight)
|
||||
if weight is None:
|
||||
weight = self._max_priority
|
||||
self._it_sum[idx] = weight**self._alpha
|
||||
self._it_min[idx] = weight**self._alpha
|
||||
|
||||
def _sample_proportional(self, batch_size):
|
||||
res = []
|
||||
for _ in range(batch_size):
|
||||
# TODO(szymon): should we ensure no repeats?
|
||||
mass = random.random() * self._it_sum.sum(0, len(self._storage))
|
||||
idx = self._it_sum.find_prefixsum_idx(mass)
|
||||
res.append(idx)
|
||||
return res
|
||||
|
||||
@DeveloperAPI
|
||||
def sample_idxes(self, batch_size):
|
||||
return self._sample_proportional(batch_size)
|
||||
|
||||
@DeveloperAPI
|
||||
def sample_with_idxes(self, idxes, beta):
|
||||
assert beta > 0
|
||||
self._num_sampled += len(idxes)
|
||||
|
||||
weights = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self._storage))**(-beta)
|
||||
|
||||
for idx in idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self._storage))**(-beta)
|
||||
weights.append(weight / max_weight)
|
||||
weights = np.array(weights)
|
||||
encoded_sample = self._encode_sample(idxes)
|
||||
return tuple(list(encoded_sample) + [weights, idxes])
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self, batch_size, beta):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
Compared to ReplayBuffer.sample, this also returns the importance
weights and idxes of the sampled experiences.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch_size: int
|
||||
How many transitions to sample.
|
||||
beta: float
|
||||
To what degree to use importance weights
|
||||
(0 - no corrections, 1 - full correction)
|
||||
|
||||
Returns
|
||||
-------
|
||||
obs_batch: np.array
|
||||
batch of observations
|
||||
act_batch: np.array
|
||||
batch of actions executed given obs_batch
|
||||
rew_batch: np.array
|
||||
rewards received as results of executing act_batch
|
||||
next_obs_batch: np.array
|
||||
next set of observations seen after executing act_batch
|
||||
done_mask: np.array
|
||||
done_mask[i] = 1 if executing act_batch[i] resulted in
|
||||
the end of an episode and 0 otherwise.
|
||||
weights: np.array
|
||||
Array of shape (batch_size,) and dtype np.float32
|
||||
denoting importance weight of each sampled transition
|
||||
idxes: np.array
|
||||
Array of shape (batch_size,) and dtype np.int32
|
||||
indexes in buffer of sampled experiences
|
||||
"""
|
||||
assert beta >= 0.0
|
||||
self._num_sampled += batch_size
|
||||
|
||||
idxes = self._sample_proportional(batch_size)
|
||||
|
||||
weights = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self._storage))**(-beta)
|
||||
|
||||
for idx in idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self._storage))**(-beta)
|
||||
weights.append(weight / max_weight)
|
||||
weights = np.array(weights)
|
||||
encoded_sample = self._encode_sample(idxes)
|
||||
return tuple(list(encoded_sample) + [weights, idxes])
|
||||
|
||||
@DeveloperAPI
|
||||
def update_priorities(self, idxes, priorities):
|
||||
"""Update priorities of sampled transitions.
|
||||
|
||||
Sets the priority of the transition at index idxes[i] in the buffer
to priorities[i].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idxes: [int]
|
||||
List of idxes of sampled transitions
|
||||
priorities: [float]
|
||||
List of updated priorities corresponding to
|
||||
transitions at the sampled idxes denoted by
|
||||
variable `idxes`.
|
||||
"""
|
||||
assert len(idxes) == len(priorities)
|
||||
for idx, priority in zip(idxes, priorities):
|
||||
assert priority > 0
|
||||
assert 0 <= idx < len(self._storage)
|
||||
delta = priority**self._alpha - self._it_sum[idx]
|
||||
self._prio_change_stats.push(delta)
|
||||
self._it_sum[idx] = priority**self._alpha
|
||||
self._it_min[idx] = priority**self._alpha
|
||||
|
||||
self._max_priority = max(self._max_priority, priority)
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self, debug=False):
|
||||
parent = ReplayBuffer.stats(self, debug)
|
||||
if debug:
|
||||
parent.update(self._prio_change_stats.stats())
|
||||
return parent
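A minimal usage sketch of the prioritized buffer defined above: fill it with dummy transitions, sample with importance-sampling correction, and feed back new priorities. The sizes, alpha/beta values, and TD errors are illustrative.

import numpy as np
from ray.rllib.execution.replay_buffer import PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=8, alpha=0.6)
for i in range(8):
    # weight=None means the current max priority is used for the new entry.
    buf.add(obs_t=np.zeros(4), action=0, reward=float(i),
            obs_tp1=np.zeros(4), done=False, weight=None)

obs, act, rew, obs_tp1, done, weights, idxes = buf.sample(4, beta=0.4)

# After computing TD errors for the sampled transitions, feed their
# absolute values (plus a small epsilon) back as new priorities.
td_errors = np.random.rand(4)
buf.update_priorities(idxes, np.abs(td_errors) + 1e-6)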
|
||||
|
||||
|
||||
# Visible for testing.
|
||||
_local_replay_buffer = None
|
||||
|
||||
|
||||
# TODO(ekl) move this class to common
|
||||
class LocalReplayBuffer(ParallelIteratorWorker):
|
||||
"""A replay buffer shard.
|
||||
|
||||
Ray actors are single-threaded, so for scalability multiple replay actors
|
||||
may be created to increase parallelism."""
|
||||
|
||||
def __init__(self,
|
||||
num_shards,
|
||||
learning_starts,
|
||||
buffer_size,
|
||||
replay_batch_size,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
multiagent_sync_replay=False):
|
||||
self.replay_starts = learning_starts // num_shards
|
||||
self.buffer_size = buffer_size // num_shards
|
||||
self.replay_batch_size = replay_batch_size
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.multiagent_sync_replay = multiagent_sync_replay
|
||||
|
||||
def gen_replay():
|
||||
while True:
|
||||
yield self.replay()
|
||||
|
||||
ParallelIteratorWorker.__init__(self, gen_replay, False)
|
||||
|
||||
def new_buffer():
|
||||
return PrioritizedReplayBuffer(
|
||||
self.buffer_size, alpha=prioritized_replay_alpha)
|
||||
|
||||
self.replay_buffers = collections.defaultdict(new_buffer)
|
||||
|
||||
# Metrics
|
||||
self.add_batch_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.update_priorities_timer = TimerStat()
|
||||
self.num_added = 0
|
||||
|
||||
# Make externally accessible for testing.
|
||||
global _local_replay_buffer
|
||||
_local_replay_buffer = self
|
||||
# If set, return this instead of the usual data for testing.
|
||||
self._fake_batch = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance_for_testing():
|
||||
global _local_replay_buffer
|
||||
return _local_replay_buffer
|
||||
|
||||
def get_host(self):
|
||||
return os.uname()[1]
|
||||
|
||||
def add_batch(self, batch):
|
||||
# Make a copy so the replay buffer doesn't pin plasma memory.
|
||||
batch = batch.copy()
|
||||
# Handle everything as if multiagent
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
|
||||
with self.add_batch_timer:
|
||||
for policy_id, s in batch.policy_batches.items():
|
||||
for row in s.rows():
|
||||
self.replay_buffers[policy_id].add(
|
||||
row["obs"], row["actions"], row["rewards"],
|
||||
row["new_obs"], row["dones"], row["weights"]
|
||||
if "weights" in row else None)
|
||||
self.num_added += batch.count
|
||||
|
||||
def replay(self):
|
||||
if self._fake_batch:
|
||||
fake_batch = SampleBatch(self._fake_batch)
|
||||
return MultiAgentBatch({
|
||||
DEFAULT_POLICY_ID: fake_batch
|
||||
}, fake_batch.count)
|
||||
|
||||
if self.num_added < self.replay_starts:
|
||||
return None
|
||||
|
||||
with self.replay_timer:
|
||||
samples = {}
|
||||
idxes = None
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
if self.multiagent_sync_replay:
|
||||
if idxes is None:
|
||||
idxes = replay_buffer.sample_idxes(
|
||||
self.replay_batch_size)
|
||||
else:
|
||||
idxes = replay_buffer.sample_idxes(self.replay_batch_size)
|
||||
(obses_t, actions, rewards, obses_tp1, dones, weights,
|
||||
batch_indexes) = replay_buffer.sample_with_idxes(
|
||||
idxes, beta=self.prioritized_replay_beta)
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return MultiAgentBatch(samples, self.replay_batch_size)
|
||||
|
||||
def update_priorities(self, prio_dict):
|
||||
with self.update_priorities_timer:
|
||||
for policy_id, (batch_indexes, td_errors) in prio_dict.items():
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + self.prioritized_replay_eps)
|
||||
self.replay_buffers[policy_id].update_priorities(
|
||||
batch_indexes, new_priorities)
|
||||
|
||||
def stats(self, debug=False):
|
||||
stat = {
|
||||
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
|
||||
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
|
||||
"update_priorities_time_ms": round(
|
||||
1000 * self.update_priorities_timer.mean, 3),
|
||||
}
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
stat.update({
|
||||
"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
|
||||
})
|
||||
return stat
|
||||
|
||||
|
||||
ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
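An illustrative sketch of how replay shards might be created as Ray actors and fed a batch; all constructor arguments and SampleBatch contents here are made-up values, not defaults from the source.

import ray
from ray.rllib.execution.replay_buffer import ReplayActor
from ray.rllib.policy.sample_batch import SampleBatch

# Two replay shards as Ray actors; a running Ray runtime is assumed.
ray.init(ignore_reinit_error=True)
actors = [
    ReplayActor.remote(
        num_shards=2, learning_starts=10, buffer_size=1000,
        replay_batch_size=4) for _ in range(2)
]
# A dummy 8-step batch; the field names match what add_batch() expects.
batch = SampleBatch({
    "obs": [[0.0]] * 8, "actions": [0] * 8, "rewards": [1.0] * 8,
    "new_obs": [[0.0]] * 8, "dones": [False] * 8,
})
ray.get([a.add_batch.remote(batch) for a in actors])
# Returns a MultiAgentBatch once learning_starts // num_shards samples
# have been added to the shard.
replayed = ray.get(actors[0].replay.remote())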
|
|
@ -3,8 +3,9 @@ import random
|
|||
|
||||
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
|
||||
from ray.util.iter_metrics import SharedMetrics
|
||||
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
|
||||
from ray.rllib.execution.common import SampleBatchType
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.execution.common import SampleBatchType, \
|
||||
STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
|
||||
|
||||
class StoreToReplayBuffer:
|
||||
|
@ -93,6 +94,18 @@ def Replay(*,
|
|||
return LocalIterator(gen_replay, SharedMetrics())
|
||||
|
||||
|
||||
class WaitUntilTimestepsElapsed:
|
||||
"""Callable that returns True once a given number of timesteps are hit."""
|
||||
|
||||
def __init__(self, target_num_timesteps):
|
||||
self.target_num_timesteps = target_num_timesteps
|
||||
|
||||
def __call__(self, item):
|
||||
metrics = _get_shared_metrics()
|
||||
ts = metrics.counters[STEPS_SAMPLED_COUNTER]
|
||||
return ts > self.target_num_timesteps
|
||||
|
||||
|
||||
class SimpleReplayBuffer:
|
||||
"""Simple replay buffer that operates over batches."""
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from ray.rllib.evaluation.rollout_worker import get_global_worker
|
|||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.common import GradientType, SampleBatchType, \
|
||||
STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, \
|
||||
GRAD_WAIT_TIMER, _check_sample_batch_type
|
||||
GRAD_WAIT_TIMER, _check_sample_batch_type, _get_shared_metrics
|
||||
from ray.rllib.policy.policy import PolicyID
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
|
@ -59,7 +59,7 @@ def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
|
|||
workers.sync_weights()
|
||||
|
||||
def report_timesteps(batch):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
|
||||
return batch
|
||||
|
||||
|
@ -123,7 +123,7 @@ def AsyncGradients(
|
|||
|
||||
def __call__(self, item):
|
||||
(grads, info), count = item
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += count
|
||||
metrics.info[LEARNER_INFO] = get_learner_stats(info)
|
||||
metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
|
||||
|
@ -169,7 +169,7 @@ class ConcatBatches:
|
|||
"This may be because you have many workers or "
|
||||
"long episodes in 'complete_episodes' batch mode.")
|
||||
out = SampleBatch.concat_samples(self.buffer)
|
||||
timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
|
||||
timer = _get_shared_metrics().timers[SAMPLE_TIMER]
|
||||
timer.push(time.perf_counter() - self.batch_start_time)
|
||||
timer.push_units_processed(self.count)
|
||||
self.batch_start_time = None
|
||||
|
|
rllib/execution/segment_tree.py (new file, 196 lines)
@ -0,0 +1,196 @@
import operator
|
||||
|
||||
|
||||
class SegmentTree:
|
||||
"""A Segment Tree data structure.
|
||||
|
||||
https://en.wikipedia.org/wiki/Segment_tree
|
||||
|
||||
Can be used as regular array, but with two important differences:
|
||||
|
||||
a) Setting an item's value is slightly slower. It is O(lg capacity),
|
||||
instead of O(1).
|
||||
b) Offers efficient `reduce` operation which reduces the tree's values
|
||||
over some specified contiguous subsequence of items in the array.
|
||||
Operation could be e.g. min/max/sum.
|
||||
|
||||
The data is stored in a list, where the length is 2 * capacity.
|
||||
The second half of the list stores the actual values for each index, so if
|
||||
capacity=8, values are stored at indices 8 to 15. The first half of the
|
||||
array contains the reduced-values of the different (binary divided)
|
||||
segments, e.g. (capacity=4):
|
||||
0=not used
|
||||
1=reduced-value over all elements (array indices 4 to 7).
|
||||
2=reduced-value over array indices (4 and 5).
|
||||
3=reduced-value over array indices (6 and 7).
|
||||
4-7: values of the tree.
|
||||
NOTE that the values of the tree are accessed by indices starting at 0, so
|
||||
`tree[0]` accesses `internal_array[4]` in the above example.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity, operation, neutral_element=None):
|
||||
"""Initializes a Segment Tree object.
|
||||
|
||||
Args:
|
||||
capacity (int): Total size of the array - must be a power of two.
|
||||
operation (operation): Lambda obj, obj -> obj
|
||||
The operation for combining elements (eg. sum, max).
|
||||
Must be a mathematical group together with the set of
|
||||
possible values for array elements.
|
||||
neutral_element (Optional[obj]): The neutral element for
|
||||
`operation`. Use None for automatically finding a value:
|
||||
max: float("-inf"), min: float("inf"), sum: 0.0.
|
||||
"""
|
||||
|
||||
assert capacity > 0 and capacity & (capacity - 1) == 0, \
|
||||
"Capacity must be positive and a power of 2!"
|
||||
self.capacity = capacity
|
||||
if neutral_element is None:
|
||||
neutral_element = 0.0 if operation is operator.add else \
|
||||
float("-inf") if operation is max else float("inf")
|
||||
self.neutral_element = neutral_element
|
||||
self.value = [self.neutral_element for _ in range(2 * capacity)]
|
||||
self.operation = operation
|
||||
|
||||
def reduce(self, start=0, end=None):
|
||||
"""Applies `self.operation` to subsequence of our values.
|
||||
|
||||
Subsequence is contiguous, includes `start` and excludes `end`.
|
||||
|
||||
self.operation(
|
||||
arr[start], operation(arr[start+1], operation(... arr[end])))
|
||||
|
||||
Args:
|
||||
start (int): Start index to apply reduction to.
|
||||
end (Optional[int]): End index to apply reduction to (excluded).
|
||||
|
||||
Returns:
|
||||
any: The result of reducing self.operation over the specified
|
||||
range of `self._value` elements.
|
||||
"""
|
||||
if end is None:
|
||||
end = self.capacity
|
||||
elif end < 0:
|
||||
end += self.capacity
|
||||
|
||||
# Init result with neutral element.
|
||||
result = self.neutral_element
|
||||
# Map start/end to our actual index space (second half of array).
|
||||
start += self.capacity
|
||||
end += self.capacity
|
||||
|
||||
# Example:
|
||||
# internal-array (first half=sums, second half=actual values):
|
||||
# 0 1 2 3 | 4 5 6 7
|
||||
# - 6 1 5 | 1 0 2 3
|
||||
|
||||
# tree.sum(0, 3) = 3
|
||||
# internally: start=4, end=7 -> sum values 1 0 2 = 3.
|
||||
|
||||
# Iterate over tree starting in the actual-values (second half)
|
||||
# section.
|
||||
# 1) start=4 is even -> do nothing.
|
||||
# 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
|
||||
# 3) int-divide start and end by 2: start=2, end=3
|
||||
# 4) start still smaller end -> iterate once more.
|
||||
# 5) start=2 is even -> do nothing.
|
||||
# 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=1
|
||||
# NOTE: This adds the sum of indices 4 and 5 to the result.
|
||||
|
||||
# Iterate as long as start != end.
|
||||
while start < end:
|
||||
|
||||
# If start is odd: Add its value to result and move start to
|
||||
# next even value.
|
||||
if start & 1:
|
||||
result = self.operation(result, self.value[start])
|
||||
start += 1
|
||||
|
||||
# If end is odd: Move end to previous even value, then add its
|
||||
# value to result. NOTE: This takes care of excluding `end` in any
|
||||
# situation.
|
||||
if end & 1:
|
||||
end -= 1
|
||||
result = self.operation(result, self.value[end])
|
||||
|
||||
# Divide both start and end by 2 to make them "jump" into the
|
||||
# next upper level reduce-index space.
|
||||
start //= 2
|
||||
end //= 2
|
||||
|
||||
# Then repeat till start == end.
|
||||
|
||||
return result
|
||||
|
||||
def __setitem__(self, idx, val):
|
||||
"""
|
||||
Inserts/overwrites a value in/into the tree.
|
||||
|
||||
Args:
|
||||
idx (int): The index to insert to. Must be in [0, `self.capacity`[
|
||||
val (float): The value to insert.
|
||||
"""
|
||||
assert 0 <= idx < self.capacity
|
||||
|
||||
# Index of the leaf to insert into (always insert in "second half"
|
||||
# of the tree, the first half is reserved for already calculated
|
||||
# reduction-values).
|
||||
idx += self.capacity
|
||||
self.value[idx] = val
|
||||
|
||||
# Recalculate all affected reduction values (in "first half" of tree).
|
||||
idx = idx >> 1 # Divide by 2 (faster than division).
|
||||
while idx >= 1:
|
||||
update_idx = 2 * idx # calculate only once
|
||||
# Update the reduction value at the correct "first half" idx.
|
||||
self.value[idx] = self.operation(self.value[update_idx],
|
||||
self.value[update_idx + 1])
|
||||
idx = idx >> 1 # Divide by 2 (faster than division).
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert 0 <= idx < self.capacity
|
||||
return self.value[idx + self.capacity]
|
||||
|
||||
|
||||
class SumSegmentTree(SegmentTree):
|
||||
"""A SegmentTree with the reduction `operation`=operator.add."""
|
||||
|
||||
def __init__(self, capacity):
|
||||
super(SumSegmentTree, self).__init__(
|
||||
capacity=capacity, operation=operator.add)
|
||||
|
||||
def sum(self, start=0, end=None):
|
||||
"""Returns the sum over a sub-segment of the tree."""
|
||||
return self.reduce(start, end)
|
||||
|
||||
def find_prefixsum_idx(self, prefixsum):
|
||||
"""Finds highest i, for which: sum(arr[0]+..+arr[i - i]) <= prefixsum.
|
||||
|
||||
Args:
|
||||
prefixsum (float): `prefixsum` upper bound in above constraint.
|
||||
|
||||
Returns:
|
||||
int: Largest possible index (i) satisfying above constraint.
|
||||
"""
|
||||
assert 0 <= prefixsum <= self.sum() + 1e-5
|
||||
# Global sum node.
|
||||
idx = 1
|
||||
|
||||
# While non-leaf (first half of tree).
|
||||
while idx < self.capacity:
|
||||
update_idx = 2 * idx
|
||||
if self.value[update_idx] > prefixsum:
|
||||
idx = update_idx
|
||||
else:
|
||||
prefixsum -= self.value[update_idx]
|
||||
idx = update_idx + 1
|
||||
return idx - self.capacity
|
||||
|
||||
|
||||
class MinSegmentTree(SegmentTree):
|
||||
def __init__(self, capacity):
|
||||
super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)
|
||||
|
||||
def min(self, start=0, end=None):
|
||||
"""Returns min(arr[start], ..., arr[end])"""
|
||||
return self.reduce(start, end)
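A small usage sketch of the tree above, mirroring the values from the class docstring example (capacity=4, leaf values 1, 0, 2, 3); the expected outputs in the comments follow from the reduce/prefix-sum logic shown above.

from ray.rllib.execution.segment_tree import SumSegmentTree

tree = SumSegmentTree(4)
tree[0], tree[1], tree[2], tree[3] = 1.0, 0.0, 2.0, 3.0
print(tree.sum())                     # 6.0: reduction over all leaves
print(tree.sum(0, 3))                 # 3.0: indices 0..2 (end is exclusive)
print(tree.find_prefixsum_idx(1.5))   # 2: highest i with sum(arr[0:i]) <= 1.5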
|
rllib/execution/test_prioritized_replay_buffer.py (new file, 180 lines)
@ -0,0 +1,180 @@
from collections import Counter
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.execution.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestPrioritizedReplayBuffer(unittest.TestCase):
|
||||
"""
|
||||
Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer.
|
||||
"""
|
||||
|
||||
capacity = 10
|
||||
alpha = 1.0
|
||||
beta = 1.0
|
||||
max_priority = 1.0
|
||||
|
||||
def _generate_data(self):
|
||||
return (
|
||||
np.random.random((4, )), # obs_t
|
||||
np.random.choice([0, 1]), # action
|
||||
np.random.rand(), # reward
|
||||
np.random.random((4, )), # obs_tp1
|
||||
np.random.choice([False, True]), # done
|
||||
)
|
||||
|
||||
def test_add(self):
|
||||
memory = PrioritizedReplayBuffer(
|
||||
size=2,
|
||||
alpha=self.alpha,
|
||||
)
|
||||
|
||||
# Assert indices 0 before insert.
|
||||
self.assertEqual(len(memory), 0)
|
||||
self.assertEqual(memory._next_idx, 0)
|
||||
|
||||
# Insert single record.
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=0.5)
|
||||
self.assertTrue(len(memory) == 1)
|
||||
self.assertTrue(memory._next_idx == 1)
|
||||
|
||||
# Insert single record.
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=0.1)
|
||||
self.assertTrue(len(memory) == 2)
|
||||
self.assertTrue(memory._next_idx == 0)
|
||||
|
||||
# Insert over capacity.
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=1.0)
|
||||
self.assertTrue(len(memory) == 2)
|
||||
self.assertTrue(memory._next_idx == 1)
|
||||
|
||||
def test_update_priorities(self):
|
||||
memory = PrioritizedReplayBuffer(size=self.capacity, alpha=self.alpha)
|
||||
|
||||
# Insert n samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=1.0)
|
||||
self.assertTrue(len(memory) == i + 1)
|
||||
self.assertTrue(memory._next_idx == i + 1)
|
||||
|
||||
# Fetch records, their indices and weights.
|
||||
_, _, _, _, _, weights, indices = \
|
||||
memory.sample(3, beta=self.beta)
|
||||
check(weights, np.ones(shape=(3, )))
|
||||
self.assertEqual(3, len(indices))
|
||||
self.assertTrue(len(memory) == num_records)
|
||||
self.assertTrue(memory._next_idx == num_records)
|
||||
|
||||
# Update weight of indices 0, 2, 3, 4 to very small.
|
||||
memory.update_priorities(
|
||||
np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01]))
|
||||
# Expect to sample almost only index 1
|
||||
# (which still has a weight of 1.0).
|
||||
for _ in range(10):
|
||||
_, _, _, _, _, weights, indices = memory.sample(
|
||||
1000, beta=self.beta)
|
||||
self.assertTrue(970 < np.sum(indices) < 1100)
|
||||
|
||||
# Update weight of indices 0 and 1 to >> 0.01.
|
||||
# Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s).
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(np.array([0, 1]), np.array([rand, rand]))
|
||||
_, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
|
||||
# Expect biased to higher values due to some 2s, 3s, and 4s.
|
||||
# print(np.sum(indices))
|
||||
self.assertTrue(400 < np.sum(indices) < 800)
|
||||
|
||||
# Update weights to be 1:2.
|
||||
# Expect to sample index 1 about twice as often as index 0,
# plus indices 2, 3, or 4 very rarely.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(
|
||||
np.array([0, 1]), np.array([rand, rand * 2]))
|
||||
_, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
|
||||
# print(np.sum(indices))
|
||||
self.assertTrue(600 < np.sum(indices) < 850)
|
||||
|
||||
# Update weights to be 1:4.
|
||||
# Expect to sample index 1 about four times as often as index 0,
# plus indices 2, 3, or 4 very rarely.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(
|
||||
np.array([0, 1]), np.array([rand, rand * 4]))
|
||||
_, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
|
||||
# print(np.sum(indices))
|
||||
self.assertTrue(750 < np.sum(indices) < 950)
|
||||
|
||||
# Update weights to be 1:9.
|
||||
# Expect to sample index 1 about nine times as often as index 0,
# plus indices 2, 3, or 4 very rarely.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(
|
||||
np.array([0, 1]), np.array([rand, rand * 9]))
|
||||
_, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
|
||||
# print(np.sum(indices))
|
||||
self.assertTrue(850 < np.sum(indices) < 1100)
|
||||
|
||||
# Insert n more samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=1.0)
|
||||
self.assertTrue(len(memory) == i + 6)
|
||||
self.assertTrue(memory._next_idx == (i + 6) % self.capacity)
|
||||
|
||||
# Set strictly increasing priorities for all 10 indices and sample large (>100) batches.
|
||||
memory.update_priorities(
|
||||
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.]))
|
||||
counts = Counter()
|
||||
for _ in range(10):
|
||||
_, _, _, _, _, _, indices = memory.sample(
|
||||
np.random.randint(100, 600), beta=self.beta)
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
print(counts)
|
||||
# Expect an approximately correct distribution of indices.
|
||||
self.assertTrue(
|
||||
counts[9] >= counts[8] >= counts[7] >= counts[6] >= counts[5] >=
|
||||
counts[4] >= counts[3] >= counts[2] >= counts[1] >= counts[0])
|
||||
|
||||
def test_alpha_parameter(self):
|
||||
# Test sampling from a PR with a very small alpha (should behave just
|
||||
# like a regular ReplayBuffer).
|
||||
memory = PrioritizedReplayBuffer(size=self.capacity, alpha=0.01)
|
||||
|
||||
# Insert n samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(*data, weight=np.random.rand())
|
||||
self.assertTrue(len(memory) == i + 1)
|
||||
self.assertTrue(memory._next_idx == i + 1)
|
||||
|
||||
# Fetch records, their indices and weights.
|
||||
_, _, _, _, _, weights, indices = \
|
||||
memory.sample(1000, beta=self.beta)
|
||||
counts = Counter()
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
print(counts)
|
||||
# Expect an approximately uniform distribution of indices.
|
||||
for i in counts.values():
|
||||
self.assertTrue(100 < i < 300)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
rllib/execution/test_segment_tree.py (new file, 133 lines)
@ -0,0 +1,133 @@
import numpy as np
|
||||
import timeit
|
||||
import unittest
|
||||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
|
||||
|
||||
class TestSegmentTree(unittest.TestCase):
|
||||
def test_tree_set(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 4.0)
|
||||
assert np.isclose(tree.sum(0, 2), 0.0)
|
||||
assert np.isclose(tree.sum(0, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, -1), 1.0)
|
||||
assert np.isclose(tree.sum(2, 4), 4.0)
|
||||
assert np.isclose(tree.sum(2), 4.0)
|
||||
|
||||
def test_tree_set_overlap(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[2] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 3.0)
|
||||
assert np.isclose(tree.sum(2, 3), 3.0)
|
||||
assert np.isclose(tree.sum(2, -1), 3.0)
|
||||
assert np.isclose(tree.sum(2, 4), 3.0)
|
||||
assert np.isclose(tree.sum(2), 3.0)
|
||||
assert np.isclose(tree.sum(1, 2), 0.0)
|
||||
|
||||
def test_prefixsum_idx(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.0) == 2
|
||||
assert tree.find_prefixsum_idx(0.5) == 2
|
||||
assert tree.find_prefixsum_idx(0.99) == 2
|
||||
assert tree.find_prefixsum_idx(1.01) == 3
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(4.00) == 3
|
||||
|
||||
def test_prefixsum_idx2(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[0] = 0.5
|
||||
tree[1] = 1.0
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.00) == 0
|
||||
assert tree.find_prefixsum_idx(0.55) == 1
|
||||
assert tree.find_prefixsum_idx(0.99) == 1
|
||||
assert tree.find_prefixsum_idx(1.51) == 2
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(5.50) == 3
|
||||
|
||||
def test_max_interval_tree(self):
|
||||
tree = MinSegmentTree(4)
|
||||
|
||||
tree[0] = 1.0
|
||||
tree[2] = 0.5
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.min(), 0.5)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.5)
|
||||
assert np.isclose(tree.min(0, -1), 0.5)
|
||||
assert np.isclose(tree.min(2, 4), 0.5)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 0.7
|
||||
|
||||
assert np.isclose(tree.min(), 0.7)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.7)
|
||||
assert np.isclose(tree.min(0, -1), 0.7)
|
||||
assert np.isclose(tree.min(2, 4), 0.7)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 4.0
|
||||
|
||||
assert np.isclose(tree.min(), 1.0)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 1.0)
|
||||
assert np.isclose(tree.min(0, -1), 1.0)
|
||||
assert np.isclose(tree.min(2, 4), 3.0)
|
||||
assert np.isclose(tree.min(2, 3), 4.0)
|
||||
assert np.isclose(tree.min(2, -1), 4.0)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
def test_microbenchmark_vs_old_version(self):
|
||||
"""
|
||||
Results from March 2020 (capacity=1048576):
|
||||
|
||||
New tree:
|
||||
0.049599366000000256s
|
||||
results = timeit.timeit("tree.sum(5, 60000)",
|
||||
setup="from ray.rllib.execution.segment_tree import
|
||||
SumSegmentTree; tree = SumSegmentTree({})".format(capacity),
|
||||
number=10000)
|
||||
|
||||
Old tree:
|
||||
0.13390400999999974s
|
||||
results = timeit.timeit("tree.sum(5, 60000)",
|
||||
setup="from ray.rllib.execution.tests.old_segment_tree import
|
||||
OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity),
|
||||
number=10000)
|
||||
"""
|
||||
capacity = 2**20
|
||||
new = timeit.timeit(
|
||||
"tree.sum(5, 60000)",
|
||||
setup="from ray.rllib.execution.segment_tree import "
|
||||
"SumSegmentTree; tree = SumSegmentTree({})".format(capacity),
|
||||
number=10000)
|
||||
old = timeit.timeit(
|
||||
"tree.sum(5, 60000)",
|
||||
setup="from ray.rllib.execution.tests.old_segment_tree import "
|
||||
"OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity),
|
||||
number=10000)
|
||||
self.assertGreater(old, new)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -5,15 +5,15 @@ import math
|
|||
from typing import List
|
||||
|
||||
import ray
|
||||
from ray.util.iter import LocalIterator
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats, LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.common import SampleBatchType, \
|
||||
STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \
|
||||
APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, WORKER_UPDATE_TIMER, \
|
||||
LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, LAST_TARGET_UPDATE_TS, \
|
||||
NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type, \
|
||||
_get_shared_metrics
|
||||
from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.policy.policy import PolicyID
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
|
@ -54,7 +54,7 @@ class TrainOneStep:
|
|||
def __call__(self,
|
||||
batch: SampleBatchType) -> (SampleBatchType, List[dict]):
|
||||
_check_sample_batch_type(batch)
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
|
||||
with learn_timer:
|
||||
if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
|
||||
|
@ -164,7 +164,7 @@ class TrainTFMultiGPU:
|
|||
DEFAULT_POLICY_ID: samples
|
||||
}, samples.count)
|
||||
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
load_timer = metrics.timers[LOAD_BATCH_TIMER]
|
||||
learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
|
||||
with load_timer:
|
||||
|
@ -245,7 +245,7 @@ class ComputeGradients:
|
|||
|
||||
def __call__(self, samples: SampleBatchType):
|
||||
_check_sample_batch_type(samples)
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
with metrics.timers[COMPUTE_GRADS_TIMER]:
|
||||
grad, info = self.workers.local_worker().compute_gradients(samples)
|
||||
metrics.info[LEARNER_INFO] = get_learner_stats(info)
|
||||
|
@ -287,7 +287,7 @@ class ApplyGradients:
|
|||
"Input must be a tuple of (grad_dict, count), got {}".format(
|
||||
item))
|
||||
gradients, count = item
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_TRAINED_COUNTER] += count
|
||||
|
||||
apply_timer = metrics.timers[APPLY_GRADS_TIMER]
|
||||
|
@ -377,7 +377,7 @@ class UpdateTargetNetwork:
|
|||
self.metric = STEPS_SAMPLED_COUNTER
|
||||
|
||||
def __call__(self, _):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
cur_ts = metrics.counters[self.metric]
|
||||
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
|
||||
if cur_ts - last_update > self.target_update_freq:
|
||||
|
|
|
@ -4,12 +4,12 @@ from typing import List
|
|||
|
||||
import ray
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
SampleBatchType
|
||||
SampleBatchType, _get_shared_metrics
|
||||
from ray.rllib.execution.replay_ops import MixInReplay
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.utils.actors import create_colocated
|
||||
from ray.util.iter import LocalIterator, ParallelIterator, \
|
||||
ParallelIteratorWorker, from_actors
|
||||
from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
|
||||
from_actors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -93,7 +93,7 @@ def gather_experiences_tree_aggregation(workers, config):
|
|||
|
||||
# TODO(ekl) properly account for replay.
|
||||
def record_steps_sampled(batch):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
|
||||
return batch
|
||||
|
rllib/optimizers/README (new file, 1 line)
@ -0,0 +1 @@
This directory is deprecated; all files in it will be removed in a future release.
|
|
@ -6,7 +6,6 @@ import random
|
|||
import ray
|
||||
from ray.rllib.utils.actors import TaskPool
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class Aggregator:
|
||||
|
@ -166,7 +165,7 @@ class AggregationWorkerBase:
|
|||
return len(self.replay_batches) > num_needed
|
||||
|
||||
for ev, sample_batch in sample_futures:
|
||||
sample_batch = ray_get_and_free(sample_batch)
|
||||
sample_batch = ray.get(sample_batch)
|
||||
yield ev, sample_batch
|
||||
|
||||
if can_replay():
|
||||
|
|
|
@ -10,7 +10,6 @@ from ray.rllib.utils.actors import TaskPool, create_colocated
|
|||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
|
||||
AggregationWorkerBase
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -100,7 +99,7 @@ class TreeAggregator(Aggregator):
|
|||
def iter_train_batches(self):
|
||||
assert self.initialized, "Must call init() before using this class."
|
||||
for agg, batches in self.agg_tasks.completed_prefetch():
|
||||
for b in ray_get_and_free(batches):
|
||||
for b in ray.get(batches):
|
||||
self.num_sent_since_broadcast += 1
|
||||
yield b
|
||||
agg.set_weights.remote(self.broadcasted_weights)
|
||||
|
|
|
@ -3,7 +3,6 @@ from ray.rllib.evaluation.metrics import get_learner_stats
|
|||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
|
@ -53,7 +52,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
|||
ready_list = wait_results[0]
|
||||
future = ready_list[0]
|
||||
|
||||
gradient, info = ray_get_and_free(future)
|
||||
gradient, info = ray.get(future)
|
||||
e = pending_gradients.pop(future)
|
||||
self.learner_stats = get_learner_stats(info)
|
||||
|
||||
|
|
|
@@ -22,7 +22,6 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat

@@ -166,8 +165,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):

    @override(PolicyOptimizer)
    def stats(self):
        replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
            self.debug))
        replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers

@@ -218,7 +216,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
            counts = {
                i: v
                for i, v in enumerate(
                    ray_get_and_free([c[1][1] for c in completed]))
                    ray.get([c[1][1] for c in completed]))
            }
            # If there are failed workers, try to recover the still good ones
            # (via non-batched ray.get()) and store the first error (to raise

@@ -227,7 +225,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
            counts = {}
            for i, c in enumerate(completed):
                try:
                    counts[i] = ray_get_and_free(c[1][1])
                    counts[i] = ray.get(c[1][1])
                except RayError as e:
                    logger.exception(
                        "Error in completed task: {}".format(e))

@@ -272,7 +270,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                self.num_samples_dropped += 1
            else:
                with self.timers["get_samples"]:
                    samples = ray_get_and_free(replay)
                    samples = ray.get(replay)
                # Defensive copy against plasma crashes, see #2610 #3452
                self.learner.inqueue.put((ra, samples and samples.copy()))

@@ -7,7 +7,6 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -65,7 +64,7 @@ class MicrobatchOptimizer(PolicyOptimizer):
            while sum(s.count for s in samples) < self.microbatch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                        ray.get([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))

@@ -2,7 +2,6 @@ import logging

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -22,7 +21,7 @@ def collect_samples(agents, rollout_fragment_length, num_envs_per_worker,
    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray_get_and_free(fut_sample)
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

@@ -7,7 +7,6 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class SyncBatchReplayOptimizer(PolicyOptimizer):

@@ -56,7 +55,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):

        with self.sample_timer:
            if self.workers.remote_workers():
                batches = ray_get_and_free(
                batches = ray.get(
                    [e.sample.remote() for e in self.workers.remote_workers()])
            else:
                batches = [self.workers.local_worker().sample()]

@@ -13,7 +13,6 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -119,7 +118,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                    ray.get([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))

@@ -7,7 +7,6 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)

@@ -54,7 +53,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                        ray.get([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))

@@ -17,10 +17,8 @@ class TestDistributedExecution(unittest.TestCase):

    def test_exec_plan_stats(ray_start_regular):
        trainer = A2CTrainer(
            env="CartPole-v0",
            config={
            env="CartPole-v0", config={
                "min_iter_time_s": 0,
                "use_exec_api": True
            })
        result = trainer.train()
        assert isinstance(result, dict)

@@ -37,10 +35,8 @@ class TestDistributedExecution(unittest.TestCase):

    def test_exec_plan_save_restore(ray_start_regular):
        trainer = A2CTrainer(
            env="CartPole-v0",
            config={
            env="CartPole-v0", config={
                "min_iter_time_s": 0,
                "use_exec_api": True
            })
        res1 = trainer.train()
        checkpoint = trainer.save()

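The two tests above drive training end to end through the execution-plan path. A minimal standalone version of the same setup; ``use_exec_api`` and ``min_iter_time_s`` are copied from the test config, everything else is ordinary Trainer usage::

    import ray
    from ray.rllib.agents.a3c import A2CTrainer

    ray.init(ignore_reinit_error=True)

    # Route training through the execution-plan API instead of a
    # PolicyOptimizer, exactly as the tests above do.
    trainer = A2CTrainer(
        env="CartPole-v0", config={
            "min_iter_time_s": 0,
            "use_exec_api": True,
        })

    result = trainer.train()
    print(result["info"])  # step counters plus per-policy learner stats
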
@@ -15,7 +15,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, AsyncGradients, \
    ConcatBatches, StandardizeFields
from ray.rllib.execution.train_ops import TrainOneStep, ComputeGradients, \
    AverageGradients
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer, \
from ray.rllib.execution.replay_buffer import LocalReplayBuffer, \
    ReplayActor
from ray.rllib.policy.sample_batch import SampleBatch
from ray.util.iter import LocalIterator, from_range

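The replay-buffer utilities move from ``rllib/optimizers/async_replay_optimizer.py`` to ``rllib/execution/replay_buffer.py``. A hedged sketch of wiring the moved ``LocalReplayBuffer`` into an execution plan; the constructor keywords and the ``StoreToReplayBuffer``/``Replay`` ops from ``ray.rllib.execution.replay_ops`` are assumptions about this era of the package, not lines taken from the diff::

    from ray.rllib.execution.replay_buffer import LocalReplayBuffer
    from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
    from ray.rllib.execution.rollout_ops import ParallelRollouts

    # Parameter names are assumptions; check the class at this commit.
    replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=200,
        buffer_size=1000,
        replay_batch_size=32)


    def store_and_replay_ops(workers):
        # Rollouts are written into the local buffer, and training later
        # reads replayed batches back out of the same buffer.
        store_op = ParallelRollouts(workers, mode="bulk_sync").for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))
        replay_op = Replay(local_buffer=replay_buffer)
        return store_op, replay_op
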
@@ -174,7 +174,8 @@ def test_train_one_step(ray_start_regular_shared):
    b = a.for_each(TrainOneStep(workers))
    batch, stats = next(b)
    assert isinstance(batch, SampleBatch)
    assert "learner_stats" in stats
    assert "default_policy" in stats
    assert "learner_stats" in stats["default_policy"]
    counters = a.shared_metrics.get().counters
    assert counters["num_steps_sampled"] == 100, counters
    assert counters["num_steps_trained"] == 100, counters

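``TrainOneStep`` now reports learner stats keyed by policy ID, which is what the extra assertions check. A small sketch of the pipeline the test drives; building the ``workers`` WorkerSet is omitted, so treat this as a fragment under that assumption rather than a runnable script::

    from ray.rllib.execution.rollout_ops import ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep


    def training_pipeline(workers):
        # `workers` is a WorkerSet, e.g. the one a Trainer builds.
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        return rollouts.for_each(TrainOneStep(workers))


    # Consuming one element trains on one batch and yields (batch, stats),
    # where stats is keyed by policy ID, e.g. stats["default_policy"].
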
@@ -1,19 +1,13 @@
import gym
import numpy as np
import random
import unittest

import ray
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, \
    MultiAgentCartPole
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.examples.env.multi_agent import BasicMultiAgent
from ray.rllib.tests.test_rollout_worker import MockPolicy
from ray.rllib.tests.test_external_env import make_simple_serving
from ray.rllib.evaluation.metrics import collect_metrics

SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)

@@ -64,32 +58,6 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
        batch = ev.sample()
        self.assertEqual(batch.count, 50)

    def test_train_external_multi_agent_cartpole_many_policies(self):
        n = 20
        single_env = gym.make("CartPole-v0")
        act_space = single_env.action_space
        obs_space = single_env.observation_space
        policies = {}
        for i in range(20):
            policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
                                           {})
        policy_ids = list(policies.keys())
        ev = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": n}),
            policy=policies,
            policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
            rollout_fragment_length=100)
        optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev))
        for i in range(100):
            optimizer.step()
            result = collect_metrics(ev)
            print("Iteration {}, rew {}".format(i,
                                                result["policy_reward_mean"]))
            print("Total reward", result["episode_reward_mean"])
            if result["episode_reward_mean"] >= 25 * n:
                return
        raise Exception("failed to improve reward")


if __name__ == "__main__":
    import pytest

@@ -6,17 +6,12 @@ import ray
from ray.tune.registry import register_env
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
    BasicMultiAgent, EarlyDoneMultiAgent, RoundRobinMultiAgent
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
                                  AsyncGradientsOptimizer)
from ray.rllib.tests.test_rollout_worker import MockPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv


def one_hot(i, n):

@@ -424,103 +419,6 @@ class TestMultiAgentEnv(unittest.TestCase):
            KeyError,
            lambda: pg.compute_action([0, 0, 0, 0], policy_id="policy_3"))

    def _test_with_optimizer(self, optimizer_cls):
        n = 3
        env = gym.make("CartPole-v0")
        act_space = env.action_space
        obs_space = env.observation_space
        dqn_config = {"gamma": 0.95, "n_step": 3}
        if optimizer_cls == SyncReplayOptimizer:
            # TODO: support replay with non-DQN graphs. Currently this can't
            # happen since the replay buffer doesn't encode extra fields like
            # "advantages" that PG uses.
            policies = {
                "p1": (DQNTFPolicy, obs_space, act_space, dqn_config),
                "p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
            }
        else:
            policies = {
                "p1": (PGTFPolicy, obs_space, act_space, {}),
                "p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
            }
        worker = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": n}),
            policy=policies,
            policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
            rollout_fragment_length=50)
        if optimizer_cls == AsyncGradientsOptimizer:

            def policy_mapper(agent_id):
                return ["p1", "p2"][agent_id % 2]

            remote_workers = [
                RolloutWorker.as_remote().remote(
                    env_creator=lambda _: MultiAgentCartPole(
                        {"num_agents": n}),
                    policy=policies,
                    policy_mapping_fn=policy_mapper,
                    rollout_fragment_length=50)
            ]
        else:
            remote_workers = []
        workers = WorkerSet._from_existing(worker, remote_workers)
        optimizer = optimizer_cls(workers)
        for i in range(200):
            optimizer.step()
            result = collect_metrics(worker, remote_workers)
            if i % 20 == 0:

                def do_update(p):
                    if isinstance(p, DQNTFPolicy):
                        p.update_target()

                worker.foreach_policy(lambda p, _: do_update(p))
            print("Iter {}, rew {}".format(i,
                                           result["policy_reward_mean"]))
            print("Total reward", result["episode_reward_mean"])
            if result["episode_reward_mean"] >= 25 * n:
                return
        print(result)
        raise Exception("failed to improve reward")

    def test_multi_agent_sync_optimizer(self):
        self._test_with_optimizer(SyncSamplesOptimizer)

    def test_multi_agent_async_gradients_optimizer(self):
        # Allow to be run via Unittest.
        ray.init(num_cpus=4, ignore_reinit_error=True)
        self._test_with_optimizer(AsyncGradientsOptimizer)

    def test_multi_agent_replay_optimizer(self):
        self._test_with_optimizer(SyncReplayOptimizer)

    def test_train_multi_agent_cartpole_many_policies(self):
        n = 20
        env = gym.make("CartPole-v0")
        act_space = env.action_space
        obs_space = env.observation_space
        policies = {}
        for i in range(20):
            policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
                                           {})
        policy_ids = list(policies.keys())
        worker = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": n}),
            policy=policies,
            policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
            rollout_fragment_length=100)
        workers = WorkerSet._from_existing(worker, [])
        optimizer = SyncSamplesOptimizer(workers)
        for i in range(100):
            optimizer.step()
            result = collect_metrics(worker)
            print("Iteration {}, rew {}".format(i,
                                                result["policy_reward_mean"]))
            print("Total reward", result["episode_reward_mean"])
            if result["episode_reward_mean"] >= 25 * n:
                return
        raise Exception("failed to improve reward")


if __name__ == "__main__":
    import pytest

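With the PolicyOptimizer-driven multi-agent tests removed, the equivalent setup goes through a Trainer's ``multiagent`` config. A hedged sketch using ``PGTrainer`` and the ``MultiAgentCartPole`` example env already imported by this test file; the config values are illustrative, not taken from the diff::

    import gym
    import ray
    from ray.rllib.agents.pg import PGTrainer
    from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

    ray.init(ignore_reinit_error=True)

    # Grab the per-agent spaces from the underlying single-agent env.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    trainer = PGTrainer(
        env=MultiAgentCartPole,
        config={
            "env_config": {"num_agents": 2},
            "multiagent": {
                # None selects the trainer's default policy class.
                "policies": {
                    "p1": (None, obs_space, act_space, {}),
                    "p2": (None, obs_space, act_space, {}),
                },
                "policy_mapping_fn": lambda agent_id: ["p1", "p2"][agent_id % 2],
            },
        })

    result = trainer.train()
    print(result["policy_reward_mean"])
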
@@ -189,10 +189,13 @@ class TestRolloutWorker(unittest.TestCase):
            print("num_steps_trained={}".format(
                result["info"]["num_steps_trained"]))
            if i == 0:
                self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
            if result["info"]["learner"]["cur_lr"] < 0.07:
                self.assertGreater(
                    result["info"]["learner"]["default_policy"]["cur_lr"],
                    0.01)
            if result["info"]["learner"]["default_policy"]["cur_lr"] < 0.07:
                break
        self.assertLess(result["info"]["learner"]["cur_lr"], 0.07)
        self.assertLess(result["info"]["learner"]["default_policy"]["cur_lr"],
                        0.07)

    def test_no_step_on_init(self):
        # Allow for Unittest run.

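The assertions change because ``result["info"]["learner"]`` is now a dict keyed by policy ID, so single-agent runs read their stats one level deeper under ``"default_policy"``. For example::

    result = trainer.train()  # any Trainer, e.g. the A2CTrainer built earlier

    # Before: result["info"]["learner"]["cur_lr"]
    # After:  stats are nested under the policy ID.
    cur_lr = result["info"]["learner"]["default_policy"]["cur_lr"]
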
@@ -41,8 +41,7 @@ def DeveloperAPI(obj):
    to be stable sans minor changes (but less stable than public APIs).

    Subclasses that inherit from a ``@DeveloperAPI`` base class can be
    assumed part of the RLlib developer API as well (e.g., all policy
    optimizers are developer API because PolicyOptimizer is ``@DeveloperAPI``).
    assumed part of the RLlib developer API as well.
    """

    return obj

@@ -1,6 +1,5 @@
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free


@DeveloperAPI

@@ -21,7 +20,7 @@ class FilterManager:
            remotes (list): Remote evaluators with filters.
            update_remote (bool): Whether to push updates to remote filters.
        """
        remote_filters = ray_get_and_free(
        remote_filters = ray.get(
            [r.get_filters.remote(flush_after=True) for r in remotes])
        for rf in remote_filters:
            for k in local_filters:

@@ -1,53 +1,4 @@
import numpy as np
import time
import os

import ray

FREE_DELAY_S = 10.0
MAX_FREE_QUEUE_SIZE = 100
_last_free_time = 0.0
_to_free = []

# TODO(ekl) remove this feature entirely. It's here for now just in case we
# need to turn it on for debugging.
RLLIB_DEBUG_EXPLICIT_FREE = bool(os.environ.get("RLLIB_DEBUG_EXPLICIT_FREE"))


def ray_get_and_free(object_ids):
    """Call ray.get and then queue the object ids for deletion.

    This function should be used whenever possible in RLlib, to optimize
    memory usage. The only exception is when an object_id is shared among
    multiple readers.

    Args:
        object_ids (ObjectID|List[ObjectID]): Object ids to fetch and free.

    Returns:
        The result of ray.get(object_ids).
    """

    if not RLLIB_DEBUG_EXPLICIT_FREE:
        return ray.get(object_ids)

    global _last_free_time
    global _to_free

    result = ray.get(object_ids)
    if type(object_ids) is not list:
        object_ids = [object_ids]
    _to_free.extend(object_ids)

    # batch calls to free to reduce overheads
    now = time.time()
    if (len(_to_free) > MAX_FREE_QUEUE_SIZE
            or now - _last_free_time > FREE_DELAY_S):
        ray.internal.free(_to_free)
        _to_free = []
        _last_free_time = now

    return result


def aligned_array(size, dtype, align=64):

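By this commit the explicit-free path was already opt-in (the ``RLLIB_DEBUG_EXPLICIT_FREE`` environment variable, off by default), so removing the helper leaves every call site above with a plain ``ray.get``. A short sketch of what remains available if an object ever needs to be dropped eagerly; ``ray.internal.free`` is the call the old helper batched, and this PR stops using it::

    import ray

    ray.init(ignore_reinit_error=True)

    ref = ray.put(list(range(1000)))
    value = ray.get(ref)  # what ray_get_and_free reduced to by default

    # Eager deletion stays available for debugging via ray.internal.free,
    # which is the call the removed helper batched -- RLlib itself no longer
    # routes fetches through a freeing helper after this PR.
    ray.internal.free([ref])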