"""
|
|
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
|
|
==============================================
|
|
|
|
This file defines the distributed Trainer class for the Deep Q-Networks
|
|
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
|
|
|
|
Detailed documentation:
|
|
https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
|
|
""" # noqa: E501

import logging
from typing import List, Optional, Type

from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.dqn.simple_q import (
    SimpleQTrainer,
    DEFAULT_CONFIG as SIMPLEQ_DEFAULT_CONFIG,
)
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    UpdateTargetNetwork,
    MultiGPUTrainOneStep,
    train_one_step,
    multi_gpu_train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
    ResultDict,
    TrainerConfigDict,
)
from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED,
)
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.utils.deprecation import (
    Deprecated,
    DEPRECATED_VALUE,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.common import (
    LAST_TARGET_UPDATE_TS,
    NUM_TARGET_UPDATES,
)

logger = logging.getLogger(__name__)

# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
    SIMPLEQ_DEFAULT_CONFIG,
    {
        # === Model ===
        # Number of atoms for representing the distribution of return. When
        # this is greater than 1, distributional Q-learning is used.
        # The discrete supports are bounded by v_min and v_max.
        "num_atoms": 1,
        "v_min": -10.0,
        "v_max": 10.0,
        # Whether to use a noisy network.
        "noisy": False,
        # Controls the initial noise level of noisy nets.
        "sigma0": 0.5,
        # Whether to use dueling DQN.
        "dueling": True,
        # Dense-layer setup for each of the advantage branch and the value
        # branch in a dueling architecture.
        "hiddens": [256],
        # Whether to use double DQN.
        "double_q": True,
        # N-step Q-learning.
        "n_step": 1,

        # === Replay buffer ===
        # Deprecated; use `capacity` in `replay_buffer_config` instead.
        "buffer_size": DEPRECATED_VALUE,
        "replay_buffer_config": {
            # Enable the new ReplayBuffer API.
            "_enable_replay_buffer_api": True,
            "type": "MultiAgentPrioritizedReplayBuffer",
            # Size of the replay buffer. Note that if async_updates is set,
            # then each worker will have a replay buffer of this size.
            "capacity": 50000,
            "prioritized_replay_alpha": 0.6,
            # Beta parameter for sampling from prioritized replay buffer.
            "prioritized_replay_beta": 0.4,
            # Epsilon to add to the TD errors when updating priorities.
            "prioritized_replay_eps": 1e-6,
            # The number of contiguous environment steps to replay at once.
            # This may be set to greater than 1 to support recurrent models.
            "replay_sequence_length": 1,
        },
        # Set this to True if you want the contents of your buffer(s) to be
        # stored in any saved checkpoints as well.
        # Warnings will be created if:
        # - This is True AND restoring from a checkpoint that contains no
        #   buffer data.
        # - This is False AND restoring from a checkpoint that does contain
        #   buffer data.
        "store_buffer_in_checkpoints": False,

        # Callback to run before learning on a multi-agent batch of
        # experiences.
        "before_learn_on_batch": None,

        # The intensity with which to update the model (vs collecting samples
        # from the env). If None, uses the "natural" value of:
        # `train_batch_size` / (`rollout_fragment_length` x `num_workers` x
        # `num_envs_per_worker`).
        # If provided, will make sure that the ratio between timesteps
        # inserted into and sampled from the buffer matches the given value.
        # Example:
        #   training_intensity=1000.0
        #   train_batch_size=250 rollout_fragment_length=1
        #   num_workers=1 (or 0) num_envs_per_worker=1
        #   -> natural value = 250 / 1 = 250.0
        #   -> will make sure that replay+train op will be executed 4x as
        #      often as rollout+insert op (4 * 250 = 1000).
        # See: rllib/agents/dqn/dqn.py::calculate_rr_weights for further
        # details.
        "training_intensity": None,

        # === Parallelism ===
        # Whether to compute priorities on workers.
        "worker_side_prioritization": False,

        # Experimental flag.
        # If True, the execution plan API will not be used. Instead,
        # a Trainer's `training_iteration` method will be called as-is each
        # training iteration.
        "_disable_execution_plan_api": True,
    },
    _allow_unknown_configs=True,
)
# __sphinx_doc_end__
# fmt: on
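
# A hedged configuration sketch (keys are the ones documented above; the values
# shown, e.g. 51 atoms for a C51/Rainbow-style distributional head, are only
# illustrative, as is the environment name):
#
#   trainer = DQNTrainer(
#       env="CartPole-v0",
#       config={
#           "num_atoms": 51,  # >1 enables distributional Q-learning
#           "noisy": True,    # NoisyNet exploration instead of epsilon-greedy
#           "n_step": 3,      # multi-step (n-step) targets
#       },
#   )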


def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
    """Calculate the round-robin weights for the rollout and train steps."""
    if not config["training_intensity"]:
        return [1, 1]

    # Calculate the "native ratio" as:
    # [train-batch-size] / [size of env-rolled-out sampled data]
    # This is to set freshly rollout-collected data in relation to
    # the data we pull from the replay buffer (which also contains old
    # samples).
    native_ratio = config["train_batch_size"] / (
        config["rollout_fragment_length"]
        * config["num_envs_per_worker"]
        # With num_workers=0, the local worker collects the samples; guard
        # against division by zero in that case.
        * max(config["num_workers"], 1)
    )

    # Training intensity is specified in terms of
    # (steps_replayed / steps_sampled), so adjust for the native ratio.
    weights = [1, config["training_intensity"] / native_ratio]
    return weights
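
# Worked example (mirroring the `training_intensity` docs in DEFAULT_CONFIG):
# with train_batch_size=250, rollout_fragment_length=1, num_workers=1 (or 0),
# and num_envs_per_worker=1, the native ratio is 250.0; a training_intensity
# of 1000.0 then yields round-robin weights [1, 4.0], i.e. the replay+train op
# runs four times for every rollout+insert op.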


class DQNTrainer(SimpleQTrainer):
    @classmethod
    @override(SimpleQTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        # Update effective batch size to include n-step.
        adjusted_rollout_len = max(config["rollout_fragment_length"], config["n_step"])
        config["rollout_fragment_length"] = adjusted_rollout_len

    @override(SimpleQTrainer)
    def get_default_policy_class(
        self, config: TrainerConfigDict
    ) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            return DQNTorchPolicy
        else:
            return DQNTFPolicy

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            "local_replay_buffer" in kwargs
        ), "DQN's execution plan requires a local replay buffer."

        # Assign to Trainer, so we can store the MultiAgentReplayBuffer's
        # data when we save checkpoints.
        local_replay_buffer = kwargs["local_replay_buffer"]

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in our local replay buffer.
        #     Calling next() on store_op drives this.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer)
        )
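
        # Helper used below: after each train step, push the learner's
        # TD-errors back into the replay buffer as new priorities. This is
        # only wired into the pipeline when the configured buffer is
        # prioritized (see the `update_prio_fn` selection further down).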
        def update_prio(item):
            samples, info_dict = item
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                # torch/tf. Clean up these results/info dicts across
                # policies (note: fixing this in torch_policy.py will
                # break e.g. DDPPO!).
                td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
                samples.policy_batches[policy_id].set_get_interceptor(None)
                batch_indices = samples.policy_batches[policy_id].get("batch_indexes")
                # In case the buffer stores sequences, the TD-error could
                # already be calculated per sequence chunk.
                if len(batch_indices) != len(td_error):
                    T = local_replay_buffer.replay_sequence_length
                    assert (
                        len(batch_indices) > len(td_error)
                        and len(batch_indices) % T == 0
                    )
                    batch_indices = batch_indices.reshape([-1, T])[:, 0]
                    assert len(batch_indices) == len(td_error)
                prio_dict[policy_id] = (batch_indices, td_error)
            local_replay_buffer.update_priorities(prio_dict)
            return info_dict

        # (2) Read and train on experiences from the replay buffer. Every batch
        # returned from the LocalReplay() iterator is passed to TrainOneStep to
        # take an SGD step, and then we decide whether to update the target
        # network.
        post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            )

        if (
            type(local_replay_buffer) is LegacyMultiAgentReplayBuffer
            and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0)
            > 0.0
        ) or isinstance(local_replay_buffer, MultiAgentPrioritizedReplayBuffer):
            update_prio_fn = update_prio
        else:

            def update_prio_fn(x):
                return x

        replay_op = (
            Replay(local_buffer=local_replay_buffer)
            .for_each(lambda x: post_fn(x, workers, config))
            .for_each(train_step_op)
            .for_each(update_prio_fn)
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        # Alternate deterministically between (1) and (2).
        # Only return the output of (2) since training metrics are not
        # available until (2) runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_rr_weights(config),
        )

        return StandardMetricsReporting(train_op, workers, config)
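
    # NOTE: With the default `_disable_execution_plan_api=True` (see
    # DEFAULT_CONFIG above), the `training_iteration` method below is used
    # instead of the `execution_plan` above.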

    @ExperimentalAPI
    def training_iteration(self) -> ResultDict:
        """DQN training iteration function.

        Each training iteration, we:
        - Sample (MultiAgentBatch) from workers.
        - Store the new samples in the replay buffer.
        - Sample a training batch (MultiAgentBatch) from the replay buffer.
        - Learn on the training batch.
        - Update the remote workers with the new policy weights.
        - Update the target network every `target_network_update_freq` steps.
        - Return all collected metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        local_worker = self.workers.local_worker()

        train_results = {}

        # We alternate between storing new samples and sampling/training.
        store_weight, sample_and_train_weight = calculate_rr_weights(self.config)

        for _ in range(store_weight):
            # (1) Sample (MultiAgentBatch) from workers.
            new_sample_batch = synchronous_parallel_sample(
                worker_set=self.workers, concat=True
            )

            # Update counters.
            self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps()
            self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps()

            # (2) Store the new samples in the replay buffer.
            self.local_replay_buffer.add_batch(new_sample_batch)

        # The train weight may be fractional (see calculate_rr_weights), so
        # round it to an int for the loop count.
        for _ in range(int(round(sample_and_train_weight))):
            # (3) Sample a training batch (MultiAgentBatch) from the replay
            # buffer.
            train_batch = self.local_replay_buffer.replay()

            # Old-style replay buffers return None if learning has not started.
            if not train_batch:
                continue

            # Postprocess the batch before we learn on it.
            post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
            train_batch = post_fn(train_batch, self.workers, self.config)

            # (4) Learn on the training batch.
            # Use the simple optimizer (only for multi-agent or tf-eager; all
            # other cases should use the multi-GPU optimizer, even if only
            # using 1 GPU).
            if self.config.get("simple_optimizer") is True:
                train_results = train_one_step(self, train_batch)
            else:
                train_results = multi_gpu_train_one_step(self, train_batch)

            # (5) Update replay buffer priorities.
            update_priorities_in_replay_buffer(
                self.local_replay_buffer,
                self.config,
                train_batch,
                train_results,
            )

            # (6) Update the target network every target_network_update_freq
            # steps.
            cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
            last_update = self._counters[LAST_TARGET_UPDATE_TS]
            if cur_ts - last_update >= self.config["target_network_update_freq"]:
                to_update = local_worker.get_policies_to_train()
                local_worker.foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
                self._counters[NUM_TARGET_UPDATES] += 1
                self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

            # Update the remote workers' weights after learning on the local
            # worker.
            if self.workers.remote_workers():
                with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                    self.workers.sync_weights()

        # (7) Return all collected metrics for the iteration.
        return train_results


@Deprecated(
    new="Sub-class directly from `DQNTrainer` and override its methods", error=False
)
class GenericOffPolicyTrainer(DQNTrainer):
    pass