ray/rllib/examples/parametric_actions_cartpole.py
Sven Mika 42991d723f
[RLlib] rllib/examples folder restructuring (#8250)
Cleans up the rllib/examples folder by moving all example Envs into rllib/examples/env (so they can be used by other scripts and tests as well).
2020-05-01 22:59:34 +02:00

118 lines
4.3 KiB
Python

"""Example of handling variable length and/or parametric action spaces.
This is a toy example of the action-embedding based approach for handling large
discrete action spaces (potentially infinite in size), similar to this:
https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/
This currently works with RLlib's policy gradient style algorithms
(e.g., PG, PPO, IMPALA, A2C) and also DQN.
Note that since the model outputs now include "-inf" (tf.float32.min)
values, not all algorithm options are supported at the moment. For example,
algorithms might crash if they don't properly ignore the -inf action scores.
Working configurations are given below.
"""
import argparse
from gym.spaces import Box
import ray
from ray import tune
from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.examples.env.parametric_actions_cartpole import \
ParametricActionsCartPole
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# Command-line interface of the example script. Both flags are optional.
parser = argparse.ArgumentParser()
# Reward threshold used as the Tune stopping criterion.
parser.add_argument(
    "--stop",
    type=int,
    default=200,
    help="Stop training once episode_reward_mean reaches this value.")
# Name of the trainable handed to tune.run; "DQN" gets extra config overrides.
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="Name of the algorithm passed to tune.run (DQN gets extra config).")
class ParametricActionsModel(DistributionalQTFModel, TFModelV2):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.

    ``forward`` reads three entries from the observation dict: "cart" (fed to
    the embedding network), "avail_actions" (one embedding per action slot)
    and "action_mask" (whose log is added to the logits).
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        """Build the wrapped network that predicts the action embedding.

        Args:
            obs_space: Forwarded unchanged to the parent constructor.
            action_space: Forwarded to the parent constructor and to the
                embedding network.
            num_outputs: Forwarded unchanged to the parent constructor.
            model_config: Forwarded to the parent constructor and to the
                embedding network.
            name: Model name; the embedding network is named
                ``name + "_action_embed"``.
            true_obs_shape: Shape of the ``Box(-1, 1)`` space the embedding
                network is built for.
            action_embed_size: Number of outputs of the embedding network.
            **kw: Extra keyword arguments forwarded to the parent
                constructor.
        """
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_action_embed")
        # Expose the inner network's variables through this model.
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Score every action slot and mask out the unavailable ones.

        Args:
            input_dict: Dict whose "obs" entry holds the "avail_actions",
                "action_mask" and "cart" tensors.
            state: Passed through unchanged.
            seq_lens: Unused here.

        Returns:
            Tuple of (masked action logits, ``state``).
        """
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["cart"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        # Assumes the mask is 1 for valid and 0 for invalid slots (confirm
        # against the env): log(1) = 0 leaves a logit untouched, log(0) = -inf
        # is clamped to the smallest finite float32.
        # tf.math.log is used instead of tf.log because the latter is absent
        # from the TF2 top-level namespace; tf.math.log works in both.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the value output of the wrapped embedding network."""
        return self.action_embed_model.value_function()
if __name__ == "__main__":
    # Entry point: parse the CLI flags, start Ray, register the custom model
    # and env under string names, then launch the Tune experiment.
    args = parser.parse_args()
    ray.init()

    ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
    register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))

    # Settings shared by every algorithm.
    run_config = {
        "env": "pa_cartpole",
        "model": {
            "custom_model": "pa_model",
        },
        "num_workers": 0,
    }

    if args.run == "DQN":
        # TODO(ekl) we need to set these to prevent the masked values
        # from being further processed in DistributionalQModel, which
        # would mess up the masking. It is possible to support these if we
        # defined a custom DistributionalQModel that is aware of masking.
        run_config.update({
            "hiddens": [],
            "dueling": False,
        })

    # Train until the mean episode reward reaches the --stop threshold.
    stop_criteria = {"episode_reward_mean": args.stop}
    tune.run(args.run, stop=stop_criteria, config=run_config)