ray/rllib/optimizers/microbatch_optimizer.py


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)


class MicrobatchOptimizer(PolicyOptimizer):
"""A microbatching synchronous RL optimizer.
This optimizer pulls sample batches from workers until the target
microbatch size is reached. Then, it computes and accumulates the policy
gradient in a local buffer. This process is repeated until the number of
samples collected equals the train batch size. Then, an accumulated
gradient update is made.
This allows for training with effective batch sizes much larger than can
fit in GPU or host memory.
"""

    def __init__(self, workers, train_batch_size=10000, microbatch_size=1000):
        PolicyOptimizer.__init__(self, workers)

        if train_batch_size <= microbatch_size:
            raise ValueError(
                "The microbatch size must be smaller than the train batch "
                "size, got {} vs {}".format(microbatch_size, train_batch_size))

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.train_batch_size = train_batch_size
        self.microbatch_size = microbatch_size
        self.learner_stats = {}
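        # Map of policy id -> policy instance for every policy this optimizer
        # trains, taken from the local worker.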
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
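        # Broadcast the latest local weights to all remote workers so that
        # sampling uses an up-to-date policy.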
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        fetches = {}
        accumulated_gradients = {}
        samples_so_far = 0

        # Accumulate microbatches until the train batch size is reached.
        i = 0
        while samples_so_far < self.train_batch_size:
            i += 1
            with self.sample_timer:
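                # Pull batches from the workers until this microbatch holds
                # at least `microbatch_size` samples.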
                samples = []
                while sum(s.count for s in samples) < self.microbatch_size:
                    if self.workers.remote_workers():
                        samples.extend(
                            ray_get_and_free([
                                e.sample.remote()
                                for e in self.workers.remote_workers()
                            ]))
                    else:
                        samples.append(self.workers.local_worker().sample())
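                # Merge the collected batches into a single microbatch.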
                samples = SampleBatch.concat_samples(samples)
                self.sample_timer.push_units_processed(samples.count)
                samples_so_far += samples.count

            logger.info(
                "Computing gradients for microbatch {} ({}/{} samples)".format(
                    i, samples_so_far, self.train_batch_size))

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)

            with self.grad_timer:
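                # Compute gradients for this microbatch on the local worker,
                # one policy at a time.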
                for policy_id, policy in self.policies.items():
                    if policy_id not in samples.policy_batches:
                        continue
                    batch = samples.policy_batches[policy_id]
                    grad_out, info_out = (
                        self.workers.local_worker().compute_gradients(
                            MultiAgentBatch({
                                policy_id: batch
                            }, batch.count)))
                    grad = grad_out[policy_id]
                    fetches.update(info_out)
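                    # The first microbatch initializes the accumulator for
                    # this policy; later ones are summed in elementwise.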
                    if policy_id not in accumulated_gradients:
                        accumulated_gradients[policy_id] = grad
                    else:
                        grad_size = len(accumulated_gradients[policy_id])
                        assert grad_size == len(grad), (grad_size, len(grad))
                        c = []
                        for a, b in zip(accumulated_gradients[policy_id],
                                        grad):
                            c.append(a + b)
                        accumulated_gradients[policy_id] = c
                self.grad_timer.push_units_processed(samples.count)

        # Apply the accumulated gradient
        logger.info("Applying accumulated gradients ({} samples)".format(
            samples_so_far))
        self.workers.local_worker().apply_gradients(accumulated_gradients)
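
        # Flatten the learner stats when only the default policy reported
        # fetches; otherwise keep them keyed by policy id.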
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples_so_far
        self.num_steps_trained += samples_so_far
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
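        # Times are reported in milliseconds; throughputs in samples/s.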
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })
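

# Example usage (a sketch, not part of the original module): this assumes a
# `workers` WorkerSet has already been built by a Trainer, since constructing
# one depends on the Trainer config and is omitted here.
#
#     optimizer = MicrobatchOptimizer(
#         workers, train_batch_size=10000, microbatch_size=1000)
#     stats = optimizer.step()     # one accumulated gradient update
#     print(optimizer.stats())     # timing and throughput metrics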