From b14cc16616473d27898515015f264911fdb84b55 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Tue, 5 May 2020 12:36:42 -0700
Subject: [PATCH] [rllib] Enable functional execution workflow API by default (#8221)

---
 rllib/agents/a3c/a3c.py                    |  2 -
 rllib/agents/ddpg/ddpg_tf_policy.py        |  2 +-
 rllib/agents/ddpg/tests/test_ddpg.py       |  7 +++-
 rllib/agents/dqn/dqn.py                    | 18 +++++----
 rllib/agents/dqn/simple_q.py               | 34 +++++++++++++++++
 rllib/agents/pg/pg.py                      |  2 -
 rllib/agents/qmix/qmix.py                  |  3 ++
 rllib/agents/sac/tests/test_sac.py         |  7 +++-
 rllib/agents/trainer.py                    |  2 +-
 rllib/contrib/maddpg/maddpg.py             |  3 ++
 rllib/execution/metric_ops.py              |  2 +-
 rllib/optimizers/async_replay_optimizer.py | 43 +++++++++++++++++++---
 rllib/tests/test_supported_spaces.py       |  1 +
 13 files changed, 102 insertions(+), 24 deletions(-)

diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py
index 6583f121e..a057c52a3 100644
--- a/rllib/agents/a3c/a3c.py
+++ b/rllib/agents/a3c/a3c.py
@@ -38,8 +38,6 @@ DEFAULT_CONFIG = with_common_config({
     # Workers sample async. Note that this increases the effective
     # rollout_fragment_length by up to 5x due to async buffering of batches.
     "sample_async": True,
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py
index 207cd4c4b..ee8b40b55 100644
--- a/rllib/agents/ddpg/ddpg_tf_policy.py
+++ b/rllib/agents/ddpg/ddpg_tf_policy.py
@@ -417,7 +417,7 @@ def setup_late_mixins(policy, obs_space, action_space, config):
 
 
 DDPGTFPolicy = build_tf_policy(
-    name="DQNTFPolicy",
+    name="DDPGTFPolicy",
     get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
     make_model=build_ddpg_models,
     action_distribution_fn=get_distribution_inputs_and_class,
diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py
index 6b35eceb6..aad842bcf 100644
--- a/rllib/agents/ddpg/tests/test_ddpg.py
+++ b/rllib/agents/ddpg/tests/test_ddpg.py
@@ -6,6 +6,7 @@ import ray.rllib.agents.ddpg as ddpg
 from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
     loss_torch
 from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
@@ -325,7 +326,8 @@ class TestDDPG(unittest.TestCase):
                 tf_inputs.append(in_)
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 updated_weights = policy.get_weights()
                 # Net must have changed.
@@ -344,7 +346,8 @@ class TestDDPG(unittest.TestCase):
                 in_ = tf_inputs[update_iteration]
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 # Compare updated model and target weights.
                for tf_key in tf_weights.keys():
diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 7e865767d..694aeec8f 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -126,9 +126,6 @@ DEFAULT_CONFIG = with_common_config({
     "soft_q": DEPRECATED_VALUE,
     "parameter_noise": DEPRECATED_VALUE,
     "grad_norm_clipping": DEPRECATED_VALUE,
-
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -301,14 +298,21 @@ def update_target_if_needed(trainer, fetches):
 
 # Experimental distributed execution impl; enable with "use_exec_api": True.
 def execution_plan(workers, config):
+    if config.get("prioritized_replay"):
+        prio_args = {
+            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
+            "prioritized_replay_beta": config["prioritized_replay_beta"],
+            "prioritized_replay_eps": config["prioritized_replay_eps"],
+        }
+    else:
+        prio_args = {}
+
     local_replay_buffer = LocalReplayBuffer(
         num_shards=1,
         learning_starts=config["learning_starts"],
         buffer_size=config["buffer_size"],
         replay_batch_size=config["train_batch_size"],
-        prioritized_replay_alpha=config["prioritized_replay_alpha"],
-        prioritized_replay_beta=config["prioritized_replay_beta"],
-        prioritized_replay_eps=config["prioritized_replay_eps"])
+        **prio_args)
 
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
@@ -320,7 +324,7 @@ def execution_plan(workers, config):
 
     def update_prio(item):
         samples, info_dict = item
-        if config["prioritized_replay"]:
+        if config.get("prioritized_replay"):
             prio_dict = {}
             for policy_id, info in info_dict.items():
                 # TODO(sven): This is currently structured differently for
diff --git a/rllib/agents/dqn/simple_q.py b/rllib/agents/dqn/simple_q.py
index fd7d3dd09..757564be7 100644
--- a/rllib/agents/dqn/simple_q.py
+++ b/rllib/agents/dqn/simple_q.py
@@ -3,6 +3,12 @@ import logging
 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.dqn import DQNTrainer
+from ray.rllib.execution.concurrency_ops import Concurrently
+from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
+from ray.rllib.execution.rollout_ops import ParallelRollouts
+from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
+from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
 
 logger = logging.getLogger(__name__)
 
@@ -81,7 +87,35 @@ def get_policy_class(config):
     return SimpleQTFPolicy
 
 
+# Experimental distributed execution impl; enable with "use_exec_api": True.
+def execution_plan(workers, config):
+    local_replay_buffer = LocalReplayBuffer(
+        num_shards=1,
+        learning_starts=config["learning_starts"],
+        buffer_size=config["buffer_size"],
+        replay_batch_size=config["train_batch_size"])
+
+    rollouts = ParallelRollouts(workers, mode="bulk_sync")
+
+    # (1) Generate rollouts and store them in our local replay buffer.
+    store_op = rollouts.for_each(
+        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+
+    # (2) Read and train on experiences from the replay buffer.
+    replay_op = Replay(local_buffer=local_replay_buffer) \
+        .for_each(TrainOneStep(workers)) \
+        .for_each(UpdateTargetNetwork(
+            workers, config["target_network_update_freq"]))
+
+    # Alternate deterministically between (1) and (2).
+    train_op = Concurrently(
+        [store_op, replay_op], mode="round_robin", output_indexes=[1])
+
+    return StandardMetricsReporting(train_op, workers, config)
+
+
 SimpleQTrainer = DQNTrainer.with_updates(
     default_policy=SimpleQTFPolicy,
     get_policy_class=get_policy_class,
+    execution_plan=execution_plan,
     default_config=DEFAULT_CONFIG)
diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py
index fd31eb52c..2f663877c 100644
--- a/rllib/agents/pg/pg.py
+++ b/rllib/agents/pg/pg.py
@@ -12,8 +12,6 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 0,
     # Learning rate.
     "lr": 0.0004,
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py
index 92f9090b8..07996f50e 100644
--- a/rllib/agents/qmix/qmix.py
+++ b/rllib/agents/qmix/qmix.py
@@ -82,6 +82,9 @@ DEFAULT_CONFIG = with_common_config({
         "lstm_cell_size": 64,
         "max_seq_len": 999999,
     },
+
+    # TODO(ekl) support sync batch replay.
+    "use_exec_api": False,
 })
 # __sphinx_doc_end__
 # yapf: enable
diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py
index 041912dd8..d2bb87f96 100644
--- a/rllib/agents/sac/tests/test_sac.py
+++ b/rllib/agents/sac/tests/test_sac.py
@@ -9,6 +9,7 @@ from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
     loss_torch
 from ray.rllib.models.tf.tf_action_dist import SquashedGaussian
 from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import fc, relu
@@ -306,7 +307,8 @@ class TestSAC(unittest.TestCase):
                 tf_inputs.append(in_)
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 updated_weights = policy.get_weights()
                 # Net must have changed.
@@ -325,7 +327,8 @@ class TestSAC(unittest.TestCase):
                 in_ = tf_inputs[update_iteration]
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 # Compare updated model.
                 for tf_key in sorted(tf_weights.keys())[2:10]:
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 83431956e..cb8e0d454 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -206,7 +206,7 @@ COMMON_CONFIG = {
     "custom_eval_function": None,
     # EXPERIMENTAL: use the execution plan based API impl of the algo. Can also
     # be enabled by setting RLLIB_EXEC_API=1.
-    "use_exec_api": False,
+    "use_exec_api": True,
 
     # === Advanced Rollout Settings ===
     # Use a background thread for sampling (slightly off-policy, usually not
diff --git a/rllib/contrib/maddpg/maddpg.py b/rllib/contrib/maddpg/maddpg.py
index ad8604588..a6ca8231c 100644
--- a/rllib/contrib/maddpg/maddpg.py
+++ b/rllib/contrib/maddpg/maddpg.py
@@ -100,6 +100,9 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 1,
     # Prevent iterations from going lower than this time span
     "min_iter_time_s": 0,
+
+    # TODO(ekl) support synchronized sampling.
+ "use_exec_api": False, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/execution/metric_ops.py b/rllib/execution/metric_ops.py index 4bb987187..28208c9ba 100644 --- a/rllib/execution/metric_ops.py +++ b/rllib/execution/metric_ops.py @@ -159,7 +159,7 @@ class OncePerTimestepsElapsed: return True metrics = LocalIterator.get_metrics() now = metrics.counters[STEPS_SAMPLED_COUNTER] - if now - self.last_called > self.delay_steps: + if now - self.last_called >= self.delay_steps: self.last_called = now return True return False diff --git a/rllib/optimizers/async_replay_optimizer.py b/rllib/optimizers/async_replay_optimizer.py index 460c475b6..caac76c77 100644 --- a/rllib/optimizers/async_replay_optimizer.py +++ b/rllib/optimizers/async_replay_optimizer.py @@ -285,6 +285,10 @@ class AsyncReplayOptimizer(PolicyOptimizer): return sample_timesteps, train_timesteps +# Visible for testing. +_local_replay_buffer = None + + # TODO(ekl) move this class to common class LocalReplayBuffer(ParallelIteratorWorker): """A replay buffer shard. @@ -292,9 +296,14 @@ class LocalReplayBuffer(ParallelIteratorWorker): Ray actors are single-threaded, so for scalability multiple replay actors may be created to increase parallelism.""" - def __init__(self, num_shards, learning_starts, buffer_size, - replay_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps): + def __init__(self, + num_shards, + learning_starts, + buffer_size, + replay_batch_size, + prioritized_replay_alpha=0.6, + prioritized_replay_beta=0.4, + prioritized_replay_eps=1e-6): self.replay_starts = learning_starts // num_shards self.buffer_size = buffer_size // num_shards self.replay_batch_size = replay_batch_size @@ -319,6 +328,17 @@ class LocalReplayBuffer(ParallelIteratorWorker): self.update_priorities_timer = TimerStat() self.num_added = 0 + # Make externally accessible for testing. + global _local_replay_buffer + _local_replay_buffer = self + # If set, return this instead of the usual data for testing. + self._fake_batch = None + + @staticmethod + def get_instance_for_testing(): + global _local_replay_buffer + return _local_replay_buffer + def get_host(self): return os.uname()[1] @@ -338,6 +358,12 @@ class LocalReplayBuffer(ParallelIteratorWorker): self.num_added += batch.count def replay(self): + if self._fake_batch: + fake_batch = SampleBatch(self._fake_batch) + return MultiAgentBatch({ + DEFAULT_POLICY_ID: fake_batch + }, fake_batch.count) + if self.num_added < self.replay_starts: return None @@ -392,9 +418,14 @@ class LocalBatchReplayBuffer(LocalReplayBuffer): This allows for RNN models, but ignores prioritization params. 
""" - def __init__(self, num_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps): + def __init__(self, + num_shards, + learning_starts, + buffer_size, + train_batch_size, + prioritized_replay_alpha=0.6, + prioritized_replay_beta=0.4, + prioritized_replay_eps=1e-6): self.replay_starts = learning_starts // num_shards self.buffer_size = buffer_size // num_shards self.train_batch_size = train_batch_size diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 55c0bee58..5c850e736 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -256,6 +256,7 @@ class ModelSupportedSpaces(unittest.TestCase): check_support_multiagent("DDPG", { "timesteps_per_iteration": 1, "use_state_preprocessor": True, + "learning_starts": 500, }) def test_dqn_multiagent(self):