"""An example of customizing PPO to leverage a centralized critic. Here the model and policy are hard-coded to implement a centralized critic for TwoStepGame, but you can adapt this for your own use cases. Compared to simply running `rllib/examples/two_step_game.py --run=PPO`, this centralized critic version reaches vf_explained_variance=1.0 more stably since it takes into account the opponent actions as well as the policy's. Note that this is also using two independent policies instead of weight-sharing with one. See also: centralized_critic_2.py for a simpler approach that instead modifies the environment. """ import argparse import numpy as np from gym.spaces import Discrete import os import ray from ray import tune from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \ ppo_surrogate_loss as tf_loss from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy, \ KLCoeffMixin as TorchKLCoeffMixin, ppo_surrogate_loss as torch_loss from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing from ray.rllib.examples.env.two_step_game import TwoStepGame from ray.rllib.examples.models.centralized_critic_models import \ CentralizedCriticModel, TorchCentralizedCriticModel from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \ EntropyCoeffSchedule as TorchEntropyCoeffSchedule from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() OPPONENT_OBS = "opponent_obs" OPPONENT_ACTION = "opponent_action" parser = argparse.ArgumentParser() parser.add_argument( "--framework", choices=["tf", "tf2", "tfe", "torch"], default="tf", help="The DL framework specifier.") parser.add_argument( "--as-test", action="store_true", help="Whether this script should be run as a test: --stop-reward must " "be achieved within --stop-timesteps AND --stop-iters.") parser.add_argument( "--stop-iters", type=int, default=100, help="Number of iterations to train.") parser.add_argument( "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train.") parser.add_argument( "--stop-reward", type=float, default=7.99, help="Reward at which we stop training.") class CentralizedValueMixin: """Add method to evaluate the central value function from the model.""" def __init__(self): if self.config["framework"] != "torch": self.compute_central_vf = make_tf_callable(self.get_session())( self.model.central_value_function) else: self.compute_central_vf = self.model.central_value_function # Grabs the opponent obs/act and includes it in the experience train_batch, # and computes GAE using the central vf predictions. 


# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(policy,
                                      sample_batch,
                                      other_agent_batches=None,
                                      episode=None):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or \
            (not pytorch and policy.loss_initialized()):
        assert other_agent_batches is not None
        [(_, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite default VF prediction with the central VF.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                convert_to_torch_tensor(
                    sample_batch[SampleBatch.CUR_OBS], policy.device),
                convert_to_torch_tensor(
                    sample_batch[OPPONENT_OBS], policy.device),
                convert_to_torch_tensor(
                    sample_batch[OPPONENT_ACTION], policy.device)) \
                .cpu().detach().numpy()
        else:
            sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
                sample_batch[OPPONENT_ACTION])
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(
            sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32)

    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"])
    return train_batch
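

# For reference, compute_advantages() above estimates GAE(lambda) advantages,
# roughly:
#
#     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#     A_t     = sum_{l >= 0} (gamma * lambda)^l * delta_{t+l}
#
# where V here is the *central* value function, since SampleBatch.VF_PREDS was
# overwritten with policy.compute_central_vf(...) in the postprocessing above.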


# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
    CentralizedValueMixin.__init__(policy)
    func = torch_loss if policy.config["framework"] == "torch" else tf_loss

    vf_saved = model.value_function
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION])

    policy._central_value_out = model.value_function()
    loss = func(policy, model, dist_class, train_batch)

    model.value_function = vf_saved

    return loss


def setup_tf_mixins(policy, obs_space, action_space, config):
    # Copied from PPOTFPolicy (w/o ValueNetworkMixin).
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def setup_torch_mixins(policy, obs_space, action_space, config):
    # Copied from PPOTorchPolicy (w/o ValueNetworkMixin).
    TorchKLCoeffMixin.__init__(policy, config)
    TorchEntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                       config["entropy_coeff_schedule"])
    TorchLR.__init__(policy, config["lr"], config["lr_schedule"])


def central_vf_stats(policy, train_batch, grads):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy._central_value_out),
    }


CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_tf_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_torch_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])


def get_policy_class(config):
    if config["framework"] == "torch":
        return CCPPOTorchPolicy


CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTFPolicy,
    get_policy_class=get_policy_class,
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    ModelCatalog.register_custom_model(
        "cc_model", TorchCentralizedCriticModel
        if args.framework == "torch" else CentralizedCriticModel)

    config = {
        "env": TwoStepGame,
        "batch_mode": "complete_episodes",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "multiagent": {
            "policies": {
                "pol1": (None, Discrete(6), TwoStepGame.action_space, {
                    "framework": args.framework,
                }),
                "pol2": (None, Discrete(6), TwoStepGame.action_space, {
                    "framework": args.framework,
                }),
            },
            "policy_mapping_fn": (
                lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
        },
        "model": {
            "custom_model": "cc_model",
        },
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(CCTrainer, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
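
# Example invocations (hypothetical command lines, assuming this file is saved
# as centralized_critic.py; all flags are defined by the argparse parser above):
#
#     python centralized_critic.py --framework=tf --stop-reward=7.99
#     python centralized_critic.py --framework=torch --as-test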