# ray/rllib/examples/models/parametric_actions_model.py
#
# 111 lines, 4.4 KiB, Python

from gym.spaces import Box
from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.agents.dqn.dqn_torch_model import \
DQNTorchModel
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import LARGE_INTEGER
# Obtain the deep-learning framework modules through RLlib's import helpers
# instead of importing them directly.
# NOTE(review): presumably the helpers tolerate a missing framework so that
# this module can still be imported with only one of TF / PyTorch installed --
# confirm against ray.rllib.utils.framework.
tf = try_import_tf()
torch, nn = try_import_torch()
class ParametricActionsModel(DistributionalQTFModel):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.

    The observation is read as a dict with three entries:

    * ``"cart"``: the input fed to the action-embedding network.
    * ``"avail_actions"``: embeddings of the candidate actions, of shape
      [BATCH, MAX_ACTIONS, EMBED_SIZE].
    * ``"action_mask"``: mask over the candidate actions (assumed to hold 1
      for valid and 0 for invalid actions -- verify against the env).
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        """Build the model and its action-embedding sub-network.

        Args:
            obs_space: Observation space, forwarded to the parent model.
            action_space: Action space, forwarded to the parent model and to
                the embedding network.
            num_outputs: Number of outputs, forwarded to the parent model.
            model_config: Model config, forwarded to the parent model and to
                the embedding network.
            name: Model name; the embedding network is named
                ``name + "_action_embed"``.
            true_obs_shape: Shape of the ``Box(-1, 1)`` space used as the
                observation space of the embedding network.
            action_embed_size: Output size requested from the embedding
                network.
            **kw: Extra keyword arguments forwarded to the parent model.
        """
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_action_embed")
        # The sub-network's variables must be registered with this model.
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Compute masked action logits.

        Args:
            input_dict: Dict whose ``"obs"`` entry holds the ``"cart"``,
                ``"avail_actions"`` and ``"action_mask"`` tensors.
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            Tuple ``(logits, state)`` where ``logits`` has shape
            [BATCH, MAX_ACTIONS].
        """
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["cart"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        # `tf.math.log` is used instead of the TF1-only alias `tf.log`, which
        # is not available in the TF 2.x top-level namespace.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Return the value function output of the embedding network."""
        return self.action_embed_model.value_function()
class TorchParametricActionsModel(DQNTorchModel, nn.Module):
    """PyTorch version of ParametricActionsModel.

    Scores each candidate action by the dot product between its embedding
    and an embedding predicted from the ``"cart"`` observation, then pushes
    the scores of masked-out actions down to ``-LARGE_INTEGER``.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        """Initialise both parent classes and create the embedding network.

        Args:
            obs_space: Observation space, forwarded to ``DQNTorchModel``.
            action_space: Action space, forwarded to ``DQNTorchModel`` and
                to the embedding network.
            num_outputs: Number of outputs, forwarded to ``DQNTorchModel``.
            model_config: Model config, forwarded to ``DQNTorchModel`` and
                to the embedding network.
            name: Model name; the embedding network is named
                ``name + "_action_embed"``.
            true_obs_shape: Shape of the ``Box(-1, 1)`` space handed to the
                embedding network as its observation space.
            action_embed_size: Output size requested from the embedding
                network.
            **kw: Extra keyword arguments forwarded to ``DQNTorchModel``.
        """
        nn.Module.__init__(self)
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        embed_obs_space = Box(-1, 1, shape=true_obs_shape)
        embed_name = name + "_action_embed"
        self.action_embed_model = TorchFC(embed_obs_space, action_space,
                                          action_embed_size, model_config,
                                          embed_name)

    def forward(self, input_dict, state, seq_lens):
        """Return ``(masked_logits, state)`` for the given observation batch.

        ``input_dict["obs"]`` must provide ``"avail_actions"``
        ([BATCH, MAX_ACTIONS, EMBED_SIZE]), ``"action_mask"`` and ``"cart"``.
        ``state`` is returned untouched; ``seq_lens`` is not used.
        """
        obs = input_dict["obs"]

        # Candidate action embeddings and the mask over them.
        candidates = obs["avail_actions"]
        mask = obs["action_mask"]

        # Predicted action embedding for the current "cart" observation.
        predicted, _ = self.action_embed_model({"obs": obs["cart"]})

        # [BATCH, EMBED_SIZE] -> [BATCH, 1, EMBED_SIZE] so that it broadcasts
        # against the [BATCH, MAX_ACTIONS, EMBED_SIZE] candidates.
        intent = predicted.unsqueeze(1)

        # Batched dot product, giving logits of shape [BATCH, MAX_ACTIONS].
        logits = (candidates * intent).sum(dim=2)

        # log(mask) is clamped from below at -LARGE_INTEGER. Entries tagged
        # this way are recognized by the EpsilonGreedy exploration component
        # as invalid actions that are not to be chosen.
        lower = -float(LARGE_INTEGER)
        upper = float("inf")
        mask_offset = torch.clamp(torch.log(mask), lower, upper)

        return logits + mask_offset, state

    def value_function(self):
        """Delegate to the embedding network's value function."""
        return self.action_embed_model.value_function()