[rllib] Fix race condition with multiple data loaders, fix stats

This commit is contained in:
Eric Liang 2019-03-23 20:17:01 -07:00 committed by GitHub
parent 7a38f9be1c
commit 01699ce4ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 3 deletions

View file

@ -368,15 +368,16 @@ class TFMultiGPULearner(LearnerThread):
assert self.loader_thread.is_alive()
with self.load_wait_timer:
opt, released = self.minibatch_buffer.get()
if released:
self.idle_optimizers.put(opt)
with self.grad_timer:
fetches = opt.optimize(self.sess, 0)
self.weights_updated = True
self.stats = fetches.get("stats", {})
self.outqueue.put(self.train_batch_size)
if released:
self.idle_optimizers.put(opt)
self.outqueue.put(opt.num_tuples_loaded)
self.learner_queue_size.push(self.inqueue.qsize())

View file

@ -188,6 +188,7 @@ class LocalSyncParallelOptimizer(object):
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
self.num_tuples_loaded = truncated_len
tuples_per_device = truncated_len // len(self.devices)
assert tuples_per_device > 0, "No data loaded?"
assert tuples_per_device % self._loaded_per_device_batch_size == 0