Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Unify the way we create local replay buffer for all agents (#19627)
* [RLlib] Unify the way we create and use LocalReplayBuffer for all the agents. This change:
  1. Gets rid of the try...except clause around the execution_plan() call, and with it the resulting deprecation warning.
  2. Fixes the execution_plan() call in Trainer._try_recover() as well.
  3. Most importantly, makes it much easier to create and use different types of local replay buffers for all our agents, e.g. a reservoir-sampling replay buffer for the APPO agent (for Riot) in the near future.
* Introduce explicit configuration for replay buffer types.
* Fix the is_training key error.
* Actually deprecate the buffer_size field.
This commit is contained in:
parent ab15dfd478
commit 99a0088233
22 changed files with 180 additions and 117 deletions
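Before the per-file diffs: after this commit, every execution_plan() shares one signature, (workers, config, **kwargs). Agents that need a local replay buffer no longer build one themselves; the Trainer constructs it from replay_buffer_config and passes it in as kwargs["local_replay_buffer"], while all other agents assert that kwargs is empty. What follows is a minimal, hypothetical sketch (not part of this commit) of a custom plan written against the new contract; it assumes ray[rllib] at this commit and uses only operators that appear in the diffs below.

from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_execution_plan(workers, config, **kwargs):
    # The Trainer-provided buffer arrives via kwargs instead of being
    # constructed here (the whole point of this commit).
    assert "local_replay_buffer" in kwargs, (
        "This execution plan requires a local replay buffer.")
    local_replay_buffer = kwargs["local_replay_buffer"]

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # (1) Store rollouts into the shared buffer.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Replay from the buffer and take one SGD step per replayed batch.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers))

    # Alternate between storing and replaying; report metrics for the
    # training op only (the same pattern DQN uses).
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)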
@@ -29,8 +29,8 @@ A2C_DEFAULT_CONFIG = merge_dicts(
 )


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the A2C algorithm. Defines the distributed
     dataflow.

@@ -42,6 +42,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "A2C execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

     if config["microbatch_size"]:
@@ -78,8 +78,8 @@ def validate_config(config: TrainerConfigDict) -> None:
         raise ValueError("`num_workers` for A3C must be >= 1!")


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the MARWIL/BC algorithm. Defines the distributed
     dataflow.

@@ -91,6 +91,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "A3C execution_plan does NOT take any additional parameters")
+
     # For A3C, compute policy gradients remotely on the rollout workers.
     grads = AsyncGradients(workers)

@@ -9,7 +9,6 @@ from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
 from ray.rllib.agents.sac.sac import SACTrainer, \
     DEFAULT_CONFIG as SAC_CONFIG
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay
 from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
     UpdateTargetNetwork
@@ -17,6 +16,7 @@ from ray.rllib.offline.shuffled_input import ShuffledInput
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import merge_dicts
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TrainerConfigDict
@@ -24,7 +24,6 @@ from ray.rllib.utils.typing import TrainerConfigDict
 tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
 logger = logging.getLogger(__name__)
-replay_buffer = None

 # yapf: disable
 # __sphinx_doc_begin__
@@ -48,7 +47,11 @@ CQL_DEFAULT_CONFIG = merge_dicts(
         "min_q_weight": 5.0,
         # Replay buffer should be larger or equal the size of the offline
         # dataset.
-        "buffer_size": int(1e6),
+        "buffer_size": DEPRECATED_VALUE,
+        "replay_buffer_config": {
+            "type": "LocalReplayBuffer",
+            "capacity": int(1e6),
+        },
     })
 # __sphinx_doc_end__
 # yapf: enable
@@ -74,29 +77,11 @@ def validate_config(config: TrainerConfigDict):
         try_import_tfp(error=True)


-def execution_plan(workers, config):
-    if config.get("prioritized_replay"):
-        prio_args = {
-            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-            "prioritized_replay_beta": config["prioritized_replay_beta"],
-            "prioritized_replay_eps": config["prioritized_replay_eps"],
-        }
-    else:
-        prio_args = {}
-
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
-
-    global replay_buffer
-    replay_buffer = local_replay_buffer
+def execution_plan(workers, config, **kwargs):
+    assert "local_replay_buffer" in kwargs, (
+        "CQL execution plan requires a local replay buffer.")
+
+    local_replay_buffer = kwargs["local_replay_buffer"]

     def update_prio(item):
         samples, info_dict = item
@@ -150,8 +135,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:

 def after_init(trainer):
     # Add the entire dataset to Replay Buffer (global variable)
-    global replay_buffer
     reader = trainer.workers.local_worker().input_reader
+    replay_buffer = trainer.local_replay_buffer

     # For d4rl, add the D4RLReaders' dataset to the buffer.
     if isinstance(trainer.config["input"], str) and \
@@ -88,8 +88,8 @@ class TestCQL(unittest.TestCase):
             # Example on how to do evaluation on the trained Trainer
             # using the data from CQL's global replay buffer.
             # Get a sample (MultiAgentBatch -> SampleBatch).
-            from ray.rllib.agents.cql.cql import replay_buffer
-            batch = replay_buffer.replay().policy_batches["default_policy"]
+            batch = trainer.local_replay_buffer.replay().policy_batches[
+                "default_policy"]

             if fw == "torch":
                 obs = torch.from_numpy(batch["obs"])
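The test change above is the user-visible side of the refactor: the replay data is no longer a module-level global in cql.py but lives on the Trainer. A hedged usage sketch (not from this commit; the env name and single train() call are illustrative, and SimpleQTrainer is used only because it is one of the agents that now receives a Trainer-built LocalReplayBuffer):

import ray
from ray.rllib.agents.dqn import SimpleQTrainer

ray.init()
trainer = SimpleQTrainer(config={"env": "CartPole-v0"})
trainer.train()

# The buffer hangs off the Trainer now (no global), so stored experience
# can be inspected directly, mirroring the updated CQL test. Note that
# replay() may return None until `learning_starts` timesteps were sampled.
replayed = trainer.local_replay_buffer.replay()
if replayed is not None:
    batch = replayed.policy_batches["default_policy"]
    print(batch["obs"].shape)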
@@ -55,6 +55,9 @@ APEX_DEFAULT_CONFIG = merge_dicts(
         "num_gpus": 1,
         "num_workers": 32,
         "buffer_size": 2000000,
+        # TODO(jungong) : add proper replay_buffer_config after
+        # DistributedReplayBuffer type is supported.
+        "replay_buffer_config": None,
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -141,8 +144,11 @@ class UpdateWorkerWeights:
         metrics.counters["num_weight_syncs"] += 1


-def apex_execution_plan(workers: WorkerSet,
-                        config: dict) -> LocalIterator[dict]:
+def apex_execution_plan(workers: WorkerSet, config: dict,
+                        **kwargs) -> LocalIterator[dict]:
+    assert len(kwargs) == 0, (
+        "Apex execution_plan does NOT take any additional parameters")
+
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
     replay_actors = create_colocated(ReplayActor, [
@@ -20,7 +20,6 @@ from ray.rllib.agents.trainer import Trainer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
@@ -145,8 +144,8 @@ def validate_config(config: TrainerConfigDict) -> None:
             "simple_optimizer=True if this doesn't work for you.")


-def execution_plan(trainer: Trainer, workers: WorkerSet,
-                   config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the DQN algorithm. Defines the distributed dataflow.

     Args:
@@ -158,28 +157,12 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    if config.get("prioritized_replay"):
-        prio_args = {
-            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-            "prioritized_replay_beta": config["prioritized_replay_beta"],
-            "prioritized_replay_eps": config["prioritized_replay_eps"],
-        }
-    else:
-        prio_args = {}
+    assert "local_replay_buffer" in kwargs, (
+        "DQN execution plan requires a local replay buffer.")

-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
-    # Assign to Trainer, so we can store the LocalReplayBuffer's
-    # data when we save checkpoints.
-    trainer.local_replay_buffer = local_replay_buffer
+    local_replay_buffer = kwargs["local_replay_buffer"]

     rollouts = ParallelRollouts(workers, mode="bulk_sync")

@@ -14,17 +14,17 @@ from typing import Optional, Type

 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
-from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import MultiGPUTrainOneStep, TrainOneStep, \
     UpdateTargetNetwork
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator

@@ -62,7 +62,11 @@ DEFAULT_CONFIG = with_common_config({
     # === Replay buffer ===
     # Size of the replay buffer. Note that if async_updates is set, then
     # each worker will have a replay buffer of this size.
-    "buffer_size": 50000,
+    "buffer_size": DEPRECATED_VALUE,
+    "replay_buffer_config": {
+        "type": "LocalReplayBuffer",
+        "capacity": 50000,
+    },
     # Set this to True, if you want the contents of your buffer(s) to be
     # stored in any saved checkpoints as well.
     # Warnings will be created if:
@@ -122,8 +126,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
         return SimpleQTorchPolicy


-def execution_plan(trainer: Trainer, workers: WorkerSet,
-                   config: TrainerConfigDict, **kwargs) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the Simple Q algorithm. Defines the distributed dataflow.

     Args:
@@ -135,16 +139,10 @@ def execution_plan(trainer: Trainer, workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config["replay_sequence_length"])
-    # Assign to Trainer, so we can store the LocalReplayBuffer's
-    # data when we save checkpoints.
-    trainer.local_replay_buffer = local_replay_buffer
+    assert "local_replay_buffer" in kwargs, (
+        "SimpleQ execution plan requires a local replay buffer.")
+
+    local_replay_buffer = kwargs["local_replay_buffer"]

     rollouts = ParallelRollouts(workers, mode="bulk_sync")

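The config hunks above (CQL, SimpleQ) and below (SlateQ) all follow the same migration: the flat buffer_size key is marked DEPRECATED_VALUE and replaced by a nested replay_buffer_config. A hedged sketch (not from this commit, values illustrative) of what that means for user configs:

from ray.rllib.agents.dqn.simple_q import DEFAULT_CONFIG

# New style: the buffer type and size live under replay_buffer_config.
new_style = dict(DEFAULT_CONFIG, **{
    "replay_buffer_config": {
        "type": "LocalReplayBuffer",
        "capacity": 100000,
    },
})

# Old style keeps working for now: a plain buffer_size is still picked up
# by Trainer._create_local_replay_buffer_if_necessary() (see the trainer.py
# diff further down) but emits a deprecation warning pointing at
# replay_buffer_config["capacity"].
old_style = dict(DEFAULT_CONFIG, **{
    "buffer_size": 100000,
})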
@@ -185,7 +185,10 @@ class DreamerIteration:
         return fetches[DEFAULT_POLICY_ID]["learner_stats"]


-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "Dreamer execution_plan does NOT take any additional parameters")
+
     # Special replay buffer for Dreamer agent.
     episode_buffer = EpisodicBuffer(length=config["batch_length"])

@@ -312,7 +312,10 @@ def gather_experiences_directly(workers, config):
     return train_batches


-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "IMPALA execution_plan does NOT take any additional parameters")
+
     if config["num_aggregation_workers"] > 0:
         train_batches = gather_experiences_tree_aggregation(workers, config)
     else:
@@ -156,7 +156,10 @@ def inner_adaptation(workers, samples):
         e.learn_on_batch.remote(samples[i])


-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "MAML execution_plan does NOT take any additional parameters")
+
     # Sync workers with meta policy
     workers.sync_weights()

@@ -90,8 +90,8 @@ def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
         return MARWILTorchPolicy


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the MARWIL/BC algorithm. Defines the distributed
     dataflow.

@@ -103,6 +103,9 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
+    assert len(kwargs) == 0, (
+        "Marwill execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
     replay_buffer = LocalReplayBuffer(
         learning_starts=config["learning_starts"],
@@ -336,8 +336,8 @@ def post_process_samples(samples, config: TrainerConfigDict):
     return samples, split_lst


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the PPO algorithm. Defines the distributed dataflow.

     Args:
@@ -349,6 +349,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PPOTrainer.
             If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "MBMPO execution_plan does NOT take any additional parameters")
+
     # Train TD Models on the driver.
     workers.local_worker().foreach_policy(fit_dynamics)

@@ -146,8 +146,8 @@ def validate_config(config):
         raise ValueError("DDPPO doesn't support KL penalties like PPO-1")


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.

     Args:
@@ -159,6 +159,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PGTrainer.
             If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "DDPPO execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="raw")

     # Setup the distributed processes.
@@ -253,8 +253,8 @@ def warn_about_bad_reward_scales(config, result):
     return result


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the PPO algorithm. Defines the distributed dataflow.

     Args:
@@ -266,6 +266,9 @@ def execution_plan(workers: WorkerSet,
         LocalIterator[dict]: The Policy class to use with PPOTrainer.
             If None, use `default_policy` provided in build_trainer().
     """
+    assert len(kwargs) == 0, (
+        "PPO execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

     # Collect batches for the trainable policies.
@@ -100,7 +100,10 @@ DEFAULT_CONFIG = with_common_config({
 # yapf: enable


-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "QMIX execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
     replay_buffer = SimpleReplayBuffer(config["buffer_size"])

@@ -248,7 +248,7 @@ class SACTFModel(TFModelV2):
         input_dict = {"obs": model_out}
         # Switch on training mode (when getting Q-values, we are usually in
         # training).
-        input_dict.is_training = True
+        input_dict["is_training"] = True

         out, _ = net(input_dict, [], None)
         return out
@@ -256,7 +256,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
         input_dict = {"obs": model_out}
         # Switch on training mode (when getting Q-values, we are usually in
         # training).
-        input_dict.is_training = True
+        input_dict["is_training"] = True

         out, _ = net(input_dict, [], None)
         return out
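The two one-line SAC model fixes above are the "Fix the is_training key error" item from the commit message: input_dict is a plain Python dict at this point, so attribute-style assignment raises, and the flag has to be set by key. A tiny, self-contained illustration:

input_dict = {"obs": [0.0, 1.0]}

try:
    input_dict.is_training = True   # old code path: raises AttributeError
except AttributeError as err:
    print(err)                      # 'dict' object has no attribute 'is_training'

input_dict["is_training"] = True    # fixed code path: ordinary key assignment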
@@ -22,11 +22,11 @@ from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.examples.policy.random_policy import RandomPolicy
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import TrainOneStep
 from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator

@@ -82,7 +82,11 @@ DEFAULT_CONFIG = with_common_config({
     # === Replay buffer ===
     # Size of the replay buffer. Note that if async_updates is set, then
     # each worker will have a replay buffer of this size.
-    "buffer_size": 50000,
+    "buffer_size": DEPRECATED_VALUE,
+    "replay_buffer_config": {
+        "type": "LocalReplayBuffer",
+        "capacity": 50000,
+    },
     # The number of contiguous environment steps to replay at once. This may
     # be set to greater than 1 to support recurrent models.
     "replay_sequence_length": 1,
@@ -152,8 +156,8 @@ def validate_config(config: TrainerConfigDict) -> None:
             "For SARSA strategy, batch_mode must be 'complete_episodes'")


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     """Execution plan of the SlateQ algorithm. Defines the distributed dataflow.

     Args:
@@ -164,14 +168,8 @@ def execution_plan(workers: WorkerSet,
     Returns:
         LocalIterator[dict]: A local iterator over training metrics.
     """
-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        capacity=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config["replay_sequence_length"],
-    )
+    assert "local_replay_buffer" in kwargs, (
+        "SlateQ execution plan requires a local replay buffer.")

     rollouts = ParallelRollouts(workers, mode="bulk_sync")

@@ -179,12 +177,12 @@ def execution_plan(workers: WorkerSet,
     # (1) Generate rollouts and store them in our local replay buffer. Calling
     # next() on store_op drives this.
     store_op = rollouts.for_each(
-        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+        StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"]))

     # (2) Read and train on experiences from the replay buffer. Every batch
     # returned from the LocalReplay() iterator is passed to TrainOneStep to
     # take a SGD step.
-    replay_op = Replay(local_buffer=local_replay_buffer) \
+    replay_op = Replay(local_buffer=kwargs["local_replay_buffer"]) \
         .for_each(TrainOneStep(workers))

     if config["slateq_strategy"] != "RANDOM":
@@ -23,6 +23,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.execution.replay_buffer import LocalReplayBuffer
 from ray.rllib.models import MODEL_DEFAULTS
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
@@ -706,6 +707,67 @@ class Trainer(Trainable):
         # to mutate the result
         Trainable.log_result(self, result)

+    @DeveloperAPI
+    def _create_local_replay_buffer_if_necessary(self, config):
+        """Create a LocalReplayBuffer instance if necessary.
+
+        Args:
+            config (dict): Algorithm-specific configuration data.
+
+        Returns:
+            LocalReplayBuffer instance based on trainer config.
+            None, if local replay buffer is not needed.
+        """
+        # These are the agents that utilizes a local replay buffer.
+        if ("replay_buffer_config" not in config
+                or not config["replay_buffer_config"]):
+            # Does not need a replay buffer.
+            return None
+
+        replay_buffer_config = config["replay_buffer_config"]
+        if ("type" not in replay_buffer_config
+                or replay_buffer_config["type"] != "LocalReplayBuffer"):
+            # DistributedReplayBuffer coming soon.
+            return None
+
+        capacity = config.get("buffer_size", DEPRECATED_VALUE)
+        if capacity != DEPRECATED_VALUE:
+            # Print a deprecation warning.
+            deprecation_warning(
+                old="config['buffer_size']",
+                new="config['replay_buffer_config']['capacity']",
+                error=False)
+        else:
+            # Get capacity out of replay_buffer_config.
+            capacity = replay_buffer_config["capacity"]
+
+        if config.get("prioritized_replay"):
+            prio_args = {
+                "prioritized_replay_alpha": config["prioritized_replay_alpha"],
+                "prioritized_replay_beta": config["prioritized_replay_beta"],
+                "prioritized_replay_eps": config["prioritized_replay_eps"],
+            }
+        else:
+            prio_args = {}
+
+        return LocalReplayBuffer(
+            num_shards=1,
+            learning_starts=config["learning_starts"],
+            capacity=capacity,
+            replay_batch_size=config["train_batch_size"],
+            replay_mode=config["multiagent"]["replay_mode"],
+            replay_sequence_length=config.get("replay_sequence_length", 1),
+            replay_burn_in=config.get("burn_in", 0),
+            replay_zero_init_states=config.get("zero_init_states", True),
+            **prio_args)
+
+    @DeveloperAPI
+    def _kwargs_for_execution_plan(self):
+        kwargs = {}
+        if self.local_replay_buffer:
+            kwargs["local_replay_buffer"] = self.local_replay_buffer
+        return kwargs
+
     @override(Trainable)
     def setup(self, config: PartialTrainerConfigDict):
         env = self._env_id
@@ -773,6 +835,10 @@ class Trainer(Trainable):
         if self.config.get("log_level"):
             logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

+        # Create local replay buffer if necessary.
+        self.local_replay_buffer = (
+            self._create_local_replay_buffer_if_necessary(self.config))
+
         self._init(self.config, self.env_creator)

         # Evaluation setup.
@@ -1747,7 +1813,8 @@ class Trainer(Trainable):

         logger.warning("Recreating execution plan after failure")
         workers.reset(healthy_workers)
-        self.train_exec_impl = self.execution_plan(workers, self.config)
+        self.train_exec_impl = self.execution_plan(
+            workers, self.config, **self._kwargs_for_execution_plan())

     @override(Trainable)
     def _export_model(self, export_formats: List[str],
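The trainer.py hunks above centralize buffer creation. A hedged, standalone sketch (not from this commit) of roughly what the Trainer now constructs when replay_buffer_config["type"] == "LocalReplayBuffer"; the argument names are exactly those used in _create_local_replay_buffer_if_necessary, while the concrete values are illustrative:

from ray.rllib.execution.replay_buffer import LocalReplayBuffer

buffer = LocalReplayBuffer(
    num_shards=1,                # single, driver-local shard
    learning_starts=1000,        # config["learning_starts"]
    capacity=50000,              # replay_buffer_config["capacity"]
    replay_batch_size=32,        # config["train_batch_size"]
    replay_mode="independent",   # config["multiagent"]["replay_mode"]
    replay_sequence_length=1,
    replay_burn_in=0,
    replay_zero_init_states=True)

# Every agent that asked for a buffer receives this same object through
# execution_plan(workers, config, local_replay_buffer=buffer).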
@@ -18,7 +18,11 @@ from ray.rllib.utils.typing import EnvConfigDict, EnvType, \
 logger = logging.getLogger(__name__)


-def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
+def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                           **kwargs):
+    assert len(kwargs) == 0, (
+        "Default execution_plan does NOT take any additional parameters")
+
     # Collects experiences in parallel from multiple RolloutWorker actors.
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

@@ -175,19 +179,8 @@ def build_trainer(
                 config=config,
                 num_workers=self.config["num_workers"])
             self.execution_plan = execution_plan
-            try:
-                self.train_exec_impl = execution_plan(self, self.workers,
-                                                      config)
-            except TypeError as e:
-                # Keyword error: Try old way w/o kwargs.
-                if "() takes 2 positional arguments but 3" in e.args[0]:
-                    self.train_exec_impl = execution_plan(self.workers, config)
-                    logger.warning(
-                        "`execution_plan` functions should accept "
-                        "`trainer`, `workers`, and `config` as args!")
-                # Other error -> re-raise.
-                else:
-                    raise e
+            self.train_exec_impl = execution_plan(
+                self.workers, config, **self._kwargs_for_execution_plan())

             if after_init:
                 after_init(self)
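With the try/except fallback above removed, any execution_plan handed to build_trainer() must accept (workers, config, **kwargs). A hedged sketch (not from this commit) of a custom trainer under the new contract; MyPGTrainer and my_plan are made-up names, and PGTFPolicy is assumed to be available (TensorFlow installed) purely to have a concrete default policy:

from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep


def my_plan(workers, config, **kwargs):
    # No buffer needed, so insist on receiving nothing extra -- the same
    # assert pattern A2C/PPO/etc. adopt in this commit.
    assert len(kwargs) == 0, (
        "my_plan does NOT take any additional parameters")
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    train_op = rollouts.for_each(TrainOneStep(workers))
    return StandardMetricsReporting(train_op, workers, config)


MyPGTrainer = build_trainer(
    name="MyPGTrainer",
    default_config=with_common_config({}),
    default_policy=PGTFPolicy,
    execution_plan=my_plan)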
@@ -160,7 +160,10 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
                                   _env_creator)


-def execution_plan(workers, config):
+def execution_plan(workers, config, **kwargs):
+    assert len(kwargs) == 0, (
+        "Alpha zero execution_plan does NOT take any additional parameters")
+
     rollouts = ParallelRollouts(workers, mode="bulk_sync")

     if config["simple_optimizer"]:
@@ -63,8 +63,8 @@ class RandomParametriclPolicy(Policy, ABC):
         pass


-def execution_plan(workers: WorkerSet,
-                   config: TrainerConfigDict) -> LocalIterator[dict]:
+def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
+                   **kwargs) -> LocalIterator[dict]:
     rollouts = ParallelRollouts(workers, mode="async")

     # Collect batches for the trainable policies.