from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version

import ray
from policy import *

use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                 distutils.version.LooseVersion('1.0.0'))


class LSTMPolicy(Policy):
    def setup_graph(self, ob_space, ac_space):
        """Set up the model used by the policy (in this A3C implementation,
        the actor and the critic share the model)."""
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        # Four convolutional layers with ELU activations extract spatial
        # features from the observation.
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flattening so that we
        # can run the LSTM over the time dimension.
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        step_size = tf.shape(self.x)[:1]

        # Initial (zero) recurrent state and placeholders for feeding the
        # recurrent state back in at each step.
        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_outputs, [-1, size])

        # Policy head (action logits) and value head share the LSTM features.
        self.logits = linear(x, ac_space, "action",
                             normalized_columns_initializer(0.01))
        self.vf = tf.reshape(
            linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.global_step = tf.get_variable(
            "global_step", [], tf.int32,
            initializer=tf.constant_initializer(0, dtype=tf.int32),
            trainable=False)

    def get_gradients(self, batch):
        """Computing the gradient is model-dependent.

        The LSTM needs its hidden states in order to compute the gradient
        accurately.
        """
        feed_dict = {
            self.x: batch.si,
            self.ac: batch.a,
            self.adv: batch.adv,
            self.r: batch.r,
            self.state_in[0]: batch.features[0],
            self.state_in[1]: batch.features[1],
        }
        self.local_steps += 1
        return self.sess.run(self.grads, feed_dict=feed_dict)

    def act(self, ob, c, h):
        # Returns the sampled action, the value estimate, and the new
        # recurrent state (c, h).
        return self.sess.run(
            [self.sample, self.vf] + self.state_out,
            {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})

    def value(self, ob, c, h):
        return self.sess.run(
            self.vf,
            {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})[0]

    def get_initial_features(self):
        return self.state_init


class RawLSTMPolicy(LSTMPolicy):
    """Variant of LSTMPolicy that caches its weights locally and applies
    plain SGD updates to them."""

    def get_weights(self):
        if not hasattr(self, "_weights"):
            self._weights = self.variables.get_weights()
        return self._weights

    def set_weights(self, weights):
        self._weights = weights

    def model_update(self, grads):
        # Plain SGD step with a fixed learning rate of 1e-4, applied directly
        # to the cached weight dictionary (keyed by variable name without the
        # trailing ":0").
        for var, grad in zip(self.var_list, grads):
            self._weights[var.name[:-2]] -= 1e-4 * grad
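

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): the commented-out
# snippet below illustrates how the recurrent state returned by act() is
# threaded back in on the next step, and how the same (c, h) features later
# accompany the rollout batch passed to get_gradients(). The names `policy`
# and `env` are placeholders for objects assumed to be constructed elsewhere
# in the A3C example; the argmax over the one-hot action sample is likewise
# an assumption about how the sampled action is converted for the
# environment.
#
#     last_state = env.reset()
#     last_features = policy.get_initial_features()   # [c_init, h_init]
#     done = False
#     while not done:
#         action, value, c, h = policy.act(last_state, *last_features)
#         last_state, reward, done, _ = env.step(action.argmax())
#         last_features = [c, h]
# ---------------------------------------------------------------------------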