# ray/rllib/examples/models/trajectory_view_utilizing_models.py

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.tf_utils import one_hot
from ray.rllib.utils.torch_utils import one_hot as torch_one_hot

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


# __sphinx_doc_begin__
class FrameStackingCartPoleModel(TFModelV2):
    """A simple FC model that takes the last n observations as input."""

    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, num_frames=3
    ):
        super(FrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name
        )

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        obs = tf.keras.layers.Input(shape=(self.num_frames, obs_space.shape[0]))
        obs_reshaped = tf.keras.layers.Reshape([obs_space.shape[0] * self.num_frames])(
            obs
        )
        rewards = tf.keras.layers.Input(shape=(self.num_frames,))
        rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
        actions = tf.keras.layers.Input(shape=(self.num_frames, self.action_space.n))
        actions_reshaped = tf.keras.layers.Reshape([action_space.n * self.num_frames])(
            actions
        )
        input_ = tf.keras.layers.Concatenate(axis=-1)(
            [obs_reshaped, actions_reshaped, rewards_reshaped]
        )
        layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
        layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
        out = tf.keras.layers.Dense(self.num_outputs)(layer2)
        # The value branch taps into the first hidden layer.
        values = tf.keras.layers.Dense(1)(layer1)
        self.base_model = tf.keras.models.Model([obs, actions, rewards], [out, values])

        self._last_value = None

        # Declare trajectory view requirements: the last `num_frames`
        # observations (including the current one), plus the `num_frames`
        # previous rewards and actions.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space
        )
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames)
        )
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space,
        )

    def forward(self, input_dict, states, seq_lens):
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        # One-hot the (discrete) actions before feeding them to the Keras model.
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []

    def value_function(self):
        return tf.squeeze(self._last_value, -1)
# __sphinx_doc_end__
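
# Illustration (not part of the original example file): with the default
# num_frames=3, the view requirements above make RLlib's sample collector add
# roughly the following extra columns to `input_dict` on each forward() call
# (B = batch size, obs_dim = 4 for CartPole); slots reaching back before the
# episode start are zero-padded:
#   input_dict["prev_n_obs"]:     [B, 3, obs_dim]  -> obs at t-2, t-1, t
#   input_dict["prev_n_rewards"]: [B, 3]           -> rewards at t-3 .. t-1
#   input_dict["prev_n_actions"]: [B, 3]           -> actions at t-3 .. t-1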


class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
    """A simple FC model that takes the last n observations as input."""

    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, num_frames=3
    ):
        nn.Module.__init__(self)
        super(TorchFrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name
        )

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        in_size = self.num_frames * (obs_space.shape[0] + action_space.n + 1)
        self.layer1 = SlimFC(in_size=in_size, out_size=256, activation_fn="relu")
        self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
        self.out = SlimFC(
            in_size=256, out_size=self.num_outputs, activation_fn="linear"
        )
        self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")

        self._last_value = None

        # Same trajectory view requirements as in the tf model above.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space
        )
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames)
        )
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space,
        )

    def forward(self, input_dict, states, seq_lens):
        # Flatten the frame-stacked obs/rewards/one-hot actions into a single
        # input vector per batch item.
        obs = input_dict["prev_n_obs"]
        obs = torch.reshape(obs, [-1, self.obs_space.shape[0] * self.num_frames])
        rewards = torch.reshape(input_dict["prev_n_rewards"], [-1, self.num_frames])
        actions = torch_one_hot(input_dict["prev_n_actions"], self.action_space)
        actions = torch.reshape(actions, [-1, self.num_frames * actions.shape[-1]])
        input_ = torch.cat([obs, actions, rewards], dim=-1)
        features = self.layer1(input_)
        features = self.layer2(features)
        out = self.out(features)
        self._last_value = self.values(features)
        return out, []

    def value_function(self):
        return torch.squeeze(self._last_value, -1)
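

# Minimal usage sketch (an assumption, not part of the original example file):
# under the old ModelV2 API stack, these classes are hooked up through
# ModelCatalog and an algorithm's "custom_model" setting. The registration
# name "frame_stack_model" and num_frames=16 below are illustrative only; on
# very recent Ray versions the old API stack may have to be enabled explicitly.
if __name__ == "__main__":
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model(
        "frame_stack_model", TorchFrameStackingCartPoleModel
    )
    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .framework("torch")
        .training(
            model={
                "custom_model": "frame_stack_model",
                # Forwarded as kwargs to the model's constructor.
                "custom_model_config": {"num_frames": 16},
            }
        )
    )
    algo = config.build()
    print(algo.train())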