# ray/examples/policy_gradient/reinforce/tfutils.py
# (file-viewer metadata: 31 lines, 1.1 KiB, Python)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import threading
class DataQueue(object):
    """Fill a TensorFlow ``RandomShuffleQueue`` from background threads.

    The queue has one component per placeholder in ``placeholder_dict``.
    Worker threads pull items from a data iterator and run the enqueue op,
    feeding each placeholder with the entry of the item stored under the
    placeholder's name.
    """

    def __init__(self, placeholder_dict):
        """Create the shuffle queue and its enqueue op.

        Args:
            placeholder_dict: An OrderedDict mapping a name to the
                placeholder that is fed with the data stored under that
                name. The order of its values fixes the order of the
                queue components.
        """
        # Kept on the instance so thread_main can build feed dicts later.
        self.placeholder_dict = placeholder_dict
        # Materialize the values: on Python 3 dict.values() is a view, and
        # the same sequence is reused for the shapes, dtypes and enqueue op.
        placeholders = list(placeholder_dict.values())
        # Drop the leading dimension of every placeholder: enqueue_many
        # slices its inputs along dimension 0 into individual elements.
        shapes = [placeholder.get_shape()[1:].as_list()
                  for placeholder in placeholders]
        dtypes = [placeholder.dtype for placeholder in placeholders]
        self.queue = tf.RandomShuffleQueue(
            shapes=shapes,
            dtypes=dtypes,
            capacity=2000,
            min_after_dequeue=1000)
        self.enqueue_op = self.queue.enqueue_many(placeholders)

    def thread_main(self, sess, data_iterator):
        """Run the enqueue op once for every item of ``data_iterator``.

        Args:
            sess: Object whose ``run(op, feed_dict=...)`` executes the
                enqueue op (a TensorFlow session in normal use).
            data_iterator: Iterable of mappings indexed by the names used
                as keys of ``placeholder_dict``.
        """
        for data in data_iterator:
            feed_dict = {
                placeholder: data[name]
                for name, placeholder in self.placeholder_dict.items()
            }
            sess.run(self.enqueue_op, feed_dict=feed_dict)

    def start_thread(self, sess, data_iterator, num_threads=1):
        """Start daemon threads that each execute ``thread_main``.

        Args:
            sess: Passed through to ``thread_main``.
            data_iterator: Passed through to ``thread_main``. The same
                object is handed to every thread.
                NOTE(review): with num_threads > 1 the iterable must
                tolerate concurrent iteration -- confirm at call sites.
            num_threads: Number of threads to start.

        Returns:
            The list of started ``threading.Thread`` objects.
        """
        threads = []
        for _ in range(num_threads):
            t = threading.Thread(target=self.thread_main,
                                 args=(sess, data_iterator))
            t.daemon = True  # Thread will close when parent quits
            t.start()
            threads.append(t)
        return threads