"""An example of customizing PPO to leverage a centralized critic. Here the model and policy are hard-coded to implement a centralized critic for TwoStepGame, but you can adapt this for your own use cases. Compared to simply running `twostep_game.py --run=PPO`, this centralized critic version reaches vf_explained_variance=1.0 more stably since it takes into account the opponent actions as well as the policy's. Note that this is also using two independent policies instead of weight-sharing with one. See also: centralized_critic_2.py for a simpler approach that instead modifies the environment. """ import argparse import numpy as np from gym.spaces import Discrete import ray from ray import tune from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \ PPOLoss as TFLoss from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy, \ KLCoeffMixin as TorchKLCoeffMixin, PPOLoss as TorchLoss from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing from ray.rllib.examples.env.two_step_game import TwoStepGame from ray.rllib.examples.models.centralized_critic_models import \ CentralizedCriticModel, TorchCentralizedCriticModel from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.torch_policy import LearningRateSchedule as TorchLR, \ EntropyCoeffSchedule as TorchEntropyCoeffSchedule from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf = try_import_tf() torch, nn = try_import_torch() OPPONENT_OBS = "opponent_obs" OPPONENT_ACTION = "opponent_action" parser = argparse.ArgumentParser() parser.add_argument("--torch", action="store_true") parser.add_argument("--as-test", action="store_true") parser.add_argument("--stop-iters", type=int, default=100) parser.add_argument("--stop-timesteps", type=int, default=100000) parser.add_argument("--stop-reward", type=float, default=7.99) class CentralizedValueMixin: """Add method to evaluate the central value function from the model.""" def __init__(self): if self.config["framework"] != "torch": self.compute_central_vf = make_tf_callable(self.get_session())( self.model.central_value_function) else: self.compute_central_vf = self.model.central_value_function # Grabs the opponent obs/act and includes it in the experience train_batch, # and computes GAE using the central vf predictions. 
def centralized_critic_postprocessing(policy,
                                      sample_batch,
                                      other_agent_batches=None,
                                      episode=None):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or \
            (not pytorch and policy.loss_initialized()):
        assert other_agent_batches is not None
        [(_, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite the default VF prediction with the central VF.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                convert_to_torch_tensor(sample_batch[SampleBatch.CUR_OBS]),
                convert_to_torch_tensor(sample_batch[OPPONENT_OBS]),
                convert_to_torch_tensor(
                    sample_batch[OPPONENT_ACTION])).detach().numpy()
        else:
            sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
                sample_batch[OPPONENT_ACTION])
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(
            sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32)

    completed = sample_batch["dones"][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"])
    return train_batch


# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, model, dist_class, train_batch):
    CentralizedValueMixin.__init__(policy)

    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    policy.central_value_out = policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION])

    func = TFLoss if not policy.config["framework"] == "torch" else TorchLoss
    # All timesteps are valid (no RNN sequence masking), so the valid-mask
    # argument of the loss is just an all-True tensor.
    mask = tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool) \
        if policy.config["framework"] != "torch" else \
        torch.ones_like(
            train_batch[Postprocessing.ADVANTAGES], dtype=torch.bool)

    policy.loss_obj = func(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        policy.central_value_out,
        policy.kl_coeff,
        mask,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"])

    return policy.loss_obj.loss


def setup_mixins(policy, obs_space, action_space, config):
    # Copied from PPO.
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"],
                                  config["lr_schedule"])


def central_vf_stats(policy, train_batch, grads):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy.central_value_out),
    }


CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])


def get_policy_class(config):
    return CCPPOTorchPolicy if config["framework"] == "torch" \
        else CCPPOTFPolicy


CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTFPolicy,
    get_policy_class=get_policy_class,
)

if __name__ == "__main__":
    ray.init(local_mode=True)
    args = parser.parse_args()

    ModelCatalog.register_custom_model(
        "cc_model", TorchCentralizedCriticModel
        if args.torch else CentralizedCriticModel)

    config = {
        "env": TwoStepGame,
        "batch_mode": "complete_episodes",
        "num_workers": 0,
        "multiagent": {
            "policies": {
                "pol1": (None, Discrete(6), TwoStepGame.action_space, {
                    "framework": "torch" if args.torch else "tf",
                }),
                "pol2": (None, Discrete(6), TwoStepGame.action_space, {
                    "framework": "torch" if args.torch else "tf",
                }),
            },
            "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
        },
        "model": {
            "custom_model": "cc_model",
        },
        "framework": "torch" if args.torch else "tf",
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(CCTrainer, config=config, stop=stop)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
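# Example invocations, assuming this script is saved as centralized_critic.py
# (the file name is not fixed by the code above):
#   python centralized_critic.py                  # TF version
#   python centralized_critic.py --torch          # PyTorch version
#   python centralized_critic.py --as-test --stop-reward=7.99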