"""
Distributed Prioritized Experience Replay (Ape-X)
=================================================
This file defines a DQN trainer using the Ape-X architecture.
Ape-X uses a single GPU learner and many CPU workers for experience collection.
Experience collection can scale to hundreds of CPU workers due to the
distributed prioritization of experience prior to storage in replay buffers.
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x
""" # noqa: E501
import queue
from collections import defaultdict
import copy
import platform
import random
from typing import Tuple, Dict, List, DefaultDict, Set
import ray
from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer
from ray.rllib.agents.dqn.dqn import (
calculate_rr_weights,
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
from ray.rllib.agents.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import ReplayActor
from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests,
wait_asynchronous_requests,
)
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_TARGET_UPDATES,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict,
ResultDict,
PartialTrainerConfigDict,
T,
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.iter import LocalIterator
# fmt: off
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
# See also the options in dqn.py, which are also supported.
DQN_DEFAULT_CONFIG,
{
"optimizer": merge_dicts(
DQN_DEFAULT_CONFIG["optimizer"], {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
}),
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"no_local_replay_buffer": True,
"type": "MultiAgentReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
        # Set this to False to relax this constraint and allow replay shards
        # to be created on node(s) other than the one on which the learner is
        # located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
        # The ratio of timesteps replayed from the buffer (and trained on) to
        # timesteps sampled from an environment (and stored in the replay
        # buffer). Must be greater than 0.
        # See the worked example right after this config block.
        # TODO: Find a way to support None again as a means to let replay
        # proceed as fast as possible.
"training_intensity": 1,
# Use `training_iteration` instead of `execution_plan` by default.
"_disable_execution_plan_api": True,
},
)
# __sphinx_doc_end__
# fmt: on
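# A rough sketch of what `training_intensity` means under the defaults above
# (assuming one env per worker; the exact weighting is computed by
# `calculate_rr_weights` in dqn.py):
#
#   env steps sampled per rollout round ~= rollout_fragment_length * num_workers
#                                        = 50 * 32 = 1600
#   timesteps trained per replay step    = train_batch_size = 512
#   "natural" trained-to-sampled ratio  ~= 512 / 1600 = 0.32
#
# `training_intensity=1` asks for roughly one trained timestep per sampled
# timestep, i.e. replay/training steps are scheduled about 1 / 0.32 ~= 3x more
# often than they would be at the natural ratio.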
# Update worker weights as they finish generating experiences.
class UpdateWorkerWeights:
def __init__(
self,
learner_thread: LearnerThread,
workers: WorkerSet,
max_weight_sync_delay: int,
):
self.learner_thread = learner_thread
self.workers = workers
self.steps_since_update = defaultdict(int)
self.max_weight_sync_delay = max_weight_sync_delay
self.weights = None
def __call__(self, item: Tuple[ActorHandle, SampleBatchType]):
actor, batch = item
self.steps_since_update[actor] += batch.count
if self.steps_since_update[actor] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors.
if self.weights is None or self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
self.weights = ray.put(self.workers.local_worker().get_weights())
actor.set_weights.remote(self.weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
self.steps_since_update[actor] = 0
# Update metrics.
metrics = _get_shared_metrics()
metrics.counters["num_weight_syncs"] += 1
class ApexTrainer(DQNTrainer):
@override(Trainable)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)
        # Shortcut: If the (legacy) `execution_plan` API is used, the learner
        # thread and replay buffers are created in there instead.
if self.config["_disable_execution_plan_api"] is False:
return
        # Only collect episode metrics from the last 1/3 of the remote workers,
        # i.e. those with the lowest epsilons under the
        # `PerWorkerEpsilonGreedy` exploration strategy.
if self.workers.remote_workers():
self._remote_workers_for_metrics = self.workers.remote_workers()[
-len(self.workers.remote_workers()) // 3 :
]
num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (
self.config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
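        # E.g. with the defaults above (capacity=2_000_000, 4 shards), each
        # replay actor holds 2_000_000 // 4 = 500_000 timesteps.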
replay_actor_args = [
num_replay_buffer_shards,
self.config["learning_starts"],
buffer_size,
self.config["train_batch_size"],
self.config["replay_buffer_config"]["prioritized_replay_alpha"],
self.config["replay_buffer_config"]["prioritized_replay_beta"],
self.config["replay_buffer_config"]["prioritized_replay_eps"],
self.config["multiagent"]["replay_mode"],
self.config["replay_buffer_config"].get("replay_sequence_length", 1),
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if self.config["replay_buffer_shards_colocated_with_driver"]:
self.replay_actors = create_colocated_actors(
actor_specs=[ # (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
],
node=platform.node(), # localhost
)[
0
] # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
self.replay_actors = [
ReplayActor.remote(*replay_actor_args)
for _ in range(num_replay_buffer_shards)
]
self.learner_thread = LearnerThread(self.workers.local_worker())
self.learner_thread.start()
self.steps_since_update = defaultdict(int)
weights = self.workers.local_worker().get_weights()
self.curr_learner_weights = ray.put(weights)
self.remote_sampling_requests_in_flight: DefaultDict[
ActorHandle, Set[ray.ObjectRef]
] = defaultdict(set)
self.remote_replay_requests_in_flight: DefaultDict[
ActorHandle, Set[ray.ObjectRef]
] = defaultdict(set)
self.curr_num_samples_collected = 0
self.replay_sample_batches = []
self._num_ts_trained_since_last_target_update = 0
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DEFAULT_CONFIG
@override(DQNTrainer)
def validate_config(self, config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
# Call DQN's validation method.
super().validate_config(config)
# if config["_disable_execution_plan_api"]:
# if not config.get("training_intensity", 1.0) > 0:
# raise ValueError("training_intensity must be > 0")
@override(Trainable)
def training_iteration(self) -> ResultDict:
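        # One training iteration roughly consists of:
        #   1. Collect finished sample requests from the rollout workers; the
        #      workers store their batches directly in the replay actors.
        #   2. Push fresh learner weights to the workers that returned samples.
        #   3. Request replay batches and feed them (non-blocking) into the
        #      learner thread's in-queue.
        #   4. Apply the priorities computed by the learner thread back to the
        #      replay actors and update counters / target networks.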
num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
worker_samples_collected = defaultdict(int)
for worker, samples_infos in num_samples_ready_dict.items():
for samples_info in samples_infos:
self._counters[NUM_AGENT_STEPS_SAMPLED] += samples_info["agent_steps"]
self._counters[NUM_ENV_STEPS_SAMPLED] += samples_info["env_steps"]
worker_samples_collected[worker] += samples_info["agent_steps"]
        # Update the weights of the workers that returned samples.
        # Only do this if there are remote workers (config["num_workers"] > 0).
if self.workers.remote_workers():
self.update_workers(worker_samples_collected)
        # Trigger asynchronous sample requests on the replay actors and enqueue
        # the returned batches (non-blocking) to the learner thread.
self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
worker_samples_collected
)
self.update_replay_sample_priority()
return copy.deepcopy(self.learner_thread.learner_info)
@staticmethod
@override(DQNTrainer)
def execution_plan(
workers: WorkerSet, config: dict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "Apex execution_plan does NOT take any additional parameters"
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (
config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
replay_actor_args = [
num_replay_buffer_shards,
config["learning_starts"],
buffer_size,
config["train_batch_size"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config["replay_buffer_config"].get("replay_sequence_length", 1),
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if config["replay_buffer_shards_colocated_with_driver"]:
replay_actors = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
],
node=platform.node(), # localhost
)[
0
] # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
            replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(num_replay_buffer_shards)
            ]
# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
learner_thread.start()
# Update experience priorities post learning.
def update_prio_and_stats(item: Tuple[ActorHandle, dict, int, int]) -> None:
actor, prio_dict, env_count, agent_count = item
if config.get("prioritized_replay"):
actor.update_priorities.remote(prio_dict)
metrics = _get_shared_metrics()
# Manually update the steps trained counter since the learner
# thread is executing outside the pipeline.
metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
metrics.counters[STEPS_TRAINED_COUNTER] += env_count
metrics.timers["learner_dequeue"] = learner_thread.queue_timer
metrics.timers["learner_grad"] = learner_thread.grad_timer
metrics.timers["learner_overall"] = learner_thread.overall_timer
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in one of our replay buffer
# actors. Update the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
# Only need to update workers if there are remote workers.
if workers.remote_workers():
store_op = store_op.zip_with_source_actor().for_each(
UpdateWorkerWeights(
learner_thread,
workers,
max_weight_sync_delay=(
config["optimizer"]["max_weight_sync_delay"]
),
)
)
# (2) Read experiences from one of the replay buffer actors and send
# to the learner thread via its in-queue.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
replay_op = (
Replay(actors=replay_actors, num_async=4)
.for_each(lambda x: post_fn(x, workers, config))
.zip_with_source_actor()
.for_each(Enqueue(learner_thread.inqueue))
)
# (3) Get priorities back from learner thread and apply them to the
# replay buffer actors.
update_op = (
Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
.for_each(update_prio_and_stats)
.for_each(
UpdateTargetNetwork(
workers, config["target_network_update_freq"], by_steps_trained=True
)
)
)
if config["training_intensity"]:
# Execute (1), (2) with a fixed intensity ratio.
rr_weights = calculate_rr_weights(config) + ["*"]
merged_op = Concurrently(
[store_op, replay_op, update_op],
mode="round_robin",
output_indexes=[2],
round_robin_weights=rr_weights,
)
else:
# Execute (1), (2), (3) asynchronously as fast as possible. Only
# output items from (3) since metrics aren't available before
# then.
merged_op = Concurrently(
[store_op, replay_op, update_op], mode="async", output_indexes=[2]
)
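        # Scheduling note (a simplification of `Concurrently`'s round-robin
        # mode): with `rr_weights = [w_store, w_replay, "*"]`, roughly
        # `w_store` items are pulled from (1) and `w_replay` items from (2)
        # per cycle, while "*" lets (3) be drained as fast as items become
        # available.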
# Add in extra replay and learner metrics to the training result.
def add_apex_metrics(result: dict) -> dict:
replay_stats = ray.get(
replay_actors[0].stats.remote(config["optimizer"].get("debug"))
)
exploration_infos = workers.foreach_policy_to_train(
lambda p, _: p.get_exploration_state()
)
result["info"].update(
{
"exploration_infos": exploration_infos,
"learner_queue": learner_thread.learner_queue_size.stats(),
LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
"replay_shard_0": replay_stats,
}
)
return result
# Only report metrics from the workers with the lowest 1/3 of
# epsilons.
selected_workers = workers.remote_workers()[
-len(workers.remote_workers()) // 3 :
]
return StandardMetricsReporting(
merged_op, workers, config, selected_workers=selected_workers
).for_each(add_apex_metrics)
def get_samples_and_store_to_replay_buffers(self):
        # In case num_workers is 0 (no remote workers): sample with the local
        # worker and store the batch directly in one of the replay actors.
if not self.workers.remote_workers():
with self._timers[SAMPLE_TIMER]:
local_sampling_worker = self.workers.local_worker()
batch = local_sampling_worker.sample()
actor = random.choice(self.replay_actors)
ray.get(actor.add_batch.remote(batch))
batch_statistics = {
local_sampling_worker: [
{
"agent_steps": batch.agent_steps(),
"env_steps": batch.env_steps(),
}
]
}
return batch_statistics
def remote_worker_sample_and_store(
worker: RolloutWorker, replay_actors: List[ReplayActor]
):
            # This function runs remotely on the sampling workers and should
            # only ever be used via RolloutWorker's `apply()` method. It
            # gathers a sample batch and stores it in one of the replay actors
            # directly from the rollout worker, instead of returning the
            # batch's object ref to the driver process and storing it from
            # there.
_batch = worker.sample()
_actor = random.choice(replay_actors)
_actor.add_batch.remote(_batch)
_batch_statistics = {
"agent_steps": _batch.agent_steps(),
"env_steps": _batch.env_steps(),
}
return _batch_statistics
        # Sample on the remote workers and store the batches in the replay
        # actors (see `remote_worker_sample_and_store` above).
with self._timers[SAMPLE_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to the
            # sample-batch statistics returned by that worker.
num_samples_ready_dict: Dict[
ActorHandle, T
] = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_sampling_requests_in_flight,
actors=self.workers.remote_workers(),
ray_wait_timeout_s=0.1,
max_remote_requests_in_flight_per_actor=4,
remote_fn=remote_worker_sample_and_store,
remote_kwargs=[{"replay_actors": self.replay_actors}]
* len(self.workers.remote_workers()),
)
return num_samples_ready_dict
def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int:
"""Update the remote workers that have samples ready.
Args:
_num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
the number of samples returned by the remote worker.
Returns:
The number of remote workers whose weights were updated.
"""
max_steps_weight_sync_delay = self.config["optimizer"]["max_weight_sync_delay"]
# Update our local copy of the weights if the learner thread has updated
# the learner worker's weights
if self.learner_thread.weights_updated:
self.learner_thread.weights_updated = False
weights = self.workers.local_worker().get_weights()
self.curr_learner_weights = ray.put(weights)
        num_workers_updated = 0
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            for (
                remote_sampler_worker,
                num_samples_collected,
            ) in _num_samples_ready.items():
                self.steps_since_update[remote_sampler_worker] += num_samples_collected
                if (
                    self.steps_since_update[remote_sampler_worker]
                    >= max_steps_weight_sync_delay
                ):
                    remote_sampler_worker.set_weights.remote(
                        self.curr_learner_weights,
                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
                    )
                    self.steps_since_update[remote_sampler_worker] = 0
                    self._counters["num_weight_syncs"] += 1
                    num_workers_updated += 1
        return num_workers_updated
def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
self, num_samples_collected: Dict[ActorHandle, int]
) -> None:
"""Get samples from the replay buffer and place them on the learner queue.
Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
                the number of samples returned by that remote worker. This is
                used to implement training intensity, i.e. triggering a certain
                amount of training based on the number of samples collected
                since training was last triggered.
"""
def wait_on_replay_actors(timeout: float) -> None:
"""Wait for the replay actors to finish sampling for timeout seconds.
If the timeout is None, then block on the actors indefinitely.
"""
replay_samples_ready: Dict[ActorHandle, T] = wait_asynchronous_requests(
remote_requests_in_flight=self.remote_replay_requests_in_flight,
ray_wait_timeout_s=timeout,
)
for replay_actor, sample_batches in replay_samples_ready.items():
for sample_batch in sample_batches:
self.replay_sample_batches.append((replay_actor, sample_batch))
num_samples_collected = sum(num_samples_collected.values())
self.curr_num_samples_collected += num_samples_collected
if self.curr_num_samples_collected >= self.config["train_batch_size"]:
wait_on_replay_actors(None)
training_intensity = int(self.config["training_intensity"] or 1)
num_requests_to_launch = (
self.curr_num_samples_collected / self.config["train_batch_size"]
) * training_intensity
num_requests_to_launch = max(1, round(num_requests_to_launch))
self.curr_num_samples_collected = 0
for _ in range(num_requests_to_launch):
rand_actor = random.choice(self.replay_actors)
replay_samples_ready: Dict[
ActorHandle, T
] = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_replay_requests_in_flight,
actors=[rand_actor],
ray_wait_timeout_s=0.1,
max_remote_requests_in_flight_per_actor=num_requests_to_launch,
remote_fn=lambda actor: actor.replay(),
)
for replay_actor, sample_batches in replay_samples_ready.items():
for sample_batch in sample_batches:
self.replay_sample_batches.append((replay_actor, sample_batch))
wait_on_replay_actors(0.1)
        # Add the sample batches to the learner queue (non-blocking).
        while self.replay_sample_batches:
            try:
                replay_actor, sample_batch = self.replay_sample_batches[0]
                # The replay buffer returns None if it has not yet been filled
                # up to the minimum threshold (`learning_starts`).
                if sample_batch is not None:
                    self.learner_thread.inqueue.put(
                        (replay_actor, sample_batch), timeout=0.001
                    )
                self.replay_sample_batches.pop(0)
            except queue.Full:
                break
def update_replay_sample_priority(self) -> int:
"""Update the priorities of the sample batches with new priorities that are
computed by the learner thread.
Returns:
The number of samples trained by the learner thread since the last
training iteration.
"""
num_samples_trained_this_itr = 0
for _ in range(self.learner_thread.outqueue.qsize()):
if self.learner_thread.is_alive():
(
replay_actor,
priority_dict,
env_steps,
agent_steps,
) = self.learner_thread.outqueue.get(timeout=0.001)
if self.config["prioritized_replay"]:
replay_actor.update_priorities.remote(priority_dict)
num_samples_trained_this_itr += env_steps
self.update_target_networks(env_steps)
self._counters[NUM_ENV_STEPS_TRAINED] += env_steps
self._counters[NUM_AGENT_STEPS_TRAINED] += agent_steps
self.workers.local_worker().set_global_vars(
{"timestep": self._counters[NUM_ENV_STEPS_TRAINED]}
)
            else:
                raise RuntimeError("The learner thread died while training.")
        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
        self._timers["learner_dequeue"] = self.learner_thread.queue_timer
        self._timers["learner_grad"] = self.learner_thread.grad_timer
        self._timers["learner_overall"] = self.learner_thread.overall_timer
        return num_samples_trained_this_itr
def update_target_networks(self, num_new_trained_samples) -> None:
"""Update the target networks."""
self._num_ts_trained_since_last_target_update += num_new_trained_samples
if (
self._num_ts_trained_since_last_target_update
>= self.config["target_network_update_freq"]
):
self._num_ts_trained_since_last_target_update = 0
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
STEPS_TRAINED_COUNTER
]
@override(Trainer)
def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
result = super()._compile_step_results(
step_ctx=step_ctx, step_attempt_results=step_attempt_results
)
replay_stats = ray.get(
self.replay_actors[0].stats.remote(self.config["optimizer"].get("debug"))
)
exploration_infos_list = self.workers.foreach_policy_to_train(
lambda p, pid: {pid: p.get_exploration_state()}
)
exploration_infos = {}
for info in exploration_infos_list:
# we're guaranteed that each info has policy ids that are unique
exploration_infos.update(info)
other_results = {
"exploration_infos": exploration_infos,
"learner_queue": self.learner_thread.learner_queue_size.stats(),
"replay_shard_0": replay_stats,
}
result["info"].update(other_results)
return result
@classmethod
@override(Trainable)
def default_resource_request(cls, config):
cf = dict(cls.get_default_config(), **config)
eval_config = cf["evaluation_config"]
# Return PlacementGroupFactory containing all needed resources
# (already properly defined as device bundles).
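        # E.g. under the defaults here, and assuming RLlib's common defaults
        # of num_cpus_for_driver=1, num_cpus_per_worker=1,
        # num_gpus_per_worker=0 and evaluation_interval=None: one
        # {"CPU": 1 + 4, "GPU": 1} bundle for the driver + replay shards, plus
        # 32 {"CPU": 1, "GPU": 0} bundles for the rollout workers, and no
        # evaluation bundles.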
return PlacementGroupFactory(
bundles=[
{
# Local worker + replay buffer actors.
# Force replay buffers to be on same node to maximize
# data bandwidth between buffers and the learner (driver).
# Replay buffer actors each contain one shard of the total
# replay buffer and use 1 CPU each.
"CPU": cf["num_cpus_for_driver"]
+ cf["optimizer"]["num_replay_buffer_shards"],
"GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"],
}
]
+ [
{
# RolloutWorkers.
"CPU": cf["num_cpus_per_worker"],
"GPU": cf["num_gpus_per_worker"],
}
for _ in range(cf["num_workers"])
]
+ (
[
{
# Evaluation workers.
# Note: The local eval worker is located on the driver
# CPU.
"CPU": eval_config.get(
"num_cpus_per_worker", cf["num_cpus_per_worker"]
),
"GPU": eval_config.get(
"num_gpus_per_worker", cf["num_gpus_per_worker"]
),
}
for _ in range(cf["evaluation_num_workers"])
]
if cf["evaluation_interval"]
else []
),
strategy=config.get("placement_strategy", "PACK"),
)