ray/rllib/optimizers/sync_replay_optimizer.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import logging

import numpy as np

import ray
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)


class SyncReplayOptimizer(PolicyOptimizer):
"""Variant of the local sync optimizer that supports replay (for DQN).
This optimizer requires that rollout workers return an additional
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
def __init__(self,
workers,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
schedule_max_timesteps=100000,
beta_annealing_fraction=0.2,
final_prioritized_replay_beta=0.4,
train_batch_size=32,
sample_batch_size=4,
before_learn_on_batch=None,
synchronize_sampling=False):
"""Initialize an sync replay optimizer.
Arguments:
workers (WorkerSet): all workers
learning_starts (int): wait until this many steps have been sampled
before starting optimization.
buffer_size (int): max size of the replay buffer
prioritized_replay (bool): whether to enable prioritized replay
prioritized_replay_alpha (float): replay alpha hyperparameter
prioritized_replay_beta (float): replay beta hyperparameter
prioritized_replay_eps (float): replay eps hyperparameter
schedule_max_timesteps (int): number of timesteps in the schedule
beta_annealing_fraction (float): fraction of schedule to anneal
beta over
final_prioritized_replay_beta (float): final value of beta
train_batch_size (int): size of batches to learn on
sample_batch_size (int): size of batches to sample from workers
before_learn_on_batch (function): callback to run before passing
the sampled batch to learn on
synchronize_sampling (bool): whether to sample the experiences for
all policies with the same indices (used in MADDPG).
"""
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # Linearly annealing beta, as used in the Rainbow paper.
self.prioritized_replay_beta = LinearSchedule(
schedule_timesteps=int(
schedule_max_timesteps * beta_annealing_fraction),
initial_p=prioritized_replay_beta,
final_p=final_prioritized_replay_beta)
self.prioritized_replay_eps = prioritized_replay_eps
self.train_batch_size = train_batch_size
self.before_learn_on_batch = before_learn_on_batch
self.synchronize_sampling = synchronize_sampling

        # Stats
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.replay_timer = TimerStat()
self.grad_timer = TimerStat()
self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size, alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

if buffer_size < self.replay_starts:
logger.warning("buffer_size={} < replay_starts={}".format(
buffer_size, self.replay_starts))
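
    # With the defaults above, beta is annealed over the first
    # schedule_max_timesteps * beta_annealing_fraction = 100000 * 0.2
    # = 20000 trained timesteps. A sketch of the schedule's behavior:
    #
    #     beta = LinearSchedule(
    #         schedule_timesteps=20000, initial_p=0.4, final_p=0.4)
    #     beta.value(0)      # -> 0.4
    #     beta.value(20000)  # -> 0.4; the defaults keep beta constant.
    #                        # Set final_prioritized_replay_beta=1.0 to fully
    #                        # anneal importance correction, as in Rainbow.
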
@override(PolicyOptimizer)
def step(self):
        # Broadcast the local worker's weights to the remote workers.
        with self.update_weights_timer:
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)

        # Sample new experiences from the rollout workers.
        with self.sample_timer:
if self.workers.remote_workers():
batch = SampleBatch.concat_samples(
ray_get_and_free([
e.sample.remote()
for e in self.workers.remote_workers()
]))
else:
batch = self.workers.local_worker().sample()

        # Handle everything as if multiagent.
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch({
DEFAULT_POLICY_ID: batch
}, batch.count)

        # Store the new experiences in the per-policy replay buffers.
        for policy_id, s in batch.policy_batches.items():
for row in s.rows():
self.replay_buffers[policy_id].add(
pack_if_needed(row["obs"]),
row["actions"],
row["rewards"],
pack_if_needed(row["new_obs"]),
row["dones"],
weight=None)

        # Only begin optimizing once enough experience has been collected.
        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
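
    # Typical driver usage (a hedged sketch; assumes `workers` is an
    # already-constructed WorkerSet for an off-policy algorithm like DQN,
    # and `num_iterations` is chosen by the caller):
    #
    #     optimizer = SyncReplayOptimizer(workers, learning_starts=1000)
    #     for _ in range(num_iterations):
    #         optimizer.step()  # sample; then also replay + learn once
    #                           # learning_starts steps have been sampled
    #         print(optimizer.stats())
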
@override(PolicyOptimizer)
def stats(self):
return dict(
PolicyOptimizer.stats(self), **{
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
"update_time_ms": round(1000 * self.update_weights_timer.mean,
3),
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
"learner": self.learner_stats,
})
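
    # Example of the dict returned by stats() (values illustrative only;
    # the base PolicyOptimizer.stats() fields are elided):
    #
    #     {
    #         ...
    #         "sample_time_ms": 12.3,
    #         "replay_time_ms": 1.1,
    #         "grad_time_ms": 8.4,
    #         "update_time_ms": 0.5,
    #         "opt_peak_throughput": 3800.0,
    #         "opt_samples": 32.0,
    #         "learner": {"default_policy": {...}},
    #     }
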
    def _optimize(self):
        # Draw one train batch per policy from the replay buffers.
        samples = self._replay()

        with self.grad_timer:
if self.before_learn_on_batch:
samples = self.before_learn_on_batch(
samples,
self.workers.local_worker().policy_map,
self.train_batch_size)
info_dict = self.workers.local_worker().learn_on_batch(samples)
for policy_id, info in info_dict.items():
self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # Update priorities from the new TD errors so that
                    # high-error samples are replayed more often.
td_error = info["td_error"]
new_priorities = (
np.abs(td_error) + self.prioritized_replay_eps)
replay_buffer.update_priorities(
samples.policy_batches[policy_id]["batch_indexes"],
new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count
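
    # Shape of a `before_learn_on_batch` hook, inferred from the call in
    # _optimize() above (a sketch; e.g. MADDPG uses a hook like this to
    # postprocess the multi-agent batch before learning):
    #
    #     def before_learn_on_batch(samples, policy_map, train_batch_size):
    #         # `samples` is a MultiAgentBatch; `policy_map` maps policy ids
    #         # to Policy objects. Return the (possibly modified) batch.
    #         return samples
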
    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    # Reuse the same indices for every policy's buffer
                    # (the buffers were filled in lockstep, e.g. MADDPG).
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

if isinstance(replay_buffer, PrioritizedReplayBuffer):
(obses_t, actions, rewards, obses_tp1, dones, weights,
batch_indexes) = replay_buffer.sample_with_idxes(
idxes,
beta=self.prioritized_replay_beta.value(
self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    # Uniform replay: unit importance weights and dummy
                    # batch indexes (there are no priorities to update).
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)

samples[policy_id] = SampleBatch({
"obs": obses_t,
"actions": actions,
"rewards": rewards,
"new_obs": obses_tp1,
"dones": dones,
"weights": weights,
"batch_indexes": batch_indexes
})

        return MultiAgentBatch(samples, self.train_batch_size)
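
# Layout of the MultiAgentBatch returned by _replay(): one SampleBatch of
# train_batch_size rows per policy id (a sketch):
#
#     MultiAgentBatch(
#         policy_batches={
#             "default_policy": SampleBatch({
#                 "obs": ..., "actions": ..., "rewards": ...,
#                 "new_obs": ..., "dones": ...,
#                 "weights": ...,        # importance weights; all ones when
#                                        # replay is not prioritized
#                 "batch_indexes": ...,  # buffer slots used for priority
#                                        # updates; -1 when not prioritized
#             }),
#         },
#         count=train_batch_size)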