from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat


class SyncSamplesOptimizer(PolicyOptimizer):
"""A simple synchronous RL optimizer.
In each step, this optimizer pulls samples from a number of remote
evaluators, concatenates them, and then updates a local model. The updated
model weights are then broadcast to all remote evaluators.
"""
def _init(self, num_sgd_iter=1, timesteps_per_batch=1):
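        # Timers and running stats for the sampling and optimization phases,
        # reported later from stats().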
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.num_sgd_iter = num_sgd_iter
self.timesteps_per_batch = timesteps_per_batch
self.learner_stats = {}
def step(self):
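        # Broadcast the latest local weights to all remote evaluators so that
        # sampling happens with an up-to-date policy.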
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
e.set_weights.remote(weights)
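
        # Collect sample batches (remotely if evaluators are available,
        # otherwise locally) until at least `timesteps_per_batch` timesteps
        # have been gathered, then merge them into one batch.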
with self.sample_timer:
samples = []
while sum(s.count for s in samples) < self.timesteps_per_batch:
if self.remote_evaluators:
samples.extend(
ray.get([
e.sample.remote() for e in self.remote_evaluators
]))
else:
samples.append(self.local_evaluator.sample())
samples = SampleBatch.concat_samples(samples)
self.sample_timer.push_units_processed(samples.count)
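
        # Apply `num_sgd_iter` gradient updates over the collected batch on
        # the local evaluator, keeping any learner stats it reports.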
with self.grad_timer:
for i in range(self.num_sgd_iter):
fetches = self.local_evaluator.compute_apply(samples)
if "stats" in fetches:
self.learner_stats = fetches["stats"]
if self.num_sgd_iter > 1:
print(i, fetches)
self.grad_timer.push_units_processed(samples.count)
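
        # Update the running counts of timesteps sampled and trained on.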
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
return fetches
def stats(self):
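        # Per-phase timings (in ms) and throughputs, merged with the
        # base-class stats.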
return dict(
PolicyOptimizer.stats(self), **{
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
"update_time_ms": round(1000 * self.update_weights_timer.mean,
3),
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
3),
"sample_peak_throughput": round(
self.sample_timer.mean_throughput, 3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
"learner": self.learner_stats,
})
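

# A minimal usage sketch (hypothetical wiring; the evaluator construction and
# the PolicyOptimizer constructor arguments below are assumptions for
# illustration, not taken from this file):
#
#     local_ev = ...        # local evaluator holding the trained policy
#     remote_evs = [...]    # list of remote evaluator actor handles
#     optimizer = SyncSamplesOptimizer(
#         local_ev, remote_evs,
#         {"num_sgd_iter": 10, "timesteps_per_batch": 4000})
#     while optimizer.num_steps_sampled < 1000000:
#         optimizer.step()  # broadcast weights -> sample -> gradient updates
#     print(optimizer.stats())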