[RLlib] QMIX training iteration function and new replay buffer API. (#24164)

Sven Mika 2022-04-27 14:24:20 +02:00 committed by GitHub
parent 29388fb25b
commit 627b9f2e88
13 changed files with 612 additions and 378 deletions


@@ -19,7 +19,7 @@ from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.utils.replay_buffers.replay_buffer import validate_buffer_config
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
@@ -34,12 +34,11 @@ from ray.rllib.execution.train_ops import (
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
@@ -200,6 +199,73 @@ class SimpleQTrainer(Trainer):
else:
return SimpleQTFPolicy
@ExperimentalAPI
@override(Trainer)
def training_iteration(self) -> ResultDict:
"""Simple Q training iteration function.
Simple Q consists of the following steps:
- Sample n MultiAgentBatches from n workers synchronously.
- Store new samples in the replay buffer.
- Sample one training MultiAgentBatch from the replay buffer.
- Learn on the training batch.
- Update the target network every `target_network_update_freq` steps.
- Return all collected training metrics for the iteration.
Returns:
The results dict from executing the training iteration.
"""
batch_size = self.config["train_batch_size"]
local_worker = self.workers.local_worker()
# Sample n MultiAgentBatches from n workers.
new_sample_batches = synchronous_parallel_sample(
worker_set=self.workers, concat=False
)
for batch in new_sample_batches:
# Update sampling step counters.
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
# Store new samples in the replay buffer
self.local_replay_buffer.add(batch)
# Sample one training MultiAgentBatch from replay buffer.
train_batch = self.local_replay_buffer.sample(batch_size)
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# TODO: Move training steps counter update outside of `train_one_step()` method.
# # Update train step counters.
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
# Update target network every `target_network_update_freq` steps.
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update remote workers' weights after learning on local worker.
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
@@ -242,66 +308,3 @@ class SimpleQTrainer(Trainer):
)
return StandardMetricsReporting(train_op, workers, config)
@ExperimentalAPI
def training_iteration(self) -> ResultDict:
"""Simple Q training iteration function.
Simple Q consists of the following steps:
- (1) Sample (MultiAgentBatch) from workers...
- (2) Store new samples in replay buffer.
- (3) Sample training batch (MultiAgentBatch) from replay buffer.
- (4) Learn on training batch.
- (5) Update target network every target_network_update_freq steps.
- (6) Return all collected metrics for the iteration.
Returns:
The results dict from executing the training iteration.
"""
batch_size = self.config["train_batch_size"]
local_worker = self.workers.local_worker()
# (1) Sample (MultiAgentBatches) from workers
new_sample_batches = synchronous_parallel_sample(
worker_set=self.workers, concat=False
)
for s in new_sample_batches:
# Update counters
self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
self._counters[NUM_AGENT_STEPS_SAMPLED] += (
len(s) if isinstance(s, SampleBatch) else s.agent_steps()
)
# (2) Store new samples in replay buffer
self.local_replay_buffer.add(s)
# (3) Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = self.local_replay_buffer.sample(batch_size)
# (4) Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# (5) Update target network every target_network_update_freq steps
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update remote workers' weights after learning on local worker
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()
# (6) Return all collected metrics for the iteration.
return train_results
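For reference, a minimal usage sketch of the new training-iteration path introduced above. The config values and the CartPole-v0 env are illustrative only; `_disable_execution_plan_api` is set explicitly so that `training_iteration()` (rather than the execution plan) is exercised:

import ray
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer

ray.init()
trainer = SimpleQTrainer(
    env="CartPole-v0",
    config={
        # Illustrative settings; SimpleQ has both torch and tf policies.
        "framework": "torch",
        # Run the new training_iteration() instead of the execution plan.
        "_disable_execution_plan_api": True,
    },
)
results = trainer.train()
print(results["episode_reward_mean"])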


@@ -11,11 +11,29 @@ from ray.rllib.execution.replay_ops import (
Replay,
StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.rollout_ops import (
ConcatBatches,
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
train_one_step,
TrainOneStep,
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator
# fmt: off
@@ -61,17 +79,21 @@ DEFAULT_CONFIG = with_common_config({
"explore": False,
},
# Number of env steps to optimize for before returning
# Number of env steps to optimize for before returning.
"timesteps_per_iteration": 1000,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# === Replay buffer ===
# Size of the replay buffer in batches (not timesteps!).
"buffer_size": 1000,
"replay_buffer_config": {
"no_local_replay_buffer": True,
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
},
# === Optimization ===
# Learning rate for RMSProp optimizer
"lr": 0.0005,
@@ -81,14 +103,12 @@ DEFAULT_CONFIG = with_common_config({
"optim_eps": 0.00001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 10,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 4,
# Size of a batched sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
# Minimum batch size used for training (in timesteps). With the default buffer
# (ReplayBuffer), this means sampling from the buffer (entire-episode SampleBatches)
# as many times as is required to reach at least this number of timesteps.
"train_batch_size": 32,
# === Parallelism ===
@@ -108,11 +128,18 @@ DEFAULT_CONFIG = with_common_config({
},
# Only torch supported so far.
"framework": "torch",
# Experimental flag.
# === Experimental Flags ===
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
"_disable_execution_plan_api": True,
# Deprecated keys:
# Use `replay_buffer_config.learning_starts` instead.
"learning_starts": DEPRECATED_VALUE,
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
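To illustrate the config migration above, a hedged sketch of overriding the new nested `replay_buffer_config` (the capacity/learning_starts values are arbitrary; building an actual QMixTrainer additionally requires a grouped MultiAgentEnv via `MultiAgentEnv.with_agent_groups()`, which is omitted here):

import copy

from ray.rllib.agents.qmix.qmix import DEFAULT_CONFIG

config = copy.deepcopy(DEFAULT_CONFIG)
# `buffer_size` and `learning_starts` are deprecated at the top level and now
# live inside `replay_buffer_config`.
config["replay_buffer_config"].update(
    {
        "capacity": 2000,
        "learning_starts": 500,
    }
)
# As of this commit, QMIX defaults to the new training_iteration() path.
assert config["_disable_execution_plan_api"] is True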
@@ -136,6 +163,78 @@ class QMixTrainer(SimpleQTrainer):
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
return QMixTorchPolicy
@override(SimpleQTrainer)
def training_iteration(self) -> ResultDict:
"""QMIX training iteration function.
- Sample n MultiAgentBatches from n workers synchronously.
- Store new samples in the replay buffer.
- Sample one training MultiAgentBatch from the replay buffer.
- Learn on the training batch.
- Update the target network every `target_network_update_freq` steps.
- Return all collected training metrics for the iteration.
Returns:
The results dict from executing the training iteration.
"""
# Sample n batches from n workers.
new_sample_batches = synchronous_parallel_sample(
worker_set=self.workers, concat=False
)
for batch in new_sample_batches:
# Update counters.
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
# Store new samples in the replay buffer.
self.local_replay_buffer.add(batch)
# Sample n batches from replay buffer until the total number of timesteps
# reaches `train_batch_size`.
train_batch = sample_min_n_steps_from_buffer(
replay_buffer=self.local_replay_buffer,
min_steps=self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)
if train_batch is None:
return {}
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# TODO: Move training steps counter update outside of `train_one_step()` method.
# # Update train step counters.
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
# Update target network every `target_network_update_freq` steps.
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
# Update remote workers' weights and global vars after learning on local worker.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(SimpleQTrainer)
def execution_plan(


@@ -1,22 +1,24 @@
from gym.spaces import Tuple, Discrete, Dict
import gym
import logging
import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional, Tuple
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType
# Torch must be installed.
torch, nn = try_import_torch(error=True)
@@ -152,7 +154,7 @@ class QMixLoss(nn.Module):
# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(Policy):
class QMixTorchPolicy(TorchPolicy):
"""QMix impl. Assumes homogeneous agents for now.
You must use MultiAgentEnv.with_agent_groups() to group agents
@@ -168,7 +170,7 @@ class QMixTorchPolicy(Policy):
_validate(obs_space, action_space)
config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
self.framework = "torch"
super().__init__(obs_space, action_space, config)
self.n_agents = len(obs_space.original_space.spaces)
config["model"]["n_agents"] = self.n_agents
self.n_actions = action_space.spaces[0].n
@@ -180,7 +182,7 @@ class QMixTorchPolicy(Policy):
)
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, Dict):
if isinstance(agent_obs_space, gym.spaces.Dict):
space_keys = set(agent_obs_space.spaces.keys())
if "obs" not in space_keys:
raise ValueError("Dict obs space must have subspace labeled `obs`")
@@ -228,6 +230,8 @@ class QMixTorchPolicy(Policy):
default_model=RNNModel,
).to(self.device)
super().__init__(obs_space, action_space, config, model=self.model)
self.exploration = self._create_exploration()
# Setup the mixer network.
@@ -273,19 +277,22 @@ class QMixTorchPolicy(Policy):
eps=config["optim_eps"],
)
@override(Policy)
def compute_actions(
@override(TorchPolicy)
def compute_actions_from_input_dict(
self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs
):
input_dict: Dict[str, TensorType],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
obs_batch = input_dict[SampleBatch.OBS]
state_batches = []
i = 0
while f"state_in_{i}" in input_dict:
state_batches.append(input_dict[f"state_in_{i}"])
i += 1
explore = explore if explore is not None else self.config["explore"]
obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
# We need to ensure we do not use the env global state
@@ -319,7 +326,11 @@ class QMixTorchPolicy(Policy):
return tuple(actions.transpose([1, 0])), hiddens, {}
@override(Policy)
@override(TorchPolicy)
def compute_actions(self, *args, **kwargs):
return self.compute_actions_from_input_dict(*args, **kwargs)
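As a rough illustration of the new input format, the toy snippet below builds an `input_dict` with observation and RNN-state columns and mirrors the `state_in_{i}` collection loop from `compute_actions_from_input_dict()` above (array shapes are made up):

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

input_dict = {
    SampleBatch.OBS: np.zeros((4, 10)),  # batch of 4 (grouped) observations
    "state_in_0": np.zeros((4, 64)),  # one column per RNN state tensor
    "state_in_1": np.zeros((4, 64)),
}
state_batches = []
i = 0
while f"state_in_{i}" in input_dict:
    state_batches.append(input_dict[f"state_in_{i}"])
    i += 1
assert len(state_batches) == 2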
@override(TorchPolicy)
def compute_log_likelihoods(
self,
actions,
@@ -331,7 +342,7 @@ class QMixTorchPolicy(Policy):
obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
return np.zeros(obs_batch.size()[0])
@override(Policy)
@override(TorchPolicy)
def learn_on_batch(self, samples):
obs_batch, action_mask, env_global_state = self._unpack_observation(
samples[SampleBatch.CUR_OBS]
@@ -456,14 +467,14 @@ class QMixTorchPolicy(Policy):
}
return {LEARNER_STATS_KEY: stats}
@override(Policy)
@override(TorchPolicy)
def get_initial_state(self): # initial RNN state
return [
s.expand([self.n_agents, -1]).cpu().numpy()
for s in self.model.get_initial_state()
]
@override(Policy)
@override(TorchPolicy)
def get_weights(self):
return {
"model": self._cpu_dict(self.model.state_dict()),
@@ -474,7 +485,7 @@ class QMixTorchPolicy(Policy):
else None,
}
@override(Policy)
@override(TorchPolicy)
def set_weights(self, weights):
self.model.load_state_dict(self._device_dict(weights["model"]))
self.target_model.load_state_dict(self._device_dict(weights["target_model"]))
@@ -484,13 +495,13 @@ class QMixTorchPolicy(Policy):
self._device_dict(weights["target_mixer"])
)
@override(Policy)
@override(TorchPolicy)
def get_state(self):
state = self.get_weights()
state["cur_epsilon"] = self.cur_epsilon
return state
@override(Policy)
@override(TorchPolicy)
def set_state(self, state):
self.set_weights(state)
self.set_epsilon(state["cur_epsilon"])
@@ -564,20 +575,20 @@ class QMixTorchPolicy(Policy):
def _validate(obs_space, action_space):
if not hasattr(obs_space, "original_space") or not isinstance(
obs_space.original_space, Tuple
obs_space.original_space, gym.spaces.Tuple
):
raise ValueError(
"Obs space must be a Tuple, got {}. Use ".format(obs_space)
+ "MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix."
)
if not isinstance(action_space, Tuple):
if not isinstance(action_space, gym.spaces.Tuple):
raise ValueError(
"Action space must be a Tuple, got {}. ".format(action_space)
+ "Use MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix."
)
if not isinstance(action_space.spaces[0], Discrete):
if not isinstance(action_space.spaces[0], gym.spaces.Discrete):
raise ValueError(
"QMix requires a discrete action space, got {}".format(
action_space.spaces[0]


@@ -33,7 +33,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers.replay_buffer import validate_buffer_config
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
logger = logging.getLogger(__name__)


@@ -1469,14 +1469,13 @@ class Trainer(Trainable):
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update weights - after learning on the local worker - on all remote
# workers.
# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
if self.workers.remote_workers():
with self._timers[WORKER_UPDATE_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
with self._timers[WORKER_UPDATE_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
return train_results


@@ -17,6 +17,7 @@ from ray.rllib.utils.typing import (
MultiAgentDict,
MultiEnvDict,
)
from ray.util import log_once
# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"
@@ -156,7 +157,8 @@ class MultiAgentEnv(gym.Env):
if self._spaces_in_preferred_format:
return self.action_space.contains(x)
logger.warning("action_space_contains() has not been implemented")
if log_once("action_space_contains"):
logger.warning("action_space_contains() has not been implemented")
return True
@ExperimentalAPI
@@ -219,7 +221,8 @@ class MultiAgentEnv(gym.Env):
samples = self.observation_space.sample()
samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
return samples
logger.warning("observation_space_sample() has not been implemented")
if log_once("observation_space_sample"):
logger.warning("observation_space_sample() has not been implemented")
return {}
@PublicAPI

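The `log_once` guard added above throttles these warnings to one emission per process; a small, self-contained illustration of the same pattern (the key string is arbitrary):

import logging

from ray.util import log_once

logger = logging.getLogger(__name__)


def warn_not_implemented() -> None:
    # Only the first call per process actually logs; later calls are no-ops.
    if log_once("some_arbitrary_warning_key"):
        logger.warning("action_space_contains() has not been implemented")


warn_not_implemented()
warn_not_implemented()  # Silent.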

@@ -95,9 +95,13 @@ class NoopResetEnv(gym.Wrapper):
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
# this environment now uses the pcg64 random number generator which
# does not have randint as an attribute only has integers
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
# This environment now uses the PCG64 random number generator, which
# does not have `randint` as an attribute; it only has `integers`.
try:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
# Also still support older versions.
except AttributeError:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):

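The try/except above covers both numpy RNG APIs; a standalone sketch of the same compatibility pattern (the helper name is made up):

import numpy as np


def sample_noops(rng, noop_max: int) -> int:
    # Newer gym envs expose a numpy Generator (`integers`), older ones a
    # RandomState (`randint`); support both.
    try:
        return int(rng.integers(1, noop_max + 1))
    except AttributeError:
        return int(rng.randint(1, noop_max + 1))


print(sample_noops(np.random.default_rng(0), 30))  # Generator path.
print(sample_noops(np.random.RandomState(0), 30))  # Legacy RandomState path.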

@@ -207,6 +207,7 @@ class WorkerSet:
)
# Only sync if we have remote workers or `from_worker` is provided.
weights = None
if self.remote_workers() or from_worker is not None:
weights = (from_worker or self.local_worker()).get_weights(policies)
# Put weights only once into object store and use same object
@@ -216,14 +217,14 @@ class WorkerSet:
for to_worker in self.remote_workers():
to_worker.set_weights.remote(weights_ref, global_vars=global_vars)
# If `from_worker` is provided, also sync to this WorkerSet's
# local worker.
if from_worker is not None and self.local_worker() is not None:
self.local_worker().set_weights(weights, global_vars=global_vars)
# If `global_vars` is provided and local worker exists -> Update its
# global_vars.
elif self.local_worker() is not None and global_vars is not None:
self.local_worker().set_global_vars(global_vars)
# If `from_worker` is provided, also sync to this WorkerSet's
# local worker.
if from_worker is not None and self.local_worker() is not None:
self.local_worker().set_weights(weights, global_vars=global_vars)
# If `global_vars` is provided and local worker exists -> Update its
# global_vars.
elif self.local_worker() is not None and global_vars is not None:
self.local_worker().set_global_vars(global_vars)
def add_workers(self, num_workers: int) -> None:
"""Creates and adds a number of remote workers to this worker set.


@@ -134,6 +134,9 @@ if __name__ == "__main__":
# Evaluate every other training iteration (together
# with every other call to Trainer.train()).
"evaluation_interval": args.evaluation_interval,
"evaluation_config": {
"input_evaluation": ["is"],
},
# Run for n episodes/timesteps (properly distribute load amongst
# all eval workers). The longer it takes to evaluate, the more sense
# it makes to use `evaluation_parallel_to_training=True`.


@@ -1,26 +1,28 @@
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
MultiAgentMixInReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.simple_replay_buffer import SimpleReplayBuffer
__all__ = [
"ReplayBuffer",
"StorageUnit",
"MultiAgentReplayBuffer",
"ReplayMode",
"ReservoirBuffer",
"PrioritizedReplayBuffer",
"MultiAgentMixInReplayBuffer",
"MultiAgentPrioritizedReplayBuffer",
"MultiAgentReplayBuffer",
"PrioritizedReplayBuffer",
"ReplayMode",
"ReplayBuffer",
"ReservoirBuffer",
"SimpleReplayBuffer",
"StorageUnit",
]
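With the consolidated `__init__.py`, all new-API buffers (including the new SimpleReplayBuffer) can be imported from the package root, e.g.:

from ray.rllib.utils.replay_buffers import (
    MultiAgentPrioritizedReplayBuffer,
    PrioritizedReplayBuffer,
    ReplayBuffer,
    SimpleReplayBuffer,
)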


@@ -17,12 +17,6 @@ from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.from_config import from_config
# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"
@@ -37,220 +31,6 @@ class StorageUnit(Enum):
EPISODES = "episodes"
@ExperimentalAPI
def validate_buffer_config(config: dict):
if config.get("replay_buffer_config", None) is None:
config["replay_buffer_config"] = {}
prioritized_replay = config.get("prioritized_replay")
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(
old="config['prioritized_replay']",
help="Replay prioritization specified at new location config["
"'replay_buffer_config']["
"'prioritized_replay'] will be overwritten.",
error=False,
)
config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
deprecation_warning(
old="config['buffer_size']",
help="Buffer size specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
config["replay_buffer_config"]["capacity"] = capacity
# Deprecation of old-style replay buffer args
# Warnings before checking of we need local buffer so that algorithms
# Without local buffer also get warned
deprecated_replay_buffer_keys = [
"prioritized_replay_alpha",
"prioritized_replay_beta",
"prioritized_replay_eps",
"learning_starts",
]
for k in deprecated_replay_buffer_keys:
if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config[{}]".format(k),
help="config['replay_buffer_config'][{}] should be used "
"for Q-Learning algorithms. Ignore this warning if "
"you are not using a Q-Learning algorithm and still "
"provide {}."
"".format(k, k),
error=False,
)
# Copy values over to new location in config to support new
# and old configuration style
if config.get("replay_buffer_config") is not None:
config["replay_buffer_config"][k] = config[k]
# Old Ape-X configs may contain no_local_replay_buffer
no_local_replay_buffer = config.get("no_local_replay_buffer", False)
if no_local_replay_buffer:
deprecation_warning(
old="config['no_local_replay_buffer']",
help="no_local_replay_buffer specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
config["replay_buffer_config"][
"no_local_replay_buffer"
] = no_local_replay_buffer
# TODO (Artur):
if config["replay_buffer_config"].get("no_local_replay_buffer", False):
return
replay_buffer_config = config["replay_buffer_config"]
assert (
"type" in replay_buffer_config
), "Can not instantiate ReplayBuffer from config without 'type' key."
# Check if old replay buffer should be instantiated
buffer_type = config["replay_buffer_config"]["type"]
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Prepend old-style buffers' path
assert buffer_type == "MultiAgentReplayBuffer", (
"Without "
"ReplayBuffer "
"API, only "
"MultiAgentReplayBuffer "
"is supported!"
)
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
else:
assert buffer_type in [
"ray.rllib.execution.MultiAgentReplayBuffer",
Legacy_MultiAgentReplayBuffer,
], (
"Without ReplayBuffer API, only " "MultiAgentReplayBuffer is supported!"
)
config["replay_buffer_config"]["type"] = buffer_type
# Remove from config, so it's not passed into the buffer c'tor
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
# We need to deprecate the old-style location of the following
# buffer arguments and make users put them into the
# "replay_buffer_config" field of their config.
replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
if replay_batch_size != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
deprecation_warning(
old="config['replay_batch_size']",
help="Replay batch size specified at new "
"location config['replay_buffer_config']["
"'replay_batch_size'] will be overwritten.",
error=False,
)
replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
if replay_mode != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_mode"] = replay_mode
deprecation_warning(
old="config['multiagent']['replay_mode']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_mode'] will be overwritten.",
error=False,
)
# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
replay_sequence_length = config.get("replay_sequence_length", None)
if replay_sequence_length is not None:
config["replay_buffer_config"][
"replay_sequence_length"
] = replay_sequence_length
deprecation_warning(
old="config['replay_sequence_length']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_sequence_length'] will be overwritten.",
error=False,
)
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
if replay_burn_in != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
deprecation_warning(
old="config['burn_in']",
help="Burn in specified at new location config["
"'replay_buffer_config']["
"'replay_burn_in'] will be overwritten.",
)
replay_zero_init_states = config.get(
"replay_zero_init_states", DEPRECATED_VALUE
)
if replay_zero_init_states != DEPRECATED_VALUE:
config["replay_buffer_config"][
"replay_zero_init_states"
] = replay_zero_init_states
deprecation_warning(
old="config['replay_zero_init_states']",
help="Replay zero init states specified at new location "
"config["
"'replay_buffer_config']["
"'replay_zero_init_states'] will be overwritten.",
error=False,
)
# TODO (Artur): Move this logic into config objects
if config["replay_buffer_config"].get("prioritized_replay", False):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
# This triggers non-prioritization in old-style replay buffer
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
else:
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
config["replay_buffer_config"]["type"] = (
"ray.rllib.utils.replay_buffers." + buffer_type
)
test_buffer = from_config(buffer_type, config["replay_buffer_config"])
if hasattr(test_buffer, "update_priorities"):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
if is_prioritized_buffer:
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
)
elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
raise ValueError(
"Prioritized replay is not supported when "
"replay_sequence_length > 1."
)
else:
if config.get("worker_side_prioritization"):
raise ValueError(
"Worker side prioritization is not supported when "
"prioritized_replay=False."
)
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
# Fall back to train batch size if no replay batch size was provided
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
# Pop prioritized replay because it's not a valid parameter for older
# replay buffers
config["replay_buffer_config"].pop("prioritized_replay", None)
@ExperimentalAPI
class ReplayBuffer:
def __init__(
@@ -316,18 +96,6 @@ class ReplayBuffer:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)
@ExperimentalAPI
@Deprecated(old="add_batch", new="add", error=False)
def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
"""Deprecated in favor of new ReplayBuffer API."""
return self.add(batch, **kwargs)
@ExperimentalAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
return self.sample(num_items, **kwargs)
@ExperimentalAPI
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch of experiences to this buffer.
@@ -522,3 +290,11 @@ class ReplayBuffer:
name could not be determined.
"""
return platform.node()
@Deprecated(old="ReplayBuffer.add_batch()", new="RepayBuffer.add()", error=False)
def add_batch(self, *args, **kwargs):
return self.add(*args, **kwargs)
@Deprecated(old="RepayBuffer.replay()", new="RepayBuffer.sample()", error=False)
def replay(self, *args, **kwargs):
return self.sample(*args, **kwargs)
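A minimal sketch of the new-API calls that the deprecated `add_batch()`/`replay()` wrappers forward to, assuming the default `storage_unit="timesteps"` (toy batch contents):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import ReplayBuffer

buffer = ReplayBuffer(capacity=100)
buffer.add(SampleBatch({"obs": [0, 1, 2], "rewards": [0.0, 1.0, 2.0]}))
# `add()` / `sample()` replace the deprecated `add_batch()` / `replay()`.
batch = buffer.sample(2)
print(batch["obs"], len(buffer))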


@@ -0,0 +1,35 @@
import random
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.replay_buffers.utils import warn_replay_buffer_capacity
from ray.rllib.utils.typing import SampleBatchType
class SimpleReplayBuffer(ReplayBuffer):
"""Simple replay buffer that operates over entire batches."""
def __init__(self, capacity: int, storage_unit: str = "timesteps", **kwargs):
"""Initialize a SimpleReplayBuffer instance."""
super().__init__(capacity=capacity, storage_unit="timesteps", **kwargs)
self.replay_batches = []
self.replay_index = 0
@override(ReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
warn_replay_buffer_capacity(item=batch, capacity=self.capacity)
if self.capacity > 0:
if len(self.replay_batches) < self.capacity:
self.replay_batches.append(batch)
else:
self.replay_batches[self.replay_index] = batch
self.replay_index += 1
self.replay_index %= self.capacity
@override(ReplayBuffer)
def sample(self, num_items: int, **kwargs) -> SampleBatchType:
return random.choice(self.replay_batches)
@override(ReplayBuffer)
def __len__(self):
return len(self.replay_batches)
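A usage sketch of the new SimpleReplayBuffer: it keeps whole (MultiAgent)SampleBatches in a ring buffer of `capacity` batches and returns one stored batch per `sample()` call (toy data):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import SimpleReplayBuffer

buffer = SimpleReplayBuffer(capacity=3)
for i in range(5):
    buffer.add(SampleBatch({"obs": [i], "rewards": [float(i)]}))

print(len(buffer))  # 3 -> the two oldest batches were overwritten.
print(buffer.sample(32))  # Returns a single stored batch; num_items is ignored.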


@@ -1,12 +1,25 @@
import logging
import psutil
from typing import Optional
from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers import (
MultiAgentPrioritizedReplayBuffer,
ReplayBuffer,
)
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict
from ray.util import log_once
logger = logging.getLogger(__name__)
def update_priorities_in_replay_buffer(
@@ -63,3 +76,288 @@ def update_priorities_in_replay_buffer(
# Make the actual buffer API call to update the priority weights on all
# policies.
replay_buffer.update_priorities(prio_dict)
def sample_min_n_steps_from_buffer(
replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
"""Samples a minimum of n timesteps from a given replay buffer.
This utility method is primarily used by the QMIX algorithm and helps with
sampling a given number of time steps which has stored samples in units
of sequences or complete episodes. Samples n batches from replay buffer
until the total number of timesteps reaches `train_batch_size`.
Args:
replay_buffer: The replay buffer to sample from
num_timesteps: The number of timesteps to sample
count_by_agent_steps: Whether to count agent steps or env steps
Returns:
A concatenated SampleBatch or MultiAgentBatch with samples from the
buffer.
"""
train_batch_size = 0
train_batches = []
while train_batch_size < min_steps:
batch = replay_buffer.sample(num_items=1)
if batch is None:
return None
train_batches.append(batch)
train_batch_size += (
train_batches[-1].agent_steps()
if count_by_agent_steps
else train_batches[-1].env_steps()
)
# All batch types are the same type, hence we can use any concat_samples()
train_batch = SampleBatch.concat_samples(train_batches)
return train_batch
@ExperimentalAPI
def validate_buffer_config(config: dict):
if config.get("replay_buffer_config", None) is None:
config["replay_buffer_config"] = {}
prioritized_replay = config.get("prioritized_replay")
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(
old="config['prioritized_replay']",
help="Replay prioritization specified at new location config["
"'replay_buffer_config']["
"'prioritized_replay'] will be overwritten.",
error=False,
)
config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
deprecation_warning(
old="config['buffer_size']",
help="Buffer size specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
config["replay_buffer_config"]["capacity"] = capacity
# Deprecation of old-style replay buffer args.
# Warn before checking whether we need a local buffer, so that algorithms
# without a local buffer also get warned.
deprecated_replay_buffer_keys = [
"prioritized_replay_alpha",
"prioritized_replay_beta",
"prioritized_replay_eps",
"learning_starts",
]
for k in deprecated_replay_buffer_keys:
if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config[{}]".format(k),
help="config['replay_buffer_config'][{}] should be used "
"for Q-Learning algorithms. Ignore this warning if "
"you are not using a Q-Learning algorithm and still "
"provide {}."
"".format(k, k),
error=False,
)
# Copy values over to new location in config to support new
# and old configuration style
if config.get("replay_buffer_config") is not None:
config["replay_buffer_config"][k] = config[k]
# Old Ape-X configs may contain no_local_replay_buffer
no_local_replay_buffer = config.get("no_local_replay_buffer", False)
if no_local_replay_buffer:
deprecation_warning(
old="config['no_local_replay_buffer']",
help="no_local_replay_buffer specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
)
config["replay_buffer_config"][
"no_local_replay_buffer"
] = no_local_replay_buffer
# TODO (Artur):
if config["replay_buffer_config"].get("no_local_replay_buffer", False):
return
replay_buffer_config = config["replay_buffer_config"]
assert (
"type" in replay_buffer_config
), "Can not instantiate ReplayBuffer from config without 'type' key."
# Check if old replay buffer should be instantiated
buffer_type = config["replay_buffer_config"]["type"]
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Prepend old-style buffers' path
assert buffer_type == "MultiAgentReplayBuffer", (
"Without "
"ReplayBuffer "
"API, only "
"MultiAgentReplayBuffer "
"is supported!"
)
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
else:
assert buffer_type in [
"ray.rllib.execution.MultiAgentReplayBuffer",
Legacy_MultiAgentReplayBuffer,
], (
"Without ReplayBuffer API, only " "MultiAgentReplayBuffer is supported!"
)
config["replay_buffer_config"]["type"] = buffer_type
# Remove from config, so it's not passed into the buffer c'tor
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
# We need to deprecate the old-style location of the following
# buffer arguments and make users put them into the
# "replay_buffer_config" field of their config.
replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
if replay_batch_size != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
deprecation_warning(
old="config['replay_batch_size']",
help="Replay batch size specified at new "
"location config['replay_buffer_config']["
"'replay_batch_size'] will be overwritten.",
error=False,
)
replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
if replay_mode != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_mode"] = replay_mode
deprecation_warning(
old="config['multiagent']['replay_mode']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_mode'] will be overwritten.",
error=False,
)
# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
replay_sequence_length = config.get("replay_sequence_length", None)
if replay_sequence_length is not None:
config["replay_buffer_config"][
"replay_sequence_length"
] = replay_sequence_length
deprecation_warning(
old="config['replay_sequence_length']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_sequence_length'] will be overwritten.",
error=False,
)
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
if replay_burn_in != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
deprecation_warning(
old="config['burn_in']",
help="Burn in specified at new location config["
"'replay_buffer_config']["
"'replay_burn_in'] will be overwritten.",
)
replay_zero_init_states = config.get(
"replay_zero_init_states", DEPRECATED_VALUE
)
if replay_zero_init_states != DEPRECATED_VALUE:
config["replay_buffer_config"][
"replay_zero_init_states"
] = replay_zero_init_states
deprecation_warning(
old="config['replay_zero_init_states']",
help="Replay zero init states specified at new location "
"config["
"'replay_buffer_config']["
"'replay_zero_init_states'] will be overwritten.",
error=False,
)
# TODO (Artur): Move this logic into config objects
if config["replay_buffer_config"].get("prioritized_replay", False):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
# This triggers non-prioritization in old-style replay buffer
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
else:
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
config["replay_buffer_config"]["type"] = (
"ray.rllib.utils.replay_buffers." + buffer_type
)
test_buffer = from_config(buffer_type, config["replay_buffer_config"])
if hasattr(test_buffer, "update_priorities"):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
if is_prioritized_buffer:
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
)
elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
raise ValueError(
"Prioritized replay is not supported when "
"replay_sequence_length > 1."
)
else:
if config.get("worker_side_prioritization"):
raise ValueError(
"Worker side prioritization is not supported when "
"prioritized_replay=False."
)
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
# Fall back to train batch size if no replay batch size was provided
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
# Pop prioritized replay because it's not a valid parameter for older
# replay buffers
config["replay_buffer_config"].pop("prioritized_replay", None)
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
"""Warn if the configured replay buffer capacity is too large for machine's memory.
Args:
item: A (example) item that's supposed to be added to the buffer.
This is used to compute the overall memory footprint estimate for the
buffer.
capacity: The capacity value of the buffer. This is interpreted as the
number of items (such as given `item`) that will eventually be stored in
the buffer.
Raises:
ValueError: If computed memory footprint for the buffer exceeds the machine's
RAM.
"""
if log_once("warn_replay_buffer_capacity"):
item_size = item.size_bytes()
psutil_mem = psutil.virtual_memory()
total_gb = psutil_mem.total / 1e9
mem_size = capacity * item_size / 1e9
msg = (
"Estimated max memory usage for replay buffer is {} GB "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, capacity, item.count, item_size, total_gb
)
)
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
logger.warning(msg)
else:
logger.info(msg)
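Finally, a usage sketch of the new `sample_min_n_steps_from_buffer()` helper, here fed from a SimpleReplayBuffer holding 2-timestep toy batches:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import SimpleReplayBuffer
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer

buffer = SimpleReplayBuffer(capacity=10)
for i in range(3):
    buffer.add(SampleBatch({"obs": [i, i + 1], "rewards": [0.0, 1.0]}))

# Draw whole stored batches until at least 5 env steps have been collected.
train_batch = sample_min_n_steps_from_buffer(
    replay_buffer=buffer, min_steps=5, count_by_agent_steps=False
)
print(train_batch.env_steps())  # >= 5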