import logging
import collections
import numpy as np

import ray
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
    PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import PiecewiseSchedule

logger = logging.getLogger(__name__)


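# Rough usage sketch, assuming a DQN-style `WorkerSet` named `workers` has
# already been built elsewhere (`num_train_iters` is just a placeholder):
#
#     optimizer = SyncReplayOptimizer(workers, learning_starts=1000,
#                                     buffer_size=50000)
#     for _ in range(num_train_iters):
#         optimizer.step()  # sample -> store in replay buffer -> learn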
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    def __init__(
            self,
            workers,
            learning_starts=1000,
            buffer_size=10000,
            prioritized_replay=True,
            prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4,
            prioritized_replay_eps=1e-6,
            final_prioritized_replay_beta=0.4,
            train_batch_size=32,
            before_learn_on_batch=None,
            synchronize_sampling=False,
            prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize a sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been
                sampled before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences
                for all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep
                at which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly annealing beta used in Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
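        # With the default arguments (beta annealed from 0.4 to 0.4 over
        # 100000 * 0.2 = 20000 timesteps) this schedule is effectively
        # constant; passing e.g. final_prioritized_replay_beta=1.0 would
        # anneal beta linearly to 1.0 by timestep 20000 and hold it there.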
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

# Set up replay buffer
|
|
if prioritized_replay:
|
|
|
|
def new_buffer():
|
|
return PrioritizedReplayBuffer(
|
|
buffer_size, alpha=prioritized_replay_alpha)
|
|
else:
|
|
|
|
def new_buffer():
|
|
return ReplayBuffer(buffer_size)
|
|
|
|
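        # One replay buffer per policy ID; a buffer is created lazily the
        # first time samples for that policy are added.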
        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

        # If set, this batch will be used for stepping/updating, instead of
        # sampling from the replay buffer. Actual sampling from the env
        # (and adding the collected experiences to the replay buffer) will
        # still happen normally.
        # After self.step(), self._fake_batch must be set again.
        self._fake_batch = None

    @override(PolicyOptimizer)
    def step(self):
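        # One optimizer step: broadcast the latest weights to the remote
        # workers, sample new experiences, add them to the replay buffer(s),
        # and run a learning update once enough steps have been sampled.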
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

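        # Collect a fresh batch of experiences, in parallel from the remote
        # workers if there are any, otherwise from the local worker.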
        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray.get([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

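            # Store each row in the replay buffer of the policy that
            # generated it; pack_if_needed() compresses observations when
            # compression is enabled.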
            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

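        # Only start learning once `learning_starts` steps have been sampled.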
        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
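        # Pull a train batch either from `self._fake_batch` (if set) or by
        # sampling from the replay buffer(s), then do one learning pass.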
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)
        else:
            samples = self._replay()

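        # Run the SGD update on the local worker; for prioritized buffers,
        # the resulting TD errors are used below to update sample priorities.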
        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    # torch/tf. Clean up these results/info dicts across
                    # policies (note: fixing this in torch_policy.py will
                    # break e.g. DDPPO!).
                    td_error = info.get("td_error",
                                        info["learner_stats"].get("td_error"))
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
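        # Draw one train batch per policy from its replay buffer; with
        # synchronize_sampling=True the same indices are reused across all
        # policies (as required by e.g. MADDPG).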
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)