Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Simple-Q uses training iteration fn (instead of execution_plan); ReplayBuffer API for Simple-Q (#22842)
parent a7e5aa8c6a
commit 9a64bd4e9b
22 changed files with 377 additions and 98 deletions
@@ -80,8 +80,6 @@ apex-breakoutnoframeskip-v4:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 1
 num_workers: 8
 num_envs_per_worker: 8
@@ -327,8 +325,6 @@ dqn-breakoutnoframeskip-v4:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 0.5
 timesteps_per_iteration: 10000
 
@@ -53,8 +53,6 @@ apex-breakoutnoframeskip-v4:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 1
 num_workers: 8
 num_envs_per_worker: 8
@@ -19,7 +19,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
         "num_workers": 32,
         "buffer_size": 2000000,
         # TODO(jungong) : update once Apex supports replay_buffer_config.
-        "replay_buffer_config": None,
+        "no_local_replay_buffer": True,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
         # This is preferred b/c the learner process should have quick

@@ -111,10 +111,6 @@ DEFAULT_CONFIG = with_common_config({
     "prioritized_replay_alpha": 0.6,
     # Beta parameter for sampling from prioritized replay buffer.
     "prioritized_replay_beta": 0.4,
-    # Time steps over which the beta parameter is annealed.
-    "prioritized_replay_beta_annealing_timesteps": 20000,
-    # Final value of beta
-    "final_prioritized_replay_beta": 0.4,
     # Epsilon to add to the TD errors when updating priorities.
     "prioritized_replay_eps": 1e-6,
     # Whether to LZ4 compress observations

@@ -66,7 +66,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
         "buffer_size": 2000000,
         # TODO(jungong) : add proper replay_buffer_config after
         # DistributedReplayBuffer type is supported.
-        "replay_buffer_config": None,
+        "no_local_replay_buffer": True,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
         # This is preferred b/c the learner process should have quick

@@ -157,9 +157,9 @@ class ApexTrainer(DQNTrainer):
             config["learning_starts"],
             config["buffer_size"],
             config["train_batch_size"],
-            config["prioritized_replay_alpha"],
-            config["prioritized_replay_beta"],
-            config["prioritized_replay_eps"],
+            config["replay_buffer_config"]["prioritized_replay_alpha"],
+            config["replay_buffer_config"]["prioritized_replay_beta"],
+            config["replay_buffer_config"]["prioritized_replay_eps"],
             config["multiagent"]["replay_mode"],
             config.get("replay_sequence_length", 1),
         ]
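With these Ape-X changes, the prioritized-replay parameters are read from the nested replay_buffer_config rather than from top-level keys. A minimal, illustrative user-side override consistent with the lookups above (plain dict only; values are examples, not defaults):

# Hypothetical override; per the hunk above, ApexTrainer now reads these
# three keys from config["replay_buffer_config"].
apex_overrides = {
    "num_workers": 8,
    "replay_buffer_config": {
        "prioritized_replay_alpha": 0.5,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
    },
}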
@@ -35,6 +35,7 @@ from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 
 logger = logging.getLogger(__name__)
 

@@ -64,19 +65,37 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
         # N-step Q learning
         "n_step": 1,
 
-        # === Prioritized replay buffer ===
-        # If True prioritized replay buffer will be used.
+        # === Replay buffer ===
+        # Size of the replay buffer. Note that if async_updates is set, then
+        # each worker will have a replay buffer of this size.
+        "buffer_size": DEPRECATED_VALUE,
+        # Prioritized replay is here since this algo uses the old replay
+        # buffer api
         "prioritized_replay": True,
-        # Alpha parameter for prioritized replay buffer.
-        "prioritized_replay_alpha": 0.6,
-        # Beta parameter for sampling from prioritized replay buffer.
-        "prioritized_replay_beta": 0.4,
-        # Final value of beta (by default, we use constant beta=0.4).
-        "final_prioritized_replay_beta": 0.4,
-        # Time steps over which the beta parameter is annealed.
-        "prioritized_replay_beta_annealing_timesteps": 20000,
-        # Epsilon to add to the TD errors when updating priorities.
-        "prioritized_replay_eps": 1e-6,
+        "replay_buffer_config": {
+            # For now we don't use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": False,
+            "type": "MultiAgentReplayBuffer",
+            "capacity": 50000,
+            "replay_batch_size": 32,
+            "prioritized_replay_alpha": 0.6,
+            # Beta parameter for sampling from prioritized replay buffer.
+            "prioritized_replay_beta": 0.4,
+            # Epsilon to add to the TD errors when updating priorities.
+            "prioritized_replay_eps": 1e-6,
+        },
+        # Set this to True, if you want the contents of your buffer(s) to be
+        # stored in any saved checkpoints as well.
+        # Warnings will be created if:
+        # - This is True AND restoring from a checkpoint that contains no buffer
+        #   data.
+        # - This is False AND restoring from a checkpoint that does contain
+        #   buffer data.
+        "store_buffer_in_checkpoints": False,
+        # The number of contiguous environment steps to replay at once. This may
+        # be set to greater than 1 to support recurrent models.
+        "replay_sequence_length": 1,
 
         # Callback to run before learning on a multi-agent batch of
         # experiences.
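DQN itself stays on the old replay buffer code path for now (_enable_replay_buffer_api: False), but its buffer settings move under replay_buffer_config and buffer_size is deprecated in favor of capacity. A hedged sketch of the intended migration (plain dicts, illustrative values):

# Old-style flat keys ...
old_style = {
    "buffer_size": 50000,
    "prioritized_replay_alpha": 0.6,
}
# ... and their new nested location after this change.
new_style = {
    "replay_buffer_config": {
        "capacity": 50000,
        "prioritized_replay_alpha": 0.6,
    },
}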
@@ -102,6 +121,12 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
         # === Parallelism ===
         # Whether to compute priorities on workers.
         "worker_side_prioritization": False,
+
+        # Experimental flag.
+        # If True, the execution plan API will not be used. Instead,
+        # a Trainer's `training_iteration` method will be called as-is each
+        # training iteration.
+        "_disable_execution_plan_api": False,
     },
     _allow_unknown_configs=True,
 )
@@ -451,10 +451,14 @@ def postprocess_nstep_and_prio(
             batch[SampleBatch.DONES],
             batch[PRIO_WEIGHTS],
         )
-        new_priorities = (
-            np.abs(convert_to_numpy(td_errors))
-            + policy.config["prioritized_replay_eps"]
-        )
+        # Retain compatibility with old-style Replay args
+        epsilon = policy.config.get("replay_buffer_config", {}).get(
+            "prioritized_replay_eps"
+        ) or policy.config.get("prioritized_replay_eps")
+        if epsilon is None:
+            raise ValueError("prioritized_replay_eps not defined in config.")
+
+        new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
         batch[PRIO_WEIGHTS] = new_priorities
 
     return batch
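The epsilon lookup above prefers the new nested location and falls back to the old flat key. A self-contained illustration of that precedence with plain dicts (not an RLlib API):

def lookup_prio_eps(config):
    # New nested location wins; the old flat key is the fallback; missing -> error.
    epsilon = config.get("replay_buffer_config", {}).get(
        "prioritized_replay_eps"
    ) or config.get("prioritized_replay_eps")
    if epsilon is None:
        raise ValueError("prioritized_replay_eps not defined in config.")
    return epsilon

assert lookup_prio_eps({"replay_buffer_config": {"prioritized_replay_eps": 1e-6}}) == 1e-6
assert lookup_prio_eps({"prioritized_replay_eps": 1e-5}) == 1e-5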
@@ -29,6 +29,19 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
         # Batch mode must be complete_episodes.
         "batch_mode": "complete_episodes",
 
+        # === Replay buffer ===
+        "replay_buffer_config": {
+            # For now we don't use the new ReplayBuffer API here
+            "_enable_replay_buffer_api": False,
+            "type": "MultiAgentReplayBuffer",
+            "capacity": 50000,
+            "replay_batch_size": 32,
+            "prioritized_replay_alpha": 0.6,
+            # Beta parameter for sampling from prioritized replay buffer.
+            "prioritized_replay_beta": 0.4,
+            # Epsilon to add to the TD errors when updating priorities.
+            "prioritized_replay_eps": 1e-6,
+        },
         # If True, assume a zero-initialized state input (no matter where in
         # the episode the sequence is located).
         # If False, store the initial states along with each SampleBatch, use

@@ -66,6 +79,12 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
 
         # Update the target network every `target_network_update_freq` steps.
         "target_network_update_freq": 2500,
+
+        # Experimental flag.
+        # If True, the execution plan API will not be used. Instead,
+        # a Trainer's `training_iteration` method will be called as-is each
+        # training iteration.
+        "_disable_execution_plan_api": False,
     },
     _allow_unknown_configs=True,
 )
@@ -15,19 +15,40 @@ from typing import Optional, Type
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
 from ray.rllib.agents.trainer import Trainer, with_common_config
+from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
 from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
-from ray.rllib.execution.rollout_ops import ParallelRollouts
+from ray.rllib.execution.rollout_ops import (
+    ParallelRollouts,
+    synchronous_parallel_sample,
+)
 from ray.rllib.execution.train_ops import (
-    MultiGPUTrainOneStep,
     TrainOneStep,
+    MultiGPUTrainOneStep,
+    train_one_step,
+    multi_gpu_train_one_step,
+)
+from ray.rllib.execution.train_ops import (
     UpdateTargetNetwork,
 )
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.metrics import (
+    NUM_ENV_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_SAMPLED,
+)
+from ray.rllib.utils.typing import (
+    ResultDict,
+    TrainerConfigDict,
+)
+from ray.rllib.utils.metrics import (
+    LAST_TARGET_UPDATE_TS,
+    NUM_TARGET_UPDATES,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -64,9 +85,18 @@ DEFAULT_CONFIG = with_common_config({
     # Size of the replay buffer. Note that if async_updates is set, then
     # each worker will have a replay buffer of this size.
     "buffer_size": DEPRECATED_VALUE,
+    # Deprecated for Simple Q because of new ReplayBuffer API
+    # Use MultiAgentPrioritizedReplayBuffer for prioritization.
+    "prioritized_replay": DEPRECATED_VALUE,
     "replay_buffer_config": {
+        # Use the new ReplayBuffer API here
+        "_enable_replay_buffer_api": True,
         "type": "MultiAgentReplayBuffer",
         "capacity": 50000,
+        "replay_batch_size": 32,
+        # The number of contiguous environment steps to replay at once. This
+        # may be set to greater than 1 to support recurrent models.
+        "replay_sequence_length": 1,
     },
     # Set this to True, if you want the contents of your buffer(s) to be
     # stored in any saved checkpoints as well.

@@ -76,9 +106,6 @@ DEFAULT_CONFIG = with_common_config({
     # - This is False AND restoring from a checkpoint that does contain
     #   buffer data.
     "store_buffer_in_checkpoints": False,
-    # The number of contiguous environment steps to replay at once. This may
-    # be set to greater than 1 to support recurrent models.
-    "replay_sequence_length": 1,
 
     # === Optimization ===
     # Learning rate for adam optimizer
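For Simple-Q the new ReplayBuffer API is switched on by default. A minimal sketch of a user config that follows the DEFAULT_CONFIG above (values are illustrative, not necessarily the defaults):

simple_q_buffer_overrides = {
    "replay_buffer_config": {
        "_enable_replay_buffer_api": True,
        "type": "MultiAgentReplayBuffer",
        "capacity": 10000,
        "replay_batch_size": 32,
        "replay_sequence_length": 1,
    },
}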
@@ -108,6 +135,12 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 0,
     # Prevent reporting frequency from going lower than this time span.
     "min_time_s_per_reporting": 1,
+
+    # Experimental flag.
+    # If True, the execution plan API will not be used. Instead,
+    # a Trainer's `training_iteration` method will be called as-is each
+    # training iteration.
+    "_disable_execution_plan_api": True,
 })
 # __sphinx_doc_end__
 # fmt: on
@@ -139,7 +172,9 @@ class SimpleQTrainer(Trainer):
                 " used at the same time!"
             )
 
-        if config.get("prioritized_replay"):
+        if config.get("prioritized_replay") or config.get(
+            "replay_buffer_config", {}
+        ).get("prioritized_replay"):
             if config["multiagent"]["replay_mode"] == "lockstep":
                 raise ValueError(
                     "Prioritized replay is not supported when replay_mode=lockstep."
@@ -215,3 +250,63 @@ class SimpleQTrainer(Trainer):
             )
 
         return StandardMetricsReporting(train_op, workers, config)
+
+    @ExperimentalAPI
+    def training_iteration(self) -> ResultDict:
+        """Simple Q training iteration function.
+
+        Simple Q consists of the following steps:
+        - (1) Sample (MultiAgentBatch) from workers...
+        - (2) Store new samples in replay buffer.
+        - (3) Sample training batch (MultiAgentBatch) from replay buffer.
+        - (4) Learn on training batch.
+        - (5) Update target network every target_network_update_freq steps.
+        - (6) Return all collected metrics for the iteration.
+
+        Returns:
+            The results dict from executing the training iteration.
+        """
+        batch_size = self.config["train_batch_size"]
+        local_worker = self.workers.local_worker()
+
+        # (1) Sample (MultiAgentBatch) from workers
+        new_sample_batches = synchronous_parallel_sample(self.workers)
+
+        for s in new_sample_batches:
+            # Update counters
+            self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
+            self._counters[NUM_AGENT_STEPS_SAMPLED] += (
+                len(s) if isinstance(s, SampleBatch) else s.agent_steps()
+            )
+            # (2) Store new samples in replay buffer
+            self.local_replay_buffer.add(s)
+
+        # (3) Sample training batch (MultiAgentBatch) from replay buffer.
+        train_batch = self.local_replay_buffer.sample(batch_size)
+
+        # (4) Learn on training batch.
+        # Use simple optimizer (only for multi-agent or tf-eager; all other
+        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
+        if self.config.get("simple_optimizer") is True:
+            train_results = train_one_step(self, train_batch)
+        else:
+            train_results = multi_gpu_train_one_step(self, train_batch)
+
+        # (5) Update target network every target_network_update_freq steps
+        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
+        last_update = self._counters[LAST_TARGET_UPDATE_TS]
+        if cur_ts - last_update >= self.config["target_network_update_freq"]:
+            to_update = local_worker.get_policies_to_train()
+            local_worker.foreach_policy_to_train(
+                lambda p, pid: pid in to_update and p.update_target()
+            )
+            self._counters[NUM_TARGET_UPDATES] += 1
+            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
+
+        # Update remote workers' weights after learning on local worker
+        if self.workers.remote_workers():
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
+                self.workers.sync_weights()
+
+        # (6) Return all collected metrics for the iteration.
+        return train_results
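Because _disable_execution_plan_api now defaults to True for Simple-Q, each Trainer.train() call runs the training_iteration() shown above instead of the execution plan. A rough usage sketch (the env id and iteration count are arbitrary; result keys are the usual RLlib ones):

import ray
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer

ray.init()
trainer = SimpleQTrainer(env="CartPole-v0", config={"num_workers": 0})
for _ in range(3):
    results = trainer.train()  # drives training_iteration() under the hood
    print(results["episode_reward_mean"])
trainer.stop()
ray.shutdown()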
@@ -109,8 +109,6 @@ DEFAULT_CONFIG = with_common_config({
     "prioritized_replay_alpha": 0.6,
     "prioritized_replay_beta": 0.4,
     "prioritized_replay_eps": 1e-6,
-    "prioritized_replay_beta_annealing_timesteps": 20000,
-    "final_prioritized_replay_beta": 0.4,
     # Whether to LZ4 compress observations
     "compress_observations": False,
 

@@ -85,7 +85,7 @@ class TestSAC(unittest.TestCase):
         # If we use default buffer size (1e6), the buffer will take up
         # 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
         # available system memory (8.34816 GB).
-        config["buffer_size"] = 40000
+        config["replay_buffer_config"]["capacity"] = 40000
         # Test with saved replay buffer.
         config["store_buffer_in_checkpoints"] = True
         num_iterations = 1
@@ -361,6 +361,26 @@ COMMON_CONFIG: TrainerConfigDict = {
         # "env_config": {...},
         # "explore": False
     },
+
+    # === Replay Buffer Settings ===
+    # Provide a dict specifying the ReplayBuffer's config.
+    # "replay_buffer_config": {
+    #     The ReplayBuffer class to use. Any class that obeys the
+    #     ReplayBuffer API can be used here. In the simplest case, this is the
+    #     name (str) of any class present in the `rllib.utils.replay_buffers`
+    #     package. You can also provide the python class directly or the
+    #     full location of your class (e.g.
+    #     "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer").
+    #     "type": "ReplayBuffer",
+    #     The capacity of units that can be stored in one ReplayBuffer
+    #     instance before eviction.
+    #     "capacity": 10000,
+    #     Specifies how experiences are stored. Either 'sequences' or
+    #     'timesteps'.
+    #     "storage_unit": "timesteps",
+    #     Add constructor kwargs here (if any).
+    # },
+
     # Number of parallel workers to use for evaluation. Note that this is set
     # to zero by default, which means evaluation will be run in the trainer
     # process (only if evaluation_interval is not None). If you increase this,
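The commented-out block above describes the ways the buffer class can be named in replay_buffer_config. A small illustration of those options, assuming the new ray.rllib.utils.replay_buffers package added in this PR (capacities are example values):

from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer

# 1) Short class name, resolved inside rllib.utils.replay_buffers.
by_short_name = {"type": "ReplayBuffer", "capacity": 10000, "storage_unit": "timesteps"}
# 2) Fully qualified [module].[class] path.
by_full_path = {
    "type": "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer",
    "capacity": 10000,
}
# 3) The Python class object itself.
by_class = {"type": ReplayBuffer, "capacity": 10000}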
@@ -652,6 +672,8 @@ COMMON_CONFIG: TrainerConfigDict = {
     # Use `metrics_episode_collection_timeout_s` instead.
     "collect_metrics_timeout": DEPRECATED_VALUE,
 }
+
+
 # __sphinx_doc_end__
 # fmt: on
 

@@ -719,7 +741,7 @@ class Trainer(Trainable):
         "custom_resources_per_worker",
         "evaluation_config",
         "exploration_config",
-        "extra_python_environs_for_driver",
+        "replay_buffer_config",
         "extra_python_environs_for_worker",
         "input_config",
         "output_config",

@@ -727,7 +749,10 @@ class Trainer(Trainable):
 
     # List of top level keys with value=dict, for which we always override the
     # entire value (dict), iff the "type" key in that value dict changes.
-    _override_all_subkeys_if_type_changes = ["exploration_config"]
+    _override_all_subkeys_if_type_changes = [
+        "exploration_config",
+        "replay_buffer_config",
+    ]
 
     # TODO: Deprecate. Instead, override `Trainer.get_default_config()`.
     _default_config = COMMON_CONFIG
@@ -2724,58 +2749,147 @@ class Trainer(Trainable):
             MultiAgentReplayBuffer instance based on trainer config.
             None, if local replay buffer is not needed.
         """
-        # These are the agents that utilizes a local replay buffer.
-        if "replay_buffer_config" not in config or not config["replay_buffer_config"]:
-            # Does not need a replay buffer.
-            return None
+        # Deprecation of old-style replay buffer args
+        # Warnings before checking of we need local buffer so that algorithms
+        # Without local buffer also get warned
+        deprecated_replay_buffer_keys = [
+            "prioritized_replay_alpha",
+            "prioritized_replay_beta",
+            "prioritized_replay_eps",
+            "learning_starts",
+        ]
+        for k in deprecated_replay_buffer_keys:
+            if config.get(k) is not None:
+                deprecation_warning(
+                    old="config[{}]".format(k),
+                    help="config['replay_buffer_config'][{}] should be used "
+                    "for Q-Learning algorithms. Ignore this warning if "
+                    "you are not using a Q-Learning algorithm and still "
+                    "provide {}."
+                    "".format(k, k),
+                    error=False,
+                )
+                # Copy values over to new location in config to support new
+                # and old configuration style
+                if config.get("replay_buffer_config") is not None:
+                    config["replay_buffer_config"][k] = config[k]
+
+        # Some agents do not need a replay buffer
+        if not config.get("replay_buffer_config") or config.get(
+            "no_local_replay_buffer", False
+        ):
+            return
+
         replay_buffer_config = config["replay_buffer_config"]
-        if (
-            "type" not in replay_buffer_config
-            or replay_buffer_config["type"] != "MultiAgentReplayBuffer"
-        ):
-            # DistributedReplayBuffer coming soon.
-            return None
+        assert (
+            "type" in replay_buffer_config
+        ), "Can not instantiate ReplayBuffer from config without 'type' key."
 
         capacity = config.get("buffer_size", DEPRECATED_VALUE)
         if capacity != DEPRECATED_VALUE:
-            # Print a deprecation warning.
             deprecation_warning(
                 old="config['buffer_size']",
-                new="config['replay_buffer_config']['capacity']",
+                help="Buffer size specified at new location config["
+                "'replay_buffer_config']["
+                "'capacity'] will be overwritten.",
                 error=False,
             )
-        else:
-            # Get capacity out of replay_buffer_config.
-            capacity = replay_buffer_config["capacity"]
+            config["replay_buffer_config"]["capacity"] = capacity
 
-        # Configure prio. replay parameters.
-        if config.get("prioritized_replay"):
-            prio_args = {
-                "prioritized_replay_alpha": config["prioritized_replay_alpha"],
-                "prioritized_replay_beta": config["prioritized_replay_beta"],
-                "prioritized_replay_eps": config["prioritized_replay_eps"],
-            }
-        # Switch off prioritization (alpha=0.0).
-        else:
-            prio_args = {"prioritized_replay_alpha": 0.0}
-
-        return MultiAgentReplayBuffer(
-            num_shards=1,
-            learning_starts=config["learning_starts"],
-            capacity=capacity,
-            replay_batch_size=config["train_batch_size"],
-            replay_mode=config["multiagent"]["replay_mode"],
-            replay_sequence_length=config.get("replay_sequence_length", 1),
-            replay_burn_in=config.get("burn_in", 0),
-            replay_zero_init_states=config.get("zero_init_states", True),
-            **prio_args,
-        )
+        # Check if old replay buffer should be instantiated
+        buffer_type = config["replay_buffer_config"]["type"]
+        if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
+            if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
+                # Prepend old-style buffers' path
+                assert buffer_type == "MultiAgentReplayBuffer", (
+                    "Without "
+                    "ReplayBuffer "
+                    "API, only "
+                    "MultiAgentReplayBuffer "
+                    "is supported!"
+                )
+                # Create valid full [module].[class] string for from_config
+                buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
+            else:
+                assert buffer_type in [
+                    "ray.rllib.execution.MultiAgentReplayBuffer",
+                    MultiAgentReplayBuffer,
+                ], (
+                    "Without ReplayBuffer API, only "
+                    "MultiAgentReplayBuffer is supported!"
+                )
+
+            config["replay_buffer_config"]["type"] = buffer_type
+
+            # Remove from config so it's not passed into the buffer c'tor
+            config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
+
+            # We need to deprecate the old-style location of the following
+            # buffer arguments and make users put them into the
+            # "replay_buffer_config" field of their config.
+            config["replay_buffer_config"]["replay_batch_size"] = config[
+                "train_batch_size"
+            ]
+            config["replay_buffer_config"]["replay_mode"] = config["multiagent"][
+                "replay_mode"
+            ]
+            deprecation_warning(
+                old="config['multiagent']['replay_mode']",
+                new="config['replay_buffer_config']['replay_mode']",
+                error=False,
+            )
+
+            config["replay_buffer_config"]["replay_sequence_length"] = config.get(
+                "replay_sequence_length", 1
+            )
+            if config.get("replay_sequence_length"):
+                deprecation_warning(
+                    old="config['replay_sequence_length']",
+                    new="config['replay_buffer_config']['replay_sequence_length']",
+                    error=False,
+                )
+
+            config["replay_buffer_config"]["replay_burn_in"] = config.get(
+                "replay_burn_in", 0
+            )
+
+            if config.get("burn_in"):
+                deprecation_warning(
+                    old="config['burn_in']",
+                    help="Burn in specified at new location config["
+                    "'replay_buffer_config']["
+                    "'replay_burn_in'] will be overwritten.",
+                )
+                config["replay_buffer_config"]["replay_burn_in"] = config["burn_in"]
+
+            config["replay_buffer_config"]["replay_zero_init_states"] = config.get(
+                "replay_zero_init_states", True
+            )
+            if config.get("replay_zero_init_states"):
+                deprecation_warning(
+                    old="config['replay_zero_init_states']",
+                    new="config['replay_buffer_config']['replay_zero_init_states']",
+                    error=False,
+                )
+
+            # If no prioritized replay, old-style replay buffer should
+            # not be handed the following parameters:
+            if config.get("prioritized_replay", False) is False:
+                # This triggers non-prioritization in old-style replay buffer
+                config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
+
+        else:
+            if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
+                # Create valid full [module].[class] string for from_config
+                buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
+                config["replay_buffer_config"]["type"] = buffer_type
+
+        return from_config(buffer_type, config["replay_buffer_config"])
 
     @DeveloperAPI
     def _kwargs_for_execution_plan(self):
         kwargs = {}
-        if self.local_replay_buffer:
+        if self.local_replay_buffer is not None:
             kwargs["local_replay_buffer"] = self.local_replay_buffer
         return kwargs
 
 
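The final from_config() call above resolves the "type" entry to a class and hands the remaining keys to its constructor. A minimal sketch of that mechanism in isolation, assuming the new-API base ReplayBuffer and an illustrative capacity:

from ray.rllib.utils.from_config import from_config

# Resolve a fully qualified class path and construct it with ctor kwargs.
buffer = from_config(
    "ray.rllib.utils.replay_buffers.replay_buffer.ReplayBuffer",
    {"capacity": 1000},
)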
@@ -31,7 +31,6 @@ if __name__ == "__main__":
         "n_step": 3,
         "lr": 0.0001,
         "prioritized_replay_alpha": 0.5,
-        "final_prioritized_replay_beta": 1.0,
         "target_network_update_freq": 50000,
         "timesteps_per_iteration": 25000,
         # Method specific.

@@ -91,8 +91,6 @@ apex:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 1
 num_workers: 8
 num_envs_per_worker: 8

@@ -141,7 +139,5 @@ atari-basic-dqn:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 0.2
 timesteps_per_iteration: 10000

@@ -20,8 +20,6 @@ apex:
 hiddens: [512]
 buffer_size: 1000000
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 
 num_gpus: 1
 

@@ -25,7 +25,5 @@ atari-dist-dqn:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 0.2
 timesteps_per_iteration: 10000

@@ -29,7 +29,5 @@ atari-basic-dqn:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 0.2
 timesteps_per_iteration: 10000

@@ -29,7 +29,5 @@ dueling-ddqn:
 epsilon_timesteps: 200000
 final_epsilon: 0.01
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 2000000
 num_gpus: 0.2
 timesteps_per_iteration: 10000

@@ -19,8 +19,6 @@ pong-deterministic-rainbow:
 target_network_update_freq: 500
 prioritized_replay: True
 prioritized_replay_alpha: 0.5
-final_prioritized_replay_beta: 1.0
-prioritized_replay_beta_annealing_timesteps: 400000
 n_step: 3
 gpu: True
 model:
@@ -0,0 +1,26 @@
+from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
+    MultiAgentReplayBuffer,
+    ReplayMode,
+)
+from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
+from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
+    PrioritizedReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
+    MultiAgentMixInReplayBuffer,
+)
+from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
+    MultiAgentPrioritizedReplayBuffer,
+)
+
+__all__ = [
+    "ReplayBuffer",
+    "StorageUnit",
+    "MultiAgentReplayBuffer",
+    "ReplayMode",
+    "ReservoirBuffer",
+    "PrioritizedReplayBuffer",
+    "MultiAgentMixInReplayBuffer",
+    "MultiAgentPrioritizedReplayBuffer",
+]
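A quick usage sketch of the package surface exported above, using the base ReplayBuffer and dummy data (constructor and method names follow the new API shown elsewhere in this diff; sizes are arbitrary):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import ReplayBuffer

buffer = ReplayBuffer(capacity=1000, storage_unit="timesteps")
# Store four dummy timesteps, then sample two of them back.
buffer.add(SampleBatch({"obs": np.zeros((4, 2)), "rewards": np.zeros(4)}))
batch = buffer.sample(2)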
@@ -14,6 +14,7 @@ from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
 from ray.rllib.utils.from_config import from_config
 from ray.util.debug import log_once
+from ray.rllib.utils.deprecation import Deprecated
 
 logger = logging.getLogger(__name__)
 

@@ -81,7 +82,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
                 'episodes'. Specifies how experiences are stored. If they
                 are stored in episodes, replay_sequence_length is ignored.
             learning_starts: Number of timesteps after which a call to
-                `replay()` will yield samples (before that, `replay()` will
+                `sample()` will yield samples (before that, `sample()` will
                 return None).
             capacity: Max number of total timesteps in all policy buffers.
                 After reaching this number, older samples will be

@@ -170,6 +171,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
         """Returns the number of items currently stored in this buffer."""
         return sum(len(buffer._storage) for buffer in self.replay_buffers.values())
 
+    @ExperimentalAPI
+    @Deprecated(old="replay", new="sample", error=False)
+    def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
+        """Deprecated in favor of new ReplayBuffer API."""
+        if num_items is None:
+            num_items = self.replay_batch_size
+        return self.sample(num_items, **kwargs)
+
     @ExperimentalAPI
     @override(ReplayBuffer)
     def add(self, batch: SampleBatchType, **kwargs) -> None:

@@ -262,7 +271,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
         kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)
 
         if self._num_added < self.replay_starts:
-            return None
+            return MultiAgentBatch({}, 0)
         with self.replay_timer:
             # Lockstep mode: Sample from all policies at the same time an
             # equal amount of steps.
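One caller-side implication of the last change above: before learning_starts is reached, sample() now returns an empty MultiAgentBatch instead of None, so callers need an emptiness check rather than a None check. A hedged sketch of such a guard:

def ready_train_batch(buffer, batch_size=32):
    train_batch = buffer.sample(batch_size)
    # Old behavior: None until learning_starts was reached.
    # New behavior: an empty MultiAgentBatch (count == 0) instead.
    if train_batch is None or train_batch.count == 0:
        return None
    return train_batch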
@@ -1,17 +1,19 @@
 import logging
 import platform
 from typing import Any, Dict, List, Optional
 
 import numpy as np
 import random
 from enum import Enum
-from ray.util.debug import log_once
 
 # Import ray before psutil will make sure we use psutil's bundled version
 import ray  # noqa F401
 import psutil  # noqa E402
 
+from ray.util.debug import log_once
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import ExperimentalAPI
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.metrics.window_stat import WindowStat
 from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity

@@ -94,6 +96,18 @@ class ReplayBuffer:
         """Returns the number of items currently stored in this buffer."""
         return len(self._storage)
 
+    @ExperimentalAPI
+    @Deprecated(old="add_batch", new="add", error=False)
+    def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
+        """Deprecated in favor of new ReplayBuffer API."""
+        return self.add(batch, **kwargs)
+
+    @ExperimentalAPI
+    @Deprecated(old="replay", new="sample", error=False)
+    def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
+        """Deprecated in favor of new ReplayBuffer API."""
+        return self.sample(num_items, **kwargs)
+
     @ExperimentalAPI
     def add(self, batch: SampleBatchType, **kwargs) -> None:
         """Adds a batch of experiences to this buffer.
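The shims above keep the old method names working while steering callers to the new API. A short sketch of both spellings against the base buffer (dummy data; the deprecated calls should emit warnings and forward as shown):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import ReplayBuffer

buffer = ReplayBuffer(capacity=100)
batch = SampleBatch({"obs": np.zeros((2, 4))})
buffer.add_batch(batch)        # deprecated alias, forwards to buffer.add(batch)
old_sample = buffer.replay(2)  # deprecated alias, forwards to buffer.sample(2)
new_sample = buffer.sample(2)  # preferred new-API call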