Revert revert #23906 [RLlib] DD-PPO training iteration function implementation. (#24035)

This commit is contained in:
Avnish Narayan 2022-04-21 08:37:49 -07:00 committed by GitHub
parent c5252c5ceb
commit a337fd994e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 187 additions and 24 deletions

View file

@ -259,6 +259,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/ppo"]
)
py_test(
name = "learning_tests_pendulum_ddppo",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = glob(["tuned_examples/ppo/pendulum-ddppo.yaml"]),
args = ["--yaml-dir=tuned_examples/ppo"]
)
# DQN
py_test(
name = "learning_tests_cartpole_dqn",

View file

@ -23,22 +23,37 @@ from typing import Callable, Optional, Union
import ray
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, PPOTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.common import (
STEPS_SAMPLED_COUNTER,
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
LEARN_ON_BATCH_TIMER,
_get_shared_metrics,
_get_global_vars,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.metrics import (
LEARN_ON_BATCH_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import EnvType, PartialTrainerConfigDict, TrainerConfigDict
from ray.rllib.utils.typing import (
EnvType,
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
from ray.tune.logger import Logger
from ray.util.iter import LocalIterator
@ -79,9 +94,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
"num_gpus": 0,
# Each rollout worker gets a GPU.
"num_gpus_per_worker": 1,
# Require evenly sized batches. Otherwise,
# collective allreduce could fail.
"truncate_episodes": True,
# This is auto set based on sample batch size.
"train_batch_size": -1,
# Kl divergence penalty should be fixed to 0 in DDPPO because in order
@ -89,9 +101,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# DDPPO
"kl_coeff": 0.0,
"kl_target": 0.0,
# Keep using execution_plan API (training_iteration fn not defined yet).
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
@ -157,6 +166,13 @@ class DDPPOTrainer(PPOTrainer):
# setting.
super().validate_config(config)
# Must have `num_workers` >= 1.
if config["num_workers"] < 1:
raise ValueError(
"Due to its ditributed, decentralized nature, "
"DD-PPO requires `num_workers` to be >= 1!"
)
# Only supported for PyTorch so far.
if config["framework"] != "torch":
raise ValueError("Distributed data parallel is only supported for PyTorch")
@ -186,6 +202,126 @@ class DDPPOTrainer(PPOTrainer):
if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0:
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
@override(PPOTrainer)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)
# Initialize torch process group for
if self.config["_disable_execution_plan_api"] is True:
self._curr_learner_info = {}
ip = ray.get(self.workers.remote_workers()[0].get_node_ip.remote())
port = ray.get(self.workers.remote_workers()[0].find_free_port.remote())
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
logger.info("Creating torch process group with leader {}".format(address))
# Get setup tasks in order to throw errors on failure.
ray.get(
[
worker.setup_torch_data_parallel.remote(
url=address,
world_rank=i,
world_size=len(self.workers.remote_workers()),
backend=self.config["torch_distributed_backend"],
)
for i, worker in enumerate(self.workers.remote_workers())
]
)
logger.info("Torch process group init completed")
@override(PPOTrainer)
def training_iteration(self) -> ResultDict:
# Shortcut.
first_worker = self.workers.remote_workers()[0]
# Run sampling and update steps on each worker in asynchronous fashion.
sample_and_update_results = asynchronous_parallel_requests(
remote_requests_in_flight=self.remote_requests_in_flight,
actors=self.workers.remote_workers(),
ray_wait_timeout_s=0.0,
max_remote_requests_in_flight_per_actor=1, # 2
remote_fn=self._sample_and_train_torch_distributed,
)
# For all results collected:
# - Update our counters and timers.
# - Update the worker's global_vars.
# - Build info dict using a LearnerInfoBuilder object.
learner_info_builder = LearnerInfoBuilder(num_devices=1)
steps_this_iter = 0
for worker, result in sample_and_update_results.items():
# TODO: Add an inner loop over (>1) results here once APEX has been merged!
steps_this_iter += result["env_steps"]
self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
self._timers[SAMPLE_TIMER].push(result["sample_time"])
# Add partial learner info to builder object.
learner_info_builder.add_learn_on_batch_results_multi_agent(result["info"])
# Broadcast the local set of global vars.
global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
for worker in self.workers.remote_workers():
worker.set_global_vars.remote(global_vars)
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter
# Sync down the weights from 1st remote worker (only if we have received
# some results from it).
# As with the sync up, this is not really needed unless the user is
# reading the local weights.
if (
self.config["keep_local_weights_in_sync"]
and first_worker in sample_and_update_results
):
self.workers.local_worker().set_weights(
ray.get(first_worker.get_weights.remote())
)
# Return merged laarner into results.
new_learner_info = learner_info_builder.finalize()
if new_learner_info:
self._curr_learner_info = new_learner_info
return self._curr_learner_info
@staticmethod
def _sample_and_train_torch_distributed(worker: RolloutWorker):
# This function is applied remotely on each rollout worker.
config = worker.policy_config
# Generate a sample.
start = time.perf_counter()
batch = worker.sample()
sample_time = time.perf_counter() - start
expected_batch_size = (
config["rollout_fragment_length"] * config["num_envs_per_worker"]
)
assert batch.count == expected_batch_size, (
"Batch size possibly out of sync between workers, expected:",
expected_batch_size,
"got:",
batch.count,
)
# Perform n minibatch SGD update(s) on the worker itself.
start = time.perf_counter()
info = do_minibatch_sgd(
batch,
worker.policy_map,
worker,
config["num_sgd_iter"],
config["sgd_minibatch_size"],
[Postprocessing.ADVANTAGES],
)
learn_on_batch_time = time.perf_counter() - start
return {
"info": info,
"env_steps": batch.env_steps(),
"agent_steps": batch.agent_steps(),
"sample_time": sample_time,
"learn_on_batch_time": learn_on_batch_time,
}
@staticmethod
@override(PPOTrainer)
def execution_plan(
@ -209,8 +345,6 @@ class DDPPOTrainer(PPOTrainer):
rollouts = ParallelRollouts(workers, mode="raw")
# Setup the distributed processes.
if not workers.remote_workers():
raise ValueError("This optimizer requires >0 remote workers.")
ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
port = ray.get(workers.remote_workers()[0].find_free_port.remote())
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
@ -271,6 +405,7 @@ class DDPPOTrainer(PPOTrainer):
self.fetch_start_time = time.perf_counter()
def __call__(self, items):
assert len(items) == config["num_workers"]
for item in items:
info, count = item
metrics = _get_shared_metrics()

View file

@ -34,11 +34,10 @@ class TestDDPPO(unittest.TestCase):
results = trainer.train()
check_train_results(results)
print(results)
# Make sure, weights on all workers are the same (including
# local one).
# Make sure, weights on all workers are the same.
weights = trainer.workers.foreach_worker(lambda w: w.get_weights())
for w in weights[1:]:
check(w, weights[0])
check(w, weights[1])
check_compute_single_action(trainer)
trainer.stop()
@ -48,15 +47,16 @@ class TestDDPPO(unittest.TestCase):
config = ppo.ddppo.DEFAULT_CONFIG.copy()
config["num_gpus_per_worker"] = 0
config["lr_schedule"] = [[0, config["lr"]], [1000, 0.0]]
num_iterations = 3
num_iterations = 10
for _ in framework_iterator(config, "torch"):
trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
for _ in range(num_iterations):
result = trainer.train()
lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"cur_lr"
]
if result["info"][LEARNER_INFO]:
lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
LEARNER_STATS_KEY
]["cur_lr"]
trainer.stop()
assert lr == 0.0, "lr should anneal to 0.0"

View file

@ -961,10 +961,8 @@ class Trainer(Trainable):
# Function defining one single training iteration's behavior.
if self.config["_disable_execution_plan_api"]:
# TODO: Ensure remote workers are initially in sync with the
# local worker.
# self.workers.sync_weights()
pass # TODO: Uncommenting line above breaks tf2+eager_tracing for A3C.
# Ensure remote workers are initially in sync with the local worker.
self.workers.sync_weights()
# LocalIterator-creating "execution plan".
# Only call this once here to create `self.train_exec_impl`,
# which is a ray.util.iter.LocalIterator that will be `next`'d

View file

@ -0,0 +1,20 @@
pendulum-ddppo:
env: Pendulum-v1
run: DDPPO
stop:
episode_reward_mean: -300
timesteps_total: 1500000
config:
framework: torch
rollout_fragment_length: 250
num_gpus_per_worker: 0
num_workers: 4
num_envs_per_worker: 10
observation_filter: MeanStdFilter
gamma: 0.95
sgd_minibatch_size: 50
num_sgd_iter: 5
clip_param: 0.3
vf_clip_param: 10.0
lambda: 0.1
lr: 0.00015