From 46af992efd312df73b2d509e0359b98b7a3e5592 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 19 Feb 2020 16:07:37 -0800 Subject: [PATCH] [rllib] [experimental] custom RL training pipelines (PG_pl, A2C_pl) (#7213) --- rllib/BUILD | 8 + rllib/agents/a3c/__init__.py | 4 +- rllib/agents/a3c/a2c_pipeline.py | 38 ++++ rllib/agents/a3c/tests/test_a2c.py | 34 ++++ rllib/agents/pg/__init__.py | 7 +- rllib/agents/pg/pg_pipeline.py | 24 +++ rllib/agents/pg/tests/test_pg.py | 15 +- rllib/agents/registry.py | 14 ++ rllib/agents/trainer_template.py | 23 ++- rllib/evaluation/rollout_worker.py | 9 +- rllib/utils/experimental_dsl.py | 280 +++++++++++++++++++++++++++++ 11 files changed, 447 insertions(+), 9 deletions(-) create mode 100644 rllib/agents/a3c/a2c_pipeline.py create mode 100644 rllib/agents/a3c/tests/test_a2c.py create mode 100644 rllib/agents/pg/pg_pipeline.py create mode 100644 rllib/utils/experimental_dsl.py diff --git a/rllib/BUILD b/rllib/BUILD index 5b6dfa3c6..99d0ae6c7 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -66,6 +66,14 @@ py_test( srcs = ["agents/dqn/tests/test_dqn.py"] ) +# A2CTrainer +py_test( + name = "test_a2c", + tags = ["agents_dir"], + size = "small", + srcs = ["agents/a3c/tests/test_a2c.py"] +) + # PGTrainer py_test( name = "test_pg", diff --git a/rllib/agents/a3c/__init__.py b/rllib/agents/a3c/__init__.py index 4a8480eab..3d63eed0b 100644 --- a/rllib/agents/a3c/__init__.py +++ b/rllib/agents/a3c/__init__.py @@ -1,10 +1,12 @@ from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG from ray.rllib.agents.a3c.a2c import A2CTrainer +from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline from ray.rllib.utils import renamed_agent A2CAgent = renamed_agent(A2CTrainer) A3CAgent = renamed_agent(A3CTrainer) __all__ = [ - "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG" + "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG", + "A2CPipeline" ] diff --git a/rllib/agents/a3c/a2c_pipeline.py b/rllib/agents/a3c/a2c_pipeline.py new file mode 100644 index 000000000..aa12cfe7f --- /dev/null +++ b/rllib/agents/a3c/a2c_pipeline.py @@ -0,0 +1,38 @@ +"""Experimental pipeline-based impl; run this with --run='A2C_pl'""" + +import math + +from ray.rllib.agents.a3c.a2c import A2CTrainer +from ray.rllib.utils.experimental_dsl import ( + ParallelRollouts, ConcatBatches, ComputeGradients, AverageGradients, + ApplyGradients, TrainOneStep, StandardMetricsReporting) + + +def training_pipeline(workers, config): + rollouts = ParallelRollouts(workers, mode="bulk_sync") + + if config["microbatch_size"]: + num_microbatches = math.ceil( + config["train_batch_size"] / config["microbatch_size"]) + # In microbatch mode, we want to compute gradients on experience + # microbatches, average a number of these microbatches, and then apply + # the averaged gradient in one SGD step. This conserves GPU memory, + # allowing for extremely large experience batches to be used. + train_op = ( + rollouts.combine( + ConcatBatches(min_batch_size=config["microbatch_size"])) + .for_each(ComputeGradients(workers)) # (grads, info) + .batch(num_microbatches) # List[(grads, info)] + .for_each(AverageGradients()) # (avg_grads, info) + .for_each(ApplyGradients(workers))) + else: + # In normal mode, we execute one SGD step per each train batch. 
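+        # ConcatBatches (below) merges rollout fragments until at least
+        # `train_batch_size` samples have been buffered; TrainOneStep then
+        # calls learn_on_batch() on the local worker and broadcasts the
+        # updated weights to any remote workers.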
+ train_op = rollouts \ + .combine(ConcatBatches( + min_batch_size=config["train_batch_size"])) \ + .for_each(TrainOneStep(workers)) + + return StandardMetricsReporting(train_op, workers, config) + + +A2CPipeline = A2CTrainer.with_updates(training_pipeline=training_pipeline) diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py new file mode 100644 index 000000000..935444b49 --- /dev/null +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -0,0 +1,34 @@ +import unittest + +import ray +from ray.rllib.agents.a3c import a2c_pipeline + + +class TestA2C(unittest.TestCase): + """Sanity tests for A2CPipeline.""" + + def setUp(self): + ray.init() + + def tearDown(self): + ray.shutdown() + + def test_a2c_pipeline(ray_start_regular): + trainer = a2c_pipeline.A2CPipeline( + env="CartPole-v0", config={"min_iter_time_s": 0}) + assert isinstance(trainer.train(), dict) + + def test_a2c_pipeline_microbatch(ray_start_regular): + trainer = a2c_pipeline.A2CPipeline( + env="CartPole-v0", + config={ + "min_iter_time_s": 0, + "microbatch_size": 10 + }) + assert isinstance(trainer.train(), dict) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/pg/__init__.py b/rllib/agents/pg/__init__.py index ae4a55a81..6aa398214 100644 --- a/rllib/agents/pg/__init__.py +++ b/rllib/agents/pg/__init__.py @@ -1,7 +1,10 @@ from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG +from ray.rllib.agents.pg.pg_pipeline import PGPipeline from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \ post_process_advantages from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss -__all__ = ["PGTrainer", "pg_tf_loss", "pg_torch_loss", - "post_process_advantages", "DEFAULT_CONFIG"] +__all__ = [ + "PGTrainer", "pg_tf_loss", "pg_torch_loss", "post_process_advantages", + "DEFAULT_CONFIG", "PGPipeline" +] diff --git a/rllib/agents/pg/pg_pipeline.py b/rllib/agents/pg/pg_pipeline.py new file mode 100644 index 000000000..23ca07ae7 --- /dev/null +++ b/rllib/agents/pg/pg_pipeline.py @@ -0,0 +1,24 @@ +"""Experimental pipeline-based impl; run this with --run='PG_pl'""" + +from ray.rllib.agents.pg.pg import PGTrainer +from ray.rllib.utils.experimental_dsl import ( + ParallelRollouts, ConcatBatches, TrainOneStep, StandardMetricsReporting) + + +def training_pipeline(workers, config): + # Collects experiences in parallel from multiple RolloutWorker actors. + rollouts = ParallelRollouts(workers, mode="bulk_sync") + + # Combine experiences batches until we hit `train_batch_size` in size. + # Then, train the policy on those experiences and update the workers. + train_op = rollouts \ + .combine(ConcatBatches( + min_batch_size=config["train_batch_size"])) \ + .for_each(TrainOneStep(workers)) + + # Add on the standard episode reward, etc. metrics reporting. This returns + # a LocalIterator[metrics_dict] representing metrics for each train step. 
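+    # Reporting is throttled to once per `min_iter_time_s`, and episode stats
+    # are smoothed over the last `metrics_smoothing_episodes` episodes (see
+    # StandardMetricsReporting in rllib/utils/experimental_dsl.py).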
+ return StandardMetricsReporting(train_op, workers, config) + + +PGPipeline = PGTrainer.with_updates(training_pipeline=training_pipeline) diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 9f3e1b2fa..62ff3ee0d 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -3,6 +3,7 @@ import unittest import ray import ray.rllib.agents.pg as pg +from ray.rllib.agents.pg import PGPipeline from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical @@ -11,8 +12,15 @@ from ray.rllib.utils import check, fc class TestPG(unittest.TestCase): + def setUp(self): + ray.init() - ray.init() + def tearDown(self): + ray.shutdown() + + def test_pg_pipeline(ray_start_regular): + trainer = PGPipeline(env="CartPole-v0", config={"min_iter_time_s": 0}) + assert isinstance(trainer.train(), dict) def test_pg_compilation(self): """Test whether a PGTrainer can be built with both frameworks.""" @@ -101,5 +109,6 @@ class TestPG(unittest.TestCase): if __name__ == "__main__": - import unittest - unittest.main(verbosity=1) + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index be6e0920a..1d98fc9ac 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -100,6 +100,16 @@ def _import_marwil(): return marwil.MARWILTrainer +def _import_a2c_pipeline(): + from ray.rllib.agents import a3c + return a3c.A2CPipeline + + +def _import_pg_pipeline(): + from ray.rllib.agents import pg + return pg.PGPipeline + + ALGORITHMS = { "SAC": _import_sac, "DDPG": _import_ddpg, @@ -120,6 +130,10 @@ ALGORITHMS = { "APPO": _import_appo, "DDPPO": _import_ddppo, "MARWIL": _import_marwil, + + # Experimental pipeline-based impls. + "A2C_pl": _import_a2c_pipeline, + "PG_pl": _import_pg_pipeline, } diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 761ba77ce..f9f33a7f1 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -22,7 +22,8 @@ def build_trainer(name, after_train_result=None, collect_metrics_fn=None, before_evaluate_fn=None, - mixins=None): + mixins=None, + training_pipeline=None): """Helper function for defining a custom trainer. Functions will be run in this order to initialize the trainer: @@ -66,6 +67,8 @@ def build_trainer(name, mixins (list): list of any class mixins for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class + training_pipeline (func): Experimental support for custom + training pipelines. This overrides `make_policy_optimizer`. Returns: a Trainer instance that uses the specified args. 
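+
+    Examples:
+        >>> # Sketch only: MyPolicy is a placeholder policy class and
+        >>> # my_pipeline a (workers, config) function returning a
+        >>> # LocalIterator, like the pipelines added in this patch.
+        >>> MyTrainer = build_trainer(
+        ...     name="MyTrainer",
+        ...     default_policy=MyPolicy,
+        ...     training_pipeline=my_pipeline)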
@@ -100,7 +103,12 @@ def build_trainer(name, else: self.workers = self._make_workers(env_creator, policy, config, self.config["num_workers"]) - if make_policy_optimizer: + self.train_pipeline = None + self.optimizer = None + + if training_pipeline: + self.train_pipeline = training_pipeline(self.workers, config) + elif make_policy_optimizer: self.optimizer = make_policy_optimizer(self.workers, config) else: optimizer_config = dict( @@ -113,6 +121,9 @@ def build_trainer(name, @override(Trainer) def _train(self): + if self.train_pipeline: + return self._train_pipeline() + if before_train_step: before_train_step(self) prev_steps = self.optimizer.num_steps_sampled @@ -140,6 +151,14 @@ def build_trainer(name, after_train_result(self, res) return res + def _train_pipeline(self): + if before_train_step: + before_train_step(self) + res = next(self.train_pipeline) + if after_train_result: + after_train_result(self, res) + return res + @override(Trainer) def _before_evaluate(self): if before_evaluate_fn: diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 31483a64b..7c65af30f 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -5,6 +5,7 @@ import logging import pickle import ray +from ray.util.iter import ParallelIteratorWorker from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext @@ -52,7 +53,7 @@ def get_global_worker(): @DeveloperAPI -class RolloutWorker(EvaluatorInterface): +class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): """Common experience collection class. This class wraps a policy instance and an environment class to @@ -241,6 +242,12 @@ class RolloutWorker(EvaluatorInterface): global _global_worker _global_worker = self + def gen_rollouts(): + while True: + yield self.sample() + + ParallelIteratorWorker.__init__(self, gen_rollouts, False) + policy_config = policy_config or {} if (tf and policy_config.get("eager") and not policy_config.get("no_eager_on_workers")): diff --git a/rllib/utils/experimental_dsl.py b/rllib/utils/experimental_dsl.py new file mode 100644 index 000000000..598e700bd --- /dev/null +++ b/rllib/utils/experimental_dsl.py @@ -0,0 +1,280 @@ +"""Experimental operators for defining distributed training pipelines. + +TODO(ekl): describe the concepts.""" + +from typing import List, Any +import time + +import ray +from ray.util.iter import from_actors, LocalIterator +from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes +from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.policy.sample_batch import SampleBatch + + +def ParallelRollouts(workers: WorkerSet, + mode="bulk_sync") -> LocalIterator[SampleBatch]: + """Operator to collect experiences in parallel from rollout workers. + + If there are no remote workers, experiences will be collected serially from + the local worker instance instead. + + Arguments: + workers (WorkerSet): set of rollout workers to use. + mode (str): One of {'async', 'bulk_sync'}. + - In 'async' mode, batches are returned as soon as they are + computed by rollout workers with no order guarantees. + - In 'bulk_sync' mode, we collect one batch from each worker + and concatenate them together into a large batch to return. + + Returns: + A local iterator over experiences collected in parallel. 
+ + Examples: + >>> rollouts = ParallelRollouts(workers, mode="async") + >>> batch = next(rollouts) + >>> print(batch.count) + 50 # config.sample_batch_size + + >>> rollouts = ParallelRollouts(workers, mode="bulk_sync") + >>> batch = next(rollouts) + >>> print(batch.count) + 200 # config.sample_batch_size * config.num_workers + """ + + if not workers.remote_workers(): + # Handle the serial sampling case. + def sampler(_): + while True: + yield workers.local_worker().sample() + + return LocalIterator(sampler) + + # Create a parallel iterator over generated experiences. + rollouts = from_actors(workers.remote_workers()) + + if mode == "bulk_sync": + return rollouts \ + .batch_across_shards() \ + .for_each(lambda batches: SampleBatch.concat_samples(batches)) + elif mode == "async": + return rollouts.gather_async() + else: + raise ValueError( + "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode)) + + +def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet, + config: dict): + """Operator to periodically collect and report metrics. + + Arguments: + train_op (LocalIterator): Operator for executing training steps. + We ignore the output values. + workers (WorkerSet): Rollout workers to collect metrics from. + config (dict): Trainer configuration, used to determine the frequency + of stats reporting. + + Returns: + A local iterator over training results. + + Examples: + >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...)) + >>> metrics_op = StandardMetricsReporting(train_op, workers, config) + >>> next(metrics_op) + {"episode_reward_max": ..., "episode_reward_mean": ..., ...} + """ + + output_op = train_op \ + .filter(OncePerTimeInterval(config["min_iter_time_s"])) \ + .for_each(CollectMetrics( + workers, min_history=config["metrics_smoothing_episodes"], + timeout_seconds=config["collect_metrics_timeout"])) + return output_op + + +class ConcatBatches: + """Callable used to merge batches into larger batches for training. + + This should be used with the .combine() operator. + + Examples: + >>> rollouts = ParallelRollouts(...) + >>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000)) + >>> print(next(rollouts).count) + 10000 + """ + + def __init__(self, min_batch_size: int): + self.min_batch_size = min_batch_size + self.buffer = [] + self.count = 0 + + def __call__(self, batch: SampleBatch) -> List[SampleBatch]: + if not isinstance(batch, SampleBatch): + raise ValueError("Expected type SampleBatch, got {}: {}".format( + type(batch), batch)) + self.buffer.append(batch) + self.count += batch.count + if self.count >= self.min_batch_size: + out = SampleBatch.concat_samples(self.buffer) + self.buffer = [] + self.count = 0 + return [out] + return [] + + +class TrainOneStep: + """Callable that improves the policy and updates workers. + + This should be used with the .for_each() operator. + + Examples: + >>> rollouts = ParallelRollouts(...) + >>> train_op = rollouts.for_each(TrainOneStep(workers)) + >>> print(next(train_op)) # This trains the policy on one batch. 
+ {"learner_stats": {"policy_loss": ...}} + """ + + def __init__(self, workers: WorkerSet): + self.workers = workers + + def __call__(self, batch: SampleBatch) -> List[dict]: + info = self.workers.local_worker().learn_on_batch(batch) + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): + e.set_weights.remote(weights) + return info + + +class CollectMetrics: + """Callable that collects metrics from workers. + + The metrics are smoothed over a given history window. + + This should be used with the .for_each() operator. For a higher level + API, consider using StandardMetricsReporting instead. + + Examples: + >>> output_op = train_op.for_each(CollectMetrics(workers)) + >>> print(next(output_op)) + {"episode_reward_max": ..., "episode_reward_mean": ..., ...} + """ + + def __init__(self, workers, min_history=100, timeout_seconds=180): + self.workers = workers + self.episode_history = [] + self.to_be_collected = [] + self.min_history = min_history + self.timeout_seconds = timeout_seconds + + def __call__(self, info): + episodes, self.to_be_collected = collect_episodes( + self.workers.local_worker(), + self.workers.remote_workers(), + self.to_be_collected, + timeout_seconds=self.timeout_seconds) + orig_episodes = list(episodes) + missing = self.min_history - len(episodes) + if missing > 0: + episodes.extend(self.episode_history[-missing:]) + assert len(episodes) <= self.min_history + self.episode_history.extend(orig_episodes) + self.episode_history = self.episode_history[-self.min_history:] + res = summarize_episodes(episodes, orig_episodes) + res.update(info=info) + return res + + +class OncePerTimeInterval: + """Callable that returns True once per given interval. + + This should be used with the .filter() operator to throttle / rate-limit + metrics reporting. For a higher-level API, consider using + StandardMetricsReporting instead. + + Examples: + >>> throttled_op = train_op.filter(OncePerTimeInterval(5)) + >>> start = time.time() + >>> next(throttled_op) + >>> print(time.time() - start) + 5.00001 # will be greater than 5 seconds + """ + + def __init__(self, delay): + self.delay = delay + self.last_called = 0 + + def __call__(self, item): + now = time.time() + if now - self.last_called > self.delay: + self.last_called = now + return True + return False + + +class ComputeGradients: + """Callable that computes gradients with respect to the policy loss. + + This should be used with the .for_each() operator. + + Examples: + >>> grads_op = rollouts.for_each(ComputeGradients(workers)) + >>> print(next(grads_op)) + {"var_0": ..., ...}, {"learner_stats": ...} # grads, learner info + """ + + def __init__(self, workers): + self.workers = workers + + def __call__(self, samples): + grad, info = self.workers.local_worker().compute_gradients(samples) + return grad, info + + +class ApplyGradients: + """Callable that applies gradients and updates workers. + + This should be used with the .for_each() operator. 
+ + Examples: + >>> apply_op = grads_op.for_each(ApplyGradients(workers)) + >>> print(next(apply_op)) + {"learner_stats": ...} # learner info + """ + + def __init__(self, workers): + self.workers = workers + + def __call__(self, item): + gradients, info = item + self.workers.local_worker().apply_gradients(gradients) + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): + e.set_weights.remote(weights) + return info + + +class AverageGradients: + """Callable that averages the gradients in a batch. + + This should be used with the .for_each() operator after a set of gradients + have been batched with .batch(). + + Examples: + >>> batched_grads = grads_op.batch(32) + >>> avg_grads = batched_grads.for_each(AverageGradients()) + >>> print(next(avg_grads)) + {"var_0": ..., ...}, {"learner_stats": ...} # avg grads, last info + """ + + def __call__(self, gradients): + acc = None + for grad, info in gradients: + if acc is None: + acc = grad + else: + acc = [a + b for a, b in zip(acc, grad)] + return acc, info
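Usage sketch (illustrative, not part of the diff above): the new trainers are registered under the experimental keys "A2C_pl" and "PG_pl", so they can be launched with `rllib train --run=PG_pl --env=CartPole-v0`, or driven from Python much like the unit tests in this patch. All classes and config keys below come from the patch; the specific values (number of workers, microbatch size) are arbitrary choices for the example.

import ray
from ray.rllib.agents.pg import PGPipeline
from ray.rllib.agents.a3c import A2CPipeline

ray.init()

# PG via the pipeline impl: ParallelRollouts -> ConcatBatches -> TrainOneStep,
# wrapped in StandardMetricsReporting.
pg_trainer = PGPipeline(
    env="CartPole-v0", config={"num_workers": 2, "min_iter_time_s": 0})
# May print nan on the very first iteration if no episode has completed yet.
print(pg_trainer.train()["episode_reward_mean"])

# A2C via the pipeline impl; setting microbatch_size switches training to the
# ComputeGradients -> AverageGradients -> ApplyGradients path.
a2c_trainer = A2CPipeline(
    env="CartPole-v0",
    config={"num_workers": 2, "min_iter_time_s": 0, "microbatch_size": 10})
print(a2c_trainer.train()["episode_reward_mean"])

ray.shutdown()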