[rllib] Support training intensity for dqn / apex (#8396)

Eric Liang 2020-05-20 11:22:30 -07:00 committed by GitHub
parent f56b3be916
commit aa7a58e92f
8 changed files with 134 additions and 25 deletions

@@ -941,13 +941,22 @@ class LocalIterator(Generic[T]):
return iterators
def union(self, *others: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
def union(self,
*others: "LocalIterator[T]",
deterministic: bool = False,
round_robin_weights: List[float] = None) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the others.
If deterministic=True, we alternate between reading from one iterator
and the others. Otherwise we return items from iterators as they
become ready.
Args:
deterministic (bool): If deterministic=True, we alternate between
reading from one iterator and the others. Otherwise we return
items from iterators as they become ready.
round_robin_weights (list): List of weights to use for round robin
mode. For example, [2, 1] will cause the iterator to pull twice
as many items from the first iterator as the second.
[2, 1, "*"] will cause as many items to be pulled as possible
from the third iterator without blocking. This overrides the
deterministic flag.
"""
for it in others:
@@ -956,32 +965,49 @@ class LocalIterator(Generic[T]):
"other must be of type LocalIterator, got {}".format(
type(it)))
timeout = None if deterministic else 0
active = []
parent_iters = [self] + list(others)
shared_metrics = SharedMetrics(
parents=[p.shared_metrics for p in parent_iters])
for it in parent_iters:
timeout = None if deterministic else 0
if round_robin_weights:
if len(round_robin_weights) != len(parent_iters):
raise ValueError(
"Length of round robin weights must equal number of "
"iterators total.")
timeouts = [0 if w == "*" else None for w in round_robin_weights]
else:
timeouts = [timeout] * len(parent_iters)
round_robin_weights = [1] * len(parent_iters)
for i, it in enumerate(parent_iters):
active.append(
LocalIterator(
it.base_iterator,
shared_metrics,
it.local_transforms,
timeout=timeout))
timeout=timeouts[i]))
active = list(zip(round_robin_weights, active))
def build_union(timeout=None):
while True:
for it in list(active):
for weight, it in list(active):
if weight == "*":
max_pull = 100 # TODO(ekl) how to best bound this?
else:
max_pull = _randomized_int_cast(weight)
try:
item = next(it)
if isinstance(item, _NextValueNotReady):
if timeout is not None:
for _ in range(max_pull):
item = next(it)
if isinstance(item, _NextValueNotReady):
if timeout is not None:
yield item
break
else:
yield item
else:
yield item
except StopIteration:
active.remove(it)
active.remove((weight, it))
if not active:
break
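
For context, a minimal usage sketch of the new weighted union. The iterators here are built with ray.util.iter.from_items purely for illustration (an assumption about how callers obtain LocalIterators; any LocalIterator works):

import ray
from ray.util.iter import from_items

ray.init()
a = from_items([1, 1, 1, 1], num_shards=1).gather_sync()
b = from_items([2, 2], num_shards=1).gather_sync()
# Pull roughly two items from `a` for every item from `b`.
mixed = a.union(b, round_robin_weights=[2, 1])
print(mixed.take(6))  # e.g. [1, 1, 2, 1, 1, 2]
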
@@ -1071,6 +1097,14 @@ class ParallelIteratorWorker(object):
return self.next_ith_buffer[start].pop(0)
def _randomized_int_cast(float_value):
base = int(float_value)
remainder = float_value - base
if random.random() < remainder:
base += 1
return base
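
Fractional weights are honored in expectation through this randomized cast; a quick standalone illustration (re-declaring the helper so the snippet runs on its own):

import random

def _randomized_int_cast(float_value):
    # 1.5 rounds down to 1 half the time and up to 2 the other half,
    # so the long-run average number of pulls per round is 1.5.
    base = int(float_value)
    if random.random() < float_value - base:
        base += 1
    return base

print(sum(_randomized_int_cast(1.5) for _ in range(10000)) / 10000)  # ~1.5
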
class _NextValueNotReady(Exception):
"""Indicates that a local iterator has no value currently available.

@@ -106,6 +106,10 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed at the native ratio determined by
# (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
# Learning rate for the critic (Q-function) optimizer.

@@ -2,7 +2,8 @@ import collections
import copy
import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer, \
DEFAULT_CONFIG as DQN_CONFIG, calculate_rr_weights
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \
SampleBatchType, _get_shared_metrics, _get_global_vars
from ray.rllib.evaluation.worker_set import WorkerSet
@@ -41,6 +42,9 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"worker_side_prioritization": True,
"min_iter_time_s": 30,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed as fast as possible.
"training_intensity": None,
},
)
# __sphinx_doc_end__
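
A minimal usage sketch of the new setting (hypothetical values; assumes the ApexTrainer entry point from ray.rllib.agents.dqn and a Gym environment are available in this Ray version):

import ray
from ray.rllib.agents.dqn import ApexTrainer

ray.init()
trainer = ApexTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 4,
        "rollout_fragment_length": 50,
        "train_batch_size": 512,
        # Replay/train roughly 4 timesteps per timestep sampled, instead of
        # replaying as fast as possible.
        "training_intensity": 4,
    })
print(trainer.train())
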
@@ -175,10 +179,19 @@ def execution_plan(workers: WorkerSet, config: dict):
workers, config["target_network_update_freq"],
by_steps_trained=True))
# Execute (1), (2), (3) asynchronously as fast as possible. Only output
# items from (3) since metrics aren't available before then.
merged_op = Concurrently(
[store_op, replay_op, update_op], mode="async", output_indexes=[2])
if config["training_intensity"]:
# Execute (1), (2) with a fixed intensity ratio.
rr_weights = calculate_rr_weights(config) + ["*"]
merged_op = Concurrently(
[store_op, replay_op, update_op],
mode="round_robin",
output_indexes=[2],
round_robin_weights=rr_weights)
else:
# Execute (1), (2), (3) asynchronously as fast as possible. Only output
# items from (3) since metrics aren't available before then.
merged_op = Concurrently(
[store_op, replay_op, update_op], mode="async", output_indexes=[2])
# Add in extra replay and learner metrics to the training result.
def add_apex_metrics(result):

@@ -89,6 +89,10 @@ DEFAULT_CONFIG = with_common_config({
"multiagent_sync_replay": False,
# Callback to run before learning on a multi-agent batch of experiences.
"before_learn_on_batch": None,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed at the native ratio determined by
# (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
# Learning rate for adam optimizer
@@ -358,11 +362,26 @@ def execution_plan(workers, config):
# Alternate deterministically between (1) and (2). Only return the output
# of (2) since training metrics are not available until (2) runs.
train_op = Concurrently(
[store_op, replay_op], mode="round_robin", output_indexes=[1])
[store_op, replay_op],
mode="round_robin",
output_indexes=[1],
round_robin_weights=calculate_rr_weights(config))
return StandardMetricsReporting(train_op, workers, config)
def calculate_rr_weights(config):
if not config["training_intensity"]:
return [1, 1]
# e.g., 32 / 4 -> native ratio of 8.0
native_ratio = (
config["train_batch_size"] / config["rollout_fragment_length"])
# Training intensity is specified in terms of
# (steps_replayed / steps_sampled), so adjust for the native ratio.
weights = [1, config["training_intensity"] / native_ratio]
return weights
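
A quick worked example of the weight calculation above:

# With train_batch_size=32 and rollout_fragment_length=4, the native ratio is
# 32 / 4 = 8 replayed timesteps per sampled timestep. Asking for
# training_intensity=16 (twice the native ratio) therefore weights the replay
# op at 2 relative to the sampling op.
config = {
    "training_intensity": 16,
    "train_batch_size": 32,
    "rollout_fragment_length": 4,
}
native_ratio = config["train_batch_size"] / config["rollout_fragment_length"]
print([1, config["training_intensity"] / native_ratio])  # [1, 2.0]
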
def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy

@@ -62,8 +62,12 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_eps": 1e-6,
"prioritized_replay_beta_annealing_timesteps": 20000,
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed at the native ratio determined by
# (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
"optimization": {

@@ -70,6 +70,10 @@ DEFAULT_CONFIG = with_common_config({
# In multi-agent mode, whether to replay experiences from the same time
# step for all policies. This is required for MADDPG.
"multiagent_sync_replay": True,
# If set, this will fix the ratio of sampled to replayed timesteps.
# Otherwise, replay will proceed at the native ratio determined by
# (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
# Learning rate for the critic (Q-function) optimizer.

@@ -8,7 +8,8 @@ from ray.util.iter_metrics import SharedMetrics
def Concurrently(ops: List[LocalIterator],
*,
mode="round_robin",
output_indexes=None):
output_indexes=None,
round_robin_weights=None):
"""Operator that runs the given parent iterators concurrently.
Arguments:
@@ -20,6 +21,12 @@ def Concurrently(ops: List[LocalIterator],
output_indexes (list): If specified, only output results from the
given ops. For example, if output_indexes=[0], only results from
the first op in ops will be returned.
round_robin_weights (list): List of weights to use for round robin
mode. For example, [2, 1] will cause the iterator to pull twice
as many items from the first iterator as the second. [2, 1, "*"] will
cause as many items to be pulled as possible from the third
iterator without blocking. This is only allowed in round robin
mode.
>>> sim_op = ParallelRollouts(...).for_each(...)
>>> replay_op = LocalReplay(...).for_each(...)
@@ -32,8 +39,13 @@ def Concurrently(ops: List[LocalIterator],
deterministic = True
elif mode == "async":
deterministic = False
if round_robin_weights:
raise ValueError(
"round_robin_weights cannot be specified in async mode")
else:
raise ValueError("Unknown mode {}".format(mode))
if round_robin_weights and all(r == "*" for r in round_robin_weights):
raise ValueError("Cannot specify all round robin weights = *")
if output_indexes:
for i in output_indexes:
@@ -44,7 +56,10 @@ def Concurrently(ops: List[LocalIterator],
ops = [tag(op, i) for i, op in enumerate(ops)]
output = ops[0].union(*ops[1:], deterministic=deterministic)
output = ops[0].union(
*ops[1:],
deterministic=deterministic,
round_robin_weights=round_robin_weights)
if output_indexes:
output = (output.filter(lambda tup: tup[0] in output_indexes)
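
As a small sketch of how the new argument is validated and used (the import path for Concurrently is assumed to be ray.rllib.execution.concurrency_ops; the iterators come from ray.util.iter):

import ray
from ray.util.iter import from_items
from ray.rllib.execution.concurrency_ops import Concurrently

ray.init()
a = from_items([1, 1, 1], num_shards=1).gather_sync()
b = from_items([2, 2, 2], num_shards=1).gather_sync()
# Weights are only meaningful in round_robin mode: passing them with
# mode="async" raises ValueError, as does an all-"*" weight list.
out = Concurrently([a, b], mode="round_robin", round_robin_weights=[2, 1])
print(out.take(6))  # roughly two items from `a` per item from `b`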

@@ -53,6 +53,22 @@ def test_concurrently(ray_start_regular_shared):
assert c.take(6) == [1, 4, 2, 5, 3, 6]
def test_concurrently_weighted(ray_start_regular_shared):
a = iter_list([1, 1, 1])
b = iter_list([2, 2, 2])
c = iter_list([3, 3, 3])
c = Concurrently(
[a, b, c], mode="round_robin", round_robin_weights=[3, 1, 2])
assert c.take(9) == [1, 1, 1, 2, 3, 3, 2, 3, 2]
a = iter_list([1, 1, 1])
b = iter_list([2, 2, 2])
c = iter_list([3, 3, 3])
c = Concurrently(
[a, b, c], mode="round_robin", round_robin_weights=[1, 1, "*"])
assert c.take(9) == [1, 2, 3, 3, 3, 1, 2, 1, 2]
def test_concurrently_output(ray_start_regular_shared):
a = iter_list([1, 2, 3])
b = iter_list([4, 5, 6])