from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class AsyncGradientsOptimizer(PolicyOptimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    workers, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def __init__(self, workers, grads_per_step=100):
        """Initialize an async gradients optimizer.

        Arguments:
            grads_per_step (int): The number of gradients to collect and
                apply per call to step(). This number should be sufficiently
                high to amortize the overhead of calling step().
        """
        PolicyOptimizer.__init__(self, workers)

        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.learner_stats = {}
        if not self.workers.remote_workers():
            raise ValueError(
                "Async optimizer requires at least 1 remote worker")

    @override(PolicyOptimizer)
    def step(self):
        weights = ray.put(self.workers.local_worker().get_weights())
        pending_gradients = {}
        num_gradients = 0

        # Kick off the first wave of async tasks
        for e in self.workers.remote_workers():
            e.set_weights.remote(weights)
            future = e.compute_gradients.remote(e.sample.remote())
            pending_gradients[future] = e
            num_gradients += 1

        while pending_gradients:
            # Wait for any one in-flight gradient task to finish.
            with self.wait_timer:
                wait_results = ray.wait(
                    list(pending_gradients.keys()), num_returns=1)
                ready_list = wait_results[0]
                future = ready_list[0]

            gradient, info = ray_get_and_free(future)
            e = pending_gradients.pop(future)
            self.learner_stats = get_learner_stats(info)

            # Apply the gradient on the local worker and count its batch.
            if gradient is not None:
                with self.apply_timer:
                    self.workers.local_worker().apply_gradients(gradient)
                self.num_steps_sampled += info["batch_count"]
                self.num_steps_trained += info["batch_count"]

            # Push fresh weights and dispatch another gradient task.
            if num_gradients < self.grads_per_step:
                with self.dispatch_timer:
                    e.set_weights.remote(
                        self.workers.local_worker().get_weights())
                    future = e.compute_gradients.remote(e.sample.remote())

                    pending_gradients[future] = e
                    num_gradients += 1

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
                "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
                "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
                "learner": self.learner_stats,
            })