From 4f46d3e9bfc13ab9c60259b281f895f6e197f0e3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 9 Apr 2019 00:36:49 -0700 Subject: [PATCH] [rllib] Add multi-agent examples for hand-coded policy, centralized VF (#4554) --- ci/jenkins_tests/run_rllib_tests.sh | 3 + doc/source/rllib-examples.rst | 6 + doc/source/rllib-training.rst | 2 + python/ray/rllib/evaluation/interface.py | 12 +- python/ray/rllib/evaluation/metrics.py | 8 +- .../ray/rllib/evaluation/policy_evaluator.py | 24 ++-- python/ray/rllib/evaluation/policy_graph.py | 12 +- .../examples/multiagent_custom_policy.py | 76 ++++++++++++ .../policy_evaluator_custom_workflow.py | 117 ++++++++++++++++++ 9 files changed, 227 insertions(+), 33 deletions(-) create mode 100644 python/ray/rllib/examples/multiagent_custom_policy.py create mode 100644 python/ray/rllib/examples/policy_evaluator_custom_workflow.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index e3ba467b1..89939bb2a 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -380,6 +380,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2 diff --git a/doc/source/rllib-examples.rst b/doc/source/rllib-examples.rst index 4edc5076f..f26e078ea 100644 --- a/doc/source/rllib-examples.rst +++ b/doc/source/rllib-examples.rst @@ -22,6 +22,8 @@ Training Workflows Example of how to adjust the configuration of an environment over time. - `Custom metrics `__: Example of how to output custom training metrics to TensorBoard. +- `Using policy evaluators directly for control over the whole training workflow `__: + Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow. Custom Envs and Models ---------------------- @@ -49,12 +51,16 @@ Multi-Agent and Hierarchical - `Two-step game `__: Example of the two-step game from the `QMIX paper `__. +- `Hand-coded policy `__: + Example of running a custom hand-coded policy alongside trainable policies. - `Weight sharing between policies `__: Example of how to define weight-sharing layers between two different policies. - `Multiple trainers `__: Example of alternating training between two DQN and PPO trainers. - `Hierarchical training `__: Example of hierarchical training using the multi-agent API. +- `PPO with centralized value function `__: + Example of customizing PPO to include a centralized value function, including a runnable script that demonstrates cooperative CartPole. Community Examples ------------------ diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index cb55ac86c..ef4f29295 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -178,6 +178,8 @@ Custom Training Workflows In the `basic training example `__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. 
Tune supports `custom trainable functions `__ that can be used to implement `custom training workflows (example) `__. +For even finer-grained control over training, you can use RLlib's lower-level `building blocks `__ directly to implement `fully customized training workflows `__. + Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. diff --git a/python/ray/rllib/evaluation/interface.py b/python/ray/rllib/evaluation/interface.py index e1c0b9108..eb705a99b 100644 --- a/python/ray/rllib/evaluation/interface.py +++ b/python/ray/rllib/evaluation/interface.py @@ -49,7 +49,9 @@ class EvaluatorInterface(object): >>> ev.learn_on_batch(samples) """ - return self.compute_apply(samples) + grads, info = self.compute_gradients(samples) + self.apply_gradients(grads) + return info @DeveloperAPI def compute_gradients(self, samples): @@ -113,14 +115,6 @@ class EvaluatorInterface(object): raise NotImplementedError - @DeveloperAPI - def compute_apply(self, samples): - """Deprecated: override learn_on_batch instead.""" - - grads, info = self.compute_gradients(samples) - self.apply_gradients(grads) - return info - @DeveloperAPI def get_host(self): """Returns the hostname of the process running this evaluator.""" diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index a8fa64b1c..6b1c766c4 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -41,7 +41,8 @@ def get_learner_stats(grad_info): @DeveloperAPI -def collect_metrics(local_evaluator, remote_evaluators=[], +def collect_metrics(local_evaluator=None, + remote_evaluators=[], timeout_seconds=180): """Gathers episode metrics from PolicyEvaluator instances.""" @@ -52,7 +53,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[], @DeveloperAPI -def collect_episodes(local_evaluator, +def collect_episodes(local_evaluator=None, remote_evaluators=[], timeout_seconds=180): """Gathers new episodes metrics tuples from the given evaluators.""" @@ -69,7 +70,8 @@ def collect_episodes(local_evaluator, "this timeout with `collect_metrics_timeout`.") metric_lists = ray.get(collected) - metric_lists.append(local_evaluator.get_metrics()) + if local_evaluator: + metric_lists.append(local_evaluator.get_metrics()) episodes = [] for metrics in metric_lists: episodes.extend(metrics) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 4c5a80c86..0993eed3c 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -564,21 +564,21 @@ class PolicyEvaluator(EvaluatorInterface): summarize(samples))) if isinstance(samples, MultiAgentBatch): info_out = {} + to_fetch = {} if self.tf_sess is not None: builder = TFRunBuilder(self.tf_sess, "learn_on_batch") - for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: - continue - info_out[pid], _ = ( - self.policy_map[pid]._build_learn_on_batch( - builder, 
batch)) - info_out = {k: builder.get(v) for k, v in info_out.items()} else: - for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: - continue - info_out[pid], _ = ( - self.policy_map[pid].learn_on_batch(batch)) + builder = None + for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue + policy = self.policy_map[pid] + if builder and hasattr(policy, "_build_learn_on_batch"): + to_fetch[pid], _ = policy._build_learn_on_batch( + builder, batch) + else: + info_out[pid], _ = policy.learn_on_batch(batch) + info_out.update({k: builder.get(v) for k, v in to_fetch.items()}) else: info_out, _ = ( self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples)) diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 20290fb1c..c3b101467 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -170,7 +170,9 @@ class PolicyGraph(object): >>> ev.learn_on_batch(samples) """ - return self.compute_apply(samples) + grads, grad_info = self.compute_gradients(samples) + apply_info = self.apply_gradients(grads) + return grad_info, apply_info @DeveloperAPI def compute_gradients(self, postprocessed_batch): @@ -195,14 +197,6 @@ class PolicyGraph(object): """ raise NotImplementedError - @DeveloperAPI - def compute_apply(self, samples): - """Deprecated: override learn_on_batch instead.""" - - grads, grad_info = self.compute_gradients(samples) - apply_info = self.apply_gradients(grads) - return grad_info, apply_info - @DeveloperAPI def get_weights(self): """Returns model weights. diff --git a/python/ray/rllib/examples/multiagent_custom_policy.py b/python/ray/rllib/examples/multiagent_custom_policy.py new file mode 100644 index 000000000..14c309c9a --- /dev/null +++ b/python/ray/rllib/examples/multiagent_custom_policy.py @@ -0,0 +1,76 @@ +"""Example of running a custom hand-coded policy alongside trainable policies. + +This example has two policies: + (1) a simple PG policy + (2) a hand-coded policy that acts at random in the env (doesn't learn) + +In the console output, you can see the PG policy does much better than random: +Result for PG_multi_cartpole_0: + ... + policy_reward_mean: + pg_policy: 185.23 + random: 21.255 + ... 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import gym + +import ray +from ray import tune +from ray.rllib.evaluation import PolicyGraph +from ray.rllib.tests.test_multi_agent_env import MultiCartpole +from ray.tune.registry import register_env + +parser = argparse.ArgumentParser() +parser.add_argument("--num-iters", type=int, default=20) + + +class RandomPolicy(PolicyGraph): + """Hand-coded policy that returns random actions.""" + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + """Compute actions on a batch of observations.""" + return [self.action_space.sample() for _ in obs_batch], [], {} + + def learn_on_batch(self, samples): + """No learning.""" + return {}, {} + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + # Simple environment with 4 independent cartpole entities + register_env("multi_cartpole", lambda _: MultiCartpole(4)) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + + tune.run( + "PG", + stop={"training_iteration": args.num_iters}, + config={ + "env": "multi_cartpole", + "multiagent": { + "policy_graphs": { + "pg_policy": (None, obs_space, act_space, {}), + "random": (RandomPolicy, obs_space, act_space, {}), + }, + "policy_mapping_fn": tune.function( + lambda agent_id: ["pg_policy", "random"][agent_id % 2]), + }, + }, + ) diff --git a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py new file mode 100644 index 000000000..0fa01b303 --- /dev/null +++ b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py @@ -0,0 +1,117 @@ +"""Example of using policy evaluator classes directly to implement training. + +Instead of using the built-in Trainer classes provided by RLlib, here we define +a custom PolicyGraph class and manually coordinate distributed sample +collection and policy optimization. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import gym + +import ray +from ray import tune +from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch +from ray.rllib.evaluation.metrics import collect_metrics + +parser = argparse.ArgumentParser() +parser.add_argument("--gpu", action="store_true") +parser.add_argument("--num-iters", type=int, default=20) +parser.add_argument("--num-workers", type=int, default=2) + + +class CustomPolicy(PolicyGraph): + """Example of a custom policy graph written from scratch. + + You might find it more convenient to extend TF/TorchPolicyGraph instead + for a real policy. 
+ """ + + def __init__(self, observation_space, action_space, config): + PolicyGraph.__init__(self, observation_space, action_space, config) + # example parameter + self.w = 1.0 + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + # return random actions + return [self.action_space.sample() for _ in obs_batch], [], {} + + def learn_on_batch(self, samples): + # implement your learning code here + return {}, {} + + def update_some_value(self, w): + # can also call other methods on policies + self.w = w + + def get_weights(self): + return {"w": self.w} + + def set_weights(self, weights): + self.w = weights["w"] + + +def training_workflow(config, reporter): + # Setup policy and policy evaluation actors + env = gym.make("CartPole-v0") + policy = CustomPolicy(env.observation_space, env.action_space, {}) + workers = [ + PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), + CustomPolicy) + for _ in range(config["num_workers"]) + ] + + for _ in range(config["num_iters"]): + # Broadcast weights to the policy evaluation workers + weights = ray.put({"default_policy": policy.get_weights()}) + for w in workers: + w.set_weights.remote(weights) + + # Gather a batch of samples + T1 = SampleBatch.concat_samples( + ray.get([w.sample.remote() for w in workers])) + + # Update the remote policy replicas and gather another batch of samples + new_value = policy.w * 2.0 + for w in workers: + w.for_policy.remote(lambda p: p.update_some_value(new_value)) + + # Gather another batch of samples + T2 = SampleBatch.concat_samples( + ray.get([w.sample.remote() for w in workers])) + + # Improve the policy using the T1 batch + policy.learn_on_batch(T1) + + # Do some arbitrary updates based on the T2 batch + policy.update_some_value(sum(T2["rewards"])) + + reporter(**collect_metrics(remote_evaluators=workers)) + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + tune.run( + training_workflow, + resources_per_trial={ + "gpu": 1 if args.gpu else 0, + "cpu": 1, + "extra_cpu": args.num_workers, + }, + config={ + "num_workers": args.num_workers, + "num_iters": args.num_iters, + }, + )