ray/rllib/optimizers/microbatch_optimizer.py
import logging

import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)


class MicrobatchOptimizer(PolicyOptimizer):
    """A microbatching synchronous RL optimizer.

    This optimizer pulls sample batches from workers until the target
    microbatch size is reached. Then, it computes and accumulates the policy
    gradient in a local buffer. This process is repeated until the number of
    samples collected equals the train batch size. Then, an accumulated
    gradient update is made.

    This allows for training with effective batch sizes much larger than can
    fit in GPU or host memory.
    """

    def __init__(self, workers, train_batch_size=10000, microbatch_size=1000):
        PolicyOptimizer.__init__(self, workers)
        if train_batch_size <= microbatch_size:
            raise ValueError(
                "The microbatch size must be smaller than the train batch "
                "size, got {} vs {}".format(microbatch_size, train_batch_size))
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.train_batch_size = train_batch_size
        self.microbatch_size = microbatch_size
        self.learner_stats = {}
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)
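
        # Per-round buffers: `fetches` keeps the most recent learner info per
        # policy, while `accumulated_gradients` holds the running elementwise
        # sum of microbatch gradients for this train batch.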
        fetches = {}
        accumulated_gradients = {}
        samples_so_far = 0

        # Accumulate microbatches.
        i = 0
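        # `i` counts microbatches for logging; the loop below runs until a
        # full train batch worth of samples has been collected.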
        while samples_so_far < self.train_batch_size:
            i += 1
            with self.sample_timer:
                samples = []
                while sum(s.count for s in samples) < self.microbatch_size:
                    if self.workers.remote_workers():
                        samples.extend(
                            ray_get_and_free([
                                e.sample.remote()
                                for e in self.workers.remote_workers()
                            ]))
                    else:
                        samples.append(self.workers.local_worker().sample())
                samples = SampleBatch.concat_samples(samples)
                self.sample_timer.push_units_processed(samples.count)
                samples_so_far += samples.count

            logger.info(
                "Computing gradients for microbatch {} ({}/{} samples)".format(
                    i, samples_so_far, self.train_batch_size))

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)
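
            # Compute gradients for each trainable policy on the local
            # worker; policies without data in this microbatch are skipped.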
            with self.grad_timer:
                for policy_id, policy in self.policies.items():
                    if policy_id not in samples.policy_batches:
                        continue
                    batch = samples.policy_batches[policy_id]
                    grad_out, info_out = (
                        self.workers.local_worker().compute_gradients(
                            MultiAgentBatch({
                                policy_id: batch
                            }, batch.count)))
                    grad = grad_out[policy_id]
                    fetches.update(info_out)
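                    # The first microbatch initializes the accumulator; later
                    # microbatches are summed into it elementwise.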
                    if policy_id not in accumulated_gradients:
                        accumulated_gradients[policy_id] = grad
                    else:
                        grad_size = len(accumulated_gradients[policy_id])
                        assert grad_size == len(grad), (grad_size, len(grad))
                        c = []
                        for a, b in zip(accumulated_gradients[policy_id],
                                        grad):
                            c.append(a + b)
                        accumulated_gradients[policy_id] = c
                self.grad_timer.push_units_processed(samples.count)

        # Apply the accumulated gradient
        logger.info("Applying accumulated gradients ({} samples)".format(
            samples_so_far))
        self.workers.local_worker().apply_gradients(accumulated_gradients)
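
        # Report learner stats, unwrapping the default-policy entry so that
        # single-agent setups see a flat stats dict.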
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples_so_far
        self.num_steps_trained += samples_so_far
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })
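

# Illustrative usage sketch: a trainer-style loop would construct this
# optimizer around an existing WorkerSet and call step() repeatedly.
# `workers` and `num_iterations` below are placeholders for whatever the
# surrounding Trainer provides.
#
#     optimizer = MicrobatchOptimizer(
#         workers, train_batch_size=10000, microbatch_size=1000)
#     for _ in range(num_iterations):
#         learner_stats = optimizer.step()  # one accumulated gradient update
#     metrics = optimizer.stats()  # sample/grad timings and throughput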