import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class AsyncGradientsOptimizer(PolicyOptimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    workers, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def __init__(self, workers, grads_per_step=100):
        """Initialize an async gradients optimizer.

        Arguments:
            grads_per_step (int): The number of gradients to collect and apply
                per call to step(). This number should be sufficiently
                high to amortize the overhead of calling step().
        """
        PolicyOptimizer.__init__(self, workers)

        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.learner_stats = {}
        if not self.workers.remote_workers():
            raise ValueError(
                "Async optimizer requires at least 1 remote worker")

    @override(PolicyOptimizer)
    def step(self):
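        # Broadcast the current local weights once via the object store so
        # every remote worker computes gradients against the same snapshot.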
        weights = ray.put(self.workers.local_worker().get_weights())
        pending_gradients = {}
        num_gradients = 0

        # Kick off the first wave of async tasks
        for e in self.workers.remote_workers():
            e.set_weights.remote(weights)
            future = e.compute_gradients.remote(e.sample.remote())
            pending_gradients[future] = e
            num_gradients += 1
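        # Main loop: wait for any one gradient to arrive, fold it into the
        # local policy, then immediately send that worker new weights and a
        # fresh gradient task, until grads_per_step gradients are collected.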
        while pending_gradients:
            with self.wait_timer:
                wait_results = ray.wait(
                    list(pending_gradients.keys()), num_returns=1)
                ready_list = wait_results[0]
                future = ready_list[0]

                gradient, info = ray_get_and_free(future)
                e = pending_gradients.pop(future)
                self.learner_stats = get_learner_stats(info)
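            # Apply the received gradient to the local (learner) copy of the
            # policy and update the step counters from the batch size
            # reported by the worker.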
            if gradient is not None:
                with self.apply_timer:
                    self.workers.local_worker().apply_gradients(gradient)
            self.num_steps_sampled += info["batch_count"]
            self.num_steps_trained += info["batch_count"]
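            # Keep the worker busy: push it the freshly updated weights and
            # request another gradient, unless enough gradients have already
            # been scheduled for this step() call.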
            if num_gradients < self.grads_per_step:
                with self.dispatch_timer:
                    e.set_weights.remote(
                        self.workers.local_worker().get_weights())
                    future = e.compute_gradients.remote(e.sample.remote())

                    pending_gradients[future] = e
                    num_gradients += 1

    @override(PolicyOptimizer)
    def stats(self):
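        # Merge the base PolicyOptimizer stats with per-phase timings (in
        # milliseconds) and the most recent learner statistics.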
        return dict(
            PolicyOptimizer.stats(self), **{
                "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
                "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
                "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
                "learner": self.learner_stats,
            })