[rllib] Add multi-agent examples for hand-coded policy, centralized VF (#4554)

This commit is contained in:
Eric Liang 2019-04-09 00:36:49 -07:00 committed by GitHub
parent 7f23e8431b
commit 4f46d3e9bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 227 additions and 33 deletions

View file

@ -380,6 +380,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2 /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2 /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2

View file

@ -22,6 +22,8 @@ Training Workflows
Example of how to adjust the configuration of an environment over time. Example of how to adjust the configuration of an environment over time.
- `Custom metrics <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_metrics_and_callbacks.py>`__: - `Custom metrics <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_metrics_and_callbacks.py>`__:
Example of how to output custom training metrics to TensorBoard. Example of how to output custom training metrics to TensorBoard.
- `Using policy evaluators directly for control over the whole training workflow <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/policy_evaluator_custom_workflow.py>`__:
Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow.
Custom Envs and Models Custom Envs and Models
---------------------- ----------------------
@ -49,12 +51,16 @@ Multi-Agent and Hierarchical
- `Two-step game <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__: - `Two-step game <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__:
Example of the two-step game from the `QMIX paper <https://arxiv.org/pdf/1803.11485.pdf>`__. Example of the two-step game from the `QMIX paper <https://arxiv.org/pdf/1803.11485.pdf>`__.
- `Hand-coded policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_custom_policy.py>`__:
Example of running a custom hand-coded policy alongside trainable policies.
- `Weight sharing between policies <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_cartpole.py>`__: - `Weight sharing between policies <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_cartpole.py>`__:
Example of how to define weight-sharing layers between two different policies. Example of how to define weight-sharing layers between two different policies.
- `Multiple trainers <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_two_trainers.py>`__: - `Multiple trainers <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_two_trainers.py>`__:
Example of alternating training between two DQN and PPO trainers. Example of alternating training between two DQN and PPO trainers.
- `Hierarchical training <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/hierarchical_training.py>`__: - `Hierarchical training <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/hierarchical_training.py>`__:
Example of hierarchical training using the multi-agent API. Example of hierarchical training using the multi-agent API.
- `PPO with centralized value function <https://github.com/ray-project/ray/pull/3642/files>`__:
Example of customizing PPO to include a centralized value function, including a runnable script that demonstrates cooperative CartPole.
Community Examples Community Examples
------------------ ------------------

View file

@ -178,6 +178,8 @@ Custom Training Workflows
In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__. In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/policy_evaluator_custom_workflow.py>`__.
Accessing Policy State Accessing Policy State
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.

View file

@ -49,7 +49,9 @@ class EvaluatorInterface(object):
>>> ev.learn_on_batch(samples) >>> ev.learn_on_batch(samples)
""" """
return self.compute_apply(samples) grads, info = self.compute_gradients(samples)
self.apply_gradients(grads)
return info
@DeveloperAPI @DeveloperAPI
def compute_gradients(self, samples): def compute_gradients(self, samples):
@ -113,14 +115,6 @@ class EvaluatorInterface(object):
raise NotImplementedError raise NotImplementedError
@DeveloperAPI
def compute_apply(self, samples):
"""Deprecated: override learn_on_batch instead."""
grads, info = self.compute_gradients(samples)
self.apply_gradients(grads)
return info
@DeveloperAPI @DeveloperAPI
def get_host(self): def get_host(self):
"""Returns the hostname of the process running this evaluator.""" """Returns the hostname of the process running this evaluator."""

View file

@ -41,7 +41,8 @@ def get_learner_stats(grad_info):
@DeveloperAPI @DeveloperAPI
def collect_metrics(local_evaluator, remote_evaluators=[], def collect_metrics(local_evaluator=None,
remote_evaluators=[],
timeout_seconds=180): timeout_seconds=180):
"""Gathers episode metrics from PolicyEvaluator instances.""" """Gathers episode metrics from PolicyEvaluator instances."""
@ -52,7 +53,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[],
@DeveloperAPI @DeveloperAPI
def collect_episodes(local_evaluator, def collect_episodes(local_evaluator=None,
remote_evaluators=[], remote_evaluators=[],
timeout_seconds=180): timeout_seconds=180):
"""Gathers new episodes metrics tuples from the given evaluators.""" """Gathers new episodes metrics tuples from the given evaluators."""
@ -69,6 +70,7 @@ def collect_episodes(local_evaluator,
"this timeout with `collect_metrics_timeout`.") "this timeout with `collect_metrics_timeout`.")
metric_lists = ray.get(collected) metric_lists = ray.get(collected)
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics()) metric_lists.append(local_evaluator.get_metrics())
episodes = [] episodes = []
for metrics in metric_lists: for metrics in metric_lists:

View file

@ -564,21 +564,21 @@ class PolicyEvaluator(EvaluatorInterface):
summarize(samples))) summarize(samples)))
if isinstance(samples, MultiAgentBatch): if isinstance(samples, MultiAgentBatch):
info_out = {} info_out = {}
to_fetch = {}
if self.tf_sess is not None: if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "learn_on_batch") builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train:
continue
info_out[pid], _ = (
self.policy_map[pid]._build_learn_on_batch(
builder, batch))
info_out = {k: builder.get(v) for k, v in info_out.items()}
else: else:
builder = None
for pid, batch in samples.policy_batches.items(): for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train: if pid not in self.policies_to_train:
continue continue
info_out[pid], _ = ( policy = self.policy_map[pid]
self.policy_map[pid].learn_on_batch(batch)) if builder and hasattr(policy, "_build_learn_on_batch"):
to_fetch[pid], _ = policy._build_learn_on_batch(
builder, batch)
else:
info_out[pid], _ = policy.learn_on_batch(batch)
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
else: else:
info_out, _ = ( info_out, _ = (
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples)) self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))

View file

@ -170,7 +170,9 @@ class PolicyGraph(object):
>>> ev.learn_on_batch(samples) >>> ev.learn_on_batch(samples)
""" """
return self.compute_apply(samples) grads, grad_info = self.compute_gradients(samples)
apply_info = self.apply_gradients(grads)
return grad_info, apply_info
@DeveloperAPI @DeveloperAPI
def compute_gradients(self, postprocessed_batch): def compute_gradients(self, postprocessed_batch):
@ -195,14 +197,6 @@ class PolicyGraph(object):
""" """
raise NotImplementedError raise NotImplementedError
@DeveloperAPI
def compute_apply(self, samples):
"""Deprecated: override learn_on_batch instead."""
grads, grad_info = self.compute_gradients(samples)
apply_info = self.apply_gradients(grads)
return grad_info, apply_info
@DeveloperAPI @DeveloperAPI
def get_weights(self): def get_weights(self):
"""Returns model weights. """Returns model weights.

View file

@ -0,0 +1,76 @@
"""Example of running a custom hand-coded policy alongside trainable policies.
This example has two policies:
(1) a simple PG policy
(2) a hand-coded policy that acts at random in the env (doesn't learn)
In the console output, you can see the PG policy does much better than random:
Result for PG_multi_cartpole_0:
...
policy_reward_mean:
pg_policy: 185.23
random: 21.255
...
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import gym
import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=20)
class RandomPolicy(PolicyGraph):
"""Hand-coded policy that returns random actions."""
def compute_actions(self,
obs_batch,
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
"""Compute actions on a batch of observations."""
return [self.action_space.sample() for _ in obs_batch], [], {}
def learn_on_batch(self, samples):
"""No learning."""
return {}, {}
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
# Simple environment with 4 independent cartpole entities
register_env("multi_cartpole", lambda _: MultiCartpole(4))
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
tune.run(
"PG",
stop={"training_iteration": args.num_iters},
config={
"env": "multi_cartpole",
"multiagent": {
"policy_graphs": {
"pg_policy": (None, obs_space, act_space, {}),
"random": (RandomPolicy, obs_space, act_space, {}),
},
"policy_mapping_fn": tune.function(
lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
},
},
)

View file

@ -0,0 +1,117 @@
"""Example of using policy evaluator classes directly to implement training.
Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom PolicyGraph class and manually coordinate distributed sample
collection and policy optimization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import gym
import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=2)
class CustomPolicy(PolicyGraph):
"""Example of a custom policy graph written from scratch.
You might find it more convenient to extend TF/TorchPolicyGraph instead
for a real policy.
"""
def __init__(self, observation_space, action_space, config):
PolicyGraph.__init__(self, observation_space, action_space, config)
# example parameter
self.w = 1.0
def compute_actions(self,
obs_batch,
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
# return random actions
return [self.action_space.sample() for _ in obs_batch], [], {}
def learn_on_batch(self, samples):
# implement your learning code here
return {}, {}
def update_some_value(self, w):
# can also call other methods on policies
self.w = w
def get_weights(self):
return {"w": self.w}
def set_weights(self, weights):
self.w = weights["w"]
def training_workflow(config, reporter):
# Setup policy and policy evaluation actors
env = gym.make("CartPole-v0")
policy = CustomPolicy(env.observation_space, env.action_space, {})
workers = [
PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
CustomPolicy)
for _ in range(config["num_workers"])
]
for _ in range(config["num_iters"]):
# Broadcast weights to the policy evaluation workers
weights = ray.put({"default_policy": policy.get_weights()})
for w in workers:
w.set_weights.remote(weights)
# Gather a batch of samples
T1 = SampleBatch.concat_samples(
ray.get([w.sample.remote() for w in workers]))
# Update the remote policy replicas and gather another batch of samples
new_value = policy.w * 2.0
for w in workers:
w.for_policy.remote(lambda p: p.update_some_value(new_value))
# Gather another batch of samples
T2 = SampleBatch.concat_samples(
ray.get([w.sample.remote() for w in workers]))
# Improve the policy using the T1 batch
policy.learn_on_batch(T1)
# Do some arbitrary updates based on the T2 batch
policy.update_some_value(sum(T2["rewards"]))
reporter(**collect_metrics(remote_evaluators=workers))
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
tune.run(
training_workflow,
resources_per_trial={
"gpu": 1 if args.gpu else 0,
"cpu": 1,
"extra_cpu": args.num_workers,
},
config={
"num_workers": args.num_workers,
"num_iters": args.num_iters,
},
)