diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py
index fc7fdb248..499d2a91f 100644
--- a/python/ray/rllib/optimizers/async_gradients_optimizer.py
+++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py
@@ -27,21 +27,26 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
     def step(self):
         weights = ray.put(self.local_evaluator.get_weights())
-        gradient_queue = []
+        pending_gradients = {}
         num_gradients = 0
 
         # Kick off the first wave of async tasks
         for e in self.remote_evaluators:
             e.set_weights.remote(weights)
-            fut = e.compute_gradients.remote(e.sample.remote())
-            gradient_queue.append((fut, e))
+            future = e.compute_gradients.remote(e.sample.remote())
+            pending_gradients[future] = e
             num_gradients += 1
 
-        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
-        while gradient_queue:
+        while pending_gradients:
             with self.wait_timer:
-                fut, e = gradient_queue.pop(0)
-                gradient, info = ray.get(fut)
+                wait_results = ray.wait(
+                    list(pending_gradients.keys()), num_returns=1)
+                ready_list = wait_results[0]
+                future = ready_list[0]
+
+                gradient, info = ray.get(future)
+                e = pending_gradients.pop(future)
+
             if "stats" in info:
                 self.learner_stats = info["stats"]
 
@@ -54,8 +59,9 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
             if num_gradients < self.grads_per_step:
                 with self.dispatch_timer:
                     e.set_weights.remote(self.local_evaluator.get_weights())
-                    fut = e.compute_gradients.remote(e.sample.remote())
-                    gradient_queue.append((fut, e))
+                    future = e.compute_gradients.remote(e.sample.remote())
+
+                    pending_gradients[future] = e
                     num_gradients += 1
 
     def stats(self):
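For context, the change replaces the FIFO `gradient_queue` (which forced `ray.get` on the oldest future, blocking even when a later worker had already finished) with a dict of pending futures polled via `ray.wait`, so each gradient is consumed as soon as any worker completes. Below is a minimal, self-contained sketch of that pattern; `compute_gradient`, the worker ids, and the `pending` dict are illustrative stand-ins, not RLlib code.

```python
import ray

ray.init()


@ray.remote
def compute_gradient(worker_id):
    # Stand-in for e.compute_gradients.remote(e.sample.remote()); returns a
    # fake "gradient" so the collection loop below has something to fetch.
    return {"worker_id": worker_id, "gradient": [0.0, 0.0]}


# Kick off one task per (hypothetical) worker and remember which worker owns
# each future, mirroring pending_gradients[future] = e in the patch.
pending = {compute_gradient.remote(i): i for i in range(4)}

# Collect results in completion order rather than submission order: ray.wait
# returns (ready, not_ready) lists of object IDs, and with num_returns=1 it
# blocks only until at least one task has finished.
while pending:
    ready, _ = ray.wait(list(pending.keys()), num_returns=1)
    future = ready[0]
    worker_id = pending.pop(future)
    result = ray.get(future)
    print("got gradient from worker", worker_id, "->", result["gradient"])
    # In the optimizer, this is where the gradient would be applied and a new
    # task dispatched for the same worker while the others keep running.
```

The dict keyed by future is what lets the loop recover which evaluator produced a finished result, so fresh weights can be pushed back to that same worker without waiting on the slower ones.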