[rllib] Enable functional execution workflow API by default (#8221)

Eric Liang 2020-05-05 12:36:42 -07:00 committed by GitHub
parent 4bdef78e2e
commit b14cc16616
13 changed files with 102 additions and 24 deletions

View file

@@ -38,8 +38,6 @@ DEFAULT_CONFIG = with_common_config({
# Workers sample async. Note that this increases the effective
# rollout_fragment_length by up to 5x due to async buffering of batches.
"sample_async": True,
# Use the execution plan API instead of policy optimizers.
"use_exec_api": True,
})
# __sphinx_doc_end__
# yapf: enable
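
A minimal sketch of why the override can simply disappear: with_common_config() layers each algorithm's dict on top of COMMON_CONFIG, and this commit flips the common "use_exec_api" default to True (see the COMMON_CONFIG hunk further down), so A3C inherits the value instead of restating it. The import path below is assumed from the current RLlib layout; the config key comes from this diff.

    # Sketch only: check that A3C picks up the flipped common default.
    from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG

    # No per-algorithm "use_exec_api" entry anymore; the value flows in
    # from COMMON_CONFIG via with_common_config().
    print(A3C_CONFIG["use_exec_api"])  # -> True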

View file

@@ -417,7 +417,7 @@ def setup_late_mixins(policy, obs_space, action_space, config):
DDPGTFPolicy = build_tf_policy(
name="DQNTFPolicy",
name="DDPGTFPolicy",
get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
make_model=build_ddpg_models,
action_distribution_fn=get_distribution_inputs_and_class,
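
This is a copy-paste fix: the name passed to build_tf_policy() is, to my recollection, assigned as the generated policy class's name, so DDPG's policy previously identified itself as "DQNTFPolicy" in reprs and logs. A hedged sanity check (TensorFlow must be importable):

    # Assumes build_tf_policy() assigns `name` to the returned class's __name__.
    from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy

    print(DDPGTFPolicy.__name__)  # expected "DDPGTFPolicy" after this fix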

View file

@@ -6,6 +6,7 @@ import ray.rllib.agents.ddpg as ddpg
from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
loss_torch
from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
@@ -325,7 +326,8 @@ class TestDDPG(unittest.TestCase):
tf_inputs.append(in_)
# Set a fake-batch to use
# (instead of sampling from replay buffer).
trainer.optimizer._fake_batch = in_
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
updated_weights = policy.get_weights()
# Net must have changed.
@@ -344,7 +346,8 @@ class TestDDPG(unittest.TestCase):
in_ = tf_inputs[update_iteration]
# Set a fake-batch to use
# (instead of sampling from replay buffer).
trainer.optimizer._fake_batch = in_
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
# Compare updated model and target weights.
for tf_key in tf_weights.keys():
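
The tests previously reached into trainer.optimizer._fake_batch, which does not exist on the execution-plan path; they now grab the buffer through the module-level handle set up in async_replay_optimizer.py (see that hunk below). A standalone sketch of the hook, with placeholder batch values rather than the real test inputs:

    from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer

    # Constructing a buffer registers it in the module-global slot that
    # get_instance_for_testing() reads back.
    buf = LocalReplayBuffer(
        num_shards=1, learning_starts=1000, buffer_size=50000,
        replay_batch_size=32)
    assert LocalReplayBuffer.get_instance_for_testing() is buf

    # Placeholder one-row batch; replay() wraps it into a MultiAgentBatch
    # and bypasses the learning_starts gate entirely.
    buf._fake_batch = {"obs": [[0.1]], "actions": [0], "rewards": [1.0],
                       "new_obs": [[0.2]], "dones": [False]}
    print(buf.replay().count)  # -> 1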

View file

@@ -126,9 +126,6 @@ DEFAULT_CONFIG = with_common_config({
"soft_q": DEPRECATED_VALUE,
"parameter_noise": DEPRECATED_VALUE,
"grad_norm_clipping": DEPRECATED_VALUE,
# Use the execution plan API instead of policy optimizers.
"use_exec_api": True,
})
# __sphinx_doc_end__
# yapf: enable
@@ -301,14 +298,21 @@ def update_target_if_needed(trainer, fetches):
# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
prioritized_replay_alpha=config["prioritized_replay_alpha"],
prioritized_replay_beta=config["prioritized_replay_beta"],
prioritized_replay_eps=config["prioritized_replay_eps"])
**prio_args)
rollouts = ParallelRollouts(workers, mode="bulk_sync")
@@ -320,7 +324,7 @@ def execution_plan(workers, config):
def update_prio(item):
samples, info_dict = item
if config["prioritized_replay"]:
if config.get("prioritized_replay"):
prio_dict = {}
for policy_id, info in info_dict.items():
# TODO(sven): This is currently structured differently for
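
The kwargs are now assembled conditionally so the same plan also serves configs without prioritized replay; in that case the buffer falls back to the keyword defaults added in async_replay_optimizer.py below, and update_prio() becomes a no-op. A rough usage sketch, assuming RLlib at this commit and a standard Gym CartPole env:

    import ray
    from ray.rllib.agents.dqn import DQNTrainer

    ray.init()
    # With the execution plan on by default, this exercises the
    # prio_args == {} branch above and skips priority updates.
    trainer = DQNTrainer(
        env="CartPole-v0",
        config={"num_workers": 0, "prioritized_replay": False,
                "learning_starts": 100, "timesteps_per_iteration": 200})
    print(trainer.train()["timesteps_total"])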

View file

@@ -3,6 +3,12 @@ import logging
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
logger = logging.getLogger(__name__)
@@ -81,7 +87,35 @@ def get_policy_class(config):
return SimpleQTFPolicy
# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"])
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# (1) Generate rollouts and store them in our local replay buffer.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer))
# (2) Read and train on experiences from the replay buffer.
replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(TrainOneStep(workers)) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))
# Alternate deterministically between (1) and (2).
train_op = Concurrently(
[store_op, replay_op], mode="round_robin", output_indexes=[1])
return StandardMetricsReporting(train_op, workers, config)
SimpleQTrainer = DQNTrainer.with_updates(
default_policy=SimpleQTFPolicy,
get_policy_class=get_policy_class,
execution_plan=execution_plan,
default_config=DEFAULT_CONFIG)
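
Design note: "round_robin" alternates deterministically between the store op and the replay/train op (as the comment above says), and output_indexes=[1] should mean that only the train branch's results are forwarded to StandardMetricsReporting. A short usage sketch; the module path ray.rllib.agents.dqn.simple_q and the env choice are assumptions:

    import ray
    from ray.rllib.agents.dqn.simple_q import SimpleQTrainer

    ray.init()
    # SimpleQ now runs through the execution plan above by default.
    trainer = SimpleQTrainer(env="CartPole-v0", config={"num_workers": 0})
    for _ in range(2):
        result = trainer.train()
        print(result["timesteps_total"], result["episode_reward_mean"])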

View file

@@ -12,8 +12,6 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
# Use the execution plan API instead of policy optimizers.
"use_exec_api": True,
})
# __sphinx_doc_end__
# yapf: enable

View file

@@ -82,6 +82,9 @@ DEFAULT_CONFIG = with_common_config({
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
# TODO(ekl) support sync batch replay.
"use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable

View file

@@ -9,6 +9,7 @@ from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
loss_torch
from ray.rllib.models.tf.tf_action_dist import SquashedGaussian
from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, relu
@@ -306,7 +307,8 @@ class TestSAC(unittest.TestCase):
tf_inputs.append(in_)
# Set a fake-batch to use
# (instead of sampling from replay buffer).
trainer.optimizer._fake_batch = in_
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
updated_weights = policy.get_weights()
# Net must have changed.
@@ -325,7 +327,8 @@ class TestSAC(unittest.TestCase):
in_ = tf_inputs[update_iteration]
# Set a fake-batch to use
# (instead of sampling from replay buffer).
trainer.optimizer._fake_batch = in_
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
# Compare updated model.
for tf_key in sorted(tf_weights.keys())[2:10]:

View file

@@ -206,7 +206,7 @@ COMMON_CONFIG = {
"custom_eval_function": None,
# EXPERIMENTAL: use the execution plan based API impl of the algo. Can also
# be enabled by setting RLLIB_EXEC_API=1.
"use_exec_api": False,
"use_exec_api": True,
# === Advanced Rollout Settings ===
# Use a background thread for sampling (slightly off-policy, usually not
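
With the default flipped here, setting RLLIB_EXEC_API=1 is now redundant, and the legacy policy-optimizer path becomes an explicit per-trainer opt-out. A hedged sketch of opting out; PPO is an arbitrary choice of an algorithm that still has an optimizer-based implementation:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    # Fall back to the pre-exec-api implementation for this trainer only.
    legacy = PPOTrainer(env="CartPole-v0",
                        config={"use_exec_api": False, "num_workers": 1})
    legacy.train()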

View file

@@ -100,6 +100,9 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,
# TODO(ekl) support synchronized sampling.
"use_exec_api": False,
})
# __sphinx_doc_end__
# yapf: enable

View file

@@ -159,7 +159,7 @@ class OncePerTimestepsElapsed:
return True
metrics = LocalIterator.get_metrics()
now = metrics.counters[STEPS_SAMPLED_COUNTER]
if now - self.last_called > self.delay_steps:
if now - self.last_called >= self.delay_steps:
self.last_called = now
return True
return False
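
The comparison change removes an off-by-one: the gate now opens once exactly delay_steps new timesteps have been sampled, rather than requiring one extra step. A tiny stand-alone illustration of the boundary, using a plain integer as a stand-in for metrics.counters[STEPS_SAMPLED_COUNTER]:

    delay_steps, last_called = 1000, 0
    now = 1000  # counter value right after 1000 timesteps were sampled

    print(now - last_called >= delay_steps)  # True  (new behavior)
    print(now - last_called > delay_steps)   # False (old behavior: one more step)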

View file

@@ -285,6 +285,10 @@ class AsyncReplayOptimizer(PolicyOptimizer):
return sample_timesteps, train_timesteps
# Visible for testing.
_local_replay_buffer = None
# TODO(ekl) move this class to common
class LocalReplayBuffer(ParallelIteratorWorker):
"""A replay buffer shard.
@@ -292,9 +296,14 @@ class LocalReplayBuffer(ParallelIteratorWorker):
Ray actors are single-threaded, so for scalability multiple replay actors
may be created to increase parallelism."""
def __init__(self, num_shards, learning_starts, buffer_size,
replay_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps):
def __init__(self,
num_shards,
learning_starts,
buffer_size,
replay_batch_size,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6):
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.replay_batch_size = replay_batch_size
@@ -319,6 +328,17 @@ class LocalReplayBuffer(ParallelIteratorWorker):
self.update_priorities_timer = TimerStat()
self.num_added = 0
# Make externally accessible for testing.
global _local_replay_buffer
_local_replay_buffer = self
# If set, return this instead of the usual data for testing.
self._fake_batch = None
@staticmethod
def get_instance_for_testing():
global _local_replay_buffer
return _local_replay_buffer
def get_host(self):
return os.uname()[1]
@@ -338,6 +358,12 @@ class LocalReplayBuffer(ParallelIteratorWorker):
self.num_added += batch.count
def replay(self):
if self._fake_batch:
fake_batch = SampleBatch(self._fake_batch)
return MultiAgentBatch({
DEFAULT_POLICY_ID: fake_batch
}, fake_batch.count)
if self.num_added < self.replay_starts:
return None
@@ -392,9 +418,14 @@ class LocalBatchReplayBuffer(LocalReplayBuffer):
This allows for RNN models, but ignores prioritization params.
"""
def __init__(self, num_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps):
def __init__(self,
num_shards,
learning_starts,
buffer_size,
train_batch_size,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6):
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.train_batch_size = train_batch_size
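
With the prioritization arguments now optional (the defaults mirror DQN's usual 0.6 / 0.4 / 1e-6), plans that do not configure prioritized replay, such as the SimpleQ plan above, can construct the buffer from the size arguments alone. A small sketch:

    from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer

    # Uniform-replay callers can omit the prioritization kwargs entirely.
    uniform = LocalReplayBuffer(
        num_shards=1, learning_starts=1000, buffer_size=50000,
        replay_batch_size=32)

    # Prioritized callers keep passing the values through explicitly,
    # exactly as dqn.py's execution_plan does via **prio_args.
    prioritized = LocalReplayBuffer(
        num_shards=1, learning_starts=1000, buffer_size=50000,
        replay_batch_size=32, prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6)

    # Until learning_starts timesteps have been added, replay() returns None.
    assert uniform.replay() is None and prioritized.replay() is None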

View file

@@ -256,6 +256,7 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support_multiagent("DDPG", {
"timesteps_per_iteration": 1,
"use_state_preprocessor": True,
"learning_starts": 500,
})
def test_dqn_multiagent(self):