[rllib] [experimental] custom RL training pipelines (PG_pl, A2C_pl) (#7213)

Eric Liang 2020-02-19 16:07:37 -08:00 committed by GitHub
parent 7bef7031c2
commit 46af992efd
11 changed files with 447 additions and 9 deletions

View file: rllib/BUILD

@@ -66,6 +66,14 @@ py_test(
srcs = ["agents/dqn/tests/test_dqn.py"]
)
# A2CTrainer
py_test(
name = "test_a2c",
tags = ["agents_dir"],
size = "small",
srcs = ["agents/a3c/tests/test_a2c.py"]
)
# PGTrainer
py_test(
name = "test_pg",

View file: rllib/agents/a3c/__init__.py

@@ -1,10 +1,12 @@
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
from ray.rllib.agents.a3c.a2c import A2CTrainer
from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline
from ray.rllib.utils import renamed_agent
A2CAgent = renamed_agent(A2CTrainer)
A3CAgent = renamed_agent(A3CTrainer)
__all__ = [
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG",
"A2CPipeline"
]

View file: rllib/agents/a3c/a2c_pipeline.py

@@ -0,0 +1,38 @@
"""Experimental pipeline-based impl; run this with --run='A2C_pl'"""
import math
from ray.rllib.agents.a3c.a2c import A2CTrainer
from ray.rllib.utils.experimental_dsl import (
ParallelRollouts, ConcatBatches, ComputeGradients, AverageGradients,
ApplyGradients, TrainOneStep, StandardMetricsReporting)
def training_pipeline(workers, config):
rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["microbatch_size"]:
num_microbatches = math.ceil(
config["train_batch_size"] / config["microbatch_size"])
# In microbatch mode, we want to compute gradients on experience
# microbatches, average a number of these microbatches, and then apply
# the averaged gradient in one SGD step. This conserves GPU memory,
# allowing for extremely large experience batches to be used.
train_op = (
rollouts.combine(
ConcatBatches(min_batch_size=config["microbatch_size"]))
.for_each(ComputeGradients(workers)) # (grads, info)
.batch(num_microbatches) # List[(grads, info)]
.for_each(AverageGradients()) # (avg_grads, info)
.for_each(ApplyGradients(workers)))
else:
# In normal mode, we execute one SGD step per each train batch.
train_op = rollouts \
.combine(ConcatBatches(
min_batch_size=config["train_batch_size"])) \
.for_each(TrainOneStep(workers))
return StandardMetricsReporting(train_op, workers, config)
A2CPipeline = A2CTrainer.with_updates(training_pipeline=training_pipeline)
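For reference, a minimal sketch of driving the microbatch path end to end; the env name and config keys mirror the tests added in this PR, and the batch sizes are illustrative only:

import ray
from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline

ray.init()
# With train_batch_size=500 and microbatch_size=50, gradients from ten
# 50-sample microbatches are combined into each apply step.
trainer = A2CPipeline(
    env="CartPole-v0",
    config={
        "min_iter_time_s": 0,
        "train_batch_size": 500,
        "microbatch_size": 50,
    })
result = trainer.train()  # one full pass through the pipeline
print(result["episode_reward_mean"])
ray.shutdown()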

View file: rllib/agents/a3c/tests/test_a2c.py

@@ -0,0 +1,34 @@
import unittest
import ray
from ray.rllib.agents.a3c import a2c_pipeline
class TestA2C(unittest.TestCase):
"""Sanity tests for A2CPipeline."""
def setUp(self):
ray.init()
def tearDown(self):
ray.shutdown()
def test_a2c_pipeline(ray_start_regular):
trainer = a2c_pipeline.A2CPipeline(
env="CartPole-v0", config={"min_iter_time_s": 0})
assert isinstance(trainer.train(), dict)
def test_a2c_pipeline_microbatch(ray_start_regular):
trainer = a2c_pipeline.A2CPipeline(
env="CartPole-v0",
config={
"min_iter_time_s": 0,
"microbatch_size": 10
})
assert isinstance(trainer.train(), dict)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file: rllib/agents/pg/__init__.py

@@ -1,7 +1,10 @@
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_pipeline import PGPipeline
from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \
post_process_advantages
from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss
__all__ = ["PGTrainer", "pg_tf_loss", "pg_torch_loss",
"post_process_advantages", "DEFAULT_CONFIG"]
__all__ = [
"PGTrainer", "pg_tf_loss", "pg_torch_loss", "post_process_advantages",
"DEFAULT_CONFIG", "PGPipeline"
]

View file: rllib/agents/pg/pg_pipeline.py

@@ -0,0 +1,24 @@
"""Experimental pipeline-based impl; run this with --run='PG_pl'"""
from ray.rllib.agents.pg.pg import PGTrainer
from ray.rllib.utils.experimental_dsl import (
ParallelRollouts, ConcatBatches, TrainOneStep, StandardMetricsReporting)
def training_pipeline(workers, config):
# Collects experiences in parallel from multiple RolloutWorker actors.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# Combine experience batches until we hit `train_batch_size` in size.
# Then, train the policy on those experiences and update the workers.
train_op = rollouts \
.combine(ConcatBatches(
min_batch_size=config["train_batch_size"])) \
.for_each(TrainOneStep(workers))
# Add on the standard episode reward, etc. metrics reporting. This returns
# a LocalIterator[metrics_dict] representing metrics for each train step.
return StandardMetricsReporting(train_op, workers, config)
PGPipeline = PGTrainer.with_updates(training_pipeline=training_pipeline)
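A minimal sketch of using the pipeline-based PG trainer directly, mirroring test_pg_pipeline added in this PR; each train() call pulls the next metrics dict from the LocalIterator built above:

import ray
from ray.rllib.agents.pg.pg_pipeline import PGPipeline

ray.init()
trainer = PGPipeline(env="CartPole-v0", config={"min_iter_time_s": 0})
result = trainer.train()  # advances the metrics iterator by one step
assert isinstance(result, dict)
print(result["episode_reward_mean"])
ray.shutdown()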

View file: rllib/agents/pg/tests/test_pg.py

@@ -3,6 +3,7 @@ import unittest
import ray
import ray.rllib.agents.pg as pg
from ray.rllib.agents.pg import PGPipeline
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
@@ -11,8 +12,15 @@ from ray.rllib.utils import check, fc
class TestPG(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.shutdown()
def test_pg_pipeline(ray_start_regular):
trainer = PGPipeline(env="CartPole-v0", config={"min_iter_time_s": 0})
assert isinstance(trainer.train(), dict)
def test_pg_compilation(self):
"""Test whether a PGTrainer can be built with both frameworks."""
@@ -101,5 +109,6 @@ class TestPG(unittest.TestCase):
if __name__ == "__main__":
import unittest
unittest.main(verbosity=1)
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file: rllib/agents/registry.py

@@ -100,6 +100,16 @@ def _import_marwil():
return marwil.MARWILTrainer
def _import_a2c_pipeline():
from ray.rllib.agents import a3c
return a3c.A2CPipeline
def _import_pg_pipeline():
from ray.rllib.agents import pg
return pg.PGPipeline
ALGORITHMS = {
"SAC": _import_sac,
"DDPG": _import_ddpg,
@@ -120,6 +130,10 @@ ALGORITHMS = {
"APPO": _import_appo,
"DDPPO": _import_ddppo,
"MARWIL": _import_marwil,
# Experimental pipeline-based impls.
"A2C_pl": _import_a2c_pipeline,
"PG_pl": _import_pg_pipeline,
}
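With these entries in place, the pipeline variants resolve by name like any built-in algorithm. A minimal sketch, assuming the get_agent_class helper defined elsewhere in registry.py:

import ray
from ray.rllib.agents.registry import get_agent_class

ray.init()
cls = get_agent_class("PG_pl")  # resolves to the experimental PGPipeline
trainer = cls(env="CartPole-v0", config={"min_iter_time_s": 0})
assert isinstance(trainer.train(), dict)
ray.shutdown()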

View file: rllib/agents/trainer_template.py

@@ -22,7 +22,8 @@ def build_trainer(name,
after_train_result=None,
collect_metrics_fn=None,
before_evaluate_fn=None,
mixins=None):
mixins=None,
training_pipeline=None):
"""Helper function for defining a custom trainer.
Functions will be run in this order to initialize the trainer:
@@ -66,6 +67,8 @@
mixins (list): list of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class
training_pipeline (func): Experimental support for custom
training pipelines. This overrides `make_policy_optimizer`.
Returns:
a Trainer instance that uses the specified args.
@@ -100,7 +103,12 @@
else:
self.workers = self._make_workers(env_creator, policy, config,
self.config["num_workers"])
if make_policy_optimizer:
self.train_pipeline = None
self.optimizer = None
if training_pipeline:
self.train_pipeline = training_pipeline(self.workers, config)
elif make_policy_optimizer:
self.optimizer = make_policy_optimizer(self.workers, config)
else:
optimizer_config = dict(
@@ -113,6 +121,9 @@
@override(Trainer)
def _train(self):
if self.train_pipeline:
return self._train_pipeline()
if before_train_step:
before_train_step(self)
prev_steps = self.optimizer.num_steps_sampled
@@ -140,6 +151,14 @@
after_train_result(self, res)
return res
def _train_pipeline(self):
if before_train_step:
before_train_step(self)
res = next(self.train_pipeline)
if after_train_result:
after_train_result(self, res)
return res
@override(Trainer)
def _before_evaluate(self):
if before_evaluate_fn:
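To illustrate the new argument, here is a sketch of a custom trainer built directly with build_trainer and a training_pipeline. PGTFPolicy and the PG DEFAULT_CONFIG are assumed stand-ins borrowed from the PG module; the pipeline body mirrors pg_pipeline.py above:

from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy  # assumed stand-in policy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.utils.experimental_dsl import (
    ParallelRollouts, ConcatBatches, TrainOneStep, StandardMetricsReporting)


def my_training_pipeline(workers, config):
    # Same shape as the PG pipeline: sample in parallel, concatenate into a
    # train batch, take one SGD step, then report the standard metrics.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)


# training_pipeline overrides make_policy_optimizer, so _train() delegates
# to _train_pipeline() as shown in the diff above.
MyPipelineTrainer = build_trainer(
    name="MyPipelineTrainer",
    default_policy=PGTFPolicy,
    default_config=DEFAULT_CONFIG,
    training_pipeline=my_training_pipeline)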

View file: rllib/evaluation/rollout_worker.py

@@ -5,6 +5,7 @@ import logging
import pickle
import ray
from ray.util.iter import ParallelIteratorWorker
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
@@ -52,7 +53,7 @@ def get_global_worker():
@DeveloperAPI
class RolloutWorker(EvaluatorInterface):
class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
"""Common experience collection class.
This class wraps a policy instance and an environment class to
@@ -241,6 +242,12 @@ class RolloutWorker(EvaluatorInterface):
global _global_worker
_global_worker = self
def gen_rollouts():
while True:
yield self.sample()
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
policy_config = policy_config or {}
if (tf and policy_config.get("eager")
and not policy_config.get("no_eager_on_workers")):
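Because RolloutWorker now also inherits from ParallelIteratorWorker, a set of remote workers can be consumed as a parallel iterator of sample batches, which is what the new pipeline DSL builds on. A small sketch of that pattern, assuming `workers` is an already-constructed WorkerSet with remote workers:

from ray.util.iter import from_actors

# Each remote RolloutWorker serves its gen_rollouts() generator, so
# from_actors() yields SampleBatches as the workers produce them.
rollouts = from_actors(workers.remote_workers())
batch = next(rollouts.gather_async())
print(batch.count)  # env steps in the first batch returned by any worker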

View file: rllib/utils/experimental_dsl.py

@@ -0,0 +1,280 @@
"""Experimental operators for defining distributed training pipelines.
TODO(ekl): describe the concepts."""
from typing import List, Any
import time
import ray
from ray.util.iter import from_actors, LocalIterator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import SampleBatch
def ParallelRollouts(workers: WorkerSet,
mode="bulk_sync") -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.
If there are no remote workers, experiences will be collected serially from
the local worker instance instead.
Arguments:
workers (WorkerSet): set of rollout workers to use.
mode (str): One of {'async', 'bulk_sync'}.
- In 'async' mode, batches are returned as soon as they are
computed by rollout workers with no order guarantees.
- In 'bulk_sync' mode, we collect one batch from each worker
and concatenate them together into a large batch to return.
Returns:
A local iterator over experiences collected in parallel.
Examples:
>>> rollouts = ParallelRollouts(workers, mode="async")
>>> batch = next(rollouts)
>>> print(batch.count)
50 # config.sample_batch_size
>>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
>>> batch = next(rollouts)
>>> print(batch.count)
200 # config.sample_batch_size * config.num_workers
"""
if not workers.remote_workers():
# Handle the serial sampling case.
def sampler(_):
while True:
yield workers.local_worker().sample()
return LocalIterator(sampler)
# Create a parallel iterator over generated experiences.
rollouts = from_actors(workers.remote_workers())
if mode == "bulk_sync":
return rollouts \
.batch_across_shards() \
.for_each(lambda batches: SampleBatch.concat_samples(batches))
elif mode == "async":
return rollouts.gather_async()
else:
raise ValueError(
"mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))
def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet,
config: dict):
"""Operator to periodically collect and report metrics.
Arguments:
train_op (LocalIterator): Operator for executing training steps.
We ignore the output values.
workers (WorkerSet): Rollout workers to collect metrics from.
config (dict): Trainer configuration, used to determine the frequency
of stats reporting.
Returns:
A local iterator over training results.
Examples:
>>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
>>> metrics_op = StandardMetricsReporting(train_op, workers, config)
>>> next(metrics_op)
{"episode_reward_max": ..., "episode_reward_mean": ..., ...}
"""
output_op = train_op \
.filter(OncePerTimeInterval(config["min_iter_time_s"])) \
.for_each(CollectMetrics(
workers, min_history=config["metrics_smoothing_episodes"],
timeout_seconds=config["collect_metrics_timeout"]))
return output_op
class ConcatBatches:
"""Callable used to merge batches into larger batches for training.
This should be used with the .combine() operator.
Examples:
>>> rollouts = ParallelRollouts(...)
>>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000))
>>> print(next(rollouts).count)
10000
"""
def __init__(self, min_batch_size: int):
self.min_batch_size = min_batch_size
self.buffer = []
self.count = 0
def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
if not isinstance(batch, SampleBatch):
raise ValueError("Expected type SampleBatch, got {}: {}".format(
type(batch), batch))
self.buffer.append(batch)
self.count += batch.count
if self.count >= self.min_batch_size:
out = SampleBatch.concat_samples(self.buffer)
self.buffer = []
self.count = 0
return [out]
return []
class TrainOneStep:
"""Callable that improves the policy and updates workers.
This should be used with the .for_each() operator.
Examples:
>>> rollouts = ParallelRollouts(...)
>>> train_op = rollouts.for_each(TrainOneStep(workers))
>>> print(next(train_op)) # This trains the policy on one batch.
{"learner_stats": {"policy_loss": ...}}
"""
def __init__(self, workers: WorkerSet):
self.workers = workers
def __call__(self, batch: SampleBatch) -> List[dict]:
info = self.workers.local_worker().learn_on_batch(batch)
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
return info
class CollectMetrics:
"""Callable that collects metrics from workers.
The metrics are smoothed over a given history window.
This should be used with the .for_each() operator. For a higher level
API, consider using StandardMetricsReporting instead.
Examples:
>>> output_op = train_op.for_each(CollectMetrics(workers))
>>> print(next(output_op))
{"episode_reward_max": ..., "episode_reward_mean": ..., ...}
"""
def __init__(self, workers, min_history=100, timeout_seconds=180):
self.workers = workers
self.episode_history = []
self.to_be_collected = []
self.min_history = min_history
self.timeout_seconds = timeout_seconds
def __call__(self, info):
episodes, self.to_be_collected = collect_episodes(
self.workers.local_worker(),
self.workers.remote_workers(),
self.to_be_collected,
timeout_seconds=self.timeout_seconds)
orig_episodes = list(episodes)
missing = self.min_history - len(episodes)
if missing > 0:
episodes.extend(self.episode_history[-missing:])
assert len(episodes) <= self.min_history
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-self.min_history:]
res = summarize_episodes(episodes, orig_episodes)
res.update(info=info)
return res
class OncePerTimeInterval:
"""Callable that returns True once per given interval.
This should be used with the .filter() operator to throttle / rate-limit
metrics reporting. For a higher-level API, consider using
StandardMetricsReporting instead.
Examples:
>>> throttled_op = train_op.filter(OncePerTimeInterval(5))
>>> start = time.time()
>>> next(throttled_op)
>>> print(time.time() - start)
5.00001 # will be greater than 5 seconds
"""
def __init__(self, delay):
self.delay = delay
self.last_called = 0
def __call__(self, item):
now = time.time()
if now - self.last_called > self.delay:
self.last_called = now
return True
return False
class ComputeGradients:
"""Callable that computes gradients with respect to the policy loss.
This should be used with the .for_each() operator.
Examples:
>>> grads_op = rollouts.for_each(ComputeGradients(workers))
>>> print(next(grads_op))
{"var_0": ..., ...}, {"learner_stats": ...} # grads, learner info
"""
def __init__(self, workers):
self.workers = workers
def __call__(self, samples):
grad, info = self.workers.local_worker().compute_gradients(samples)
return grad, info
class ApplyGradients:
"""Callable that applies gradients and updates workers.
This should be used with the .for_each() operator.
Examples:
>>> apply_op = grads_op.for_each(ApplyGradients(workers))
>>> print(next(apply_op))
{"learner_stats": ...} # learner info
"""
def __init__(self, workers):
self.workers = workers
def __call__(self, item):
gradients, info = item
self.workers.local_worker().apply_gradients(gradients)
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
return info
class AverageGradients:
"""Callable that averages the gradients in a batch.
This should be used with the .for_each() operator after a set of gradients
have been batched with .batch().
Examples:
>>> batched_grads = grads_op.batch(32)
>>> avg_grads = batched_grads.for_each(AverageGradients())
>>> print(next(avg_grads))
{"var_0": ..., ...}, {"learner_stats": ...} # avg grads, last info
"""
def __call__(self, gradients):
acc = None
for grad, info in gradients:
if acc is None:
acc = grad
else:
acc = [a + b for a, b in zip(acc, grad)]
return acc, info
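A tiny worked example of what AverageGradients does to a batch of (gradients, info) tuples as written: it accumulates the elementwise sum of the gradient lists and passes through the info dict from the last tuple (plain floats stand in for gradient tensors here):

avg = AverageGradients()
batched = [
    ([1.0, 2.0], {"learner_stats": {"step": 1}}),
    ([3.0, 4.0], {"learner_stats": {"step": 2}}),
]
grads, info = avg(batched)
print(grads)  # [4.0, 6.0] -- elementwise sum across microbatch gradients
print(info)   # {'learner_stats': {'step': 2}} -- info from the last tuple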