[rllib] Add multi-agent examples for hand-coded policy, centralized VF (#4554)
parent 7f23e8431b
commit 4f46d3e9bf

9 changed files with 227 additions and 33 deletions
@@ -380,6 +380,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2
 
@@ -22,6 +22,8 @@ Training Workflows
   Example of how to adjust the configuration of an environment over time.
 - `Custom metrics <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_metrics_and_callbacks.py>`__:
   Example of how to output custom training metrics to TensorBoard.
+- `Using policy evaluators directly for control over the whole training workflow <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/policy_evaluator_custom_workflow.py>`__:
+  Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow.
 
 Custom Envs and Models
 ----------------------
@@ -49,12 +51,16 @@ Multi-Agent and Hierarchical
 
 - `Two-step game <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__:
   Example of the two-step game from the `QMIX paper <https://arxiv.org/pdf/1803.11485.pdf>`__.
+- `Hand-coded policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_custom_policy.py>`__:
+  Example of running a custom hand-coded policy alongside trainable policies.
 - `Weight sharing between policies <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_cartpole.py>`__:
   Example of how to define weight-sharing layers between two different policies.
 - `Multiple trainers <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/multiagent_two_trainers.py>`__:
   Example of alternating training between two DQN and PPO trainers.
 - `Hierarchical training <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/hierarchical_training.py>`__:
   Example of hierarchical training using the multi-agent API.
+- `PPO with centralized value function <https://github.com/ray-project/ray/pull/3642/files>`__:
+  Example of customizing PPO to include a centralized value function, including a runnable script that demonstrates cooperative CartPole.
 
 Community Examples
 ------------------
@@ -178,6 +178,8 @@ Custom Training Workflows
 
 In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
 
+For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/policy_evaluator_custom_workflow.py>`__.
+
 Accessing Policy State
 ~~~~~~~~~~~~~~~~~~~~~~
 It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.
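A minimal sketch of the "Accessing Policy State" pattern documented in the hunk above. The trainer construction (``PPOTrainer`` from ``ray.rllib.agents.ppo``) is an assumption, not something taken from this diff; the ``foreach_evaluator`` / ``foreach_evaluator_with_index`` calls, ``get_host()``, ``for_policy()``, and ``get_weights()`` are the APIs that appear in this change set and its documentation text::

    # Sketch only: PPOTrainer is assumed to be available; it is not part of this diff.
    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})
    trainer.train()

    # Apply a function to each policy evaluator; return values come back as a list.
    hosts = trainer.optimizer.foreach_evaluator(lambda ev: ev.get_host())

    # The _with_index variant also passes the evaluator's index, e.g. to treat
    # the local evaluator (index 0) differently from the remote replicas.
    weights = trainer.optimizer.foreach_evaluator_with_index(
        lambda ev, i: (i, ev.for_policy(lambda p: p.get_weights())))
    print(hosts, [i for i, _ in weights])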
@@ -49,7 +49,9 @@ class EvaluatorInterface(object):
             >>> ev.learn_on_batch(samples)
         """
 
-        return self.compute_apply(samples)
+        grads, info = self.compute_gradients(samples)
+        self.apply_gradients(grads)
+        return info
 
     @DeveloperAPI
     def compute_gradients(self, samples):
@@ -113,14 +115,6 @@ class EvaluatorInterface(object):
 
         raise NotImplementedError
 
-    @DeveloperAPI
-    def compute_apply(self, samples):
-        """Deprecated: override learn_on_batch instead."""
-
-        grads, info = self.compute_gradients(samples)
-        self.apply_gradients(grads)
-        return info
-
     @DeveloperAPI
     def get_host(self):
         """Returns the hostname of the process running this evaluator."""
@@ -41,7 +41,8 @@ def get_learner_stats(grad_info):
 
 
 @DeveloperAPI
-def collect_metrics(local_evaluator, remote_evaluators=[],
+def collect_metrics(local_evaluator=None,
+                    remote_evaluators=[],
                     timeout_seconds=180):
     """Gathers episode metrics from PolicyEvaluator instances."""
 
@@ -52,7 +53,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[],
 
 
 @DeveloperAPI
-def collect_episodes(local_evaluator,
+def collect_episodes(local_evaluator=None,
                      remote_evaluators=[],
                      timeout_seconds=180):
     """Gathers new episodes metrics tuples from the given evaluators."""
@@ -69,7 +70,8 @@ def collect_episodes(local_evaluator,
             "this timeout with `collect_metrics_timeout`.")
 
     metric_lists = ray.get(collected)
-    metric_lists.append(local_evaluator.get_metrics())
+    if local_evaluator:
+        metric_lists.append(local_evaluator.get_metrics())
     episodes = []
     for metrics in metric_lists:
         episodes.extend(metrics)
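The new ``local_evaluator=None`` defaults above are what let the custom-workflow example added later in this commit call ``collect_metrics(remote_evaluators=workers)`` with no local evaluator at all. A self-contained sketch of that call; the ``RandomPolicy`` class here is a stand-in modeled on the example files added below, not part of this hunk::

    import gym
    import ray
    from ray.rllib.evaluation import PolicyEvaluator, PolicyGraph
    from ray.rllib.evaluation.metrics import collect_metrics


    class RandomPolicy(PolicyGraph):
        """Stand-in policy (modeled on the examples in this commit)."""

        def compute_actions(self, obs_batch, state_batches, **kwargs):
            return [self.action_space.sample() for _ in obs_batch], [], {}

        def learn_on_batch(self, samples):
            return {}, {}


    if __name__ == "__main__":
        ray.init()
        workers = [
            PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
                                               RandomPolicy)
            for _ in range(2)
        ]
        ray.get([w.sample.remote() for w in workers])  # generate some episodes
        # No local evaluator passed -- the case the new default supports.
        print(collect_metrics(remote_evaluators=workers))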
@@ -564,21 +564,21 @@ class PolicyEvaluator(EvaluatorInterface):
                 summarize(samples)))
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
+            to_fetch = {}
             if self.tf_sess is not None:
                 builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
-                for pid, batch in samples.policy_batches.items():
-                    if pid not in self.policies_to_train:
-                        continue
-                    info_out[pid], _ = (
-                        self.policy_map[pid]._build_learn_on_batch(
-                            builder, batch))
-                info_out = {k: builder.get(v) for k, v in info_out.items()}
             else:
-                for pid, batch in samples.policy_batches.items():
-                    if pid not in self.policies_to_train:
-                        continue
-                    info_out[pid], _ = (
-                        self.policy_map[pid].learn_on_batch(batch))
+                builder = None
+            for pid, batch in samples.policy_batches.items():
+                if pid not in self.policies_to_train:
+                    continue
+                policy = self.policy_map[pid]
+                if builder and hasattr(policy, "_build_learn_on_batch"):
+                    to_fetch[pid], _ = policy._build_learn_on_batch(
+                        builder, batch)
+                else:
+                    info_out[pid], _ = policy.learn_on_batch(batch)
+            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
         else:
             info_out, _ = (
                 self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
@@ -170,7 +170,9 @@ class PolicyGraph(object):
             >>> ev.learn_on_batch(samples)
         """
 
-        return self.compute_apply(samples)
+        grads, grad_info = self.compute_gradients(samples)
+        apply_info = self.apply_gradients(grads)
+        return grad_info, apply_info
 
     @DeveloperAPI
     def compute_gradients(self, postprocessed_batch):
@@ -195,14 +197,6 @@ class PolicyGraph(object):
         """
         raise NotImplementedError
 
-    @DeveloperAPI
-    def compute_apply(self, samples):
-        """Deprecated: override learn_on_batch instead."""
-
-        grads, grad_info = self.compute_gradients(samples)
-        apply_info = self.apply_gradients(grads)
-        return grad_info, apply_info
-
     @DeveloperAPI
     def get_weights(self):
         """Returns model weights.
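The net effect of the interface changes above: ``learn_on_batch()`` now has a working default built from ``compute_gradients()`` and ``apply_gradients()``, and the deprecated ``compute_apply()`` is removed. A minimal sketch of a policy graph that relies on that default; the gradient logic and the ``ToyGradientPolicy`` name are placeholders for illustration, not taken from this commit::

    import numpy as np

    from ray.rllib.evaluation import PolicyGraph


    class ToyGradientPolicy(PolicyGraph):
        """Illustrative only: leans on the default learn_on_batch(), which now
        calls compute_gradients() followed by apply_gradients()."""

        def __init__(self, observation_space, action_space, config):
            PolicyGraph.__init__(self, observation_space, action_space, config)
            self.w = np.zeros(observation_space.shape[0])  # assumes a Box obs space

        def compute_actions(self, obs_batch, state_batches, **kwargs):
            # act by thresholding a linear score (placeholder behavior)
            return [int(np.dot(self.w, obs) > 0) for obs in obs_batch], [], {}

        def compute_gradients(self, postprocessed_batch):
            # placeholder "gradient": reward-weighted mean observation
            grad = np.mean(
                postprocessed_batch["obs"] *
                postprocessed_batch["rewards"][:, None], axis=0)
            return grad, {"grad_norm": float(np.linalg.norm(grad))}

        def apply_gradients(self, gradients):
            self.w += 0.01 * gradients
            return {}

        def get_weights(self):
            return {"w": self.w}

        def set_weights(self, weights):
            self.w = weights["w"]


    # learn_on_batch() is inherited: it computes gradients, applies them, and
    # returns (grad_info, apply_info) -- exactly the sequence added in this diff.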
python/ray/rllib/examples/multiagent_custom_policy.py (new file, 76 lines)
@@ -0,0 +1,76 @@
"""Example of running a custom hand-coded policy alongside trainable policies.

This example has two policies:
    (1) a simple PG policy
    (2) a hand-coded policy that acts at random in the env (doesn't learn)

In the console output, you can see the PG policy does much better than random:
Result for PG_multi_cartpole_0:
  ...
  policy_reward_mean:
    pg_policy: 185.23
    random: 21.255
  ...
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import gym

import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=20)


class RandomPolicy(PolicyGraph):
    """Hand-coded policy that returns random actions."""

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Compute actions on a batch of observations."""
        return [self.action_space.sample() for _ in obs_batch], [], {}

    def learn_on_batch(self, samples):
        """No learning."""
        return {}, {}


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Simple environment with 4 independent cartpole entities
    register_env("multi_cartpole", lambda _: MultiCartpole(4))
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    tune.run(
        "PG",
        stop={"training_iteration": args.num_iters},
        config={
            "env": "multi_cartpole",
            "multiagent": {
                "policy_graphs": {
                    "pg_policy": (None, obs_space, act_space, {}),
                    "random": (RandomPolicy, obs_space, act_space, {}),
                },
                "policy_mapping_fn": tune.function(
                    lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
            },
        },
    )
python/ray/rllib/examples/policy_evaluator_custom_workflow.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""Example of using policy evaluator classes directly to implement training.

Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom PolicyGraph class and manually coordinate distributed sample
collection and policy optimization.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import gym

import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", action="store_true")
parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=2)


class CustomPolicy(PolicyGraph):
    """Example of a custom policy graph written from scratch.

    You might find it more convenient to extend TF/TorchPolicyGraph instead
    for a real policy.
    """

    def __init__(self, observation_space, action_space, config):
        PolicyGraph.__init__(self, observation_space, action_space, config)
        # example parameter
        self.w = 1.0

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # return random actions
        return [self.action_space.sample() for _ in obs_batch], [], {}

    def learn_on_batch(self, samples):
        # implement your learning code here
        return {}, {}

    def update_some_value(self, w):
        # can also call other methods on policies
        self.w = w

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]


def training_workflow(config, reporter):
    # Setup policy and policy evaluation actors
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
                                           CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers
        weights = ray.put({"default_policy": policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas and gather another batch of samples
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_evaluators=workers))


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    tune.run(
        training_workflow,
        resources_per_trial={
            "gpu": 1 if args.gpu else 0,
            "cpu": 1,
            "extra_cpu": args.num_workers,
        },
        config={
            "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
    )