ray/rllib/examples/centralized_critic.py
2020-01-20 23:06:50 -08:00

219 lines
8.3 KiB
Python

"""An example of customizing PPO to leverage a centralized critic.
Here the model and policy are hard-coded to implement a centralized critic
for TwoStepGame, but you can adapt this for your own use cases.
Compared to simply running `twostep_game.py --run=PPO`, this centralized
critic version reaches vf_explained_variance=1.0 more stably since it takes
into account the opponent actions as well as the policy's. Note that this is
also using two independent policies instead of weight-sharing with one.
See also: centralized_critic_2.py for a simpler approach that instead
modifies the environment.
"""
import argparse
import numpy as np
from gym.spaces import Discrete
from ray import tune
from ray.rllib.agents.impala.vtrace_policy import BEHAVIOUR_LOGITS
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin, \
PPOLoss
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.examples.twostep_game import TwoStepGame
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule, \
EntropyCoeffSchedule, ACTION_LOGP
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils.tf_ops import make_tf_callable
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"
parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=100000)
class CentralizedCriticModel(TFModelV2):
"""Multi-agent model that implements a centralized VF."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(CentralizedCriticModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
# Base of the model
self.model = FullyConnectedNetwork(obs_space, action_space,
num_outputs, model_config, name)
self.register_variables(self.model.variables())
# Central VF maps (obs, opp_obs, opp_act) -> vf_pred
obs = tf.keras.layers.Input(shape=(6, ), name="obs")
opp_obs = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
opp_act = tf.keras.layers.Input(shape=(2, ), name="opp_act")
concat_obs = tf.keras.layers.Concatenate(axis=1)(
[obs, opp_obs, opp_act])
central_vf_dense = tf.keras.layers.Dense(
16, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
central_vf_out = tf.keras.layers.Dense(
1, activation=None, name="c_vf_out")(central_vf_dense)
self.central_vf = tf.keras.Model(
inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
self.register_variables(self.central_vf.variables)
def forward(self, input_dict, state, seq_lens):
return self.model.forward(input_dict, state, seq_lens)
def central_value_function(self, obs, opponent_obs, opponent_actions):
return tf.reshape(
self.central_vf(
[obs, opponent_obs,
tf.one_hot(opponent_actions, 2)]), [-1])
def value_function(self):
return self.model.value_function() # not used
class CentralizedValueMixin:
"""Add method to evaluate the central value function from the model."""
def __init__(self):
self.compute_central_vf = make_tf_callable(self.get_session())(
self.model.central_value_function)
# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(policy,
sample_batch,
other_agent_batches=None,
episode=None):
if policy.loss_initialized():
assert sample_batch["dones"][-1], \
"Not implemented for train_batch_mode=truncate_episodes"
assert other_agent_batches is not None
[(_, opponent_batch)] = list(other_agent_batches.values())
# also record the opponent obs and actions in the trajectory
sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]
# overwrite default VF prediction with the central VF
sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
sample_batch[OPPONENT_ACTION])
else:
# policy hasn't initialized yet, use zeros
sample_batch[OPPONENT_OBS] = np.zeros_like(
sample_batch[SampleBatch.CUR_OBS])
sample_batch[OPPONENT_ACTION] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS])
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
train_batch = compute_advantages(
sample_batch,
0.0,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
return train_batch
# Copied from PPO but optimizing the central value function
def loss_with_central_critic(policy, model, dist_class, train_batch):
CentralizedValueMixin.__init__(policy)
logits, state = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
policy.central_value_out = policy.model.central_value_function(
train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
train_batch[OPPONENT_ACTION])
policy.loss_obj = PPOLoss(
policy.action_space,
dist_class,
model,
train_batch[Postprocessing.VALUE_TARGETS],
train_batch[Postprocessing.ADVANTAGES],
train_batch[SampleBatch.ACTIONS],
train_batch[BEHAVIOUR_LOGITS],
train_batch[ACTION_LOGP],
train_batch[SampleBatch.VF_PREDS],
action_dist,
policy.central_value_out,
policy.kl_coeff,
tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool),
entropy_coeff=policy.entropy_coeff,
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
use_gae=policy.config["use_gae"],
model_config=policy.config["model"])
return policy.loss_obj.loss
def setup_mixins(policy, obs_space, action_space, config):
# copied from PPO
KLCoeffMixin.__init__(policy, config)
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
def central_vf_stats(policy, train_batch, grads):
# Report the explained variance of the central value function.
return {
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
policy.central_value_out),
}
CCPPO = PPOTFPolicy.with_updates(
name="CCPPO",
postprocess_fn=centralized_critic_postprocessing,
loss_fn=loss_with_central_critic,
before_loss_init=setup_mixins,
grad_stats_fn=central_vf_stats,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
CentralizedValueMixin
])
CCTrainer = PPOTrainer.with_updates(name="CCPPOTrainer", default_policy=CCPPO)
if __name__ == "__main__":
args = parser.parse_args()
ModelCatalog.register_custom_model("cc_model", CentralizedCriticModel)
tune.run(
CCTrainer,
stop={
"timesteps_total": args.stop,
"episode_reward_mean": 7.99,
},
config={
"env": TwoStepGame,
"batch_mode": "complete_episodes",
"eager": False,
"num_workers": 0,
"multiagent": {
"policies": {
"pol1": (None, Discrete(6), TwoStepGame.action_space, {}),
"pol2": (None, Discrete(6), TwoStepGame.action_space, {}),
},
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
},
"model": {
"custom_model": "cc_model",
},
})