# ray/rllib/agents/ppo/ddppo.py

"""
Decentralized Distributed PPO (DD-PPO)
======================================

Unlike APPO or PPO, learning is no longer done centrally in the trainer
process. Instead, gradients are computed remotely on each rollout worker and
all-reduced to sync them at each mini-batch. This allows each worker's GPU
to be used both for sampling and for training.

DD-PPO should be used if you have envs that require GPUs to function, or if
you have a very large model that cannot be effectively optimized with the
GPUs available on a single machine (DD-PPO allows scaling to arbitrary
numbers of GPUs across multiple nodes, unlike PPO/APPO, which is limited to
GPUs on a single node).

Paper reference: https://arxiv.org/abs/1911.00357
Note that unlike the paper, we currently do not implement straggler mitigation.
"""

import logging
import time
from typing import Callable, Optional, Union
import ray
from ray.rllib.agents.ppo.ppo import PPOConfig, PPOTrainer
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import (
    STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
    LEARN_ON_BATCH_TIMER,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
    SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import (
    EnvType,
    PartialTrainerConfigDict,
    ResultDict,
    TrainerConfigDict,
)
from ray.tune.logger import Logger
logger = logging.getLogger(__name__)


class DDPPOConfig(PPOConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.

    Example:
        >>> from ray.rllib.agents.ppo import DDPPOConfig
        >>> config = DDPPOConfig().training(lr=0.003, keep_local_weights_in_sync=True)\
        ...     .resources(num_gpus_per_worker=1)\
        ...     .rollouts(num_workers=10)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.agents.ppo import DDPPOConfig
        >>> from ray import tune
        >>> config = DDPPOConfig()
        >>> # Print out some default values.
        >>> print(config.kl_coeff)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), num_sgd_iter=15)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "DDPPO",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, trainer_class=None):
        """Initializes a DDPPOConfig instance."""
        super().__init__(trainer_class=trainer_class or DDPPOTrainer)

        # fmt: off
        # __sphinx_doc_begin__
        # DD-PPO specific settings:
        self.keep_local_weights_in_sync = True
        self.torch_distributed_backend = "gloo"

        # Override some of PPO/Trainer's default values with DD-PPO-specific
        # values.
        # During the sampling phase, each rollout worker will collect a batch
        # of `rollout_fragment_length * num_envs_per_worker` steps in size.
        self.rollout_fragment_length = 100
        # Vectorize the env (should be enabled by default, since each worker
        # has a GPU).
        self.num_envs_per_worker = 5
        # During the SGD phase, workers iterate over minibatches of this size.
        # The effective minibatch size will be:
        # `sgd_minibatch_size * num_workers`.
        self.sgd_minibatch_size = 50
        # Number of SGD epochs per optimization round.
        self.num_sgd_iter = 10

        # *** WARNING: The configs below are DD-PPO overrides over PPO; you
        # shouldn't need to adjust them. ***
        # DD-PPO requires PyTorch distributed.
        self.framework_str = "torch"
        # Learning is no longer done on the driver process, so
        # giving GPUs to the driver does not make sense!
        self.num_gpus = 0
        # Each rollout worker gets a GPU.
        self.num_gpus_per_worker = 1
        # This is auto-set based on the sample batch size.
        self.train_batch_size = -1
        # The KL divergence penalty must be fixed to 0 in DD-PPO, because in
        # order to use it as a penalty, we would have to un-decentralize
        # DD-PPO.
        self.kl_coeff = 0.0
        self.kl_target = 0.0
        # __sphinx_doc_end__
        # fmt: on
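
        # With these defaults, each worker samples 100 * 5 = 500 env steps per
        # iteration (which also becomes the auto-computed `train_batch_size`),
        # and each SGD minibatch covers 50 env steps per worker, i.e. an
        # effective minibatch of `50 * num_workers` across all workers.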

    @override(PPOConfig)
    def training(
        self,
        *,
        keep_local_weights_in_sync: Optional[bool] = None,
        torch_distributed_backend: Optional[str] = None,
        **kwargs,
    ) -> "DDPPOConfig":
        """Sets the training-related configuration.

        Args:
            keep_local_weights_in_sync: Download weights between each training
                step. This adds a bit of overhead but allows the user to
                access the weights from the trainer.
            torch_distributed_backend: The communication backend for PyTorch
                distributed.

        Returns:
            This updated DDPPOConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if keep_local_weights_in_sync is not None:
            self.keep_local_weights_in_sync = keep_local_weights_in_sync
        if torch_distributed_backend is not None:
            self.torch_distributed_backend = torch_distributed_backend

        return self
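
# Usage sketch (illustrative only, not part of the original module): when each
# rollout worker has its own GPU, the "nccl" backend is usually preferred over
# the default "gloo" for the gradient all-reduce:
#
#     config = (
#         DDPPOConfig()
#         .training(torch_distributed_backend="nccl")
#         .resources(num_gpus_per_worker=1)
#         .rollouts(num_workers=4)
#     )
#     trainer = config.build(env="CartPole-v1")
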
class DDPPOTrainer(PPOTrainer):
    def __init__(
        self,
        config: Optional[PartialTrainerConfigDict] = None,
        env: Optional[Union[str, EnvType]] = None,
        logger_creator: Optional[Callable[[], Logger]] = None,
        remote_checkpoint_dir: Optional[str] = None,
        sync_function_tpl: Optional[str] = None,
    ):
        """Initializes a DDPPOTrainer instance.

        Args:
            config: Algorithm-specific configuration dict.
            env: Name of the environment to use (e.g. a gym-registered str),
                a full class path (e.g.
                "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
                class directly. Note that this arg can also be specified via
                the "env" key in `config`.
            logger_creator: Callable that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """
        super().__init__(
            config, env, logger_creator, remote_checkpoint_dir, sync_function_tpl
        )

        if "train_batch_size" in config.keys() and config["train_batch_size"] != -1:
            # Users should not define `train_batch_size` directly (always -1).
            raise ValueError(
                "Set rollout_fragment_length instead of train_batch_size for DDPPO."
            )

        # Auto-train_batch_size: Calculate from rollout len and
        # envs-per-worker.
        config["train_batch_size"] = config.get(
            "rollout_fragment_length", DEFAULT_CONFIG["rollout_fragment_length"]
        ) * config.get("num_envs_per_worker", DEFAULT_CONFIG["num_envs_per_worker"])
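
        # Note: `DEFAULT_CONFIG` is only defined at the very bottom of this
        # module; since the name is looked up lazily when `__init__` runs,
        # this forward reference works at runtime.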

    @classmethod
    @override(PPOTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DDPPOConfig().to_dict()

    @override(PPOTrainer)
    def validate_config(self, config):
        """Validates the Trainer's config dict.

        Args:
            config (TrainerConfigDict): The Trainer's config to check.

        Raises:
            ValueError: In case something is wrong with the config.
        """
        # Call (base) PPO's config validation function first.
        # Note that this will not touch or check on the train_batch_size=-1
        # setting.
        super().validate_config(config)

        # Must have `num_workers` >= 1.
        if config["num_workers"] < 1:
            raise ValueError(
                "Due to its distributed, decentralized nature, "
                "DD-PPO requires `num_workers` to be >= 1!"
            )

        # Only supported for PyTorch so far.
        if config["framework"] != "torch":
            raise ValueError("Distributed data parallel is only supported for PyTorch")

        if config["torch_distributed_backend"] not in ("gloo", "mpi", "nccl"):
            raise ValueError(
                "Only gloo, mpi, or nccl is supported for "
                "the backend of PyTorch distributed."
            )

        # `num_gpus` must be 0/None, since all optimization happens on Workers.
        if config["num_gpus"]:
            raise ValueError(
                "When using distributed data parallel, you should set "
                "num_gpus=0 since all optimization "
                "is happening on workers. Enable GPUs for workers by setting "
                "num_gpus_per_worker=1."
            )

        # `batch_mode` must be "truncate_episodes".
        if config["batch_mode"] != "truncate_episodes":
            raise ValueError(
                "Distributed data parallel requires truncate_episodes batch mode."
            )

        # DDPPO doesn't support KL penalties like PPO-1.
        # In order to support KL penalties, DDPPO would need to become
        # undecentralized, which defeats the purpose of the algorithm.
        # Users can still tune the entropy coefficient to control the
        # policy entropy (similar to controlling the KL penalty).
        if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0:
            raise ValueError("DDPPO doesn't support KL penalties like PPO-1")

    @override(PPOTrainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)

        # Initialize the torch process group for distributed training (only
        # when the new `training_iteration` code path is used).
        if self.config["_disable_execution_plan_api"] is True:
            self._curr_learner_info = {}
            ip = ray.get(self.workers.remote_workers()[0].get_node_ip.remote())
            port = ray.get(self.workers.remote_workers()[0].find_free_port.remote())
            address = "tcp://{ip}:{port}".format(ip=ip, port=port)
            logger.info("Creating torch process group with leader {}".format(address))

            # Get setup tasks in order to throw errors on failure.
            ray.get(
                [
                    worker.setup_torch_data_parallel.remote(
                        url=address,
                        world_rank=i,
                        world_size=len(self.workers.remote_workers()),
                        backend=self.config["torch_distributed_backend"],
                    )
                    for i, worker in enumerate(self.workers.remote_workers())
                ]
            )
            logger.info("Torch process group init completed")

            self._ddppo_worker_manager = AsyncRequestsManager(
                self.workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=1,
                ray_wait_timeout_s=0.03,
            )
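
            # The manager above keeps at most one in-flight
            # `_sample_and_train_torch_distributed` request per worker and
            # polls for completed requests with a short (0.03s) `ray.wait`
            # timeout; `training_iteration()` harvests whatever is ready.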

    @override(PPOTrainer)
    def training_iteration(self) -> ResultDict:
        # Shortcut.
        first_worker = self.workers.remote_workers()[0]

        self._ddppo_worker_manager.call_on_all_available(
            self._sample_and_train_torch_distributed
        )
        sample_and_update_results = self._ddppo_worker_manager.get_ready()

        # For all results collected:
        # - Update our counters and timers.
        # - Update the worker's global_vars.
        # - Build info dict using a LearnerInfoBuilder object.
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        steps_this_iter = 0
        for worker, results in sample_and_update_results.items():
            for result in results:
                steps_this_iter += result["env_steps"]
                self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
                self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
                self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
                self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
                self._timers[SAMPLE_TIMER].push(result["sample_time"])
                # Add partial learner info to builder object.
                learner_info_builder.add_learn_on_batch_results_multi_agent(
                    result["info"]
                )

        # Broadcast the local set of global vars.
        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
        for worker in self.workers.remote_workers():
            worker.set_global_vars.remote(global_vars)

        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter

        # Sync down the weights from the 1st remote worker (only if we have
        # received some results from it).
        # As with the sync up, this is not really needed unless the user is
        # reading the local weights.
        if (
            self.config["keep_local_weights_in_sync"]
            and first_worker in sample_and_update_results
        ):
            self.workers.local_worker().set_weights(
                ray.get(first_worker.get_weights.remote())
            )

        # Return the merged learner info results.
        new_learner_info = learner_info_builder.finalize()
        if new_learner_info:
            self._curr_learner_info = new_learner_info

        return self._curr_learner_info

    @staticmethod
    def _sample_and_train_torch_distributed(worker: RolloutWorker):
        # This function is applied remotely on each rollout worker.
        config = worker.policy_config

        # Generate a sample.
        start = time.perf_counter()
        batch = worker.sample()
        sample_time = time.perf_counter() - start
        expected_batch_size = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"]
        )
        assert batch.count == expected_batch_size, (
            "Batch size possibly out of sync between workers, expected:",
            expected_batch_size,
            "got:",
            batch.count,
        )
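
        # NOTE (assumption, added for clarity): the gradient all-reduce is not
        # explicit in this function. `setup_torch_data_parallel()` (called in
        # `setup()` above) is expected to have wrapped each policy's model for
        # torch distributed data parallelism, so gradients get synced across
        # workers during the backward passes run by `do_minibatch_sgd()` below.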

        # Perform n minibatch SGD update(s) on the worker itself.
        start = time.perf_counter()
        info = do_minibatch_sgd(
            batch,
            worker.policy_map,
            worker,
            config["num_sgd_iter"],
            config["sgd_minibatch_size"],
            [Postprocessing.ADVANTAGES],
        )
        learn_on_batch_time = time.perf_counter() - start

        return {
            "info": info,
            "env_steps": batch.env_steps(),
            "agent_steps": batch.agent_steps(),
            "sample_time": sample_time,
            "learn_on_batch_time": learn_on_batch_time,
        }

# Deprecated: Use ray.rllib.agents.ppo.DDPPOConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(DDPPOConfig().to_dict())

    @Deprecated(
        old="ray.rllib.agents.ppo.ddppo.DEFAULT_CONFIG",
        new="ray.rllib.agents.ppo.ddppo.DDPPOConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()