from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version

import ray
from policy import *

use_tf100_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('1.0.0')
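# The RNN cell and state-tuple constructors moved between TensorFlow
# releases, so this flag selects the matching tensorflow.contrib.rnn API
# in the branches below.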


class LSTMPolicy(Policy):
    def setup_graph(self, ob_space, ac_space):
        """Set up the model used by the policy (in this A3C implementation,
        the critic and the actor share the model)."""
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flattening so that we
        # can run the LSTM over the time dimension.
        x = tf.expand_dims(flatten(x), [0])
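        # Shape trace: the leading None dimension of self.x is time, not
        # batch, so [time] + ob_space becomes [time, features] after the
        # convolutions and flatten, then [1, time, features] for dynamic_rnn
        # (batch size 1, time_major=False).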

        size = 256
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        step_size = tf.shape(self.x)[:1]

        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]
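        # The leading dimension of 1 in the state placeholders matches the
        # fake batch dimension above: each policy instance steps a single
        # environment, carrying (c, h) across successive calls to act().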

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
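        # lstm_outputs has shape [1, time, size]; lstm_state is the final
        # (c, h) pair, exposed as state_out below so rollouts can resume
        # from where they left off.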
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_outputs, [-1, size])
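        # The actor and critic heads share the conv/LSTM trunk: one linear
        # layer over the LSTM features produces the action logits, another
        # produces the value estimate.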
        self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
        self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
        self.global_step = tf.get_variable("global_step", [], tf.int32,
                                           initializer=tf.constant_initializer(0, dtype=tf.int32),
                                           trainable=False)

    def get_gradients(self, batch):
        """Computing the gradient is model-dependent.

        The LSTM needs its hidden states in order to compute the gradient
        accurately."""
        feed_dict = {
            self.x: batch.si,
            self.ac: batch.a,
            self.adv: batch.adv,
            self.r: batch.r,
            self.state_in[0]: batch.features[0],
            self.state_in[1]: batch.features[1],
        }
        self.local_steps += 1
        return self.sess.run(self.grads, feed_dict=feed_dict)
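
        # Hedged sketch of a trainer-side call (self.ac, self.adv, self.r,
        # self.grads, self.sess, and self.local_steps are wired up by the
        # Policy base class in policy.py, which is not shown in this file):
        #
        #     grads = policy.get_gradients(rollout_batch)
        #
        # where rollout_batch carries fields si, a, adv, r, and the LSTM
        # features saved at the start of the rollout.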

    def act(self, ob, c, h):
        return self.sess.run([self.sample, self.vf] + self.state_out,
                             {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})

    def value(self, ob, c, h):
        return self.sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})[0]

    def get_initial_features(self):
        return self.state_init
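
    # Hedged rollout sketch (policy construction and session setup come from
    # the Policy base class in policy.py, not shown here). act() returns
    # [sampled_action, value_estimate, c_out, h_out]:
    #
    #     c, h = policy.get_initial_features()
    #     for ob in observations:
    #         action, value_estimate, c, h = policy.act(ob, c, h)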


class RawLSTMPolicy(LSTMPolicy):
    def get_weights(self):
        # Lazily cache the current weights; self.variables comes from the
        # base Policy setup (not shown in this file).
        if not hasattr(self, "_weights"):
            self._weights = self.variables.get_weights()
        return self._weights

    def set_weights(self, weights):
        self._weights = weights

    def model_update(self, grads):
        # Plain SGD with a fixed 1e-4 step size; var.name looks like
        # "scope/name:0", so [:-2] strips the ":0" suffix to match the keys
        # of the cached weight dict.
        for var, grad in zip(self.var_list, grads):
            self._weights[var.name[:-2]] -= 1e-4 * grad
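
# Hedged driver-loop sketch for RawLSTMPolicy (the ray worker/parameter-server
# plumbing lives elsewhere in this repo and is assumed here, not shown):
#
#     grads = worker_policy.get_gradients(batch)  # computed on a worker
#     raw_policy.model_update(grads)              # plain SGD step, lr 1e-4
#     weights = raw_policy.get_weights()          # broadcast back to workers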