This commit is contained in:
parent c5252c5ceb
commit a337fd994e

5 changed files with 187 additions and 24 deletions
rllib/BUILD (10 changes)

@@ -259,6 +259,16 @@ py_test(
     args = ["--yaml-dir=tuned_examples/ppo"]
 )
 
+py_test(
+    name = "learning_tests_pendulum_ddppo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = glob(["tuned_examples/ppo/pendulum-ddppo.yaml"]),
+    args = ["--yaml-dir=tuned_examples/ppo"]
+)
+
 # DQN
 py_test(
     name = "learning_tests_cartpole_dqn",
rllib/agents/ppo/ddppo.py

@@ -23,22 +23,37 @@ from typing import Callable, Optional, Union
 import ray
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, PPOTrainer
 from ray.rllib.agents.trainer import Trainer
+from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.execution.rollout_ops import ParallelRollouts
-from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.common import (
     STEPS_SAMPLED_COUNTER,
     STEPS_TRAINED_COUNTER,
     STEPS_TRAINED_THIS_ITER_COUNTER,
-    LEARN_ON_BATCH_TIMER,
     _get_shared_metrics,
     _get_global_vars,
 )
+from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
+from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
+from ray.rllib.utils.metrics import (
+    LEARN_ON_BATCH_TIMER,
+    NUM_AGENT_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_TRAINED,
+    NUM_ENV_STEPS_SAMPLED,
+    NUM_ENV_STEPS_TRAINED,
+    SAMPLE_TIMER,
+)
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LearnerInfoBuilder
 from ray.rllib.utils.sgd import do_minibatch_sgd
-from ray.rllib.utils.typing import EnvType, PartialTrainerConfigDict, TrainerConfigDict
+from ray.rllib.utils.typing import (
+    EnvType,
+    PartialTrainerConfigDict,
+    ResultDict,
+    TrainerConfigDict,
+)
 from ray.tune.logger import Logger
 from ray.util.iter import LocalIterator
 
@@ -79,9 +94,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
         "num_gpus": 0,
         # Each rollout worker gets a GPU.
         "num_gpus_per_worker": 1,
-        # Require evenly sized batches. Otherwise,
-        # collective allreduce could fail.
-        "truncate_episodes": True,
         # This is auto set based on sample batch size.
         "train_batch_size": -1,
         # Kl divergence penalty should be fixed to 0 in DDPPO because in order
@@ -89,9 +101,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
         # DDPPO
         "kl_coeff": 0.0,
         "kl_target": 0.0,
-
-        # Keep using execution_plan API (training_iteration fn not defined yet).
-        "_disable_execution_plan_api": False,
     },
     _allow_unknown_configs=True,
 )
@@ -157,6 +166,13 @@ class DDPPOTrainer(PPOTrainer):
         # setting.
         super().validate_config(config)
 
+        # Must have `num_workers` >= 1.
+        if config["num_workers"] < 1:
+            raise ValueError(
+                "Due to its distributed, decentralized nature, "
+                "DD-PPO requires `num_workers` to be >= 1!"
+            )
+
         # Only supported for PyTorch so far.
         if config["framework"] != "torch":
             raise ValueError("Distributed data parallel is only supported for PyTorch")
@@ -186,6 +202,126 @@ class DDPPOTrainer(PPOTrainer):
         if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0:
             raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
 
+    @override(PPOTrainer)
+    def setup(self, config: PartialTrainerConfigDict):
+        super().setup(config)
+
+        # Initialize torch process group for distributed training.
+        if self.config["_disable_execution_plan_api"] is True:
+            self._curr_learner_info = {}
+            ip = ray.get(self.workers.remote_workers()[0].get_node_ip.remote())
+            port = ray.get(self.workers.remote_workers()[0].find_free_port.remote())
+            address = "tcp://{ip}:{port}".format(ip=ip, port=port)
+            logger.info("Creating torch process group with leader {}".format(address))
+
+            # Get setup tasks in order to throw errors on failure.
+            ray.get(
+                [
+                    worker.setup_torch_data_parallel.remote(
+                        url=address,
+                        world_rank=i,
+                        world_size=len(self.workers.remote_workers()),
+                        backend=self.config["torch_distributed_backend"],
+                    )
+                    for i, worker in enumerate(self.workers.remote_workers())
+                ]
+            )
+            logger.info("Torch process group init completed")
+
+    @override(PPOTrainer)
+    def training_iteration(self) -> ResultDict:
+        # Shortcut.
+        first_worker = self.workers.remote_workers()[0]
+
+        # Run sampling and update steps on each worker in asynchronous fashion.
+        sample_and_update_results = asynchronous_parallel_requests(
+            remote_requests_in_flight=self.remote_requests_in_flight,
+            actors=self.workers.remote_workers(),
+            ray_wait_timeout_s=0.0,
+            max_remote_requests_in_flight_per_actor=1,  # 2
+            remote_fn=self._sample_and_train_torch_distributed,
+        )
+
+        # For all results collected:
+        # - Update our counters and timers.
+        # - Update the worker's global_vars.
+        # - Build info dict using a LearnerInfoBuilder object.
+        learner_info_builder = LearnerInfoBuilder(num_devices=1)
+        steps_this_iter = 0
+        for worker, result in sample_and_update_results.items():
+            # TODO: Add an inner loop over (>1) results here once APEX has been merged!
+            steps_this_iter += result["env_steps"]
+            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
+            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
+            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
+            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
+            self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
+            self._timers[SAMPLE_TIMER].push(result["sample_time"])
+            # Add partial learner info to builder object.
+            learner_info_builder.add_learn_on_batch_results_multi_agent(result["info"])
+
+        # Broadcast the local set of global vars.
+        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
+        for worker in self.workers.remote_workers():
+            worker.set_global_vars.remote(global_vars)
+
+        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter
+
+        # Sync down the weights from 1st remote worker (only if we have received
+        # some results from it).
+        # As with the sync up, this is not really needed unless the user is
+        # reading the local weights.
+        if (
+            self.config["keep_local_weights_in_sync"]
+            and first_worker in sample_and_update_results
+        ):
+            self.workers.local_worker().set_weights(
+                ray.get(first_worker.get_weights.remote())
+            )
+        # Return merged learner info results.
+        new_learner_info = learner_info_builder.finalize()
+        if new_learner_info:
+            self._curr_learner_info = new_learner_info
+        return self._curr_learner_info
+
+    @staticmethod
+    def _sample_and_train_torch_distributed(worker: RolloutWorker):
+        # This function is applied remotely on each rollout worker.
+        config = worker.policy_config
+
+        # Generate a sample.
+        start = time.perf_counter()
+        batch = worker.sample()
+        sample_time = time.perf_counter() - start
+        expected_batch_size = (
+            config["rollout_fragment_length"] * config["num_envs_per_worker"]
+        )
+        assert batch.count == expected_batch_size, (
+            "Batch size possibly out of sync between workers, expected:",
+            expected_batch_size,
+            "got:",
+            batch.count,
+        )
+
+        # Perform n minibatch SGD update(s) on the worker itself.
+        start = time.perf_counter()
+        info = do_minibatch_sgd(
+            batch,
+            worker.policy_map,
+            worker,
+            config["num_sgd_iter"],
+            config["sgd_minibatch_size"],
+            [Postprocessing.ADVANTAGES],
+        )
+        learn_on_batch_time = time.perf_counter() - start
+        return {
+            "info": info,
+            "env_steps": batch.env_steps(),
+            "agent_steps": batch.agent_steps(),
+            "sample_time": sample_time,
+            "learn_on_batch_time": learn_on_batch_time,
+        }
+
     @staticmethod
     @override(PPOTrainer)
     def execution_plan(
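The new training_iteration path fans a combined sample-plus-SGD request out to every remote worker and then harvests only the requests that have already finished. The snippet below is a minimal, self-contained sketch of that fan-out pattern using plain Ray actors; the actor and method names are illustrative, and it is not RLlib's actual asynchronous_parallel_requests implementation.

    # Sketch of the asynchronous fan-out pattern (illustrative names only).
    import ray

    ray.init(num_cpus=2)

    @ray.remote
    class SampleAndTrainWorker:
        def sample_and_train(self):
            # Stand-in for worker.sample() + do_minibatch_sgd() on the worker.
            return {"env_steps": 100, "agent_steps": 100}

    workers = [SampleAndTrainWorker.remote() for _ in range(2)]
    # Keep at most one request in flight per actor.
    in_flight = {w.sample_and_train.remote(): w for w in workers}

    # Non-blocking harvest: take whatever finished, leave the rest in flight.
    ready, _ = ray.wait(list(in_flight), num_returns=len(in_flight), timeout=0.0)
    results = {in_flight[ref]: ray.get(ref) for ref in ready}
    for worker, result in results.items():
        print(worker, result["env_steps"])

    ray.shutdown()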
@@ -209,8 +345,6 @@ class DDPPOTrainer(PPOTrainer):
         rollouts = ParallelRollouts(workers, mode="raw")
 
         # Setup the distributed processes.
-        if not workers.remote_workers():
-            raise ValueError("This optimizer requires >0 remote workers.")
         ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
         port = ray.get(workers.remote_workers()[0].find_free_port.remote())
         address = "tcp://{ip}:{port}".format(ip=ip, port=port)
@@ -271,6 +405,7 @@ class DDPPOTrainer(PPOTrainer):
                 self.fetch_start_time = time.perf_counter()
 
             def __call__(self, items):
+                assert len(items) == config["num_workers"]
                 for item in items:
                     info, count = item
                     metrics = _get_shared_metrics()
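For reference, a minimal end-to-end usage sketch of the trainer this diff touches. It mirrors the unit test below rather than a recommended setup: num_workers=2, CPU-only, and CartPole-v0 are illustrative choices, not values prescribed by the commit.

    # Minimal DD-PPO usage sketch (CPU-only, illustrative settings).
    import ray
    import ray.rllib.agents.ppo as ppo

    ray.init()
    config = ppo.ddppo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 2          # DD-PPO needs at least one remote worker.
    config["num_gpus_per_worker"] = 0  # Run the sketch on CPU.

    trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
    for _ in range(3):
        results = trainer.train()
        print(results["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()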
rllib/agents/ppo/tests/test_ddppo.py

@@ -34,11 +34,10 @@ class TestDDPPO(unittest.TestCase):
                 results = trainer.train()
                 check_train_results(results)
                 print(results)
-                # Make sure weights on all workers are the same (including
-                # local one).
+                # Make sure weights on all workers are the same.
                 weights = trainer.workers.foreach_worker(lambda w: w.get_weights())
                 for w in weights[1:]:
-                    check(w, weights[0])
+                    check(w, weights[1])
 
                 check_compute_single_action(trainer)
             trainer.stop()
@@ -48,15 +47,16 @@ class TestDDPPO(unittest.TestCase):
         config = ppo.ddppo.DEFAULT_CONFIG.copy()
         config["num_gpus_per_worker"] = 0
         config["lr_schedule"] = [[0, config["lr"]], [1000, 0.0]]
-        num_iterations = 3
+        num_iterations = 10
 
         for _ in framework_iterator(config, "torch"):
             trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
             for _ in range(num_iterations):
                 result = trainer.train()
-                lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
-                    "cur_lr"
-                ]
+                if result["info"][LEARNER_INFO]:
+                    lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
+                        LEARNER_STATS_KEY
+                    ]["cur_lr"]
             trainer.stop()
             assert lr == 0.0, "lr should anneal to 0.0"
 
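The lr_schedule used in this test, [[0, lr], [1000, 0.0]], is a list of [timestep, value] breakpoints that RLlib interpolates between, holding the last value afterwards; that is why the assertion expects the learning rate to reach 0.0 after enough training steps. A simplified stand-in for that interpretation (not RLlib's schedule class; the 5e-5 starting value is illustrative):

    # Simplified two-point piecewise-linear schedule, mirroring
    # config["lr_schedule"] = [[0, lr], [1000, 0.0]] from the test above.
    def piecewise_linear(schedule, t):
        (t0, v0), (t1, v1) = schedule[0], schedule[-1]
        if t >= t1:
            return v1
        return v0 + (t - t0) / (t1 - t0) * (v1 - v0)

    schedule = [[0, 5e-5], [1000, 0.0]]
    assert piecewise_linear(schedule, 0) == 5e-5
    assert piecewise_linear(schedule, 1000) == 0.0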
rllib/agents/trainer.py

@@ -961,10 +961,8 @@ class Trainer(Trainable):
 
         # Function defining one single training iteration's behavior.
         if self.config["_disable_execution_plan_api"]:
-            # TODO: Ensure remote workers are initially in sync with the
-            # local worker.
-            # self.workers.sync_weights()
-            pass  # TODO: Uncommenting line above breaks tf2+eager_tracing for A3C.
+            # Ensure remote workers are initially in sync with the local worker.
+            self.workers.sync_weights()
         # LocalIterator-creating "execution plan".
         # Only call this once here to create `self.train_exec_impl`,
         # which is a ray.util.iter.LocalIterator that will be `next`'d
rllib/tuned_examples/ppo/pendulum-ddppo.yaml (new file, 20 lines)

@@ -0,0 +1,20 @@
+pendulum-ddppo:
+    env: Pendulum-v1
+    run: DDPPO
+    stop:
+        episode_reward_mean: -300
+        timesteps_total: 1500000
+    config:
+        framework: torch
+        rollout_fragment_length: 250
+        num_gpus_per_worker: 0
+        num_workers: 4
+        num_envs_per_worker: 10
+        observation_filter: MeanStdFilter
+        gamma: 0.95
+        sgd_minibatch_size: 50
+        num_sgd_iter: 5
+        clip_param: 0.3
+        vf_clip_param: 10.0
+        lambda: 0.1
+        lr: 0.00015
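One way to launch the new tuned example programmatically is sketched below; it assumes a checkout where the YAML path above exists and simply feeds each experiment in the file to Ray Tune (the rllib CLI's regression-test driver used by the new BUILD target does essentially this, with additional checks).

    # Sketch: run the new tuned example via Ray Tune.
    import yaml
    from ray import tune

    with open("rllib/tuned_examples/ppo/pendulum-ddppo.yaml") as f:
        experiments = yaml.safe_load(f)

    for name, exp in experiments.items():
        tune.run(exp["run"], name=name, stop=exp["stop"], config=exp["config"])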