# ray/rllib/examples/models/autoregressive_action_model.py
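"""Example TF and torch models for use with an autoregressive action
distribution: the model encodes the observation into a `context` vector, and
the distribution then queries the model again to get logits for a1 and, given
the sampled a1, logits for a2."""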
from gym.spaces import Discrete, Tuple

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class AutoregressiveActionModel(TFModelV2):
    """Implements the `.action_model` branch required by the autoregressive
    action distribution (see the companion autoregressive_action_dist.py)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(AutoregressiveActionModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Inputs
        obs_input = tf.keras.layers.Input(
            shape=obs_space.shape, name="obs_input")
        a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
        ctx_input = tf.keras.layers.Input(
            shape=(num_outputs, ), name="ctx_input")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs).
        context = tf.keras.layers.Dense(
            num_outputs,
            name="hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(obs_input)

        # V(s)
        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(context)

        # P(a1 | obs)
        a1_logits = tf.keras.layers.Dense(
            2,
            name="a1_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(ctx_input)

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) as
        # follows:
        # a2_context = tf.keras.layers.Concatenate(axis=1)(
        #     [ctx_input, a1_input])
        a2_context = a1_input
        a2_hidden = tf.keras.layers.Dense(
            16,
            name="a2_hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(a2_context)
        a2_logits = tf.keras.layers.Dense(
            2,
            name="a2_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(a2_hidden)

        # Base layers
        self.base_model = tf.keras.Model(obs_input, [context, value_out])
        self.base_model.summary()

        # Autoregressive action sampler
        self.action_model = tf.keras.Model([ctx_input, a1_input],
                                           [a1_logits, a2_logits])
        self.action_model.summary()

    def forward(self, input_dict, state, seq_lens):
        context, self._value_out = self.base_model(input_dict["obs"])
        return context, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
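

# A minimal sketch (not part of the original example) of how an
# autoregressive action distribution consumes `action_model` above: two
# passes through the same keras model, first with a dummy a1 to get
# P(a1 | obs), then with the sampled a1 to get P(a2 | a1). The helper name
# `_sample_autoregressive` is hypothetical; in RLlib the actual sampling
# lives in the custom action distribution class.
def _sample_autoregressive(model, context):
    batch = tf.shape(context)[0]
    # Pass 1: a1 is not known yet, so feed zeros for the a1 input.
    a1_logits, _ = model.action_model([context, tf.zeros((batch, 1))])
    a1 = tf.random.categorical(a1_logits, 1)  # int64, shape [batch, 1]
    # Pass 2: condition a2 on the freshly sampled a1.
    _, a2_logits = model.action_model([context, tf.cast(a1, tf.float32)])
    a2 = tf.random.categorical(a2_logits, 1)
    return tf.squeeze(a1, axis=-1), tf.squeeze(a2, axis=-1)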


class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """PyTorch version of the AutoregressiveActionModel above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs).
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs)
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01))

        class _ActionModel(nn.Module):
            """Wraps the a2 branch; a1_logits is taken from the outer model.

            Note the `self_` arg in `forward`: using `self` there would
            shadow the outer model's `self`, which is still needed to reach
            `self.a1_logits`.
            """

            def __init__(self):
                nn.Module.__init__(self)
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0))
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01))

            def forward(self_, ctx_input, a1_input):
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs); in torch
        # that would be, e.g.:
        # a2_context = torch.cat([ctx_input, a1_input], dim=1)
        self.action_module = _ActionModel()

        self._context = None

    def forward(self, input_dict, state, seq_lens):
        self._context = self.context_layer(input_dict["obs"])
        return self._context, state

    def value_function(self):
        return torch.reshape(self.value_branch(self._context), [-1])
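

# A minimal usage sketch (not part of the original file): register the model
# under a custom name and point the trainer config's `custom_model` at it.
# The names "autoreg_model" and "binary_autoreg_dist" follow the companion
# example script; a matching custom action distribution must be registered
# under the latter name for the two-pass sampling to work.
if __name__ == "__main__":
    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model("autoreg_model",
                                       AutoregressiveActionModel)
    config = {
        "model": {
            "custom_model": "autoreg_model",
            "custom_action_dist": "binary_autoreg_dist",
        },
    }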