from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.rllib.optimizers.optimizer import Optimizer
from ray.rllib.utils.timer import TimerStat


class AsyncOptimizer(Optimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    evaluators, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def _init(self, grads_per_step=100, batch_size=10):
        # Timers for the three phases of step(): waiting on a gradient,
        # applying it locally, and dispatching new work to the evaluator.
        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.batch_size = batch_size

    def step(self):
        weights = ray.put(self.local_evaluator.get_weights())
        gradient_queue = []
        num_gradients = 0

        # Kick off the first wave of async tasks
        for e in self.remote_evaluators:
            e.set_weights.remote(weights)
            fut = e.compute_gradients.remote(e.sample.remote())
            gradient_queue.append((fut, e))
            num_gradients += 1

        # Note: can't use wait: https://github.com/ray-project/ray/issues/1128
        while gradient_queue:
            with self.wait_timer:
                fut, e = gradient_queue.pop(0)
                gradient = ray.get(fut)

            if gradient is not None:
                with self.apply_timer:
                    self.local_evaluator.apply_gradients(gradient)

            if num_gradients < self.grads_per_step:
                with self.dispatch_timer:
                    e.set_weights.remote(self.local_evaluator.get_weights())
                    fut = e.compute_gradients.remote(e.sample.remote())
                    gradient_queue.append((fut, e))
                    num_gradients += 1

        # Bookkeeping: each of the `grads_per_step` gradients is assumed to
        # cover `batch_size` sampled (and trained) timesteps.
        self.num_steps_sampled += self.grads_per_step * self.batch_size
        self.num_steps_trained += self.grads_per_step * self.batch_size

    def stats(self):
        return dict(Optimizer.stats(), **{
            "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
            "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
            "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
        })
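

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the RLlib API): it
# reproduces the dispatch/collect pattern that AsyncOptimizer.step() uses,
# with a stand-in `_DemoEvaluator` actor. The class name, the scalar
# "weights"/"gradients", and the worker count are hypothetical; real
# evaluators are supplied by an RLlib agent and expose the same methods
# called above (get_weights, set_weights, sample, compute_gradients,
# apply_gradients).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ray.init()

    @ray.remote
    class _DemoEvaluator(object):
        """Stand-in for a remote evaluator; stores a scalar "weight"."""

        def __init__(self):
            self.weights = 0.0

        def set_weights(self, weights):
            self.weights = weights

        def sample(self):
            # Pretend batch of experience.
            return [1.0, 2.0, 3.0]

        def compute_gradients(self, samples):
            # Pretend gradient computed from the sampled batch.
            return sum(samples)

    evaluators = [_DemoEvaluator.remote() for _ in range(2)]
    weights = ray.put(0.0)

    # First wave of asynchronous gradient tasks, as in step().
    queue = []
    for e in evaluators:
        e.set_weights.remote(weights)
        queue.append((e.compute_gradients.remote(e.sample.remote()), e))

    # Collect results in FIFO order; a real optimizer would apply each
    # gradient locally and re-dispatch work to the same evaluator.
    while queue:
        fut, e = queue.pop(0)
        print("received gradient:", ray.get(fut))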