# ray/rllib/optimizers/async_gradients_optimizer.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class AsyncGradientsOptimizer(PolicyOptimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    workers, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def __init__(self, workers, grads_per_step=100):
        """Initialize an async gradients optimizer.

        Arguments:
            workers (WorkerSet): the rollout workers to use; at least one
                remote worker is required.
            grads_per_step (int): The number of gradients to collect and apply
                per call to step(). This number should be sufficiently
                high to amortize the overhead of calling step().
        """
        PolicyOptimizer.__init__(self, workers)

        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.learner_stats = {}
        if not self.workers.remote_workers():
            raise ValueError(
                "Async optimizer requires at least 1 remote worker")

    @override(PolicyOptimizer)
    def step(self):
        weights = ray.put(self.workers.local_worker().get_weights())
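        # The weights are serialized into the object store once via ray.put,
        # so each remote worker below is handed a reference to that single
        # copy rather than shipping the weights once per task.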
        pending_gradients = {}
        num_gradients = 0

        # Kick off the first wave of async tasks
        for e in self.workers.remote_workers():
            e.set_weights.remote(weights)
            future = e.compute_gradients.remote(e.sample.remote())
            pending_gradients[future] = e
            num_gradients += 1
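
        # Gather-and-redispatch loop: wait for any single remote gradient,
        # apply it to the local worker, and (until grads_per_step gradients
        # have been launched) send fresh weights back to that worker and
        # schedule another gradient computation on it.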
        while pending_gradients:
            with self.wait_timer:
                wait_results = ray.wait(
                    list(pending_gradients.keys()), num_returns=1)
                ready_list = wait_results[0]
                future = ready_list[0]

                gradient, info = ray_get_and_free(future)
                e = pending_gradients.pop(future)
                self.learner_stats = get_learner_stats(info)

            if gradient is not None:
                with self.apply_timer:
                    self.workers.local_worker().apply_gradients(gradient)
                self.num_steps_sampled += info["batch_count"]
                self.num_steps_trained += info["batch_count"]

            if num_gradients < self.grads_per_step:
                with self.dispatch_timer:
                    e.set_weights.remote(
                        self.workers.local_worker().get_weights())
                    future = e.compute_gradients.remote(e.sample.remote())
                    pending_gradients[future] = e
                    num_gradients += 1

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
                "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
                "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
                "learner": self.learner_stats,
            })
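

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how this
# optimizer is typically driven once a WorkerSet exists. build_worker_set()
# is a hypothetical placeholder for however the surrounding trainer
# constructs its workers (one local worker plus at least one remote worker);
# only the optimizer's own API (step(), stats(), num_steps_sampled) is taken
# from this file.
#
#     workers = build_worker_set()  # hypothetical helper, see note above
#     optimizer = AsyncGradientsOptimizer(workers, grads_per_step=100)
#     while optimizer.num_steps_sampled < 100000:
#         optimizer.step()          # collect/apply grads_per_step gradients
#         print(optimizer.stats())  # wait/apply/dispatch timings, learner stats
# ---------------------------------------------------------------------------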