From 3c0803e7e909f2f617aa9bd8f3d266d31cbe3026 Mon Sep 17 00:00:00 2001
From: Richard Liu
Date: Wed, 17 Oct 2018 17:44:51 -0700
Subject: [PATCH] [rllib] use `ray.wait` to get next worker result in async
 sample optimizer (#2993)

---
 .../optimizers/async_gradients_optimizer.py | 24 ++++++++++++-------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py
index fc7fdb248..499d2a91f 100644
--- a/python/ray/rllib/optimizers/async_gradients_optimizer.py
+++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py
@@ -27,21 +27,26 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
 
     def step(self):
         weights = ray.put(self.local_evaluator.get_weights())
-        gradient_queue = []
+        pending_gradients = {}
         num_gradients = 0
 
         # Kick off the first wave of async tasks
         for e in self.remote_evaluators:
             e.set_weights.remote(weights)
-            fut = e.compute_gradients.remote(e.sample.remote())
-            gradient_queue.append((fut, e))
+            future = e.compute_gradients.remote(e.sample.remote())
+            pending_gradients[future] = e
             num_gradients += 1
 
-        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
-        while gradient_queue:
+        while pending_gradients:
             with self.wait_timer:
-                fut, e = gradient_queue.pop(0)
-                gradient, info = ray.get(fut)
+                wait_results = ray.wait(
+                    list(pending_gradients.keys()), num_returns=1)
+                ready_list = wait_results[0]
+                future = ready_list[0]
+
+                gradient, info = ray.get(future)
+                e = pending_gradients.pop(future)
+
                 if "stats" in info:
                     self.learner_stats = info["stats"]
 
@@ -54,8 +59,9 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
             if num_gradients < self.grads_per_step:
                 with self.dispatch_timer:
                     e.set_weights.remote(self.local_evaluator.get_weights())
-                    fut = e.compute_gradients.remote(e.sample.remote())
-                    gradient_queue.append((fut, e))
+                    future = e.compute_gradients.remote(e.sample.remote())
+
+                    pending_gradients[future] = e
                     num_gradients += 1
 
     def stats(self):