[RLlib] Simple-Q uses training iteration fn (instead of execution_plan); ReplayBuffer API for Simple-Q (#22842)

Artur Niederfahrenhorst 2022-03-29 15:44:40 +03:00 committed by GitHub
parent a7e5aa8c6a
commit 9a64bd4e9b
22 changed files with 377 additions and 98 deletions

View file

@ -80,8 +80,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
@ -327,8 +325,6 @@ dqn-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.5
timesteps_per_iteration: 10000

View file

@ -53,8 +53,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

View file

@ -19,7 +19,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick

View file

@ -111,10 +111,6 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Final value of beta
"final_prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations

View file

@ -66,7 +66,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
@ -157,9 +157,9 @@ class ApexTrainer(DQNTrainer):
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
]

View file

@ -35,6 +35,7 @@ from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__)
@ -64,19 +65,37 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# N-step Q learning
"n_step": 1,
# === Prioritized replay buffer ===
# If True prioritized replay buffer will be used.
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Prioritized replay settings are kept here because this algo still
# uses the old replay buffer API.
"prioritized_replay": True,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Final value of beta (by default, we use constant beta=0.4).
"final_prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
# - This is True AND restoring from a checkpoint that contains no buffer
# data.
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Callback to run before learning on a multi-agent batch of
# experiences.
@ -102,6 +121,12 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# === Parallelism ===
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
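
For orientation, the nested layout introduced above can be exercised from user code roughly as follows. This is a hedged sketch: the env name and worker count are illustrative and not taken from the commit; the keys mirror the defaults shown in the hunk.

    # Sketch: overriding DQN's replay settings via the nested
    # "replay_buffer_config" dict shown above (the old flat keys such as
    # "prioritized_replay_alpha" move into this dict).
    from ray.rllib.agents.dqn import DQNTrainer

    config = {
        "num_workers": 0,
        "replay_buffer_config": {
            "_enable_replay_buffer_api": False,  # DQN still uses the old buffer here
            "type": "MultiAgentReplayBuffer",
            "capacity": 50000,
            "replay_batch_size": 32,
            "prioritized_replay_alpha": 0.6,
            "prioritized_replay_beta": 0.4,
            "prioritized_replay_eps": 1e-6,
        },
    }
    trainer = DQNTrainer(config=config, env="CartPole-v0")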

View file

@ -451,10 +451,14 @@ def postprocess_nstep_and_prio(
batch[SampleBatch.DONES],
batch[PRIO_WEIGHTS],
)
new_priorities = (
np.abs(convert_to_numpy(td_errors))
+ policy.config["prioritized_replay_eps"]
)
# Retain compatibility with old-style Replay args
epsilon = policy.config.get("replay_buffer_config", {}).get(
"prioritized_replay_eps"
) or policy.config.get("prioritized_replay_eps")
if epsilon is None:
raise ValueError("prioritized_replay_eps not defined in config.")
new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
batch[PRIO_WEIGHTS] = new_priorities
return batch
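
The nested-first, flat-fallback lookup above is the general migration pattern for the replay keys. A small self-contained sketch of the same idea (plain Python; the helper name is illustrative, not part of the commit):

    # Sketch of the lookup pattern used above: prefer the nested
    # "replay_buffer_config" entry, fall back to the old flat key,
    # and fail loudly if neither is set.
    def get_replay_setting(config: dict, key: str):
        value = config.get("replay_buffer_config", {}).get(key)
        if value is None:
            value = config.get(key)
        if value is None:
            raise ValueError(f"{key} not defined in config.")
        return value

    # Both the old flat style and the new nested style resolve:
    assert get_replay_setting(
        {"prioritized_replay_eps": 1e-6}, "prioritized_replay_eps") == 1e-6
    assert get_replay_setting(
        {"replay_buffer_config": {"prioritized_replay_eps": 1e-6}},
        "prioritized_replay_eps") == 1e-6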

View file

@ -29,6 +29,19 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",
# === Replay buffer ===
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
@ -66,6 +79,12 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 2500,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)

View file

@ -15,19 +15,40 @@ from typing import Optional, Type
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
MultiGPUTrainOneStep,
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.execution.train_ops import (
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
)
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
)
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)
logger = logging.getLogger(__name__)
@ -64,9 +85,18 @@ DEFAULT_CONFIG = with_common_config({
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Deprecated for Simple Q because of new ReplayBuffer API
# Use MultiAgentPrioritizedReplayBuffer for prioritization.
"prioritized_replay": DEPRECATED_VALUE,
"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
@ -76,9 +106,6 @@ DEFAULT_CONFIG = with_common_config({
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# === Optimization ===
# Learning rate for adam optimizer
@ -108,6 +135,12 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on
@ -139,7 +172,9 @@ class SimpleQTrainer(Trainer):
" used at the same time!"
)
if config.get("prioritized_replay"):
if config.get("prioritized_replay") or config.get(
"replay_buffer_config", {}
).get("prioritized_replay"):
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
@ -215,3 +250,63 @@ class SimpleQTrainer(Trainer):
)
return StandardMetricsReporting(train_op, workers, config)
@ExperimentalAPI
def training_iteration(self) -> ResultDict:
"""Simple Q training iteration function.
Simple Q consists of the following steps:
- (1) Sample (MultiAgentBatch) from workers...
- (2) Store new samples in replay buffer.
- (3) Sample training batch (MultiAgentBatch) from replay buffer.
- (4) Learn on training batch.
- (5) Update target network every target_network_update_freq steps.
- (6) Return all collected metrics for the iteration.
Returns:
The results dict from executing the training iteration.
"""
batch_size = self.config["train_batch_size"]
local_worker = self.workers.local_worker()
# (1) Sample (MultiAgentBatch) from workers
new_sample_batches = synchronous_parallel_sample(self.workers)
for s in new_sample_batches:
# Update counters
self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
self._counters[NUM_AGENT_STEPS_SAMPLED] += (
len(s) if isinstance(s, SampleBatch) else s.agent_steps()
)
# (2) Store new samples in replay buffer
self.local_replay_buffer.add(s)
# (3) Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = self.local_replay_buffer.sample(batch_size)
# (4) Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# (5) Update target network every target_network_update_freq steps
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update remote workers' weights after learning on local worker
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()
# (6) Return all collected metrics for the iteration.
return train_results
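
End to end, the new path can be exercised with a short driver script. A hedged sketch (env name and iteration count are illustrative) that relies on the defaults above, i.e. `_disable_execution_plan_api: True` and the new ReplayBuffer API:

    # Sketch: train Simple-Q through the new training_iteration() path.
    import ray
    from ray.rllib.agents.dqn.simple_q import SimpleQTrainer

    ray.init()
    trainer = SimpleQTrainer(config={"num_workers": 0}, env="CartPole-v0")
    for _ in range(3):
        result = trainer.train()  # calls training_iteration() under the hood
        print(result["episode_reward_mean"])
    ray.shutdown()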

View file

@ -109,8 +109,6 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"prioritized_replay_beta_annealing_timesteps": 20000,
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,

View file

@ -85,7 +85,7 @@ class TestSAC(unittest.TestCase):
# If we use default buffer size (1e6), the buffer will take up
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
# available system memory (8.34816 GB).
config["buffer_size"] = 40000
config["replay_buffer_config"]["capacity"] = 40000
# Test with saved replay buffer.
config["store_buffer_in_checkpoints"] = True
num_iterations = 1

View file

@ -361,6 +361,26 @@ COMMON_CONFIG: TrainerConfigDict = {
# "env_config": {...},
# "explore": False
},
# === Replay Buffer Settings ===
# Provide a dict specifying the ReplayBuffer's config.
# "replay_buffer_config": {
# The ReplayBuffer class to use. Any class that obeys the
# ReplayBuffer API can be used here. In the simplest case, this is the
# name (str) of any class present in the `rllib.utils.replay_buffers`
# package. You can also provide the python class directly or the
# full location of your class (e.g.
# "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer").
# "type": "ReplayBuffer",
# The capacity of units that can be stored in one ReplayBuffer
# instance before eviction.
# "capacity": 10000,
# Specifies how experiences are stored. Either 'sequences' or
# 'timesteps'.
# "storage_unit": "timesteps",
# Add constructor kwargs here (if any).
# },
# Number of parallel workers to use for evaluation. Note that this is set
# to zero by default, which means evaluation will be run in the trainer
# process (only if evaluation_interval is not None). If you increase this,
@ -652,6 +672,8 @@ COMMON_CONFIG: TrainerConfigDict = {
# Use `metrics_episode_collection_timeout_s` instead.
"collect_metrics_timeout": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# fmt: on
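
Per the commented-out spec above, the buffer class can be given by bare name, by full dotted path, or as the class object itself. A hedged sketch of the three equivalent forms:

    # Sketch: three equivalent ways to fill in "type" in
    # "replay_buffer_config" per the documentation above.
    from ray.rllib.utils.replay_buffers import ReplayBuffer

    by_name = {"type": "ReplayBuffer", "capacity": 10000,
               "storage_unit": "timesteps"}
    by_path = {"type": "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer",
               "capacity": 10000}
    by_class = {"type": ReplayBuffer, "capacity": 10000}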
@ -719,7 +741,7 @@ class Trainer(Trainable):
"custom_resources_per_worker",
"evaluation_config",
"exploration_config",
"extra_python_environs_for_driver",
"replay_buffer_config",
"extra_python_environs_for_worker",
"input_config",
"output_config",
@ -727,7 +749,10 @@ class Trainer(Trainable):
# List of top level keys with value=dict, for which we always override the
# entire value (dict), iff the "type" key in that value dict changes.
_override_all_subkeys_if_type_changes = ["exploration_config"]
_override_all_subkeys_if_type_changes = [
"exploration_config",
"replay_buffer_config",
]
# TODO: Deprecate. Instead, override `Trainer.get_default_config()`.
_default_config = COMMON_CONFIG
@ -2724,58 +2749,147 @@ class Trainer(Trainable):
MultiAgentReplayBuffer instance based on trainer config.
None, if local replay buffer is not needed.
"""
# These are the agents that utilizes a local replay buffer.
if "replay_buffer_config" not in config or not config["replay_buffer_config"]:
# Does not need a replay buffer.
return None
# Deprecation of old-style replay buffer args
# Warn before checking whether we need a local buffer, so that
# algorithms without a local buffer also get warned.
deprecated_replay_buffer_keys = [
"prioritized_replay_alpha",
"prioritized_replay_beta",
"prioritized_replay_eps",
"learning_starts",
]
for k in deprecated_replay_buffer_keys:
if config.get(k) is not None:
deprecation_warning(
old="config[{}]".format(k),
help="config['replay_buffer_config'][{}] should be used "
"for Q-Learning algorithms. Ignore this warning if "
"you are not using a Q-Learning algorithm and still "
"provide {}."
"".format(k, k),
error=False,
)
# Copy values over to the new location in the config to support both
# the new and the old configuration styles.
if config.get("replay_buffer_config") is not None:
config["replay_buffer_config"][k] = config[k]
# Some agents do not need a replay buffer
if not config.get("replay_buffer_config") or config.get(
"no_local_replay_buffer", False
):
return
replay_buffer_config = config["replay_buffer_config"]
if (
"type" not in replay_buffer_config
or replay_buffer_config["type"] != "MultiAgentReplayBuffer"
):
# DistributedReplayBuffer coming soon.
return None
assert (
"type" in replay_buffer_config
), "Can not instantiate ReplayBuffer from config without 'type' key."
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
# Print a deprecation warning.
deprecation_warning(
old="config['buffer_size']",
new="config['replay_buffer_config']['capacity']",
help="Buffer size specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
else:
# Get capacity out of replay_buffer_config.
capacity = replay_buffer_config["capacity"]
config["replay_buffer_config"]["capacity"] = capacity
# Configure prio. replay parameters.
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
# Switch off prioritization (alpha=0.0).
else:
prio_args = {"prioritized_replay_alpha": 0.0}
# Check if old replay buffer should be instantiated
buffer_type = config["replay_buffer_config"]["type"]
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Prepend old-style buffers' path
assert buffer_type == "MultiAgentReplayBuffer", (
"Without ReplayBuffer API, only "
"MultiAgentReplayBuffer is supported!"
)
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
else:
assert buffer_type in [
"ray.rllib.execution.MultiAgentReplayBuffer",
MultiAgentReplayBuffer,
], (
"Without ReplayBuffer API, only "
"MultiAgentReplayBuffer is supported!"
)
return MultiAgentReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=capacity,
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args,
)
config["replay_buffer_config"]["type"] = buffer_type
# Remove from config so it's not passed into the buffer c'tor
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
# We need to deprecate the old-style location of the following
# buffer arguments and make users put them into the
# "replay_buffer_config" field of their config.
config["replay_buffer_config"]["replay_batch_size"] = config[
"train_batch_size"
]
config["replay_buffer_config"]["replay_mode"] = config["multiagent"][
"replay_mode"
]
deprecation_warning(
old="config['multiagent']['replay_mode']",
new="config['replay_buffer_config']['replay_mode']",
error=False,
)
config["replay_buffer_config"]["replay_sequence_length"] = config.get(
"replay_sequence_length", 1
)
if config.get("replay_sequence_length"):
deprecation_warning(
old="config['replay_sequence_length']",
new="config['replay_buffer_config']['replay_sequence_length']",
error=False,
)
config["replay_buffer_config"]["replay_burn_in"] = config.get(
"replay_burn_in", 0
)
if config.get("burn_in"):
deprecation_warning(
old="config['burn_in']",
help="Burn in specified at new location config["
"'replay_buffer_config']["
"'replay_burn_in'] will be overwritten.",
)
config["replay_buffer_config"]["replay_burn_in"] = config["burn_in"]
config["replay_buffer_config"]["replay_zero_init_states"] = config.get(
"replay_zero_init_states", True
)
if config.get("replay_zero_init_states"):
deprecation_warning(
old="config['replay_zero_init_states']",
new="config['replay_buffer_config']['replay_zero_init_states']",
error=False,
)
# If no prioritized replay, old-style replay buffer should
# not be handed the following parameters:
if config.get("prioritized_replay", False) is False:
# This triggers non-prioritization in old-style replay buffer
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
else:
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
config["replay_buffer_config"]["type"] = buffer_type
return from_config(buffer_type, config["replay_buffer_config"])
@DeveloperAPI
def _kwargs_for_execution_plan(self):
kwargs = {}
if self.local_replay_buffer:
if self.local_replay_buffer is not None:
kwargs["local_replay_buffer"] = self.local_replay_buffer
return kwargs
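
Stripped of the deprecation handling, the new-style branch above reduces to resolving a bare class name to its full module path and handing the config to `from_config`. A condensed, hedged sketch (the config values are illustrative):

    # Condensed sketch of the new-style buffer construction above.
    from ray.rllib.utils.from_config import from_config

    replay_buffer_config = {"type": "MultiAgentReplayBuffer", "capacity": 50000}
    buffer_type = replay_buffer_config["type"]
    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
        # Prepend the new package path for bare class names.
        buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
    replay_buffer_config["type"] = buffer_type
    local_replay_buffer = from_config(buffer_type, replay_buffer_config)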

View file

@ -31,7 +31,6 @@ if __name__ == "__main__":
"n_step": 3,
"lr": 0.0001,
"prioritized_replay_alpha": 0.5,
"final_prioritized_replay_beta": 1.0,
"target_network_update_freq": 50000,
"timesteps_per_iteration": 25000,
# Method specific.

View file

@ -91,8 +91,6 @@ apex:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
@ -141,7 +139,5 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

View file

@ -20,8 +20,6 @@ apex:
hiddens: [512]
buffer_size: 1000000
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1

View file

@ -25,7 +25,5 @@ atari-dist-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

View file

@ -29,7 +29,5 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

View file

@ -29,7 +29,5 @@ dueling-ddqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

View file

@ -19,8 +19,6 @@ pong-deterministic-rainbow:
target_network_update_freq: 500
prioritized_replay: True
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 400000
n_step: 3
gpu: True
model:

View file

@ -0,0 +1,26 @@
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
MultiAgentMixInReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)
__all__ = [
"ReplayBuffer",
"StorageUnit",
"MultiAgentReplayBuffer",
"ReplayMode",
"ReservoirBuffer",
"PrioritizedReplayBuffer",
"MultiAgentMixInReplayBuffer",
"MultiAgentPrioritizedReplayBuffer",
]
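
A minimal usage sketch of the newly re-exported classes; the exact constructor and sample signatures are assumptions based on the surrounding API, not quoted from this commit:

    # Sketch: add one timestep to a prioritized buffer and sample it back.
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.replay_buffers import PrioritizedReplayBuffer

    buffer = PrioritizedReplayBuffer(capacity=1000, alpha=0.6)
    buffer.add(SampleBatch({"obs": [1], "rewards": [0.0]}))
    batch = buffer.sample(1, beta=0.4)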

View file

@ -14,6 +14,7 @@ from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
from ray.rllib.utils.from_config import from_config
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__)
@ -81,7 +82,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
'episodes'. Specifies how experiences are stored. If they
are stored in episodes, replay_sequence_length is ignored.
learning_starts: Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
`sample()` will yield samples (before that, `sample()` will
return None).
capacity: Max number of total timesteps in all policy buffers.
After reaching this number, older samples will be
@ -170,6 +171,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
"""Returns the number of items currently stored in this buffer."""
return sum(len(buffer._storage) for buffer in self.replay_buffers.values())
@ExperimentalAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
if num_items is None:
num_items = self.replay_batch_size
return self.sample(num_items, **kwargs)
@ExperimentalAPI
@override(ReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
@ -262,7 +271,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)
if self._num_added < self.replay_starts:
return None
return MultiAgentBatch({}, 0)
with self.replay_timer:
# Lockstep mode: Sample from all policies at the same time an
# equal amount of steps.
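
Note the behavior change in the last hunk: sampling before `learning_starts` is reached now returns an empty `MultiAgentBatch` rather than `None`. A hedged sketch (constructor keywords assumed from the docstring above):

    # Sketch: before learning_starts is reached, sample() yields an
    # empty MultiAgentBatch instead of None.
    from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

    buffer = MultiAgentReplayBuffer(capacity=100, learning_starts=10)
    batch = buffer.sample(5)
    assert batch is not None and len(batch.policy_batches) == 0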

View file

@ -1,17 +1,19 @@
import logging
import platform
from typing import Any, Dict, List, Optional
import numpy as np
import random
from enum import Enum
from ray.util.debug import log_once
# Import ray before psutil will make sure we use psutil's bundled version
import ray # noqa F401
import psutil # noqa E402
from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
@ -94,6 +96,18 @@ class ReplayBuffer:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)
@ExperimentalAPI
@Deprecated(old="add_batch", new="add", error=False)
def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
"""Deprecated in favor of new ReplayBuffer API."""
return self.add(batch, **kwargs)
@ExperimentalAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
return self.sample(num_items, **kwargs)
@ExperimentalAPI
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch of experiences to this buffer.