mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
This commit is contained in:
parent
c5252c5ceb
commit
a337fd994e
5 changed files with 187 additions and 24 deletions
10
rllib/BUILD
10
rllib/BUILD
|
@ -259,6 +259,16 @@ py_test(
|
|||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "learning_tests_pendulum_ddppo",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = glob(["tuned_examples/ppo/pendulum-ddppo.yaml"]),
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
)
|
||||
|
||||
# DQN
|
||||
py_test(
|
||||
name = "learning_tests_cartpole_dqn",
|
||||
|
|
|
@ -23,22 +23,37 @@ from typing import Callable, Optional, Union
|
|||
import ray
|
||||
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, PPOTrainer
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.common import (
|
||||
STEPS_SAMPLED_COUNTER,
|
||||
STEPS_TRAINED_COUNTER,
|
||||
STEPS_TRAINED_THIS_ITER_COUNTER,
|
||||
LEARN_ON_BATCH_TIMER,
|
||||
_get_shared_metrics,
|
||||
_get_global_vars,
|
||||
)
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.metrics import (
|
||||
LEARN_ON_BATCH_TIMER,
|
||||
NUM_AGENT_STEPS_SAMPLED,
|
||||
NUM_AGENT_STEPS_TRAINED,
|
||||
NUM_ENV_STEPS_SAMPLED,
|
||||
NUM_ENV_STEPS_TRAINED,
|
||||
SAMPLE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LearnerInfoBuilder
|
||||
from ray.rllib.utils.sgd import do_minibatch_sgd
|
||||
from ray.rllib.utils.typing import EnvType, PartialTrainerConfigDict, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import (
|
||||
EnvType,
|
||||
PartialTrainerConfigDict,
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
)
|
||||
from ray.tune.logger import Logger
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
|
@ -79,9 +94,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
|
|||
"num_gpus": 0,
|
||||
# Each rollout worker gets a GPU.
|
||||
"num_gpus_per_worker": 1,
|
||||
# Require evenly sized batches. Otherwise,
|
||||
# collective allreduce could fail.
|
||||
"truncate_episodes": True,
|
||||
# This is auto set based on sample batch size.
|
||||
"train_batch_size": -1,
|
||||
# Kl divergence penalty should be fixed to 0 in DDPPO because in order
|
||||
|
@ -89,9 +101,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
|
|||
# DDPPO
|
||||
"kl_coeff": 0.0,
|
||||
"kl_target": 0.0,
|
||||
|
||||
# Keep using execution_plan API (training_iteration fn not defined yet).
|
||||
"_disable_execution_plan_api": False,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
|
@ -157,6 +166,13 @@ class DDPPOTrainer(PPOTrainer):
|
|||
# setting.
|
||||
super().validate_config(config)
|
||||
|
||||
# Must have `num_workers` >= 1.
|
||||
if config["num_workers"] < 1:
|
||||
raise ValueError(
|
||||
"Due to its ditributed, decentralized nature, "
|
||||
"DD-PPO requires `num_workers` to be >= 1!"
|
||||
)
|
||||
|
||||
# Only supported for PyTorch so far.
|
||||
if config["framework"] != "torch":
|
||||
raise ValueError("Distributed data parallel is only supported for PyTorch")
|
||||
|
@ -186,6 +202,126 @@ class DDPPOTrainer(PPOTrainer):
|
|||
if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0:
|
||||
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
|
||||
|
||||
@override(PPOTrainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
super().setup(config)
|
||||
|
||||
# Initialize torch process group for
|
||||
if self.config["_disable_execution_plan_api"] is True:
|
||||
self._curr_learner_info = {}
|
||||
ip = ray.get(self.workers.remote_workers()[0].get_node_ip.remote())
|
||||
port = ray.get(self.workers.remote_workers()[0].find_free_port.remote())
|
||||
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
|
||||
logger.info("Creating torch process group with leader {}".format(address))
|
||||
|
||||
# Get setup tasks in order to throw errors on failure.
|
||||
ray.get(
|
||||
[
|
||||
worker.setup_torch_data_parallel.remote(
|
||||
url=address,
|
||||
world_rank=i,
|
||||
world_size=len(self.workers.remote_workers()),
|
||||
backend=self.config["torch_distributed_backend"],
|
||||
)
|
||||
for i, worker in enumerate(self.workers.remote_workers())
|
||||
]
|
||||
)
|
||||
logger.info("Torch process group init completed")
|
||||
|
||||
@override(PPOTrainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
# Shortcut.
|
||||
first_worker = self.workers.remote_workers()[0]
|
||||
|
||||
# Run sampling and update steps on each worker in asynchronous fashion.
|
||||
sample_and_update_results = asynchronous_parallel_requests(
|
||||
remote_requests_in_flight=self.remote_requests_in_flight,
|
||||
actors=self.workers.remote_workers(),
|
||||
ray_wait_timeout_s=0.0,
|
||||
max_remote_requests_in_flight_per_actor=1, # 2
|
||||
remote_fn=self._sample_and_train_torch_distributed,
|
||||
)
|
||||
|
||||
# For all results collected:
|
||||
# - Update our counters and timers.
|
||||
# - Update the worker's global_vars.
|
||||
# - Build info dict using a LearnerInfoBuilder object.
|
||||
learner_info_builder = LearnerInfoBuilder(num_devices=1)
|
||||
steps_this_iter = 0
|
||||
for worker, result in sample_and_update_results.items():
|
||||
# TODO: Add an inner loop over (>1) results here once APEX has been merged!
|
||||
steps_this_iter += result["env_steps"]
|
||||
self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
|
||||
self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
|
||||
self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
|
||||
self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
|
||||
self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
|
||||
self._timers[SAMPLE_TIMER].push(result["sample_time"])
|
||||
# Add partial learner info to builder object.
|
||||
learner_info_builder.add_learn_on_batch_results_multi_agent(result["info"])
|
||||
|
||||
# Broadcast the local set of global vars.
|
||||
global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
|
||||
for worker in self.workers.remote_workers():
|
||||
worker.set_global_vars.remote(global_vars)
|
||||
|
||||
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter
|
||||
|
||||
# Sync down the weights from 1st remote worker (only if we have received
|
||||
# some results from it).
|
||||
# As with the sync up, this is not really needed unless the user is
|
||||
# reading the local weights.
|
||||
if (
|
||||
self.config["keep_local_weights_in_sync"]
|
||||
and first_worker in sample_and_update_results
|
||||
):
|
||||
self.workers.local_worker().set_weights(
|
||||
ray.get(first_worker.get_weights.remote())
|
||||
)
|
||||
# Return merged laarner into results.
|
||||
new_learner_info = learner_info_builder.finalize()
|
||||
if new_learner_info:
|
||||
self._curr_learner_info = new_learner_info
|
||||
return self._curr_learner_info
|
||||
|
||||
@staticmethod
|
||||
def _sample_and_train_torch_distributed(worker: RolloutWorker):
|
||||
# This function is applied remotely on each rollout worker.
|
||||
config = worker.policy_config
|
||||
|
||||
# Generate a sample.
|
||||
start = time.perf_counter()
|
||||
batch = worker.sample()
|
||||
sample_time = time.perf_counter() - start
|
||||
expected_batch_size = (
|
||||
config["rollout_fragment_length"] * config["num_envs_per_worker"]
|
||||
)
|
||||
assert batch.count == expected_batch_size, (
|
||||
"Batch size possibly out of sync between workers, expected:",
|
||||
expected_batch_size,
|
||||
"got:",
|
||||
batch.count,
|
||||
)
|
||||
|
||||
# Perform n minibatch SGD update(s) on the worker itself.
|
||||
start = time.perf_counter()
|
||||
info = do_minibatch_sgd(
|
||||
batch,
|
||||
worker.policy_map,
|
||||
worker,
|
||||
config["num_sgd_iter"],
|
||||
config["sgd_minibatch_size"],
|
||||
[Postprocessing.ADVANTAGES],
|
||||
)
|
||||
learn_on_batch_time = time.perf_counter() - start
|
||||
return {
|
||||
"info": info,
|
||||
"env_steps": batch.env_steps(),
|
||||
"agent_steps": batch.agent_steps(),
|
||||
"sample_time": sample_time,
|
||||
"learn_on_batch_time": learn_on_batch_time,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@override(PPOTrainer)
|
||||
def execution_plan(
|
||||
|
@ -209,8 +345,6 @@ class DDPPOTrainer(PPOTrainer):
|
|||
rollouts = ParallelRollouts(workers, mode="raw")
|
||||
|
||||
# Setup the distributed processes.
|
||||
if not workers.remote_workers():
|
||||
raise ValueError("This optimizer requires >0 remote workers.")
|
||||
ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
|
||||
port = ray.get(workers.remote_workers()[0].find_free_port.remote())
|
||||
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
|
||||
|
@ -271,6 +405,7 @@ class DDPPOTrainer(PPOTrainer):
|
|||
self.fetch_start_time = time.perf_counter()
|
||||
|
||||
def __call__(self, items):
|
||||
assert len(items) == config["num_workers"]
|
||||
for item in items:
|
||||
info, count = item
|
||||
metrics = _get_shared_metrics()
|
||||
|
|
|
@ -34,11 +34,10 @@ class TestDDPPO(unittest.TestCase):
|
|||
results = trainer.train()
|
||||
check_train_results(results)
|
||||
print(results)
|
||||
# Make sure, weights on all workers are the same (including
|
||||
# local one).
|
||||
# Make sure, weights on all workers are the same.
|
||||
weights = trainer.workers.foreach_worker(lambda w: w.get_weights())
|
||||
for w in weights[1:]:
|
||||
check(w, weights[0])
|
||||
check(w, weights[1])
|
||||
|
||||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
|
@ -48,15 +47,16 @@ class TestDDPPO(unittest.TestCase):
|
|||
config = ppo.ddppo.DEFAULT_CONFIG.copy()
|
||||
config["num_gpus_per_worker"] = 0
|
||||
config["lr_schedule"] = [[0, config["lr"]], [1000, 0.0]]
|
||||
num_iterations = 3
|
||||
num_iterations = 10
|
||||
|
||||
for _ in framework_iterator(config, "torch"):
|
||||
trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
|
||||
for _ in range(num_iterations):
|
||||
result = trainer.train()
|
||||
lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
|
||||
"cur_lr"
|
||||
]
|
||||
if result["info"][LEARNER_INFO]:
|
||||
lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
|
||||
LEARNER_STATS_KEY
|
||||
]["cur_lr"]
|
||||
trainer.stop()
|
||||
assert lr == 0.0, "lr should anneal to 0.0"
|
||||
|
||||
|
|
|
@ -961,10 +961,8 @@ class Trainer(Trainable):
|
|||
|
||||
# Function defining one single training iteration's behavior.
|
||||
if self.config["_disable_execution_plan_api"]:
|
||||
# TODO: Ensure remote workers are initially in sync with the
|
||||
# local worker.
|
||||
# self.workers.sync_weights()
|
||||
pass # TODO: Uncommenting line above breaks tf2+eager_tracing for A3C.
|
||||
# Ensure remote workers are initially in sync with the local worker.
|
||||
self.workers.sync_weights()
|
||||
# LocalIterator-creating "execution plan".
|
||||
# Only call this once here to create `self.train_exec_impl`,
|
||||
# which is a ray.util.iter.LocalIterator that will be `next`'d
|
||||
|
|
20
rllib/tuned_examples/ppo/pendulum-ddppo.yaml
Normal file
20
rllib/tuned_examples/ppo/pendulum-ddppo.yaml
Normal file
|
@ -0,0 +1,20 @@
|
|||
pendulum-ddppo:
|
||||
env: Pendulum-v1
|
||||
run: DDPPO
|
||||
stop:
|
||||
episode_reward_mean: -300
|
||||
timesteps_total: 1500000
|
||||
config:
|
||||
framework: torch
|
||||
rollout_fragment_length: 250
|
||||
num_gpus_per_worker: 0
|
||||
num_workers: 4
|
||||
num_envs_per_worker: 10
|
||||
observation_filter: MeanStdFilter
|
||||
gamma: 0.95
|
||||
sgd_minibatch_size: 50
|
||||
num_sgd_iter: 5
|
||||
clip_param: 0.3
|
||||
vf_clip_param: 10.0
|
||||
lambda: 0.1
|
||||
lr: 0.00015
|
Loading…
Add table
Reference in a new issue