diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py
index 22b7ea18b..de171c0ca 100644
--- a/python/ray/rllib/optimizers/async_samples_optimizer.py
+++ b/python/ray/rllib/optimizers/async_samples_optimizer.py
@@ -368,15 +368,16 @@ class TFMultiGPULearner(LearnerThread):
         assert self.loader_thread.is_alive()
         with self.load_wait_timer:
             opt, released = self.minibatch_buffer.get()
-            if released:
-                self.idle_optimizers.put(opt)
 
         with self.grad_timer:
             fetches = opt.optimize(self.sess, 0)
             self.weights_updated = True
             self.stats = fetches.get("stats", {})
 
-        self.outqueue.put(self.train_batch_size)
+        if released:
+            self.idle_optimizers.put(opt)
+
+        self.outqueue.put(opt.num_tuples_loaded)
         self.learner_queue_size.push(self.inqueue.qsize())
diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py
index 7c00fda99..337ca11aa 100644
--- a/python/ray/rllib/optimizers/multi_gpu_impl.py
+++ b/python/ray/rllib/optimizers/multi_gpu_impl.py
@@ -188,6 +188,7 @@ class LocalSyncParallelOptimizer(object):
         sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
 
+        self.num_tuples_loaded = truncated_len
         tuples_per_device = truncated_len // len(self.devices)
         assert tuples_per_device > 0, "No data loaded?"
         assert tuples_per_device % self._loaded_per_device_batch_size == 0
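Context for the second change: load_data() records truncated_len, the number of tuples that actually fit evenly across devices and per-device SGD minibatches, so the count reported to the outqueue can be smaller than the configured train_batch_size. A minimal sketch of that relationship follows (illustrative only; truncated_tuple_count is a hypothetical helper, not RLlib code, and assumes truncation keeps the largest length divisible by both the device count and the per-device batch size, consistent with the asserts in load_data()):

    # Illustrative sketch: why opt.num_tuples_loaded (the truncated length
    # recorded in LocalSyncParallelOptimizer.load_data) can differ from the
    # nominal train_batch_size that the learner previously put on the outqueue.
    def truncated_tuple_count(batch_len, num_devices, per_device_batch_size):
        # Largest per-device length that is a whole multiple of the SGD
        # minibatch size, mirroring the asserts in load_data().
        per_device = (batch_len // num_devices //
                      per_device_batch_size) * per_device_batch_size
        return per_device * num_devices

    # A nominal batch of 1000 tuples on 2 devices with 64-tuple minibatches
    # only loads 896 tuples, so the throughput counter should use the loaded
    # count rather than the configured batch size.
    assert truncated_tuple_count(1000, 2, 64) == 896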