[rllib] Enable functional execution workflow API by default (#8221)
parent 4bdef78e2e
commit b14cc16616

13 changed files with 102 additions and 24 deletions
@@ -38,8 +38,6 @@ DEFAULT_CONFIG = with_common_config({
     # Workers sample async. Note that this increases the effective
     # rollout_fragment_length by up to 5x due to async buffering of batches.
     "sample_async": True,
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -417,7 +417,7 @@ def setup_late_mixins(policy, obs_space, action_space, config):


 DDPGTFPolicy = build_tf_policy(
-    name="DQNTFPolicy",
+    name="DDPGTFPolicy",
     get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
     make_model=build_ddpg_models,
     action_distribution_fn=get_distribution_inputs_and_class,

@@ -6,6 +6,7 @@ import ray.rllib.agents.ddpg as ddpg
 from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
     loss_torch
 from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid

@@ -325,7 +326,8 @@ class TestDDPG(unittest.TestCase):
                 tf_inputs.append(in_)
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 updated_weights = policy.get_weights()
                 # Net must have changed.

@@ -344,7 +346,8 @@ class TestDDPG(unittest.TestCase):
                 in_ = tf_inputs[update_iteration]
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 # Compare updated model and target weights.
                 for tf_key in tf_weights.keys():

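Both TestDDPG hunks above move the fake-batch injection point from the old policy optimizer onto the replay buffer itself. A minimal sketch of the new hook, assuming the exec-API path (now the default) has built the buffer; the environment name and config values are illustrative, not taken from the diff:

import ray
from ray.rllib.agents.ddpg import DDPGTrainer
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer

ray.init()
trainer = DDPGTrainer(env="Pendulum-v0", config={"num_workers": 0})

# The execution plan registers its buffer in a module-level global, so a test
# can reach it without touching trainer internals (see the LocalReplayBuffer
# hunks further down in this commit).
buf = LocalReplayBuffer.get_instance_for_testing()
assert buf is not None

# Assigning a dict of SampleBatch columns here makes replay() return it
# verbatim instead of sampled experience, which is what the tests rely on:
#   buf._fake_batch = {"obs": ..., "actions": ..., ...}
trainer.train()
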
@@ -126,9 +126,6 @@ DEFAULT_CONFIG = with_common_config({
     "soft_q": DEPRECATED_VALUE,
     "parameter_noise": DEPRECATED_VALUE,
     "grad_norm_clipping": DEPRECATED_VALUE,
-
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -301,14 +298,21 @@ def update_target_if_needed(trainer, fetches):

 # Experimental distributed execution impl; enable with "use_exec_api": True.
 def execution_plan(workers, config):
+    if config.get("prioritized_replay"):
+        prio_args = {
+            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
+            "prioritized_replay_beta": config["prioritized_replay_beta"],
+            "prioritized_replay_eps": config["prioritized_replay_eps"],
+        }
+    else:
+        prio_args = {}
+
     local_replay_buffer = LocalReplayBuffer(
         num_shards=1,
         learning_starts=config["learning_starts"],
         buffer_size=config["buffer_size"],
         replay_batch_size=config["train_batch_size"],
-        prioritized_replay_alpha=config["prioritized_replay_alpha"],
-        prioritized_replay_beta=config["prioritized_replay_beta"],
-        prioritized_replay_eps=config["prioritized_replay_eps"])
+        **prio_args)

     rollouts = ParallelRollouts(workers, mode="bulk_sync")

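The DQN execution plan now forwards the prioritization settings only when prioritized replay is enabled, and otherwise relies on the constructor defaults introduced further down. The same keyword-unpacking pattern in isolation (the helper name is mine, purely illustrative):

def prio_kwargs(config):
    # Forward prioritization settings only when the feature is switched on;
    # otherwise LocalReplayBuffer falls back to its new defaults
    # (alpha=0.6, beta=0.4, eps=1e-6, per the buffer hunk below).
    if config.get("prioritized_replay"):
        return {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    return {}

# Usage: LocalReplayBuffer(num_shards=1, ..., **prio_kwargs(config))
print(prio_kwargs({"prioritized_replay": True,
                   "prioritized_replay_alpha": 0.6,
                   "prioritized_replay_beta": 0.4,
                   "prioritized_replay_eps": 1e-6}))
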
@@ -320,7 +324,7 @@ def execution_plan(workers, config):

     def update_prio(item):
         samples, info_dict = item
-        if config["prioritized_replay"]:
+        if config.get("prioritized_replay"):
             prio_dict = {}
             for policy_id, info in info_dict.items():
                 # TODO(sven): This is currently structured differently for

@@ -3,6 +3,12 @@ import logging
 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.dqn import DQNTrainer
+from ray.rllib.execution.concurrency_ops import Concurrently
+from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
+from ray.rllib.execution.rollout_ops import ParallelRollouts
+from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
+from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer

 logger = logging.getLogger(__name__)

@@ -81,7 +87,35 @@ def get_policy_class(config):
     return SimpleQTFPolicy


+# Experimental distributed execution impl; enable with "use_exec_api": True.
+def execution_plan(workers, config):
+    local_replay_buffer = LocalReplayBuffer(
+        num_shards=1,
+        learning_starts=config["learning_starts"],
+        buffer_size=config["buffer_size"],
+        replay_batch_size=config["train_batch_size"])
+
+    rollouts = ParallelRollouts(workers, mode="bulk_sync")
+
+    # (1) Generate rollouts and store them in our local replay buffer.
+    store_op = rollouts.for_each(
+        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+
+    # (2) Read and train on experiences from the replay buffer.
+    replay_op = Replay(local_buffer=local_replay_buffer) \
+        .for_each(TrainOneStep(workers)) \
+        .for_each(UpdateTargetNetwork(
+            workers, config["target_network_update_freq"]))
+
+    # Alternate deterministically between (1) and (2).
+    train_op = Concurrently(
+        [store_op, replay_op], mode="round_robin", output_indexes=[1])
+
+    return StandardMetricsReporting(train_op, workers, config)
+
+
 SimpleQTrainer = DQNTrainer.with_updates(
     default_policy=SimpleQTFPolicy,
     get_policy_class=get_policy_class,
+    execution_plan=execution_plan,
     default_config=DEFAULT_CONFIG)

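SimpleQ's new plan alternates between storing rollouts and training from replay, and only the training op's results flow on to metrics reporting (output_indexes=[1]). A plain-Python toy of my reading of that round_robin behavior, not the actual ray.util.iter implementation:

import itertools

def round_robin(ops, output_indexes):
    # Step each op in turn; only forward results from the selected indexes.
    iterators = [iter(op) for op in ops]
    while True:
        for i, it in enumerate(iterators):
            item = next(it)
            if i in output_indexes:
                yield item

store_op = itertools.repeat("stored a rollout batch")   # stands in for (1)
replay_op = itertools.repeat("train step results")      # stands in for (2)
train_op = round_robin([store_op, replay_op], output_indexes=[1])

print([next(train_op) for _ in range(3)])
# ['train step results', 'train step results', 'train step results']
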
@@ -12,8 +12,6 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 0,
     # Learning rate.
     "lr": 0.0004,
-    # Use the execution plan API instead of policy optimizers.
-    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -82,6 +82,9 @@ DEFAULT_CONFIG = with_common_config({
         "lstm_cell_size": 64,
         "max_seq_len": 999999,
     },
+
+    # TODO(ekl) support sync batch replay.
+    "use_exec_api": False,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -9,6 +9,7 @@ from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
     loss_torch
 from ray.rllib.models.tf.tf_action_dist import SquashedGaussian
 from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.numpy import fc, relu

@@ -306,7 +307,8 @@ class TestSAC(unittest.TestCase):
                 tf_inputs.append(in_)
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 updated_weights = policy.get_weights()
                 # Net must have changed.

@@ -325,7 +327,8 @@ class TestSAC(unittest.TestCase):
                 in_ = tf_inputs[update_iteration]
                 # Set a fake-batch to use
                 # (instead of sampling from replay buffer).
-                trainer.optimizer._fake_batch = in_
+                buf = LocalReplayBuffer.get_instance_for_testing()
+                buf._fake_batch = in_
                 trainer.train()
                 # Compare updated model.
                 for tf_key in sorted(tf_weights.keys())[2:10]:

@@ -206,7 +206,7 @@ COMMON_CONFIG = {
     "custom_eval_function": None,
     # EXPERIMENTAL: use the execution plan based API impl of the algo. Can also
     # be enabled by setting RLLIB_EXEC_API=1.
-    "use_exec_api": False,
+    "use_exec_api": True,

     # === Advanced Rollout Settings ===
     # Use a background thread for sampling (slightly off-policy, usually not

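This is the headline change: the execution-plan path is opted in by default for every algorithm that supports it (it could previously be enabled per run via RLLIB_EXEC_API=1). A minimal usage sketch; the environment name and config values are illustrative, and the legacy policy-optimizer path can still be selected explicitly:

import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer

ray.init()

# Default: the training loop is assembled from execution ops (use_exec_api=True).
exec_trainer = DQNTrainer(env="CartPole-v0", config={"num_workers": 0})
print(exec_trainer.train()["episode_reward_mean"])

# Explicit opt-out back to the legacy policy-optimizer path.
legacy_trainer = DQNTrainer(
    env="CartPole-v0",
    config={"num_workers": 0, "use_exec_api": False})
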
@@ -100,6 +100,9 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 1,
     # Prevent iterations from going lower than this time span
     "min_iter_time_s": 0,
+
+    # TODO(ekl) support synchronized sampling.
+    "use_exec_api": False,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -159,7 +159,7 @@ class OncePerTimestepsElapsed:
             return True
         metrics = LocalIterator.get_metrics()
         now = metrics.counters[STEPS_SAMPLED_COUNTER]
-        if now - self.last_called > self.delay_steps:
+        if now - self.last_called >= self.delay_steps:
             self.last_called = now
             return True
         return False

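The comparison change above fixes a boundary condition: with ">", a gate configured to fire every delay_steps sampled steps would not fire until delay_steps + 1 had elapsed. A toy version of the predicate (not the RLlib class itself) to show the difference:

class OncePerSteps:
    def __init__(self, delay_steps):
        self.delay_steps = delay_steps
        self.last_called = 0

    def __call__(self, steps_sampled_so_far):
        # With ">" a 500-step schedule would not fire until step 501; ">=" fires at 500.
        if steps_sampled_so_far - self.last_called >= self.delay_steps:
            self.last_called = steps_sampled_so_far
            return True
        return False

gate = OncePerSteps(500)
print(gate(500))  # True with ">=", False with the old ">"
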
@@ -285,6 +285,10 @@ class AsyncReplayOptimizer(PolicyOptimizer):
         return sample_timesteps, train_timesteps


+# Visible for testing.
+_local_replay_buffer = None
+
+
 # TODO(ekl) move this class to common
 class LocalReplayBuffer(ParallelIteratorWorker):
     """A replay buffer shard.

@@ -292,9 +296,14 @@ class LocalReplayBuffer(ParallelIteratorWorker):
     Ray actors are single-threaded, so for scalability multiple replay actors
     may be created to increase parallelism."""

-    def __init__(self, num_shards, learning_starts, buffer_size,
-                 replay_batch_size, prioritized_replay_alpha,
-                 prioritized_replay_beta, prioritized_replay_eps):
+    def __init__(self,
+                 num_shards,
+                 learning_starts,
+                 buffer_size,
+                 replay_batch_size,
+                 prioritized_replay_alpha=0.6,
+                 prioritized_replay_beta=0.4,
+                 prioritized_replay_eps=1e-6):
         self.replay_starts = learning_starts // num_shards
         self.buffer_size = buffer_size // num_shards
         self.replay_batch_size = replay_batch_size

@@ -319,6 +328,17 @@ class LocalReplayBuffer(ParallelIteratorWorker):
         self.update_priorities_timer = TimerStat()
         self.num_added = 0

+        # Make externally accessible for testing.
+        global _local_replay_buffer
+        _local_replay_buffer = self
+        # If set, return this instead of the usual data for testing.
+        self._fake_batch = None
+
+    @staticmethod
+    def get_instance_for_testing():
+        global _local_replay_buffer
+        return _local_replay_buffer
+
     def get_host(self):
         return os.uname()[1]

@@ -338,6 +358,12 @@ class LocalReplayBuffer(ParallelIteratorWorker):
         self.num_added += batch.count

     def replay(self):
+        if self._fake_batch:
+            fake_batch = SampleBatch(self._fake_batch)
+            return MultiAgentBatch({
+                DEFAULT_POLICY_ID: fake_batch
+            }, fake_batch.count)
+
         if self.num_added < self.replay_starts:
             return None

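With default values on the prioritization arguments and the _fake_batch short-circuit in replay(), a buffer can be exercised standalone. A small sketch, under the assumption that LocalReplayBuffer needs no setup beyond this constructor; the column values are toy data:

from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer

# Only the sizing arguments are required now; prioritization settings default
# to alpha=0.6, beta=0.4, eps=1e-6.
buf = LocalReplayBuffer(
    num_shards=1,
    learning_starts=100,
    buffer_size=1000,
    replay_batch_size=32)

# Before learning_starts samples have been added, replay() returns None ...
assert buf.replay() is None

# ... unless a fake batch is injected, in which case it is returned verbatim
# as a single-policy MultiAgentBatch (the testing hook added in this commit).
buf._fake_batch = {"obs": [[0.0]], "actions": [0], "rewards": [1.0]}
print(buf.replay())
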
@@ -392,9 +418,14 @@ class LocalBatchReplayBuffer(LocalReplayBuffer):
     This allows for RNN models, but ignores prioritization params.
     """

-    def __init__(self, num_shards, learning_starts, buffer_size,
-                 train_batch_size, prioritized_replay_alpha,
-                 prioritized_replay_beta, prioritized_replay_eps):
+    def __init__(self,
+                 num_shards,
+                 learning_starts,
+                 buffer_size,
+                 train_batch_size,
+                 prioritized_replay_alpha=0.6,
+                 prioritized_replay_beta=0.4,
+                 prioritized_replay_eps=1e-6):
         self.replay_starts = learning_starts // num_shards
         self.buffer_size = buffer_size // num_shards
         self.train_batch_size = train_batch_size

@@ -256,6 +256,7 @@ class ModelSupportedSpaces(unittest.TestCase):
         check_support_multiagent("DDPG", {
             "timesteps_per_iteration": 1,
             "use_state_preprocessor": True,
+            "learning_starts": 500,
         })

     def test_dqn_multiagent(self):