ray/rllib/models/tf/lstm_v1.py

import numpy as np

from ray.rllib.models.model import Model
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
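# try_import_tf() returns the tensorflow module if it is installed and None
# otherwise, so this file can be imported without a hard TF dependency.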


# Deprecated: see models/tf/recurrent_tf_modelv2.py as an alternative.
class LSTM(Model):
    """Adds an LSTM cell on top of some other model output.

    Uses a linear layer at the end for output.

    Important: we assume inputs is a padded batch of sequences denoted by
    self.seq_lens. See add_time_dimension() for more information.
    """

    @override(Model)
    def _build_layers_v2(self, input_dict, num_outputs, options):
        cell_size = options.get("lstm_cell_size")
        if options.get("lstm_use_prev_action_reward"):
            # Flatten the previous action, cast it to float32, and
            # concatenate it (plus the previous reward) onto the flat obs.
            action_dim = int(
                np.product(
                    input_dict["prev_actions"].get_shape().as_list()[1:]))
            features = tf.concat(
                [
                    input_dict["obs"],
                    tf.reshape(
                        tf.cast(input_dict["prev_actions"], tf.float32),
                        [-1, action_dim]),
                    tf.reshape(input_dict["prev_rewards"], [-1, 1]),
                ],
                axis=1)
        else:
            features = input_dict["obs"]
        last_layer = add_time_dimension(features, self.seq_lens)
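        # last_layer is now batch-major [B, T, F]: add_time_dimension()
        # folds the flat, padded [B * T, F] features back into a time axis.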

        # Setup the LSTM cell.
        lstm = tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm.state_size.c, np.float32),
            np.zeros(lstm.state_size.h, np.float32)
        ]

        # Setup LSTM inputs: reuse the state-in tensors if they are already
        # set, otherwise create placeholders for the c and h state vectors.
        if self.state_in:
            c_in, h_in = self.state_in
        else:
            c_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.c], name="c")
            h_in = tf.placeholder(
                tf.float32, [None, lstm.state_size.h], name="h")
            self.state_in = [c_in, h_in]

        # Setup LSTM outputs.
        state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_out, lstm_state = tf.nn.dynamic_rnn(
            lstm,
            last_layer,
            initial_state=state_in,
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)
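        # lstm_out has shape [B, T, cell_size]; lstm_state is an
        # LSTMStateTuple (c, h) whose members each have shape [B, cell_size].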
        self.state_out = list(lstm_state)

        # Compute outputs: flatten the time axis again and apply a final
        # linear layer to produce the action logits.
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        return logits, last_layer
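
For context, here is a minimal usage sketch (not part of the file above). It assumes the model config keys of this RLlib era (use_lstm, lstm_cell_size, lstm_use_prev_action_reward, max_seq_len), which are how RLlib decides to wrap the default model with this LSTM class:

import ray
from ray import tune

# Minimal sketch; the model config keys below are assumptions based on this
# RLlib era's MODEL_DEFAULTS, not something defined in lstm_v1.py itself.
ray.init()
tune.run(
    "PPO",
    stop={"training_iteration": 1},
    config={
        "env": "CartPole-v0",
        "model": {
            "use_lstm": True,                     # wrap base model with LSTM
            "lstm_cell_size": 256,                # size of the c and h vectors
            "lstm_use_prev_action_reward": True,  # concat prev action/reward
            "max_seq_len": 20,                    # sequence padding length (T)
        },
    },
)

As the deprecation comment notes, new recurrent models should build on models/tf/recurrent_tf_modelv2.py rather than this ModelV1 wrapper.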