from gym.spaces import Discrete, Tuple

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class AutoregressiveActionModel(TFModelV2):
    """Implements the `.action_model` branch required above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(AutoregressiveActionModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Inputs
        obs_input = tf.keras.layers.Input(
            shape=obs_space.shape, name="obs_input")
        a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
        ctx_input = tf.keras.layers.Input(
            shape=(num_outputs, ), name="ctx_input")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        context = tf.keras.layers.Dense(
            num_outputs,
            name="hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(obs_input)

        # V(s)
        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(context)

        # P(a1 | obs)
        a1_logits = tf.keras.layers.Dense(
            2,
            name="a1_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(ctx_input)

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
        # a2_context = tf.keras.layers.Concatenate(axis=1)(
        #     [ctx_input, a1_input])
        a2_context = a1_input
        a2_hidden = tf.keras.layers.Dense(
            16,
            name="a2_hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(a2_context)
        a2_logits = tf.keras.layers.Dense(
            2,
            name="a2_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(a2_hidden)

        # Base layers
        self.base_model = tf.keras.Model(obs_input, [context, value_out])
        self.base_model.summary()

        # Autoregressive action sampler
        self.action_model = tf.keras.Model([ctx_input, a1_input],
                                           [a1_logits, a2_logits])
        self.action_model.summary()

    def forward(self, input_dict, state, seq_lens):
        context, self._value_out = self.base_model(input_dict["obs"])
        return context, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """PyTorch version of the AutoregressiveActionModel above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )

        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )

        # P(a1 | obs)
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01))

        class _ActionModel(nn.Module):
            def __init__(self):
                nn.Module.__init__(self)
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0))
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01))

            def forward(self_, ctx_input, a1_input):
                # Note: `self` (no underscore) is the outer model instance,
                # captured here via closure; `self_` is this inner module.
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
        # a2_context = torch.cat([ctx_input, a1_input], dim=1)
        self.action_module = _ActionModel()

        self._context = None

    def forward(self, input_dict, state, seq_lens):
        self._context = self.context_layer(input_dict["obs"])
        return self._context, state

    def value_function(self):
        return torch.reshape(self.value_branch(self._context), [-1])
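

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original example): both
# models above return a context vector from `forward()`; an autoregressive
# action distribution is expected to sample a1 from that context and then
# feed a1 back through `.action_model` / `.action_module` to obtain the a2
# logits. A minimal registration might look like the commented code below,
# where `BinaryAutoregressiveDistribution` is a placeholder name for whatever
# custom action distribution you register separately.
#
#     from ray.rllib.models import ModelCatalog
#
#     ModelCatalog.register_custom_model(
#         "autoregressive_model", AutoregressiveActionModel)
#     ModelCatalog.register_custom_action_dist(
#         "binary_autoregressive_dist", BinaryAutoregressiveDistribution)
#     config = {
#         "model": {
#             "custom_model": "autoregressive_model",
#             "custom_action_dist": "binary_autoregressive_dist",
#         },
#     }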