ray/rllib/examples/two_trainer_workflow.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

201 lines
6.5 KiB
Python

"""Example of using a custom training workflow.
Here we create a number of CartPole agents, some of which are trained with
DQN, and some of which are trained with PPO. Both are executed concurrently
via a custom training workflow.
"""
import argparse
import os
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import _get_shared_metrics
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
ConcatBatches,
StandardizeFields,
SelectExperiences,
)
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.registry import register_env
def _build_arg_parser():
    """Create the command-line parser for this example script.

    Flags, in registration order:
      --torch, --mixed-torch-tf : framework selection switches.
      --as-test                 : fail unless the reward target is reached.
      --stop-iters, --stop-timesteps, --stop-reward : stopping criteria.
    """
    arg_parser = argparse.ArgumentParser()

    # Boolean switches.
    for switch in ("--torch", "--mixed-torch-tf"):
        arg_parser.add_argument(switch, action="store_true")
    arg_parser.add_argument(
        "--as-test",
        action="store_true",
        help=(
            "Whether this script should be run as a test: --stop-reward must "
            "be achieved within --stop-timesteps AND --stop-iters."
        ),
    )

    # Stopping criteria: (flag, type, default, help text).
    stop_options = (
        ("--stop-iters", int, 20, "Number of iterations to train."),
        ("--stop-timesteps", int, 100000, "Number of timesteps to train."),
        ("--stop-reward", float, 150.0, "Reward at which we stop training."),
    )
    for flag, value_type, default_value, help_text in stop_options:
        arg_parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )

    return arg_parser


parser = _build_arg_parser()
def custom_training_workflow(workers: WorkerSet, config: dict):
    """Build an execution plan that trains a PPO and a DQN policy side by side.

    A single stream of rollouts is duplicated into two sub-flows:

    * DQN: samples of "dqn_policy" are stored in a local replay buffer.
      Replayed batches are used to train that policy and to update its
      target network.
    * PPO: samples of "ppo_policy" are concatenated into batches of at least
      200 env steps. Advantages are standardized, and the policy is trained
      with several SGD passes.

    Args:
        workers: WorkerSet used both to sample rollouts and to train.
        config: Trainer config, forwarded to `StandardMetricsReporting`.

    Returns:
        The `StandardMetricsReporting` iterator wrapping the combined flow.
    """
    local_replay_buffer = MultiAgentReplayBuffer(
        num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
    )

    def _make_metrics_fn(label):
        """Return a pass-through op that logs a batch and counts its agent steps.

        `label` ("PPO" or "DQN") prefixes the log line and selects the shared
        counter `agent_steps_trained_<label>`.
        """
        counter_key = f"agent_steps_trained_{label}"

        def add_metrics(batch):
            print(
                f"{label} policy learning on samples from",
                batch.policy_batches.keys(),
                "env steps",
                batch.env_steps(),
                "agent steps",
                # Fixed: this field previously printed env_steps() again.
                batch.agent_steps(),
            )
            metrics = _get_shared_metrics()
            # The counter tracks *agent* steps, as its name says. It
            # previously accumulated env_steps().
            metrics.counters[counter_key] += batch.agent_steps()
            return batch

        return add_metrics

    add_ppo_metrics = _make_metrics_fn("PPO")
    add_dqn_metrics = _make_metrics_fn("DQN")

    # Generate common experiences and feed one copy to each sub-flow.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # DQN sub-flow: store incoming samples, then train on replayed batches.
    dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])).for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer)
    )
    dqn_replay_op = (
        Replay(local_buffer=local_replay_buffer)
        .for_each(add_dqn_metrics)
        .for_each(TrainOneStep(workers, policies=["dqn_policy"]))
        .for_each(
            UpdateTargetNetwork(
                workers, target_update_freq=500, policies=["dqn_policy"]
            )
        )
    )
    # Alternate between storing and replaying. Only the replay/train
    # results (index 1) are passed downstream.
    dqn_train_op = Concurrently(
        [dqn_store_op, dqn_replay_op], mode="round_robin", output_indexes=[1]
    )

    # PPO sub-flow: on-policy training on concatenated sample batches.
    ppo_train_op = (
        r2.for_each(SelectExperiences(["ppo_policy"]))
        .combine(ConcatBatches(min_batch_size=200, count_steps_by="env_steps"))
        .for_each(add_ppo_metrics)
        .for_each(StandardizeFields(["advantages"]))
        .for_each(
            TrainOneStep(
                workers,
                policies=["ppo_policy"],
                num_sgd_iter=10,
                sgd_minibatch_size=128,
            )
        )
    )

    # Combined training flow. Both sub-flows run asynchronously, and only
    # the DQN branch's outputs (index 1) reach the metrics reporter.
    train_op = Concurrently(
        [ppo_train_op, dqn_train_op], mode="async", output_indexes=[1]
    )
    return StandardMetricsReporting(train_op, workers, config)
if __name__ == "__main__":
    args = parser.parse_args()
    # Validate with parser.error rather than `assert`: asserts are stripped
    # under `python -O`. parser.error prints usage and exits with an error.
    if args.torch and args.mixed_torch_tf:
        parser.error("Use either --torch or --mixed-torch-tf, not both!")

    ray.init()

    # Simple environment with 4 independent cartpole entities.
    register_env(
        "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4})
    )

    # The trainer built below has no default policy, so both policies must
    # be declared explicitly in the multiagent config. Each entry is
    # presumably (policy class, obs space, action space, policy config), per
    # RLlib's policy-spec tuple format. The spaces are left as None.
    # With --mixed-torch-tf, PPO uses the torch policy while DQN stays on TF.
    policies = {
        "ppo_policy": (
            PPOTorchPolicy if args.torch or args.mixed_torch_tf else PPOTFPolicy,
            None,
            None,
            PPO_CONFIG,
        ),
        "dqn_policy": (
            DQNTorchPolicy if args.torch else DQNTFPolicy,
            None,
            None,
            DQN_CONFIG,
        ),
    }

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        """Map even agent ids to the PPO policy and odd ids to the DQN policy."""
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"

    # The training loop is entirely defined by the custom execution plan.
    MyTrainer = build_trainer(
        name="PPO_DQN_MultiAgent",
        default_policy=None,
        execution_plan=custom_training_workflow,
    )

    config = {
        "rollout_fragment_length": 50,
        "num_workers": 0,
        "env": "multi_agent_cartpole",
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
            "policies_to_train": ["dqn_policy", "ppo_policy"],
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": "torch" if args.torch else "tf",
    }

    # Tune stops the trial as soon as any one of these criteria is met.
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(MyTrainer, config=config, stop=stop)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()