Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
[RLlib] Simple-Q uses training iteration fn (instead of execution_plan); ReplayBuffer API for Simple-Q (#22842)
parent a7e5aa8c6a
commit 9a64bd4e9b
22 changed files with 377 additions and 98 deletions
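The core of the change, condensed into a minimal sketch from the `SimpleQTrainer.training_iteration()` implementation in the diff below (target-network updates and worker weight syncing are omitted here; see the full hunk for those):

# Minimal sketch of the new Simple-Q training-iteration path (condensed from
# the diff below; not the complete implementation).
def training_iteration(self):
    # (1) Sample from the rollout workers.
    new_sample_batches = synchronous_parallel_sample(self.workers)
    # (2) Store the new samples via the new ReplayBuffer API.
    for batch in new_sample_batches:
        self.local_replay_buffer.add(batch)
    # (3) Sample a training batch from the buffer ...
    train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"])
    # (4) ... and learn on it (multi-GPU path shown; see the diff for the
    # simple-optimizer branch).
    return multi_gpu_train_one_step(self, train_batch)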
@@ -80,8 +80,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

@@ -327,8 +325,6 @@ dqn-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.5
timesteps_per_iteration: 10000

@@ -53,8 +53,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
@@ -19,7 +19,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
@@ -111,10 +111,6 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Final value of beta
"final_prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
@@ -66,7 +66,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
@@ -157,9 +157,9 @@ class ApexTrainer(DQNTrainer):
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
]
@@ -35,6 +35,7 @@ from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)
@@ -64,19 +65,37 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# N-step Q learning
"n_step": 1,

# === Prioritized replay buffer ===
# If True prioritized replay buffer will be used.
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Prioritized replay is here since this algo uses the old replay
# buffer api
"prioritized_replay": True,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Final value of beta (by default, we use constant beta=0.4).
"final_prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
# - This is True AND restoring from a checkpoint that contains no buffer
# data.
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,

# Callback to run before learning on a multi-agent batch of
# experiences.
@@ -102,6 +121,12 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# === Parallelism ===
# Whether to compute priorities on workers.
"worker_side_prioritization": False,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
@@ -451,10 +451,14 @@ def postprocess_nstep_and_prio(
batch[SampleBatch.DONES],
batch[PRIO_WEIGHTS],
)
new_priorities = (
np.abs(convert_to_numpy(td_errors))
+ policy.config["prioritized_replay_eps"]
)
# Retain compatibility with old-style Replay args
epsilon = policy.config.get("replay_buffer_config", {}).get(
"prioritized_replay_eps"
) or policy.config.get("prioritized_replay_eps")
if epsilon is None:
raise ValueError("prioritized_replay_eps not defined in config.")

new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
batch[PRIO_WEIGHTS] = new_priorities

return batch
@@ -29,6 +29,19 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",

# === Replay buffer ===
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
@@ -66,6 +79,12 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(

# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 2500,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
@@ -15,19 +15,40 @@ from typing import Optional, Type
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
MultiGPUTrainOneStep,
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.execution.train_ops import (
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
)
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
)
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)

logger = logging.getLogger(__name__)
@@ -64,9 +85,18 @@ DEFAULT_CONFIG = with_common_config({
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Deprecated for Simple Q because of new ReplayBuffer API
# Use MultiAgentPrioritizedReplayBuffer for prioritization.
"prioritized_replay": DEPRECATED_VALUE,
"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
@@ -76,9 +106,6 @@ DEFAULT_CONFIG = with_common_config({
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,

# === Optimization ===
# Learning rate for adam optimizer
@@ -108,6 +135,12 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on
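Note that Simple-Q now ships with `"_disable_execution_plan_api": True` by default (above). A sketch of opting back into the legacy execution-plan path, assuming the standard Trainer constructor (the env name here is only an example):

# Illustrative only: revert Simple-Q to the legacy execution_plan path.
from ray.rllib.agents.dqn.simple_q import DEFAULT_CONFIG, SimpleQTrainer

config = DEFAULT_CONFIG.copy()
config["_disable_execution_plan_api"] = False  # use execution_plan again
trainer = SimpleQTrainer(config=config, env="CartPole-v0")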
@@ -139,7 +172,9 @@ class SimpleQTrainer(Trainer):
" used at the same time!"
)

if config.get("prioritized_replay"):
if config.get("prioritized_replay") or config.get(
"replay_buffer_config", {}
).get("prioritized_replay"):
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
@@ -215,3 +250,63 @@ class SimpleQTrainer(Trainer):
)

return StandardMetricsReporting(train_op, workers, config)

@ExperimentalAPI
def training_iteration(self) -> ResultDict:
"""Simple Q training iteration function.

Simple Q consists of the following steps:
- (1) Sample (MultiAgentBatch) from workers...
- (2) Store new samples in replay buffer.
- (3) Sample training batch (MultiAgentBatch) from replay buffer.
- (4) Learn on training batch.
- (5) Update target network every target_network_update_freq steps.
- (6) Return all collected metrics for the iteration.

Returns:
The results dict from executing the training iteration.
"""
batch_size = self.config["train_batch_size"]
local_worker = self.workers.local_worker()

# (1) Sample (MultiAgentBatch) from workers
new_sample_batches = synchronous_parallel_sample(self.workers)

for s in new_sample_batches:
# Update counters
self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
self._counters[NUM_AGENT_STEPS_SAMPLED] += (
len(s) if isinstance(s, SampleBatch) else s.agent_steps()
)
# (2) Store new samples in replay buffer
self.local_replay_buffer.add(s)

# (3) Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = self.local_replay_buffer.sample(batch_size)

# (4) Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)

# (5) Update target network every target_network_update_freq steps
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update remote workers' weights after learning on local worker
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()

# (6) Return all collected metrics for the iteration.
return train_results
@@ -109,8 +109,6 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"prioritized_replay_beta_annealing_timesteps": 20000,
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,
@@ -85,7 +85,7 @@ class TestSAC(unittest.TestCase):
# If we use default buffer size (1e6), the buffer will take up
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
# available system memory (8.34816 GB).
config["buffer_size"] = 40000
config["replay_buffer_config"]["capacity"] = 40000
# Test with saved replay buffer.
config["store_buffer_in_checkpoints"] = True
num_iterations = 1
@@ -361,6 +361,26 @@ COMMON_CONFIG: TrainerConfigDict = {
# "env_config": {...},
# "explore": False
},

# === Replay Buffer Settings ===
# Provide a dict specifying the ReplayBuffer's config.
# "replay_buffer_config": {
# The ReplayBuffer class to use. Any class that obeys the
# ReplayBuffer API can be used here. In the simplest case, this is the
# name (str) of any class present in the `rllib.utils.replay_buffers`
# package. You can also provide the python class directly or the
# full location of your class (e.g.
# "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer").
# "type": "ReplayBuffer",
# The capacity of units that can be stored in one ReplayBuffer
# instance before eviction.
# "capacity": 10000,
# Specifies how experiences are stored. Either 'sequences' or
# 'timesteps'.
# "storage_unit": "timesteps",
# Add constructor kwargs here (if any).
# },

# Number of parallel workers to use for evaluation. Note that this is set
# to zero by default, which means evaluation will be run in the trainer
# process (only if evaluation_interval is not None). If you increase this,
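For illustration, a user-side `replay_buffer_config` following the commented template above might look as follows (the keys come from the template; the concrete values are made up):

# Hypothetical example config based on the commented defaults above.
config = {
    "replay_buffer_config": {
        # Any class obeying the ReplayBuffer API: a short name from
        # rllib.utils.replay_buffers, a full "module.Class" path, or the class.
        "type": "MultiAgentReplayBuffer",
        # Number of stored units before eviction starts.
        "capacity": 100000,
        # Store experiences as individual timesteps (vs. "sequences").
        "storage_unit": "timesteps",
    },
}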
@@ -652,6 +672,8 @@ COMMON_CONFIG: TrainerConfigDict = {
# Use `metrics_episode_collection_timeout_s` instead.
"collect_metrics_timeout": DEPRECATED_VALUE,
}


# __sphinx_doc_end__
# fmt: on
@@ -719,7 +741,7 @@ class Trainer(Trainable):
"custom_resources_per_worker",
"evaluation_config",
"exploration_config",
"extra_python_environs_for_driver",
"replay_buffer_config",
"extra_python_environs_for_worker",
"input_config",
"output_config",
@@ -727,7 +749,10 @@ class Trainer(Trainable):

# List of top level keys with value=dict, for which we always override the
# entire value (dict), iff the "type" key in that value dict changes.
_override_all_subkeys_if_type_changes = ["exploration_config"]
_override_all_subkeys_if_type_changes = [
"exploration_config",
"replay_buffer_config",
]

# TODO: Deprecate. Instead, override `Trainer.get_default_config()`.
_default_config = COMMON_CONFIG
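To illustrate what adding `replay_buffer_config` to `_override_all_subkeys_if_type_changes` means (a sketch of the merge rule described in the comment above, not code from the diff): if a user config changes the `type` key, the default sub-keys are dropped rather than deep-merged.

# Illustrative sketch of the "override all subkeys if type changes" rule.
default_conf = {"type": "MultiAgentReplayBuffer", "capacity": 50000, "replay_batch_size": 32}
user_conf = {"type": "ReplayBuffer", "capacity": 10000}
# Since "type" differs, the merged result is the user dict as-is
# (no "replay_batch_size" carried over from the default):
merged = dict(user_conf)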
@@ -2724,58 +2749,147 @@ class Trainer(Trainable):
MultiAgentReplayBuffer instance based on trainer config.
None, if local replay buffer is not needed.
"""
# These are the agents that utilizes a local replay buffer.
if "replay_buffer_config" not in config or not config["replay_buffer_config"]:
# Does not need a replay buffer.
return None
# Deprecation of old-style replay buffer args
# Warnings before checking of we need local buffer so that algorithms
# Without local buffer also get warned
deprecated_replay_buffer_keys = [
"prioritized_replay_alpha",
"prioritized_replay_beta",
"prioritized_replay_eps",
"learning_starts",
]
for k in deprecated_replay_buffer_keys:
if config.get(k) is not None:
deprecation_warning(
old="config[{}]".format(k),
help="config['replay_buffer_config'][{}] should be used "
"for Q-Learning algorithms. Ignore this warning if "
"you are not using a Q-Learning algorithm and still "
"provide {}."
"".format(k, k),
error=False,
)
# Copy values over to new location in config to support new
# and old configuration style
if config.get("replay_buffer_config") is not None:
config["replay_buffer_config"][k] = config[k]

# Some agents do not need a replay buffer
if not config.get("replay_buffer_config") or config.get(
"no_local_replay_buffer", False
):
return

replay_buffer_config = config["replay_buffer_config"]
if (
"type" not in replay_buffer_config
or replay_buffer_config["type"] != "MultiAgentReplayBuffer"
):
# DistributedReplayBuffer coming soon.
return None
assert (
"type" in replay_buffer_config
), "Can not instantiate ReplayBuffer from config without 'type' key."

capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
# Print a deprecation warning.
deprecation_warning(
old="config['buffer_size']",
new="config['replay_buffer_config']['capacity']",
help="Buffer size specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
else:
# Get capacity out of replay_buffer_config.
capacity = replay_buffer_config["capacity"]
config["replay_buffer_config"]["capacity"] = capacity

# Configure prio. replay parameters.
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
# Switch off prioritization (alpha=0.0).
else:
prio_args = {"prioritized_replay_alpha": 0.0}
# Check if old replay buffer should be instantiated
buffer_type = config["replay_buffer_config"]["type"]
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Prepend old-style buffers' path
assert buffer_type == "MultiAgentReplayBuffer", (
"Without "
"ReplayBuffer "
"API, only "
"MultiAgentReplayBuffer "
"is supported!"
)
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
else:
assert buffer_type in [
"ray.rllib.execution.MultiAgentReplayBuffer",
MultiAgentReplayBuffer,
], (
"Without ReplayBuffer API, only "
"MultiAgentReplayBuffer is supported!"
)

return MultiAgentReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
capacity=capacity,
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args,
)
config["replay_buffer_config"]["type"] = buffer_type

# Remove from config so it's not passed into the buffer c'tor
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)

# We need to deprecate the old-style location of the following
# buffer arguments and make users put them into the
# "replay_buffer_config" field of their config.
config["replay_buffer_config"]["replay_batch_size"] = config[
"train_batch_size"
]
config["replay_buffer_config"]["replay_mode"] = config["multiagent"][
"replay_mode"
]
deprecation_warning(
old="config['multiagent']['replay_mode']",
new="config['replay_buffer_config']['replay_mode']",
error=False,
)

config["replay_buffer_config"]["replay_sequence_length"] = config.get(
"replay_sequence_length", 1
)
if config.get("replay_sequence_length"):
deprecation_warning(
old="config['replay_sequence_length']",
new="config['replay_buffer_config']['replay_sequence_length']",
error=False,
)

config["replay_buffer_config"]["replay_burn_in"] = config.get(
"replay_burn_in", 0
)

if config.get("burn_in"):
deprecation_warning(
old="config['burn_in']",
help="Burn in specified at new location config["
"'replay_buffer_config']["
"'replay_burn_in'] will be overwritten.",
)
config["replay_buffer_config"]["replay_burn_in"] = config["burn_in"]

config["replay_buffer_config"]["replay_zero_init_states"] = config.get(
"replay_zero_init_states", True
)
if config.get("replay_zero_init_states"):
deprecation_warning(
old="config['replay_zero_init_states']",
new="config['replay_buffer_config']['replay_zero_init_states']",
error=False,
)

# If no prioritized replay, old-style replay buffer should
# not be handed the following parameters:
if config.get("prioritized_replay", False) is False:
# This triggers non-prioritization in old-style replay buffer
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0

else:
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
config["replay_buffer_config"]["type"] = buffer_type

return from_config(buffer_type, config["replay_buffer_config"])

@DeveloperAPI
def _kwargs_for_execution_plan(self):
kwargs = {}
if self.local_replay_buffer:
if self.local_replay_buffer is not None:
kwargs["local_replay_buffer"] = self.local_replay_buffer
return kwargs
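As a quick reference, these are the old-style keys the helper above migrates and where they end up under `replay_buffer_config` (summarized from the code above; the dict itself is illustrative, not an RLlib API):

# Summary of the old-style -> new-style replay settings handled above.
OLD_TO_NEW_REPLAY_KEYS = {
    "buffer_size": "replay_buffer_config['capacity']",
    "prioritized_replay_alpha": "replay_buffer_config['prioritized_replay_alpha']",
    "prioritized_replay_beta": "replay_buffer_config['prioritized_replay_beta']",
    "prioritized_replay_eps": "replay_buffer_config['prioritized_replay_eps']",
    "learning_starts": "replay_buffer_config['learning_starts']",
    "burn_in": "replay_buffer_config['replay_burn_in']",
    "replay_sequence_length": "replay_buffer_config['replay_sequence_length']",
    "replay_zero_init_states": "replay_buffer_config['replay_zero_init_states']",
    "multiagent/replay_mode": "replay_buffer_config['replay_mode']",
    "train_batch_size": "replay_buffer_config['replay_batch_size']",
}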
@@ -31,7 +31,6 @@ if __name__ == "__main__":
"n_step": 3,
"lr": 0.0001,
"prioritized_replay_alpha": 0.5,
"final_prioritized_replay_beta": 1.0,
"target_network_update_freq": 50000,
"timesteps_per_iteration": 25000,
# Method specific.
@@ -91,8 +91,6 @@ apex:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

@@ -141,7 +139,5 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000
@@ -20,8 +20,6 @@ apex:
hiddens: [512]
buffer_size: 1000000
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000

num_gpus: 1
@@ -25,7 +25,5 @@ atari-dist-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

@@ -29,7 +29,5 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

@@ -29,7 +29,5 @@ dueling-ddqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.2
timesteps_per_iteration: 10000

@@ -19,8 +19,6 @@ pong-deterministic-rainbow:
target_network_update_freq: 500
prioritized_replay: True
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 400000
n_step: 3
gpu: True
model:
@@ -0,0 +1,26 @@
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
MultiAgentMixInReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)

__all__ = [
"ReplayBuffer",
"StorageUnit",
"MultiAgentReplayBuffer",
"ReplayMode",
"ReservoirBuffer",
"PrioritizedReplayBuffer",
"MultiAgentMixInReplayBuffer",
"MultiAgentPrioritizedReplayBuffer",
]
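A minimal usage sketch of the new package added here (the class names and the `add()`/`sample()` methods appear in this diff; the constructor argument and the batch contents are illustrative):

# Minimal sketch of the new ReplayBuffer API (illustrative values).
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import ReplayBuffer

buffer = ReplayBuffer(capacity=1000)
buffer.add(SampleBatch({"obs": [0.1], "actions": [0], "rewards": [1.0]}))
batch = buffer.sample(1)  # replaces the deprecated replay()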
@@ -14,6 +14,7 @@ from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
from ray.rllib.utils.from_config import from_config
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import Deprecated

logger = logging.getLogger(__name__)
@@ -81,7 +82,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
'episodes'. Specifies how experiences are stored. If they
are stored in episodes, replay_sequence_length is ignored.
learning_starts: Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
`sample()` will yield samples (before that, `sample()` will
return None).
capacity: Max number of total timesteps in all policy buffers.
After reaching this number, older samples will be
@@ -170,6 +171,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
"""Returns the number of items currently stored in this buffer."""
return sum(len(buffer._storage) for buffer in self.replay_buffers.values())

@ExperimentalAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
if num_items is None:
num_items = self.replay_batch_size
return self.sample(num_items, **kwargs)

@ExperimentalAPI
@override(ReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
@@ -262,7 +271,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

if self._num_added < self.replay_starts:
return None
return MultiAgentBatch({}, 0)
with self.replay_timer:
# Lockstep mode: Sample from all policies at the same time an
# equal amount of steps.
@@ -1,17 +1,19 @@
import logging
import platform
from typing import Any, Dict, List, Optional

import numpy as np
import random
from enum import Enum
from ray.util.debug import log_once

# Import ray before psutil will make sure we use psutil's bundled version
import ray  # noqa F401
import psutil  # noqa E402

from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
@@ -94,6 +96,18 @@ class ReplayBuffer:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)

@ExperimentalAPI
@Deprecated(old="add_batch", new="add", error=False)
def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
"""Deprecated in favor of new ReplayBuffer API."""
return self.add(batch, **kwargs)

@ExperimentalAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
return self.sample(num_items, **kwargs)

@ExperimentalAPI
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch of experiences to this buffer.