mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Fix race condition with multiple data loaders, fix stats
This commit is contained in:
parent
7a38f9be1c
commit
01699ce4ea
2 changed files with 5 additions and 3 deletions
|
@@ -368,15 +368,16 @@ class TFMultiGPULearner(LearnerThread):
|
|||
assert self.loader_thread.is_alive()
|
||||
with self.load_wait_timer:
|
||||
opt, released = self.minibatch_buffer.get()
|
||||
-if released:
|
||||
-self.idle_optimizers.put(opt)
|
||||
|
||||
with self.grad_timer:
|
||||
fetches = opt.optimize(self.sess, 0)
|
||||
self.weights_updated = True
|
||||
self.stats = fetches.get("stats", {})
|
||||
|
||||
-self.outqueue.put(self.train_batch_size)
|
||||
+if released:
|
||||
+self.idle_optimizers.put(opt)
|
||||
|
||||
+self.outqueue.put(opt.num_tuples_loaded)
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
|
||||
|
||||
|
|
|
@@ -188,6 +188,7 @@ class LocalSyncParallelOptimizer(object):
|
|||
|
||||
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
|
||||
|
||||
+self.num_tuples_loaded = truncated_len
|
||||
tuples_per_device = truncated_len // len(self.devices)
|
||||
assert tuples_per_device > 0, "No data loaded?"
|
||||
assert tuples_per_device % self._loaded_per_device_batch_size == 0
|
||||
|
|
Loading…
Add table
Reference in a new issue