[rllib] Port QMIX, MADDPG to new execution API (#8344)

Eric Liang 2020-05-07 23:41:10 -07:00 committed by GitHub
parent 9f04a65922
commit 2c599dbf05
6 changed files with 135 additions and 10 deletions


@@ -157,7 +157,9 @@ def execution_plan(workers: WorkerSet, config: dict):
# (2) Read experiences from the replay buffer actors and send to the
# learner thread via its in-queue.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
replay_op = Replay(actors=replay_actors, num_async=4) \
.for_each(lambda x: post_fn(x, workers, config)) \
.zip_with_source_actor() \
.for_each(Enqueue(learner_thread.inqueue))
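
As a side note, the post_fn fallback above follows a simple pattern; here is a minimal stand-alone sketch of it in plain Python (the resolve_post_fn helper is a hypothetical name used only for illustration, not part of this diff):

# A minimal stand-alone sketch (plain Python, not RLlib code) of the fallback
# idiom above: if "before_learn_on_batch" is unset, an identity function is
# used, so the replay pipeline passes batches through unchanged.
def resolve_post_fn(config):
    return config.get("before_learn_on_batch") or (lambda batch, *args: batch)

identity_fn = resolve_post_fn({})
custom_fn = resolve_post_fn(
    {"before_learn_on_batch": lambda batch, workers, config: batch})
assert identity_fn("batch", None, None) == "batch"
assert custom_fn("batch", None, None) == "batch"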


@@ -84,6 +84,11 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
"compress_observations": False,
# In multi-agent mode, whether to replay experiences from the same time
# step for all policies. This is required for MADDPG.
"multiagent_sync_replay": False,
# Callback to run before learning on a multi-agent batch of experiences.
"before_learn_on_batch": None,
# === Optimization ===
# Learning rate for adam optimizer
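
For illustration, the two new keys might be set together in a trainer config dict as follows (a hedged sketch; the key names come from the diff above, while the other values are placeholders rather than defaults):

# Hypothetical usage sketch only.
config = {
    "multiagent_sync_replay": True,   # replay the same timesteps for all policies
    "before_learn_on_batch": None,    # or a callable: (batch, workers, config) -> batch
    "buffer_size": 50000,
    "train_batch_size": 1024,
}
print(config["multiagent_sync_replay"])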
@@ -312,6 +317,7 @@ def execution_plan(workers, config):
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
multiagent_sync_replay=config.get("multiagent_sync_replay"),
**prio_args)
rollouts = ParallelRollouts(workers, mode="bulk_sync")
@@ -341,7 +347,9 @@ def execution_plan(workers, config):
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take an SGD step, and then we decide whether to update the target network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(lambda x: post_fn(x, workers, config)) \
.for_each(TrainOneStep(workers)) \
.for_each(update_prio) \
.for_each(UpdateTargetNetwork(
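
The chained for_each calls above behave like simple function composition over an iterator; a rough stand-alone sketch of that idea (not RLlib's LocalIterator implementation):

# Illustrative sketch only: each stage is applied to every item that flows out
# of the replay source, mirroring the .for_each(...) chain above.
def chain(source, *stages):
    for item in source:
        for stage in stages:
            item = stage(item)
        yield item

batches = iter([{"obs": [1, 2]}, {"obs": [3, 4]}])
train_op = chain(
    batches,
    lambda b: b,                        # stand-in for post_fn (identity here)
    lambda b: {**b, "trained": True},   # stand-in for TrainOneStep
    lambda b: b)                        # stand-in for update_prio
print(list(train_op))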


@@ -1,6 +1,10 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.optimizers import SyncBatchReplayOptimizer
# yapf: disable
@@ -82,9 +86,6 @@ DEFAULT_CONFIG = with_common_config({
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
# TODO(ekl) support sync batch replay.
"use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable
@@ -98,9 +99,25 @@ def make_sync_batch_optimizer(workers, config):
train_batch_size=config["train_batch_size"])
# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
rollouts = ParallelRollouts(workers, mode="bulk_sync")
train_op = rollouts \
.for_each(MixInReplay(config["buffer_size"])) \
.combine(
ConcatBatches(min_batch_size=config["train_batch_size"])) \
.for_each(TrainOneStep(workers)) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))
return StandardMetricsReporting(train_op, workers, config)
QMixTrainer = GenericOffPolicyTrainer.with_updates(
name="QMIX",
default_config=DEFAULT_CONFIG,
default_policy=QMixTorchPolicy,
get_policy_class=None,
make_policy_optimizer=make_sync_batch_optimizer)
make_policy_optimizer=make_sync_batch_optimizer,
execution_plan=execution_plan)
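
ConcatBatches in the QMIX plan above buffers small rollout fragments until a size threshold is met; a rough self-contained sketch of that behavior (toy dict batches, not RLlib's SampleBatch or its actual implementation):

# Sketch only: buffer incoming batches until their combined length reaches
# min_batch_size, then emit one concatenated batch.
def concat_batches(batches, min_batch_size):
    buf, count = [], 0
    for b in batches:
        buf.append(b)
        count += len(b["obs"])
        if count >= min_batch_size:
            yield {"obs": sum((x["obs"] for x in buf), [])}
            buf, count = [], 0

small = iter([{"obs": [1, 2]}, {"obs": [3]}, {"obs": [4, 5, 6]}])
print(list(concat_batches(small, min_batch_size=3)))
# -> [{'obs': [1, 2, 3]}, {'obs': [4, 5, 6]}]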


@@ -67,6 +67,9 @@ DEFAULT_CONFIG = with_common_config({
# Observation compression. Note that compression makes simulation slow in
# MPE.
"compress_observations": False,
# In multi-agent mode, whether to replay experiences from the same time
# step for all policies. This is required for MADDPG.
"multiagent_sync_replay": True,
# === Optimization ===
# Learning rate for the critic (Q-function) optimizer.
@@ -100,9 +103,6 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,
# TODO(ekl) support synchronized sampling.
"use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable
@@ -171,10 +171,28 @@ def collect_metrics(trainer):
return result
def add_maddpg_postprocessing(config):
"""Add the before learn on batch hook.
This hook is called explicitly prior to TrainOneStep() in the execution
setups for DQN and APEX.
"""
def f(batch, workers, config):
policies = dict(workers.local_worker()
.foreach_trainable_policy(lambda p, i: (i, p)))
return before_learn_on_batch(batch, policies,
config["train_batch_size"])
config["before_learn_on_batch"] = f
return config
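
The hook installed here follows the same (batch, workers, config) -> batch contract used by post_fn in the DQN/APEX plans above; a minimal stand-alone sketch of that contract (placeholder logic only; the real MADDPG hook needs access to all policies so it can share information across agents' batches, which is omitted here):

# Stand-alone sketch of the hook contract only.
def make_hook(train_batch_size):
    def hook(multi_agent_batch, workers, config):
        return multi_agent_batch  # placeholder: pass the batch through
    return hook

config = {"before_learn_on_batch": make_hook(1024), "train_batch_size": 1024}
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
print(post_fn({"policy_1": "batch"}, None, config))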
MADDPGTrainer = GenericOffPolicyTrainer.with_updates(
name="MADDPG",
default_config=DEFAULT_CONFIG,
default_policy=MADDPGTFPolicy,
validate_config=add_maddpg_postprocessing,
get_policy_class=None,
before_init=None,
before_train_step=set_global_timestep,


@@ -91,3 +91,74 @@ def Replay(*,
yield item
return LocalIterator(gen_replay, SharedMetrics())
class MixInReplay:
"""This operator adds replay to a stream of experiences.
It takes input batches, and returns a list of batches that include replayed
data as well. The number of replayed batches is determined by the
configured replay proportion. The max age of a batch is determined by the
number of replay slots.
"""
def __init__(self, num_slots, replay_proportion: float = None):
"""Initialize MixInReplay.
Args:
num_slots (int): Number of batches to store in total.
replay_proportion (float): If None, one batch will be replayed for
each input batch. Otherwise, the input batch is returned along with
a number of replayed batches proportional to this value.
Examples:
# 1:1 mode (default)
>>> replay_op = rollouts.for_each(MixInReplay(num_slots=100))
>>> print(next(replay_op))
SampleBatch(<replay>)
# proportional mode
>>> replay_op = rollouts.for_each(
...     MixInReplay(num_slots=100, replay_proportion=2))
>>> print(next(replay_op))
[SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]
# proportional mode, replay disabled
>>> replay_op = rollouts.for_each(
...     MixInReplay(num_slots=100, replay_proportion=0))
>>> print(next(replay_op))
[SampleBatch(<input>)]
"""
if replay_proportion is not None:
if replay_proportion > 0 and num_slots == 0:
raise ValueError(
"You must set num_slots > 0 if replay_proportion > 0.")
elif num_slots == 0:
raise ValueError(
"You must set num_slots > 0 if replay_proportion = None.")
self.num_slots = num_slots
self.replay_proportion = replay_proportion
self.replay_batches = []
self.replay_index = 0
def __call__(self, sample_batch):
# Put in replay buffer if enabled.
if self.num_slots > 0:
if len(self.replay_batches) < self.num_slots:
self.replay_batches.append(sample_batch)
else:
self.replay_batches[self.replay_index] = sample_batch
self.replay_index += 1
self.replay_index %= self.num_slots
# 1:1 replay mode.
if self.replay_proportion is None:
return random.choice(self.replay_batches)
# Proportional replay mode.
output_batches = [sample_batch]
f = self.replay_proportion
while random.random() < f:
f -= 1
replay_batch = random.choice(self.replay_batches)
output_batches.append(replay_batch)
return output_batches
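
The while random.random() < f loop above yields, on average, replay_proportion replayed batches per input batch; a quick stand-alone check of that expectation:

# Stand-alone check: with proportion p, floor(p) replays are guaranteed and one
# more is added with probability p - floor(p), so the mean is p.
import random

def num_replays(p):
    n = 0
    while random.random() < p:
        p -= 1
        n += 1
    return n

random.seed(0)
draws = [num_replays(2.5) for _ in range(20000)]
print(sum(draws) / len(draws))  # approximately 2.5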


@@ -303,12 +303,14 @@ class LocalReplayBuffer(ParallelIteratorWorker):
replay_batch_size,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6):
prioritized_replay_eps=1e-6,
multiagent_sync_replay=False):
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.replay_batch_size = replay_batch_size
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.multiagent_sync_replay = multiagent_sync_replay
def gen_replay():
while True:
@@ -369,10 +371,17 @@ class LocalReplayBuffer(ParallelIteratorWorker):
with self.replay_timer:
samples = {}
idxes = None
for policy_id, replay_buffer in self.replay_buffers.items():
if self.multiagent_sync_replay:
if idxes is None:
idxes = replay_buffer.sample_idxes(
self.replay_batch_size)
else:
idxes = replay_buffer.sample_idxes(self.replay_batch_size)
(obses_t, actions, rewards, obses_tp1, dones, weights,
batch_indexes) = replay_buffer.sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
batch_indexes) = replay_buffer.sample_with_idxes(
idxes, beta=self.prioritized_replay_beta)
samples[policy_id] = SampleBatch({
"obs": obses_t,
"actions": actions,