[RLlib] Simple-Q uses training iteration fn (instead of execution_plan); ReplayBuffer API for Simple-Q (#22842)

Author: Artur Niederfahrenhorst, 2022-03-29 15:44:40 +03:00 (committed by GitHub)
Commit: 9a64bd4e9b (parent: a7e5aa8c6a)
22 changed files with 377 additions and 98 deletions


@@ -80,8 +80,6 @@ apex-breakoutnoframeskip-v4:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 1
     num_workers: 8
     num_envs_per_worker: 8
@@ -327,8 +325,6 @@ dqn-breakoutnoframeskip-v4:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 0.5
     timesteps_per_iteration: 10000


@@ -53,8 +53,6 @@ apex-breakoutnoframeskip-v4:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 1
     num_workers: 8
     num_envs_per_worker: 8


@@ -19,7 +19,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
         "num_workers": 32,
         "buffer_size": 2000000,
         # TODO(jungong) : update once Apex supports replay_buffer_config.
-        "replay_buffer_config": None,
+        "no_local_replay_buffer": True,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
         # This is preferred b/c the learner process should have quick


@@ -111,10 +111,6 @@ DEFAULT_CONFIG = with_common_config({
    "prioritized_replay_alpha": 0.6,
    # Beta parameter for sampling from prioritized replay buffer.
    "prioritized_replay_beta": 0.4,
-    # Time steps over which the beta parameter is annealed.
-    "prioritized_replay_beta_annealing_timesteps": 20000,
-    # Final value of beta
-    "final_prioritized_replay_beta": 0.4,
    # Epsilon to add to the TD errors when updating priorities.
    "prioritized_replay_eps": 1e-6,
    # Whether to LZ4 compress observations


@@ -66,7 +66,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
         "buffer_size": 2000000,
         # TODO(jungong) : add proper replay_buffer_config after
         # DistributedReplayBuffer type is supported.
-        "replay_buffer_config": None,
+        "no_local_replay_buffer": True,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
         # This is preferred b/c the learner process should have quick

@@ -157,9 +157,9 @@ class ApexTrainer(DQNTrainer):
            config["learning_starts"],
            config["buffer_size"],
            config["train_batch_size"],
-            config["prioritized_replay_alpha"],
-            config["prioritized_replay_beta"],
-            config["prioritized_replay_eps"],
+            config["replay_buffer_config"]["prioritized_replay_alpha"],
+            config["replay_buffer_config"]["prioritized_replay_beta"],
+            config["replay_buffer_config"]["prioritized_replay_eps"],
            config["multiagent"]["replay_mode"],
            config.get("replay_sequence_length", 1),
        ]


@@ -35,6 +35,7 @@ from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE

 logger = logging.getLogger(__name__)

@@ -64,19 +65,37 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
        # N-step Q learning
        "n_step": 1,

-        # === Prioritized replay buffer ===
-        # If True prioritized replay buffer will be used.
+        # === Replay buffer ===
+        # Size of the replay buffer. Note that if async_updates is set, then
+        # each worker will have a replay buffer of this size.
+        "buffer_size": DEPRECATED_VALUE,
+        # Prioritized replay is here since this algo uses the old replay
+        # buffer api
        "prioritized_replay": True,
-        # Alpha parameter for prioritized replay buffer.
-        "prioritized_replay_alpha": 0.6,
-        # Beta parameter for sampling from prioritized replay buffer.
-        "prioritized_replay_beta": 0.4,
-        # Final value of beta (by default, we use constant beta=0.4).
-        "final_prioritized_replay_beta": 0.4,
-        # Time steps over which the beta parameter is annealed.
-        "prioritized_replay_beta_annealing_timesteps": 20000,
-        # Epsilon to add to the TD errors when updating priorities.
-        "prioritized_replay_eps": 1e-6,
+        "replay_buffer_config": {
+            # For now we don't use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": False,
+            "type": "MultiAgentReplayBuffer",
+            "capacity": 50000,
+            "replay_batch_size": 32,
+            "prioritized_replay_alpha": 0.6,
+            # Beta parameter for sampling from prioritized replay buffer.
+            "prioritized_replay_beta": 0.4,
+            # Epsilon to add to the TD errors when updating priorities.
+            "prioritized_replay_eps": 1e-6,
+        },
+        # Set this to True, if you want the contents of your buffer(s) to be
+        # stored in any saved checkpoints as well.
+        # Warnings will be created if:
+        # - This is True AND restoring from a checkpoint that contains no buffer
+        #   data.
+        # - This is False AND restoring from a checkpoint that does contain
+        #   buffer data.
+        "store_buffer_in_checkpoints": False,
+        # The number of contiguous environment steps to replay at once. This may
+        # be set to greater than 1 to support recurrent models.
+        "replay_sequence_length": 1,

        # Callback to run before learning on a multi-agent batch of
        # experiences.

@@ -102,6 +121,12 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
        # === Parallelism ===
        # Whether to compute priorities on workers.
        "worker_side_prioritization": False,
+
+        # Experimental flag.
+        # If True, the execution plan API will not be used. Instead,
+        # a Trainer's `training_iteration` method will be called as-is each
+        # training iteration.
+        "_disable_execution_plan_api": False,
    },
    _allow_unknown_configs=True,
 )
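
As a quick illustration of the nested layout above (a sketch only, not part of the diff; it assumes the `ray.rllib.agents.dqn` import path used elsewhere in this PR), overriding buffer settings for DQN now means editing `replay_buffer_config` rather than top-level keys:

    from ray.rllib.agents.dqn import DEFAULT_CONFIG

    config = DEFAULT_CONFIG.copy()
    # Capacity and the prioritization parameters now live inside
    # "replay_buffer_config" instead of at the top level of the config.
    config["replay_buffer_config"] = {
        **config["replay_buffer_config"],
        "capacity": 100000,
        "prioritized_replay_alpha": 0.5,
    }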


@@ -451,10 +451,14 @@ def postprocess_nstep_and_prio(
            batch[SampleBatch.DONES],
            batch[PRIO_WEIGHTS],
        )
-        new_priorities = (
-            np.abs(convert_to_numpy(td_errors))
-            + policy.config["prioritized_replay_eps"]
-        )
+        # Retain compatibility with old-style Replay args
+        epsilon = policy.config.get("replay_buffer_config", {}).get(
+            "prioritized_replay_eps"
+        ) or policy.config.get("prioritized_replay_eps")
+        if epsilon is None:
+            raise ValueError("prioritized_replay_eps not defined in config.")
+        new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
        batch[PRIO_WEIGHTS] = new_priorities
    return batch
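
The lookup above prefers the new nested key and falls back to the old top-level one. A minimal standalone restatement of that precedence (hypothetical helper name, plain dicts only):

    def lookup_prio_eps(config: dict) -> float:
        # New location first, then the old top-level key.
        eps = config.get("replay_buffer_config", {}).get(
            "prioritized_replay_eps"
        ) or config.get("prioritized_replay_eps")
        if eps is None:
            raise ValueError("prioritized_replay_eps not defined in config.")
        return eps

    assert lookup_prio_eps({"prioritized_replay_eps": 1e-6}) == 1e-6
    assert lookup_prio_eps(
        {"replay_buffer_config": {"prioritized_replay_eps": 1e-5}}
    ) == 1e-5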


@@ -29,6 +29,19 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
        # Batch mode must be complete_episodes.
        "batch_mode": "complete_episodes",
+        # === Replay buffer ===
+        "replay_buffer_config": {
+            # For now we don't use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": False,
+            "type": "MultiAgentReplayBuffer",
+            "capacity": 50000,
+            "replay_batch_size": 32,
+            "prioritized_replay_alpha": 0.6,
+            # Beta parameter for sampling from prioritized replay buffer.
+            "prioritized_replay_beta": 0.4,
+            # Epsilon to add to the TD errors when updating priorities.
+            "prioritized_replay_eps": 1e-6,
+        },
        # If True, assume a zero-initialized state input (no matter where in
        # the episode the sequence is located).
        # If False, store the initial states along with each SampleBatch, use

@@ -66,6 +79,12 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
        # Update the target network every `target_network_update_freq` steps.
        "target_network_update_freq": 2500,
+
+        # Experimental flag.
+        # If True, the execution plan API will not be used. Instead,
+        # a Trainer's `training_iteration` method will be called as-is each
+        # training iteration.
+        "_disable_execution_plan_api": False,
    },
    _allow_unknown_configs=True,
 )


@@ -15,19 +15,40 @@ from typing import Optional, Type
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
 from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
-from ray.rllib.execution.rollout_ops import ParallelRollouts
+from ray.rllib.execution.rollout_ops import (
+    ParallelRollouts,
+    synchronous_parallel_sample,
+)
 from ray.rllib.execution.train_ops import (
-    MultiGPUTrainOneStep,
     TrainOneStep,
+    MultiGPUTrainOneStep,
+    train_one_step,
+    multi_gpu_train_one_step,
+)
+from ray.rllib.execution.train_ops import (
     UpdateTargetNetwork,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.metrics import (
+    NUM_ENV_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_SAMPLED,
+)
+from ray.rllib.utils.typing import (
+    ResultDict,
+    TrainerConfigDict,
+)
+from ray.rllib.utils.metrics import (
+    LAST_TARGET_UPDATE_TS,
+    NUM_TARGET_UPDATES,
+)

 logger = logging.getLogger(__name__)

@@ -64,9 +85,18 @@ DEFAULT_CONFIG = with_common_config({
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": DEPRECATED_VALUE,
+    # Deprecated for Simple Q because of new ReplayBuffer API
+    # Use MultiAgentPrioritizedReplayBuffer for prioritization.
+    "prioritized_replay": DEPRECATED_VALUE,
    "replay_buffer_config": {
+        # Use the new ReplayBuffer API here
+        "_enable_replay_buffer_api": True,
        "type": "MultiAgentReplayBuffer",
        "capacity": 50000,
+        "replay_batch_size": 32,
+        # The number of contiguous environment steps to replay at once. This
+        # may be set to greater than 1 to support recurrent models.
+        "replay_sequence_length": 1,
    },
    # Set this to True, if you want the contents of your buffer(s) to be
    # stored in any saved checkpoints as well.

@@ -76,9 +106,6 @@ DEFAULT_CONFIG = with_common_config({
    # - This is False AND restoring from a checkpoint that does contain
    #   buffer data.
    "store_buffer_in_checkpoints": False,
-    # The number of contiguous environment steps to replay at once. This may
-    # be set to greater than 1 to support recurrent models.
-    "replay_sequence_length": 1,

    # === Optimization ===
    # Learning rate for adam optimizer

@@ -108,6 +135,12 @@ DEFAULT_CONFIG = with_common_config({
    "num_workers": 0,
    # Prevent reporting frequency from going lower than this time span.
    "min_time_s_per_reporting": 1,
+
+    # Experimental flag.
+    # If True, the execution plan API will not be used. Instead,
+    # a Trainer's `training_iteration` method will be called as-is each
+    # training iteration.
+    "_disable_execution_plan_api": True,
 })
 # __sphinx_doc_end__
 # fmt: on

@@ -139,7 +172,9 @@ class SimpleQTrainer(Trainer):
                " used at the same time!"
            )

-        if config.get("prioritized_replay"):
+        if config.get("prioritized_replay") or config.get(
+            "replay_buffer_config", {}
+        ).get("prioritized_replay"):
            if config["multiagent"]["replay_mode"] == "lockstep":
                raise ValueError(
                    "Prioritized replay is not supported when replay_mode=lockstep."

@@ -215,3 +250,63 @@ class SimpleQTrainer(Trainer):
            )

        return StandardMetricsReporting(train_op, workers, config)
+
+    @ExperimentalAPI
+    def training_iteration(self) -> ResultDict:
+        """Simple Q training iteration function.
+
+        Simple Q consists of the following steps:
+        - (1) Sample (MultiAgentBatch) from workers...
+        - (2) Store new samples in replay buffer.
+        - (3) Sample training batch (MultiAgentBatch) from replay buffer.
+        - (4) Learn on training batch.
+        - (5) Update target network every target_network_update_freq steps.
+        - (6) Return all collected metrics for the iteration.
+
+        Returns:
+            The results dict from executing the training iteration.
+        """
+        batch_size = self.config["train_batch_size"]
+        local_worker = self.workers.local_worker()
+
+        # (1) Sample (MultiAgentBatch) from workers
+        new_sample_batches = synchronous_parallel_sample(self.workers)
+
+        for s in new_sample_batches:
+            # Update counters
+            self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
+            self._counters[NUM_AGENT_STEPS_SAMPLED] += (
+                len(s) if isinstance(s, SampleBatch) else s.agent_steps()
+            )
+            # (2) Store new samples in replay buffer
+            self.local_replay_buffer.add(s)
+
+        # (3) Sample training batch (MultiAgentBatch) from replay buffer.
+        train_batch = self.local_replay_buffer.sample(batch_size)
+
+        # (4) Learn on training batch.
+        # Use simple optimizer (only for multi-agent or tf-eager; all other
+        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
+        if self.config.get("simple_optimizer") is True:
+            train_results = train_one_step(self, train_batch)
+        else:
+            train_results = multi_gpu_train_one_step(self, train_batch)
+
+        # (5) Update target network every target_network_update_freq steps
+        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
+        last_update = self._counters[LAST_TARGET_UPDATE_TS]
+        if cur_ts - last_update >= self.config["target_network_update_freq"]:
+            to_update = local_worker.get_policies_to_train()
+            local_worker.foreach_policy_to_train(
+                lambda p, pid: pid in to_update and p.update_target()
+            )
+            self._counters[NUM_TARGET_UPDATES] += 1
+            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
+
+        # Update remote workers' weights after learning on local worker
+        if self.workers.remote_workers():
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
+                self.workers.sync_weights()
+
+        # (6) Return all collected metrics for the iteration.
+        return train_results
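
A usage sketch of the new code path (assumptions: a Ray version matching this commit and the `SimpleQTrainer` export from `ray.rllib.agents.dqn`; the config values simply repeat the defaults shown above):

    import ray
    from ray.rllib.agents.dqn import SimpleQTrainer

    ray.init()
    trainer = SimpleQTrainer(
        env="CartPole-v0",
        config={
            # Default after this PR: use training_iteration(), not the
            # execution plan.
            "_disable_execution_plan_api": True,
            "replay_buffer_config": {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentReplayBuffer",
                "capacity": 50000,
                "replay_batch_size": 32,
                "replay_sequence_length": 1,
            },
        },
    )
    # Each call runs one training_iteration(): sample, add to the buffer,
    # sample a train batch, learn, and maybe update the target network.
    results = trainer.train()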


@@ -109,8 +109,6 @@ DEFAULT_CONFIG = with_common_config({
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_eps": 1e-6,
-    "prioritized_replay_beta_annealing_timesteps": 20000,
-    "final_prioritized_replay_beta": 0.4,
    # Whether to LZ4 compress observations
    "compress_observations": False,


@@ -85,7 +85,7 @@ class TestSAC(unittest.TestCase):
        # If we use default buffer size (1e6), the buffer will take up
        # 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
        # available system memory (8.34816 GB).
-        config["buffer_size"] = 40000
+        config["replay_buffer_config"]["capacity"] = 40000
        # Test with saved replay buffer.
        config["store_buffer_in_checkpoints"] = True
        num_iterations = 1


@@ -361,6 +361,26 @@ COMMON_CONFIG: TrainerConfigDict = {
        # "env_config": {...},
        # "explore": False
    },
+
+    # === Replay Buffer Settings ===
+    # Provide a dict specifying the ReplayBuffer's config.
+    # "replay_buffer_config": {
+    #     The ReplayBuffer class to use. Any class that obeys the
+    #     ReplayBuffer API can be used here. In the simplest case, this is the
+    #     name (str) of any class present in the `rllib.utils.replay_buffers`
+    #     package. You can also provide the python class directly or the
+    #     full location of your class (e.g.
+    #     "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer").
+    #     "type": "ReplayBuffer",
+    #     The capacity of units that can be stored in one ReplayBuffer
+    #     instance before eviction.
+    #     "capacity": 10000,
+    #     Specifies how experiences are stored. Either 'sequences' or
+    #     'timesteps'.
+    #     "storage_unit": "timesteps",
+    #     Add constructor kwargs here (if any).
+    # },
+
    # Number of parallel workers to use for evaluation. Note that this is set
    # to zero by default, which means evaluation will be run in the trainer
    # process (only if evaluation_interval is not None). If you increase this,

@@ -652,6 +672,8 @@ COMMON_CONFIG: TrainerConfigDict = {
    # Use `metrics_episode_collection_timeout_s` instead.
    "collect_metrics_timeout": DEPRECATED_VALUE,
 }

 # __sphinx_doc_end__
 # fmt: on

@@ -719,7 +741,7 @@ class Trainer(Trainable):
        "custom_resources_per_worker",
        "evaluation_config",
        "exploration_config",
-        "extra_python_environs_for_driver",
+        "replay_buffer_config",
        "extra_python_environs_for_worker",
        "input_config",
        "output_config",

@@ -727,7 +749,10 @@ class Trainer(Trainable):
    # List of top level keys with value=dict, for which we always override the
    # entire value (dict), iff the "type" key in that value dict changes.
-    _override_all_subkeys_if_type_changes = ["exploration_config"]
+    _override_all_subkeys_if_type_changes = [
+        "exploration_config",
+        "replay_buffer_config",
+    ]

    # TODO: Deprecate. Instead, override `Trainer.get_default_config()`.
    _default_config = COMMON_CONFIG
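
Spelled out as a concrete dict (illustrative values only; the keys mirror the commented template added to COMMON_CONFIG above):

    replay_buffer_config = {
        # A short class name from rllib.utils.replay_buffers, a python class,
        # or a full "module.Class" path are all accepted per the docs above.
        "type": "ReplayBuffer",
        # Number of stored units before eviction starts.
        "capacity": 10000,
        # How experiences are stored: "timesteps" or "sequences".
        "storage_unit": "timesteps",
    }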
@@ -2724,58 +2749,147 @@ class Trainer(Trainable):
            MultiAgentReplayBuffer instance based on trainer config.
            None, if local replay buffer is not needed.
        """
-        # These are the agents that utilizes a local replay buffer.
-        if "replay_buffer_config" not in config or not config["replay_buffer_config"]:
-            # Does not need a replay buffer.
-            return None
+        # Deprecation of old-style replay buffer args
+        # Warnings before checking of we need local buffer so that algorithms
+        # Without local buffer also get warned
+        deprecated_replay_buffer_keys = [
+            "prioritized_replay_alpha",
+            "prioritized_replay_beta",
+            "prioritized_replay_eps",
+            "learning_starts",
+        ]
+        for k in deprecated_replay_buffer_keys:
+            if config.get(k) is not None:
+                deprecation_warning(
+                    old="config[{}]".format(k),
+                    help="config['replay_buffer_config'][{}] should be used "
+                    "for Q-Learning algorithms. Ignore this warning if "
+                    "you are not using a Q-Learning algorithm and still "
+                    "provide {}."
+                    "".format(k, k),
+                    error=False,
+                )
+                # Copy values over to new location in config to support new
+                # and old configuration style
+                if config.get("replay_buffer_config") is not None:
+                    config["replay_buffer_config"][k] = config[k]
+
+        # Some agents do not need a replay buffer
+        if not config.get("replay_buffer_config") or config.get(
+            "no_local_replay_buffer", False
+        ):
+            return

        replay_buffer_config = config["replay_buffer_config"]
-        if (
-            "type" not in replay_buffer_config
-            or replay_buffer_config["type"] != "MultiAgentReplayBuffer"
-        ):
-            # DistributedReplayBuffer coming soon.
-            return None
+        assert (
+            "type" in replay_buffer_config
+        ), "Can not instantiate ReplayBuffer from config without 'type' key."

        capacity = config.get("buffer_size", DEPRECATED_VALUE)
        if capacity != DEPRECATED_VALUE:
-            # Print a deprecation warning.
            deprecation_warning(
                old="config['buffer_size']",
-                new="config['replay_buffer_config']['capacity']",
+                help="Buffer size specified at new location config["
+                "'replay_buffer_config']["
+                "'capacity'] will be overwritten.",
                error=False,
            )
-        else:
-            # Get capacity out of replay_buffer_config.
-            capacity = replay_buffer_config["capacity"]
+            config["replay_buffer_config"]["capacity"] = capacity

-        # Configure prio. replay parameters.
-        if config.get("prioritized_replay"):
-            prio_args = {
-                "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-                "prioritized_replay_beta": config["prioritized_replay_beta"],
-                "prioritized_replay_eps": config["prioritized_replay_eps"],
-            }
-        # Switch off prioritization (alpha=0.0).
-        else:
-            prio_args = {"prioritized_replay_alpha": 0.0}
-
-        return MultiAgentReplayBuffer(
-            num_shards=1,
-            learning_starts=config["learning_starts"],
-            capacity=capacity,
-            replay_batch_size=config["train_batch_size"],
-            replay_mode=config["multiagent"]["replay_mode"],
-            replay_sequence_length=config.get("replay_sequence_length", 1),
-            replay_burn_in=config.get("burn_in", 0),
-            replay_zero_init_states=config.get("zero_init_states", True),
-            **prio_args,
-        )
+        # Check if old replay buffer should be instantiated
+        buffer_type = config["replay_buffer_config"]["type"]
+        if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
+            if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
+                # Prepend old-style buffers' path
+                assert buffer_type == "MultiAgentReplayBuffer", (
+                    "Without "
+                    "ReplayBuffer "
+                    "API, only "
+                    "MultiAgentReplayBuffer "
+                    "is supported!"
+                )
+                # Create valid full [module].[class] string for from_config
+                buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
+            else:
+                assert buffer_type in [
+                    "ray.rllib.execution.MultiAgentReplayBuffer",
+                    MultiAgentReplayBuffer,
+                ], (
+                    "Without ReplayBuffer API, only "
+                    "MultiAgentReplayBuffer is supported!"
+                )
+
+            config["replay_buffer_config"]["type"] = buffer_type
+
+            # Remove from config so it's not passed into the buffer c'tor
+            config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
+
+            # We need to deprecate the old-style location of the following
+            # buffer arguments and make users put them into the
+            # "replay_buffer_config" field of their config.
+            config["replay_buffer_config"]["replay_batch_size"] = config[
+                "train_batch_size"
+            ]
+            config["replay_buffer_config"]["replay_mode"] = config["multiagent"][
+                "replay_mode"
+            ]
+            deprecation_warning(
+                old="config['multiagent']['replay_mode']",
+                new="config['replay_buffer_config']['replay_mode']",
+                error=False,
+            )
+            config["replay_buffer_config"]["replay_sequence_length"] = config.get(
+                "replay_sequence_length", 1
+            )
+            if config.get("replay_sequence_length"):
+                deprecation_warning(
+                    old="config['replay_sequence_length']",
+                    new="config['replay_buffer_config']['replay_sequence_length']",
+                    error=False,
+                )
+            config["replay_buffer_config"]["replay_burn_in"] = config.get(
+                "replay_burn_in", 0
+            )
+            if config.get("burn_in"):
+                deprecation_warning(
+                    old="config['burn_in']",
+                    help="Burn in specified at new location config["
+                    "'replay_buffer_config']["
+                    "'replay_burn_in'] will be overwritten.",
+                )
+                config["replay_buffer_config"]["replay_burn_in"] = config["burn_in"]
+            config["replay_buffer_config"]["replay_zero_init_states"] = config.get(
+                "replay_zero_init_states", True
+            )
+            if config.get("replay_zero_init_states"):
+                deprecation_warning(
+                    old="config['replay_zero_init_states']",
+                    new="config['replay_buffer_config']['replay_zero_init_states']",
+                    error=False,
+                )
+
+            # If no prioritized replay, old-style replay buffer should
+            # not be handed the following parameters:
+            if config.get("prioritized_replay", False) is False:
+                # This triggers non-prioritization in old-style replay buffer
+                config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
+        else:
+            if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
+                # Create valid full [module].[class] string for from_config
+                buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
+                config["replay_buffer_config"]["type"] = buffer_type
+
+        return from_config(buffer_type, config["replay_buffer_config"])

    @DeveloperAPI
    def _kwargs_for_execution_plan(self):
        kwargs = {}
-        if self.local_replay_buffer:
+        if self.local_replay_buffer is not None:
            kwargs["local_replay_buffer"] = self.local_replay_buffer
        return kwargs
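
The new-API branch above resolves a bare class name by prefixing the replay_buffers package and then hands the whole dict to from_config. A hedged sketch of that resolution step in isolation (constructor kwargs beyond "capacity" depend on the chosen buffer class):

    from ray.rllib.utils.from_config import from_config

    replay_buffer_config = {
        "type": "MultiAgentReplayBuffer",  # short name, no "." in it
        "capacity": 50000,
    }
    buffer_type = replay_buffer_config["type"]
    if isinstance(buffer_type, str) and "." not in buffer_type:
        # Same prefixing as in the diff above.
        buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
    buffer = from_config(buffer_type, replay_buffer_config)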


@@ -31,7 +31,6 @@ if __name__ == "__main__":
        "n_step": 3,
        "lr": 0.0001,
        "prioritized_replay_alpha": 0.5,
-        "final_prioritized_replay_beta": 1.0,
        "target_network_update_freq": 50000,
        "timesteps_per_iteration": 25000,
        # Method specific.


@@ -91,8 +91,6 @@ apex:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 1
     num_workers: 8
     num_envs_per_worker: 8

@@ -141,7 +139,5 @@ atari-basic-dqn:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 0.2
     timesteps_per_iteration: 10000


@@ -20,8 +20,6 @@ apex:
     hiddens: [512]
     buffer_size: 1000000
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 1


@@ -25,7 +25,5 @@ atari-dist-dqn:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 0.2
     timesteps_per_iteration: 10000
timesteps_per_iteration: 10000 timesteps_per_iteration: 10000


@@ -29,7 +29,5 @@ atari-basic-dqn:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 0.2
     timesteps_per_iteration: 10000
timesteps_per_iteration: 10000 timesteps_per_iteration: 10000


@@ -29,7 +29,5 @@ dueling-ddqn:
       epsilon_timesteps: 200000
       final_epsilon: 0.01
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 2000000
     num_gpus: 0.2
     timesteps_per_iteration: 10000
timesteps_per_iteration: 10000 timesteps_per_iteration: 10000


@@ -19,8 +19,6 @@ pong-deterministic-rainbow:
     target_network_update_freq: 500
     prioritized_replay: True
     prioritized_replay_alpha: 0.5
-    final_prioritized_replay_beta: 1.0
-    prioritized_replay_beta_annealing_timesteps: 400000
     n_step: 3
     gpu: True
     model:
model: model:


@@ -0,0 +1,26 @@
+from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
+    MultiAgentReplayBuffer,
+    ReplayMode,
+)
+from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
+from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
+    PrioritizedReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
+    MultiAgentMixInReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
+    MultiAgentPrioritizedReplayBuffer,
+)
+
+__all__ = [
+    "ReplayBuffer",
+    "StorageUnit",
+    "MultiAgentReplayBuffer",
+    "ReplayMode",
+    "ReservoirBuffer",
+    "PrioritizedReplayBuffer",
+    "MultiAgentMixInReplayBuffer",
+    "MultiAgentPrioritizedReplayBuffer",
+]
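
A small usage sketch of the exports registered above (add()/sample() follow the new ReplayBuffer API this PR targets; the constructor argument is an assumption based on the capacity setting documented in COMMON_CONFIG):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.replay_buffers import ReplayBuffer

    buffer = ReplayBuffer(capacity=1000)
    buffer.add(SampleBatch({"obs": [1, 2, 3], "rewards": [0.1, 0.2, 0.3]}))
    # Draws previously added timesteps; returns a SampleBatch.
    batch = buffer.sample(2)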


@@ -14,6 +14,7 @@ from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
 from ray.rllib.utils.from_config import from_config
 from ray.util.debug import log_once
+from ray.rllib.utils.deprecation import Deprecated

 logger = logging.getLogger(__name__)

@@ -81,7 +82,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
                'episodes'. Specifies how experiences are stored. If they
                are stored in episodes, replay_sequence_length is ignored.
            learning_starts: Number of timesteps after which a call to
-                `replay()` will yield samples (before that, `replay()` will
+                `sample()` will yield samples (before that, `sample()` will
                return None).
            capacity: Max number of total timesteps in all policy buffers.
                After reaching this number, older samples will be

@@ -170,6 +171,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
        """Returns the number of items currently stored in this buffer."""
        return sum(len(buffer._storage) for buffer in self.replay_buffers.values())

+    @ExperimentalAPI
+    @Deprecated(old="replay", new="sample", error=False)
+    def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
+        """Deprecated in favor of new ReplayBuffer API."""
+        if num_items is None:
+            num_items = self.replay_batch_size
+        return self.sample(num_items, **kwargs)
+
    @ExperimentalAPI
    @override(ReplayBuffer)
    def add(self, batch: SampleBatchType, **kwargs) -> None:

@@ -262,7 +271,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

        if self._num_added < self.replay_starts:
-            return None
+            return MultiAgentBatch({}, 0)
        with self.replay_timer:
            # Lockstep mode: Sample from all policies at the same time an
            # equal amount of steps.
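
Two behavioral consequences of the changes above, sketched (constructor kwargs are assumptions): replay() still works but forwards to sample(), and sampling before learning_starts now yields an empty MultiAgentBatch instead of None:

    from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

    buffer = MultiAgentReplayBuffer(capacity=1000, learning_starts=100)
    # Nothing added yet, so this returns MultiAgentBatch({}, 0), not None.
    batch = buffer.sample(32)
    assert batch.env_steps() == 0
    # Deprecated entry point; emits a warning and calls sample() internally.
    batch = buffer.replay(32)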


@@ -1,17 +1,19 @@
 import logging
 import platform
 from typing import Any, Dict, List, Optional

 import numpy as np
 import random
 from enum import Enum

+from ray.util.debug import log_once
+
 # Import ray before psutil will make sure we use psutil's bundled version
 import ray  # noqa F401
 import psutil  # noqa E402

-from ray.util.debug import log_once
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import ExperimentalAPI
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.metrics.window_stat import WindowStat
 from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity

@@ -94,6 +96,18 @@ class ReplayBuffer:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

+    @ExperimentalAPI
+    @Deprecated(old="add_batch", new="add", error=False)
+    def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
+        """Deprecated in favor of new ReplayBuffer API."""
+        return self.add(batch, **kwargs)
+
+    @ExperimentalAPI
+    @Deprecated(old="replay", new="sample", error=False)
+    def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
+        """Deprecated in favor of new ReplayBuffer API."""
+        return self.sample(num_items, **kwargs)
+
    @ExperimentalAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.