mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
136 lines
5.5 KiB
Python
136 lines
5.5 KiB
Python
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
|
from ray.rllib.models.torch.misc import SlimFC
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
from ray.rllib.utils.tf_utils import one_hot
|
|
from ray.rllib.utils.torch_utils import one_hot as torch_one_hot
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
torch, nn = try_import_torch()
|
|
|
|
# __sphinx_doc_begin__
|
|
|
|
|
|
class FrameStackingCartPoleModel(TFModelV2):
    """A simple FC model over the last n observations, actions and rewards.

    Instead of keeping recurrent state, the model registers extra view
    requirements (``prev_n_obs``, ``prev_n_actions``, ``prev_n_rewards``)
    so that every forward pass receives a window of the most recent
    ``num_frames`` timesteps. These windows are flattened, concatenated and
    fed through a small two-layer MLP with separate policy and value heads.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the Keras model and registers the frame-stacking views.

        Args:
            obs_space: Observation space; must be 1-D (asserted below).
            action_space: Action space; its ``.n`` is used as the one-hot
                width, so a Discrete-like space is expected.
            num_outputs: Size of the policy output layer.
            model_config: Model config dict, forwarded to the base class.
            name: Name of this model.
            num_frames: Number of consecutive timesteps stacked as input.
        """
        # The base class is given None for num_outputs; the real value is
        # stored on the instance right after.
        super(FrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        obs_dim = obs_space.shape[0]
        num_actions = self.action_space.n

        obs = tf.keras.layers.Input(shape=(self.num_frames, obs_dim))
        obs_reshaped = tf.keras.layers.Reshape(
            [obs_dim * self.num_frames])(obs)
        # A real 1-tuple: `(self.num_frames)` would only be a parenthesised
        # int.
        rewards = tf.keras.layers.Input(shape=(self.num_frames, ))
        rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
        actions = tf.keras.layers.Input(shape=(self.num_frames, num_actions))
        actions_reshaped = tf.keras.layers.Reshape(
            [num_actions * self.num_frames])(actions)
        input_ = tf.keras.layers.Concatenate(axis=-1)(
            [obs_reshaped, actions_reshaped, rewards_reshaped])
        layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
        layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
        out = tf.keras.layers.Dense(self.num_outputs)(layer2)
        # The value head reads the same (post-layer2) features as the policy
        # head, matching the PyTorch variant of this model.
        values = tf.keras.layers.Dense(1)(layer2)
        self.base_model = tf.keras.models.Model([obs, actions, rewards],
                                                [out, values])
        # Filled in by forward(); read by value_function().
        self._last_value = None

        # Observations from (num_frames - 1) steps back up to and including
        # the current step -> num_frames entries.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(self.num_frames - 1),
            space=obs_space)
        # Rewards and actions from num_frames steps back up to the previous
        # step (the current step excluded) -> num_frames entries each.
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Runs the MLP on the stacked obs/action/reward windows.

        Args:
            input_dict: Must contain the ``prev_n_obs``, ``prev_n_rewards``
                and ``prev_n_actions`` views registered in ``__init__``.
            states: Unused (the model is stateless).
            seq_lens: Unused.

        Returns:
            Tuple of (policy outputs, empty state list). The value output is
            cached for ``value_function()``.
        """
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []

    def value_function(self):
        """Returns the value output of the last forward() call.

        The trailing size-1 axis is squeezed away.
        """
        return tf.squeeze(self._last_value, -1)
|
|
|
|
|
|
# __sphinx_doc_end__
|
|
|
|
|
|
class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
    """A simple FC model over the last n observations, actions and rewards.

    PyTorch counterpart of the TF frame-stacking model: extra view
    requirements supply windows of the most recent ``num_frames``
    observations, actions and rewards, which are flattened, concatenated
    and passed through a two-layer MLP with policy and value heads.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Builds the layers and registers the frame-stacking views.

        Args:
            obs_space: Observation space; must be 1-D (asserted below).
            action_space: Action space; its ``.n`` is used as the one-hot
                width, so a Discrete-like space is expected.
            num_outputs: Size of the policy output layer.
            model_config: Model config dict, forwarded to the base class.
            name: Name of this model.
            num_frames: Number of consecutive timesteps stacked as input.
        """
        nn.Module.__init__(self)
        # The base class is given None for num_outputs; the real value is
        # stored on the instance right after.
        super(TorchFrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        # Per frame: obs features + one-hot action + one scalar reward.
        in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1)
        self.layer1 = SlimFC(
            in_size=in_size, out_size=256, activation_fn="relu")
        self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
        self.out = SlimFC(
            in_size=256, out_size=self.num_outputs, activation_fn="linear")
        self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")

        # Filled in by forward(); read by value_function().
        self._last_value = None

        # Observations from (num_frames - 1) steps back up to and including
        # the current step -> num_frames entries.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(self.num_frames - 1),
            space=obs_space)
        # Rewards and actions from num_frames steps back up to the previous
        # step (the current step excluded) -> num_frames entries each.
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Runs the MLP on the stacked obs/action/reward windows.

        Args:
            input_dict: Must contain the ``prev_n_obs``, ``prev_n_rewards``
                and ``prev_n_actions`` views (as torch tensors) registered
                in ``__init__``.
            states: Unused (the model is stateless).
            seq_lens: Unused.

        Returns:
            Tuple of (policy outputs, empty state list). The value output is
            cached for ``value_function()``.
        """
        # All three inputs are cast to float32 before concatenation, as the
        # TF variant does for obs/rewards. The one-hot helper may return an
        # integer tensor (TODO confirm), and mixed dtypes can break
        # torch.cat or the float32 linear layers.
        obs = input_dict["prev_n_obs"].float()
        obs = torch.reshape(obs,
                            [-1, self.obs_space.shape[0] * self.num_frames])
        rewards = torch.reshape(input_dict["prev_n_rewards"].float(),
                                [-1, self.num_frames])
        actions = torch_one_hot(input_dict["prev_n_actions"],
                                self.action_space).float()
        actions = torch.reshape(actions,
                                [-1, self.num_frames * actions.shape[-1]])
        input_ = torch.cat([obs, actions, rewards], dim=-1)
        features = self.layer1(input_)
        features = self.layer2(features)
        out = self.out(features)
        self._last_value = self.values(features)
        return out, []

    def value_function(self):
        """Returns the value output of the last forward() call.

        The trailing size-1 axis is squeezed away.
        """
        return torch.squeeze(self._last_value, -1)
|