[rllib] First pass at pipeline implementation of DQN (#7433)

* wip iters

* add test

* speed up

* update docs

* document it

* support serial sampling

* add test

* spacing

* annotate it

* update

* rename to pipeline

* comment

* iter2 wip

* update

* update

* context test

* update

* fix

* fix

* a3c pipeline

* doc

* update

* move timer

* comment

* add pipeline test

* fix

* clean up

* document

* iters

* wip dqn

* wip

* wip

* metrics

* metrics rename

* metrics ctx

* wip

* constants

* add todo

* support .union

* wip

* support union

* remove prints

* add todo

* remove auto timer

* fix up

* fix pipeline test

* typing

* fix breakage

* remove bad assert

* wip

* fix multiagent example

* fix apply

* update a3c

* remove a2c pl

* 0 workers

* wip

* wip

* share metrics

* wip

* wip

* doc

* fix weight sync and global var updates

* mode

* fix

* fix

* doc

* fix
Eric Liang 2020-03-07 14:47:58 -08:00 committed by GitHub
parent beb9b02dbd
commit a644060daa
8 changed files with 258 additions and 74 deletions


@@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters."""
# Truncated representations of column names (to accommodate small screens).
DEFAULT_COLUMNS = {
EPISODE_REWARD_MEAN: "reward",
DEFAULT_COLUMNS = collections.OrderedDict({
MEAN_ACCURACY: "acc",
MEAN_LOSS: "loss",
TRAINING_ITERATION: "iter",
TIME_TOTAL_S: "total time (s)",
TIMESTEPS_TOTAL: "ts",
TRAINING_ITERATION: "iter",
}
EPISODE_REWARD_MEAN: "reward",
})
def __init__(self,
metric_columns=None,
@@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
k for k in keys if any(
t.last_result.get(k) is not None for t in trials)
]
keys = sorted(keys)
# Build trial rows.
params = sorted(set().union(*[t.evaluated_params for t in trials]))
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]


@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
if i >= n:
break
def union(self, other: "LocalIterator[T]",
def union(self, *others: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the other.
"""Return an iterator that is the union of this and the others.
If deterministic=True, we alternate between reading from one iterator
and the other. Otherwise we return items from iterators as they
and the others. Otherwise we return items from iterators as they
become ready.
"""
if not isinstance(other, LocalIterator):
for it in others:
if not isinstance(it, LocalIterator):
raise ValueError(
"other must be of type LocalIterator, got {}".format(
type(other)))
type(it)))
if deterministic:
timeout = None
else:
timeout = 0
it1 = LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms,
timeout=timeout)
it2 = LocalIterator(
other.base_iterator,
other.metrics,
other.local_transforms,
timeout=timeout)
active = [it1, it2]
active = []
shared_metrics = MetricsContext()
for it in [self] + list(others):
active.append(
LocalIterator(
it.base_iterator,
shared_metrics,
it.local_transforms,
timeout=timeout))
def build_union(timeout=None):
while True:
@@ -826,15 +825,11 @@ class LocalIterator(Generic[T]):
if not active:
break
# TODO(ekl) is this the best way to represent union() of metrics?
new_ctx = MetricsContext()
new_ctx.parent_metrics.append(self.metrics)
new_ctx.parent_metrics.append(other.metrics)
return LocalIterator(
build_union,
new_ctx, [],
name="LocalUnion[{}, {}]".format(self, other))
shared_metrics, [],
name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
others))))
class ParallelIteratorWorker(object):
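
For orientation, a minimal usage sketch of the variadic union() introduced above (not part of this diff; the items and shard count are arbitrary examples):

    # All union'ed iterators now share one MetricsContext, and
    # deterministic=True alternates between the sources in order, which is
    # what the "round_robin" mode of Concurrently() relies on later in this
    # commit.
    import ray
    from ray.util.iter import from_items

    ray.init()
    a = from_items([1, 2, 3], num_shards=1).gather_sync()
    b = from_items([4, 5, 6], num_shards=1).gather_sync()
    print(a.union(b, deterministic=True).take(6))  # items interleaved from a and b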


@@ -5,9 +5,13 @@ from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.agents.dqn.simple_q_policy import SimpleQPolicy
from ray.rllib.optimizers import SyncReplayOptimizer
from ray.rllib.optimizers.replay_buffer import ReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy
from ray.rllib.utils.experimental_dsl import (
ParallelRollouts, Concurrently, StoreToReplayBuffer, LocalReplay,
TrainOneStep, StandardMetricsReporting, UpdateTargetNetwork)
logger = logging.getLogger(__name__)
@@ -308,6 +312,30 @@ def update_target_if_needed(trainer, fetches):
trainer.state["num_target_updates"] += 1
# Experimental pipeline-based impl; enable with "use_pipeline_impl": True.
def training_pipeline(workers, config):
local_replay_buffer = ReplayBuffer(config["buffer_size"])
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our local replay buffer. Calling
# next() on store_op drives this.
store_op = rollouts.for_each(StoreToReplayBuffer(local_replay_buffer))
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take an SGD step, and then we decide whether to update the target network.
replay_op = LocalReplay(local_replay_buffer, config["train_batch_size"]) \
.for_each(TrainOneStep(workers)) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))
# Alternate deterministically between (1) and (2).
train_op = Concurrently([store_op, replay_op], mode="round_robin")
return StandardMetricsReporting(train_op, workers, config)
GenericOffPolicyTrainer = build_trainer(
name="GenericOffPolicyAlgorithm",
default_policy=None,
@@ -317,7 +345,8 @@ GenericOffPolicyTrainer = build_trainer(
make_policy_optimizer=make_policy_optimizer,
before_train_step=update_worker_exploration,
after_optimizer_step=update_target_if_needed,
after_train_result=after_train_result)
after_train_result=after_train_result,
training_pipeline=training_pipeline)
DQNTrainer = GenericOffPolicyTrainer.with_updates(
name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)


@@ -168,10 +168,10 @@ def build_trainer(name,
def _train_pipeline(self):
if before_train_step:
before_train_step(self)
logger.warning("Ignoring before_train_step callback")
res = next(self.train_pipeline)
if after_train_result:
after_train_result(self, res)
logger.warning("Ignoring after_train_result callback")
return res
@override(Trainer)


@@ -546,9 +546,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
}
@override(EvaluatorInterface)
def set_weights(self, weights):
def set_weights(self, weights, global_vars=None):
for pid, w in weights.items():
self.policy_map[pid].set_weights(w)
if global_vars:
self.set_global_vars(global_vars)
@override(EvaluatorInterface)
def compute_gradients(self, samples):
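
A short sketch of the calling convention this change enables (illustrative only; workers is assumed to be a constructed WorkerSet, and the "timestep" key matches _get_global_vars() in the DSL module later in this commit):

    # Broadcasting weights together with global vars lets the worker update
    # its exploration/annealing schedules from the pipeline's sampled-step count.
    weights = workers.local_worker().get_weights()
    workers.local_worker().set_weights(weights, global_vars={"timestep": 1000})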


@@ -1,6 +1,7 @@
import logging
from types import FunctionType
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
_validate_multiagent_config
@@ -71,6 +72,12 @@ class WorkerSet:
"""Return a list of remote rollout workers."""
return self._remote_workers
def sync_weights(self):
"""Syncs weights of remote workers with the local worker."""
weights = ray.put(self.local_worker().get_weights())
for e in self.remote_workers():
e.set_weights.remote(weights)
def add_workers(self, num_workers):
"""Creates and add a number of remote workers to this worker set.


@@ -97,18 +97,18 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
if alg not in ["DDPG", "ES", "ARS", "SAC"]:
if o_name in ["atari", "image"]:
if torch:
assert isinstance(
a.get_policy().model, TorchVisionNetV2)
assert isinstance(a.get_policy().model,
TorchVisionNetV2)
else:
assert isinstance(
a.get_policy().model, VisionNetV2)
assert isinstance(a.get_policy().model,
VisionNetV2)
elif o_name in ["vector", "vector2"]:
if torch:
assert isinstance(
a.get_policy().model, TorchFCNetV2)
assert isinstance(a.get_policy().model,
TorchFCNetV2)
else:
assert isinstance(
a.get_policy().model, FCNetV2)
assert isinstance(a.get_policy().model,
FCNetV2)
a.train()
covered_a.add(a_name)
covered_o.add(o_name)
@@ -159,12 +159,7 @@ class ModelSupportedSpaces(unittest.TestCase):
ray.shutdown()
def test_a3c(self):
config = {
"num_workers": 1,
"optimizer": {
"grads_per_step": 1
}
}
config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
check_support("A3C", config, self.stats, check_bounds=True)
config["use_pytorch"] = True
check_support("A3C", config, self.stats, check_bounds=True)
@@ -228,10 +223,7 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support("PPO", config, self.stats, check_bounds=True)
def test_pg(self):
config = {
"num_workers": 1,
"optimizer": {}
}
config = {"num_workers": 1, "optimizer": {}}
check_support("PG", config, self.stats, check_bounds=True)
config["use_pytorch"] = True
check_support("PG", config, self.stats, check_bounds=True)


@@ -4,28 +4,40 @@ TODO(ekl): describe the concepts."""
import logging
from typing import List, Any, Tuple, Union
import numpy as np
import time
import ray
from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import MetricsContext
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.evaluation.metrics import collect_episodes, \
summarize_episodes, get_learner_stats
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.compression import pack_if_needed
logger = logging.getLogger(__name__)
# Metrics context key definitions.
# Counters for training progress (keys for metrics.counters).
STEPS_SAMPLED_COUNTER = "num_steps_sampled"
STEPS_TRAINED_COUNTER = "num_steps_trained"
# Counters to track target network updates.
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
NUM_TARGET_UPDATES = "num_target_updates"
# Performance timers (keys for metrics.timers).
APPLY_GRADS_TIMER = "apply_grad"
COMPUTE_GRADS_TIMER = "compute_grads"
WORKER_UPDATE_TIMER = "update"
GRAD_WAIT_TIMER = "grad_wait"
SAMPLE_TIMER = "sample"
LEARN_ON_BATCH_TIMER = "learn"
# Instant metrics (keys for metrics.info).
LEARNER_INFO = "learner"
# Type aliases.
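
To show how the metric keys defined above are meant to be used, a sketch modeled on the report_timesteps and TrainOneStep code further down (the function name is illustrative; LocalIterator and the constants come from this module):

    def count_steps(batch):
        # Any pipeline stage can reach the shared MetricsContext through the
        # running iterator; counters and timers are keyed by the constants above.
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        with metrics.timers[LEARN_ON_BATCH_TIMER]:
            pass  # e.g. run an SGD step here and record its duration
        return batch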
@@ -33,12 +45,19 @@ GradientType = dict
SampleBatchType = Union[SampleBatch, MultiAgentBatch]
# Asserts that an object is a type of SampleBatch.
def _check_sample_batch_type(batch):
if not isinstance(batch, SampleBatchType.__args__):
raise ValueError("Expected either SampleBatch or MultiAgentBatch, "
"got {}: {}".format(type(batch), batch))
# Returns pipeline global vars that should be periodically sent to each worker.
def _get_global_vars():
metrics = LocalIterator.get_metrics()
return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
def ParallelRollouts(workers: WorkerSet,
mode="bulk_sync") -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.
@@ -71,6 +90,9 @@ def ParallelRollouts(workers: WorkerSet,
Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
"""
# Ensure workers are initially in sync.
workers.sync_weights()
def report_timesteps(batch):
metrics = LocalIterator.get_metrics()
metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
@@ -119,6 +141,9 @@ def AsyncGradients(
local iterator context.
"""
# Ensure workers are initially in sync.
workers.sync_weights()
# This function will be applied remotely on the workers.
def samples_to_grads(samples):
return get_global_worker().compute_gradients(samples), samples.count
@@ -240,7 +265,9 @@ class TrainOneStep:
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
e.set_weights.remote(weights, _get_global_vars())
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
return info
@@ -266,9 +293,7 @@ class CollectMetrics:
self.timeout_seconds = timeout_seconds
def __call__(self, _):
metrics = LocalIterator.get_metrics()
if metrics.parent_metrics:
raise ValueError("TODO: support nested metrics")
# Collect worker metrics.
episodes, self.to_be_collected = collect_episodes(
self.workers.local_worker(),
self.workers.remote_workers(),
@@ -282,22 +307,31 @@ class CollectMetrics:
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-self.min_history:]
res = summarize_episodes(episodes, orig_episodes)
res.update(info=metrics.info)
res["info"].update({
STEPS_SAMPLED_COUNTER: metrics.counters[STEPS_SAMPLED_COUNTER],
STEPS_TRAINED_COUNTER: metrics.counters[STEPS_TRAINED_COUNTER],
})
# Add in iterator metrics.
metrics = LocalIterator.get_metrics()
if metrics.parent_metrics:
print("TODO: support nested metrics better")
all_metrics = [metrics] + metrics.parent_metrics
timers = {}
counters = {}
info = {}
for metrics in all_metrics:
info.update(metrics.info)
for k, counter in metrics.counters.items():
counters[k] = counter
for k, timer in metrics.timers.items():
timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
if timer.has_units_processed():
timers["{}_throughput".format(k)] = round(
timer.mean_throughput, 3)
res["timers"] = timers
res.update({
"num_healthy_workers": len(self.workers.remote_workers()),
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
})
res["timers"] = timers
res["info"] = info
res["info"].update(counters)
return res
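
For orientation, an assumed sketch (not asserted by the diff) of the merged result dict this step now yields each iteration; the numbers are placeholders:

    result = {
        "episode_reward_mean": 123.4,        # from summarize_episodes()
        "num_healthy_workers": 2,
        "timesteps_total": 4000,             # STEPS_SAMPLED_COUNTER
        "timers": {"learn_time_ms": 12.3, "learn_throughput": 8100.0},
        "info": {"learner": {},              # LEARNER_INFO from metrics.info
                 "num_steps_sampled": 4000,  # merged counters
                 "num_steps_trained": 4000},
    }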
@@ -392,13 +426,16 @@ class ApplyGradients:
self.workers.local_worker().apply_gradients(gradients)
apply_timer.push_units_processed(count)
# Also update global vars of the local worker.
self.workers.local_worker().set_global_vars(_get_global_vars())
if self.update_all:
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(
self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
e.set_weights.remote(weights, _get_global_vars())
else:
if metrics.cur_actor is None:
raise ValueError("Could not find actor to update. When "
@@ -406,7 +443,8 @@ class ApplyGradients:
"in the iterator context.")
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = self.workers.local_worker().get_weights()
metrics.cur_actor.set_weights.remote(weights)
metrics.cur_actor.set_weights.remote(weights,
_get_global_vars())
class AverageGradients:
@@ -434,3 +472,125 @@ class AverageGradients:
logger.info("Computing average of {} microbatch gradients "
"({} samples total)".format(len(gradients), sum_count))
return acc, sum_count
class StoreToReplayBuffer:
def __init__(self, replay_buffer):
self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
def __call__(self, batch: SampleBatchType):
# Handle everything as if multiagent
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
for policy_id, s in batch.policy_batches.items():
for row in s.rows():
self.replay_buffers[policy_id].add(
pack_if_needed(row["obs"]),
row["actions"],
row["rewards"],
pack_if_needed(row["new_obs"]),
row["dones"],
weight=None)
def LocalReplay(replay_buffer, train_batch_size):
replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
# TODO(ekl) support more options
synchronize_sampling = False
prioritized_replay_beta = None
def gen_replay(timeout):
while True:
samples = {}
idxes = None
for policy_id, replay_buffer in replay_buffers.items():
if synchronize_sampling:
if idxes is None:
idxes = replay_buffer.sample_idxes(train_batch_size)
else:
idxes = replay_buffer.sample_idxes(train_batch_size)
if isinstance(replay_buffer, PrioritizedReplayBuffer):
metrics = LocalIterator.get_metrics()
num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
(obses_t, actions, rewards, obses_tp1, dones, weights,
batch_indexes) = replay_buffer.sample_with_idxes(
idxes,
beta=prioritized_replay_beta.value(num_steps_trained))
else:
(obses_t, actions, rewards, obses_tp1,
dones) = replay_buffer.sample_with_idxes(idxes)
weights = np.ones_like(rewards)
batch_indexes = -np.ones_like(rewards)
samples[policy_id] = SampleBatch({
"obs": obses_t,
"actions": actions,
"rewards": rewards,
"new_obs": obses_tp1,
"dones": dones,
"weights": weights,
"batch_indexes": batch_indexes
})
yield MultiAgentBatch(samples, train_batch_size)
return LocalIterator(gen_replay, MetricsContext())
def Concurrently(ops: List[LocalIterator], mode="round_robin"):
"""Operator that runs the given parent iterators concurrently.
Arguments:
mode (str): One of {'round_robin', 'async'}.
- In 'round_robin' mode, we alternate between pulling items from
each parent iterator in order deterministically.
- In 'async' mode, we pull from each parent iterator as fast as
they are produced. This is non-deterministic.
>>> sim_op = ParallelRollouts(...).for_each(...)
>>> replay_op = LocalReplay(...).for_each(...)
>>> combined_op = Concurrently([sim_op, replay_op])
"""
if len(ops) < 2:
raise ValueError("Should specify at least 2 ops.")
if mode == "round_robin":
deterministic = True
elif mode == "async":
deterministic = False
else:
raise ValueError("Unknown mode {}".format(mode))
return ops[0].union(*ops[1:], deterministic=deterministic)
class UpdateTargetNetwork:
"""Periodically call policy.update_target() on all trainable policies.
This should be used with the .for_each() operator after a training step
has been taken.
Examples:
>>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
>>> update_op = train_op.for_each(
... UpdateTargetNetwork(workers, target_update_freq=500))
>>> print(next(update_op))
None
Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
local iterator context. The value of the last update counter is used to
track when we should update the target next.
"""
def __init__(self, workers, target_update_freq):
self.workers = workers
self.target_update_freq = target_update_freq
def __call__(self, _):
metrics = LocalIterator.get_metrics()
cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update > self.target_update_freq:
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
metrics.counters[NUM_TARGET_UPDATES] += 1
metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts