Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[rllib] use ray.wait to get next worker result in async sample optimizer (#2993)
This commit is contained in:
parent a41bbc10ef
commit 3c0803e7e9
1 changed file with 15 additions and 9 deletions
@@ -27,21 +27,26 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
     def step(self):
         weights = ray.put(self.local_evaluator.get_weights())
-        gradient_queue = []
+        pending_gradients = {}
         num_gradients = 0
 
         # Kick off the first wave of async tasks
         for e in self.remote_evaluators:
             e.set_weights.remote(weights)
-            fut = e.compute_gradients.remote(e.sample.remote())
-            gradient_queue.append((fut, e))
+            future = e.compute_gradients.remote(e.sample.remote())
+            pending_gradients[future] = e
             num_gradients += 1
 
-        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
-        while gradient_queue:
+        while pending_gradients:
             with self.wait_timer:
-                fut, e = gradient_queue.pop(0)
-                gradient, info = ray.get(fut)
+                wait_results = ray.wait(
+                    list(pending_gradients.keys()), num_returns=1)
+                ready_list = wait_results[0]
+                future = ready_list[0]
+
+                gradient, info = ray.get(future)
+                e = pending_gradients.pop(future)
+
                 if "stats" in info:
                     self.learner_stats = info["stats"]

@@ -54,8 +59,9 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
             if num_gradients < self.grads_per_step:
                 with self.dispatch_timer:
                     e.set_weights.remote(self.local_evaluator.get_weights())
-                    fut = e.compute_gradients.remote(e.sample.remote())
-                    gradient_queue.append((fut, e))
+                    future = e.compute_gradients.remote(e.sample.remote())
+
+                    pending_gradients[future] = e
                     num_gradients += 1
 
     def stats(self):
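For context, the change swaps a FIFO queue of futures, where ray.get blocked on the oldest outstanding task, for a pending_gradients dict drained with ray.wait, which blocks only until whichever gradient task finishes first; a slow worker no longer stalls consumption of faster workers' results. Below is a minimal standalone sketch of the same ray.wait polling pattern. The slow_task remote function, its timings, and the worker indices are made up for illustration; only the ray.wait / ray.get / dict-pop usage mirrors the diff above.

import random
import time

import ray

ray.init()


# Hypothetical stand-in for evaluator.compute_gradients.remote(...):
# a remote task whose runtime varies from call to call.
@ray.remote
def slow_task(i):
    time.sleep(random.random())
    return i * i


# Map each pending future back to the "worker" (here just an index) that
# produced it, as the diff does with pending_gradients[future] = e.
pending = {slow_task.remote(i): i for i in range(4)}

while pending:
    # ray.wait returns ([ready refs], [not-ready refs]); with num_returns=1
    # it blocks only until the single fastest outstanding task has finished.
    ready, _ = ray.wait(list(pending.keys()), num_returns=1)
    future = ready[0]
    result = ray.get(future)       # fetch only the finished result
    worker = pending.pop(future)   # recover which worker it came from
    print(f"worker {worker} finished, result={result}")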