"""An example of customizing PPO to leverage a centralized critic. Here the model and policy are hard-coded to implement a centralized critic for TwoStepGame, but you can adapt this for your own use cases. Compared to simply running `rllib/examples/two_step_game.py --run=PPO`, this centralized critic version reaches vf_explained_variance=1.0 more stably since it takes into account the opponent actions as well as the policy's. Note that this is also using two independent policies instead of weight-sharing with one. See also: centralized_critic_2.py for a simpler approach that instead modifies the environment. """ import argparse import numpy as np from gym.spaces import Discrete import os import ray from ray import tune from ray.rllib.agents.maml.maml_torch_policy import KLCoeffMixin as TorchKLCoeffMixin from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_tf_policy import ( PPOTFPolicy, KLCoeffMixin, ppo_surrogate_loss as tf_loss, ) from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing from ray.rllib.examples.env.two_step_game import TwoStepGame from ray.rllib.examples.models.centralized_critic_models import ( CentralizedCriticModel, TorchCentralizedCriticModel, ) from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule from ray.rllib.policy.torch_policy import ( LearningRateSchedule as TorchLR, EntropyCoeffSchedule as TorchEntropyCoeffSchedule, ) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.torch_utils import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() OPPONENT_OBS = "opponent_obs" OPPONENT_ACTION = "opponent_action" parser = argparse.ArgumentParser() parser.add_argument( "--framework", choices=["tf", "tf2", "tfe", "torch"], default="tf", help="The DL framework specifier.", ) parser.add_argument( "--as-test", action="store_true", help="Whether this script should be run as a test: --stop-reward must " "be achieved within --stop-timesteps AND --stop-iters.", ) parser.add_argument( "--stop-iters", type=int, default=100, help="Number of iterations to train." ) parser.add_argument( "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train." ) parser.add_argument( "--stop-reward", type=float, default=7.99, help="Reward at which we stop training." ) class CentralizedValueMixin: """Add method to evaluate the central value function from the model.""" def __init__(self): if self.config["framework"] != "torch": self.compute_central_vf = make_tf_callable(self.get_session())( self.model.central_value_function ) else: self.compute_central_vf = self.model.central_value_function # Grabs the opponent obs/act and includes it in the experience train_batch, # and computes GAE using the central vf predictions. 

# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(
    policy, sample_batch, other_agent_batches=None, episode=None
):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or (
        not pytorch and policy.loss_initialized()
    ):
        assert other_agent_batches is not None
        [(_, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite default VF prediction with the central VF.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = (
                policy.compute_central_vf(
                    convert_to_torch_tensor(
                        sample_batch[SampleBatch.CUR_OBS], policy.device
                    ),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_OBS], policy.device
                    ),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_ACTION], policy.device
                    ),
                )
                .cpu()
                .detach()
                .numpy()
            )
        else:
            sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                sample_batch[SampleBatch.CUR_OBS],
                sample_batch[OPPONENT_OBS],
                sample_batch[OPPONENT_ACTION],
            )
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS]
        )
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32
        )

    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch


# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
    # Set up `policy.compute_central_vf` (idempotent re-assignment).
    CentralizedValueMixin.__init__(policy)
    func = (
        tf_loss
        if policy.config["framework"] != "torch"
        else PPOTorchPolicy.loss
    )

    # Temporarily swap in the central value function, so the PPO loss
    # optimizes it instead of the decentralized one.
    vf_saved = model.value_function
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS],
        train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION],
    )

    policy._central_value_out = model.value_function()
    loss = func(policy, model, dist_class, train_batch)
    model.value_function = vf_saved

    return loss


def setup_tf_mixins(policy, obs_space, action_space, config):
    # Copied from PPOTFPolicy (w/o ValueNetworkMixin).
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
    )
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def setup_torch_mixins(policy, obs_space, action_space, config):
    # Copied from PPOTorchPolicy (w/o ValueNetworkMixin).
    TorchKLCoeffMixin.__init__(policy, config)
    TorchEntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
    )
    TorchLR.__init__(policy, config["lr"], config["lr_schedule"])


def central_vf_stats(policy, train_batch, grads):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy._central_value_out
        )
    }


CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_tf_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule,
        EntropyCoeffSchedule,
        KLCoeffMixin,
        CentralizedValueMixin,
    ],
)


class CCPPOTorchPolicy(PPOTorchPolicy):
    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.compute_central_vf = self.model.central_value_function

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        return loss_with_central_critic(self, model, dist_class, train_batch)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        return centralized_critic_postprocessing(
            self, sample_batch, other_agent_batches, episode
        )


class CCTrainer(PPOTrainer):
    @override(PPOTrainer)
    def get_default_policy_class(self, config):
        if config["framework"] == "torch":
            return CCPPOTorchPolicy
        else:
            return CCPPOTFPolicy


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    ModelCatalog.register_custom_model(
        "cc_model",
        TorchCentralizedCriticModel
        if args.framework == "torch"
        else CentralizedCriticModel,
    )

    config = {
        "env": TwoStepGame,
        "batch_mode": "complete_episodes",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "multiagent": {
            "policies": {
                "pol1": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    {
                        "framework": args.framework,
                    },
                ),
                "pol2": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    {
                        "framework": args.framework,
                    },
                ),
            },
            "policy_mapping_fn": (
                lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"
            ),
        },
        "model": {
            "custom_model": "cc_model",
        },
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(CCTrainer, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
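
# Example invocations (flags are defined at the top of this file; the
# filename is assumed here):
#
#   python centralized_critic.py --framework=torch
#   python centralized_critic.py --as-test --stop-reward=7.99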