Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Port QMIX, MADDPG to new execution API (#8344)
parent 9f04a65922
commit 2c599dbf05
6 changed files with 135 additions and 10 deletions

@@ -157,7 +157,9 @@ def execution_plan(workers: WorkerSet, config: dict):
    # (2) Read experiences from the replay buffer actors and send to the
    # learner thread via its in-queue.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(actors=replay_actors, num_async=4) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .zip_with_source_actor() \
        .for_each(Enqueue(learner_thread.inqueue))
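The `before_learn_on_batch` hook applied above is called with each replayed batch plus the worker set and the trainer config, and must return the batch that will actually be learned on. A minimal sketch of such a hook, not part of this commit (the function name and its logging body are illustrative; the batch type is assumed to be a MultiAgentBatch):

import ray
from ray.rllib.policy.sample_batch import MultiAgentBatch


def log_batch_size(batch, workers, config):
    # Inspect the replayed batch before the learner consumes it; any
    # transformation must still return a batch object.
    if isinstance(batch, MultiAgentBatch):
        print("replaying {} env steps".format(batch.count))
    return batch

# Wiring (hedged): config["before_learn_on_batch"] = log_batch_size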

@@ -84,6 +84,11 @@ DEFAULT_CONFIG = with_common_config({
    "prioritized_replay_eps": 1e-6,
    # Whether to LZ4 compress observations
    "compress_observations": False,
    # In multi-agent mode, whether to replay experiences from the same time
    # step for all policies. This is required for MADDPG.
    "multiagent_sync_replay": False,
    # Callback to run before learning on a multi-agent batch of experiences.
    "before_learn_on_batch": None,

    # === Optimization ===
    # Learning rate for adam optimizer
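A hedged usage sketch of the two new DQN config keys, not taken from this commit (the environment and values are placeholders; synchronized replay only matters once several policies are trained):

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(
    env="CartPole-v0",  # placeholder single-agent env
    config={
        # Replay the same timesteps for every policy (required by MADDPG).
        "multiagent_sync_replay": False,
        # Any callable (batch, workers, config) -> batch; identity shown here.
        "before_learn_on_batch": lambda batch, workers, config: batch,
    })
result = trainer.train()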
@@ -312,6 +317,7 @@ def execution_plan(workers, config):
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        multiagent_sync_replay=config.get("multiagent_sync_replay"),
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

@@ -341,7 +347,9 @@ def execution_plan(workers, config):
    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
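The chained `.for_each(...)` calls above follow the iterator-composition style of the new execution API. A toy sketch of that composition using Ray's utility iterators with plain integers in place of sample batches and trainer ops (illustrative only, not part of this commit):

import ray
from ray.util.iter import from_items

ray.init()
# Build a local iterator over a few items, then compose transformations
# the same way the execution plan composes replay, training, and updates.
it = from_items([1, 2, 3, 4], num_shards=1).gather_sync()
out = it.for_each(lambda x: x * 2).for_each(lambda x: x + 1)
print(out.take(4))  # [3, 5, 7, 9]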

@@ -1,6 +1,10 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.optimizers import SyncBatchReplayOptimizer

# yapf: disable

@@ -82,9 +86,6 @@ DEFAULT_CONFIG = with_common_config({
        "lstm_cell_size": 64,
        "max_seq_len": 999999,
    },

    # TODO(ekl) support sync batch replay.
    "use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable

@@ -98,9 +99,25 @@ def make_sync_batch_optimizer(workers, config):
        train_batch_size=config["train_batch_size"])


# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    train_op = rollouts \
        .for_each(MixInReplay(config["buffer_size"])) \
        .combine(
            ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    return StandardMetricsReporting(train_op, workers, config)


QMixTrainer = GenericOffPolicyTrainer.with_updates(
    name="QMIX",
    default_config=DEFAULT_CONFIG,
    default_policy=QMixTorchPolicy,
    get_policy_class=None,
-   make_policy_optimizer=make_sync_batch_optimizer)
+   make_policy_optimizer=make_sync_batch_optimizer,
+   execution_plan=execution_plan)
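The new QMIX execution plan stays behind the `use_exec_api` flag noted in the comment above. A hedged sketch of opting in (the environment name and stopping criterion are placeholders, not part of this commit; QMIX additionally requires a grouped multi-agent env registered elsewhere):

import ray
from ray import tune

ray.init()
tune.run(
    "QMIX",
    stop={"timesteps_total": 10000},  # placeholder stopping criterion
    config={
        # Assumed: a grouped multi-agent env registered under this name.
        "env": "grouped_twostep",
        # Opt into the experimental execution_plan added above.
        "use_exec_api": True,
    })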

@@ -67,6 +67,9 @@ DEFAULT_CONFIG = with_common_config({
    # Observation compression. Note that compression makes simulation slow in
    # MPE.
    "compress_observations": False,
    # In multi-agent mode, whether to replay experiences from the same time
    # step for all policies. This is required for MADDPG.
    "multiagent_sync_replay": True,

    # === Optimization ===
    # Learning rate for the critic (Q-function) optimizer.

@@ -100,9 +103,6 @@ DEFAULT_CONFIG = with_common_config({
    "num_workers": 1,
    # Prevent iterations from going lower than this time span
    "min_iter_time_s": 0,

    # TODO(ekl) support synchronized sampling.
    "use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable

@@ -171,10 +171,28 @@ def collect_metrics(trainer):
    return result


def add_maddpg_postprocessing(config):
    """Add the before learn on batch hook.

    This hook is called explicitly prior to TrainOneStep() in the execution
    setups for DQN and APEX.
    """

    def f(batch, workers, config):
        policies = dict(workers.local_worker()
                        .foreach_trainable_policy(lambda p, i: (i, p)))
        return before_learn_on_batch(batch, policies,
                                     config["train_batch_size"])

    config["before_learn_on_batch"] = f
    return config


MADDPGTrainer = GenericOffPolicyTrainer.with_updates(
    name="MADDPG",
    default_config=DEFAULT_CONFIG,
    default_policy=MADDPGTFPolicy,
    validate_config=add_maddpg_postprocessing,
    get_policy_class=None,
    before_init=None,
    before_train_step=set_global_timestep,
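Because `add_maddpg_postprocessing` is wired in as `validate_config`, it runs once on the merged trainer config, so both the legacy optimizer path and the shared DQN/APEX execution plans later find the hook under "before_learn_on_batch". A small sanity-check sketch (the import path is assumed to match RLlib's contrib layout at the time; the minimal config is illustrative):

from ray.rllib.contrib.maddpg.maddpg import add_maddpg_postprocessing

cfg = {"train_batch_size": 1024}  # stand-in for a full MADDPG config
cfg = add_maddpg_postprocessing(cfg)
# The hook is installed on the config dict itself.
assert callable(cfg["before_learn_on_batch"])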

@@ -91,3 +91,74 @@ def Replay(*,
            yield item

    return LocalIterator(gen_replay, SharedMetrics())


class MixInReplay:
    """This operator adds replay to a stream of experiences.

    It takes input batches, and returns a list of batches that include replayed
    data as well. The number of replayed batches is determined by the
    configured replay proportion. The max age of a batch is determined by the
    number of replay slots.
    """

    def __init__(self, num_slots, replay_proportion: float = None):
        """Initialize MixInReplay.

        Args:
            num_slots (int): Number of batches to store in total.
            replay_proportion (float): If None, one batch will be replayed per
                each input batch. Otherwise, the input batch will be returned
                and an additional number of batches proportional to this value
                will be added as well.

        Examples:
            # 1:1 mode (default)
            >>> replay_op = MixInReplay(rollouts, 100)
            >>> print(next(replay_op))
            SampleBatch(<replay>)

            # proportional mode
            >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=2)
            >>> print(next(replay_op))
            [SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]

            # proportional mode, replay disabled
            >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=0)
            >>> print(next(replay_op))
            [SampleBatch(<input>)]
        """
        if replay_proportion is not None:
            if replay_proportion > 0 and num_slots == 0:
                raise ValueError(
                    "You must set num_slots > 0 if replay_proportion > 0.")
        elif num_slots == 0:
            raise ValueError(
                "You must set num_slots > 0 if replay_proportion = None.")
        self.num_slots = num_slots
        self.replay_proportion = replay_proportion
        self.replay_batches = []
        self.replay_index = 0

    def __call__(self, sample_batch):
        # Put in replay buffer if enabled.
        if self.num_slots > 0:
            if len(self.replay_batches) < self.num_slots:
                self.replay_batches.append(sample_batch)
            else:
                self.replay_batches[self.replay_index] = sample_batch
                self.replay_index += 1
                self.replay_index %= self.num_slots

        # 1:1 replay mode.
        if self.replay_proportion is None:
            return random.choice(self.replay_batches)

        # Proportional replay mode.
        output_batches = [sample_batch]
        f = self.replay_proportion
        while random.random() < f:
            f -= 1
            replay_batch = random.choice(self.replay_batches)
            output_batches.append(replay_batch)
        return output_batches
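The docstring examples above sketch MixInReplay used inside an iterator pipeline, while the constructor itself takes (num_slots, replay_proportion). Since the class is a plain callable, it can also be exercised standalone. A toy sketch in proportional mode, with integers standing in for SampleBatch objects (an assumption for brevity, not part of this commit):

from ray.rllib.execution.replay_ops import MixInReplay

mixin = MixInReplay(num_slots=3, replay_proportion=1.0)
for step in range(5):
    # Each call stores the input and returns [input, <zero or more replayed items>].
    print(step, mixin(step))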

@@ -303,12 +303,14 @@ class LocalReplayBuffer(ParallelIteratorWorker):
                 replay_batch_size,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
-                prioritized_replay_eps=1e-6):
+                prioritized_replay_eps=1e-6,
+                multiagent_sync_replay=False):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.multiagent_sync_replay = multiagent_sync_replay

        def gen_replay():
            while True:

@@ -369,10 +371,17 @@ class LocalReplayBuffer(ParallelIteratorWorker):

        with self.replay_timer:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.multiagent_sync_replay:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.replay_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.replay_batch_size)
                (obses_t, actions, rewards, obses_tp1, dones, weights,
-                batch_indexes) = replay_buffer.sample(
-                    self.replay_batch_size, beta=self.prioritized_replay_beta)
+                batch_indexes) = replay_buffer.sample_with_idxes(
+                    idxes, beta=self.prioritized_replay_beta)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
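The `multiagent_sync_replay` branch above draws one set of indexes and reuses it for every policy's buffer, so all policies replay the same timesteps. A toy illustration of that control flow, with plain lists standing in for the per-policy replay buffers (illustrative only, not RLlib code):

import random

# Plain lists stand in for per-policy replay buffers of equal length.
buffers = {"policy_1": list(range(0, 10)), "policy_2": list(range(100, 110))}
multiagent_sync_replay = True
replay_batch_size = 4

idxes = None
samples = {}
for policy_id, buf in buffers.items():
    if multiagent_sync_replay:
        # Draw indexes once and reuse them for every policy.
        if idxes is None:
            idxes = [random.randrange(len(buf)) for _ in range(replay_batch_size)]
    else:
        # Draw fresh indexes per policy.
        idxes = [random.randrange(len(buf)) for _ in range(replay_batch_size)]
    samples[policy_id] = [buf[i] for i in idxes]

# With multiagent_sync_replay=True, both policies sample identical indexes.
print(samples)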