[RLlib] QMIX training iteration function and new replay buffer API. (#24164)
Parent: 29388fb25b
Commit: 627b9f2e88
13 changed files with 612 additions and 378 deletions
@@ -19,7 +19,7 @@ from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.utils.replay_buffers.replay_buffer import validate_buffer_config
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    synchronous_parallel_sample,

@@ -34,12 +34,11 @@ from ray.rllib.execution.train_ops import (
    UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
@@ -200,6 +199,73 @@ class SimpleQTrainer(Trainer):
        else:
            return SimpleQTFPolicy

    @ExperimentalAPI
    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        """Simple Q training iteration function.

        Simple Q consists of the following steps:
        - Sample n MultiAgentBatches from n workers synchronously.
        - Store new samples in the replay buffer.
        - Sample one training MultiAgentBatch from the replay buffer.
        - Learn on the training batch.
        - Update the target network every `target_network_update_freq` steps.
        - Return all collected training metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        batch_size = self.config["train_batch_size"]
        local_worker = self.workers.local_worker()

        # Sample n MultiAgentBatches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update sampling step counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer
            self.local_replay_buffer.add(batch)

        # Sample one training MultiAgentBatch from replay buffer.
        train_batch = self.local_replay_buffer.sample(batch_size)

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update target network every `target_network_update_freq` steps.
        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            with self._timers[TARGET_NET_UPDATE_TIMER]:
                to_update = local_worker.get_policies_to_train()
                local_worker.foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        # Update remote workers' weights after learning on local worker.
        if self.workers.remote_workers():
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                self.workers.sync_weights()

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
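For context, a minimal, hedged sketch of exercising this new training-iteration code path directly. The `SimpleQTrainer` import location and the CartPole environment are assumptions based on RLlib at the time of this PR, not part of the diff:

import ray
from ray.rllib.agents.dqn import SimpleQTrainer  # assumed import location

ray.init()

trainer = SimpleQTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 2,
        # With the execution-plan API disabled, Trainer.train() calls the
        # training_iteration() method shown above once per iteration.
        "_disable_execution_plan_api": True,
    },
)

for _ in range(3):
    results = trainer.train()
    print(results["timesteps_total"], results["episode_reward_mean"])

trainer.stop()
ray.shutdown()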
@@ -242,66 +308,3 @@ class SimpleQTrainer(Trainer):
        )

        return StandardMetricsReporting(train_op, workers, config)

    @ExperimentalAPI
    def training_iteration(self) -> ResultDict:
        """Simple Q training iteration function.

        Simple Q consists of the following steps:
        - (1) Sample (MultiAgentBatch) from workers...
        - (2) Store new samples in replay buffer.
        - (3) Sample training batch (MultiAgentBatch) from replay buffer.
        - (4) Learn on training batch.
        - (5) Update target network every target_network_update_freq steps.
        - (6) Return all collected metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        batch_size = self.config["train_batch_size"]
        local_worker = self.workers.local_worker()

        # (1) Sample (MultiAgentBatches) from workers
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for s in new_sample_batches:
            # Update counters
            self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
            self._counters[NUM_AGENT_STEPS_SAMPLED] += (
                len(s) if isinstance(s, SampleBatch) else s.agent_steps()
            )
            # (2) Store new samples in replay buffer
            self.local_replay_buffer.add(s)

        # (3) Sample training batch (MultiAgentBatch) from replay buffer.
        train_batch = self.local_replay_buffer.sample(batch_size)

        # (4) Learn on training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # (5) Update target network every target_network_update_freq steps
        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            with self._timers[TARGET_NET_UPDATE_TIMER]:
                to_update = local_worker.get_policies_to_train()
                local_worker.foreach_policy_to_train(
                    lambda p, pid: pid in to_update and p.update_target()
                )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        # Update remote workers' weights after learning on local worker
        if self.workers.remote_workers():
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                self.workers.sync_weights()

        # (6) Return all collected metrics for the iteration.
        return train_results
@@ -11,11 +11,29 @@ from ray.rllib.execution.replay_ops import (
    Replay,
    StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.rollout_ops import (
    ConcatBatches,
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
    TrainOneStep,
    UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics import (
    LAST_TARGET_UPDATE_TS,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_TARGET_UPDATES,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator

# fmt: off
@@ -61,17 +79,21 @@ DEFAULT_CONFIG = with_common_config({
        "explore": False,
    },

    # Number of env steps to optimize for before returning
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 1000,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": 1000,
    "replay_buffer_config": {
        "no_local_replay_buffer": True,
        # Use the new ReplayBuffer API here
        "_enable_replay_buffer_api": True,
        "type": "SimpleReplayBuffer",
        # Size of the replay buffer in batches (not timesteps!).
        "capacity": 1000,
        "learning_starts": 1000,
    },

    # === Optimization ===
    # Learning rate for RMSProp optimizer
    "lr": 0.0005,

@@ -81,14 +103,12 @@ DEFAULT_CONFIG = with_common_config({
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    # Minimum batch size used for training (in timesteps). With the default buffer
    # (ReplayBuffer) this means sampling from the buffer (entire-episode SampleBatches)
    # as many times as is required to reach at least this number of timesteps.
    "train_batch_size": 32,

    # === Parallelism ===

@@ -108,11 +128,18 @@ DEFAULT_CONFIG = with_common_config({
    },
    # Only torch supported so far.
    "framework": "torch",
    # Experimental flag.

    # === Experimental Flags ===
    # If True, the execution plan API will not be used. Instead,
    # a Trainer's `training_iteration` method will be called as-is each
    # training iteration.
    "_disable_execution_plan_api": False,
    "_disable_execution_plan_api": True,

    # Deprecated keys:
    # Use `replay_buffer_config.learning_starts` instead.
    "learning_starts": DEPRECATED_VALUE,
    # Use `replay_buffer_config.capacity` instead.
    "buffer_size": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on
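To make the config migration in this hunk concrete, here is a hedged before/after sketch; the key names come from the diff above, while the surrounding dicts are illustrative only:

# Old style (now deprecated): replay buffer settings at the config's top level.
old_style_config = {
    "buffer_size": 1000,
    "learning_starts": 1000,
}

# New style: everything lives under `replay_buffer_config`, and the new
# ReplayBuffer API plus the training_iteration code path are switched on.
new_style_config = {
    "replay_buffer_config": {
        "_enable_replay_buffer_api": True,
        "type": "SimpleReplayBuffer",
        # Capacity is counted in batches (not timesteps!).
        "capacity": 1000,
        "learning_starts": 1000,
    },
    "_disable_execution_plan_api": True,
}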
@@ -136,6 +163,78 @@ class QMixTrainer(SimpleQTrainer):
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return QMixTorchPolicy

    @override(SimpleQTrainer)
    def training_iteration(self) -> ResultDict:
        """QMIX training iteration function.

        - Sample n MultiAgentBatches from n workers synchronously.
        - Store new samples in the replay buffer.
        - Sample one training MultiAgentBatch from the replay buffer.
        - Learn on the training batch.
        - Update the target network every `target_network_update_freq` steps.
        - Return all collected training metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # Sample n batches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer.
            self.local_replay_buffer.add(batch)

        # Sample n batches from replay buffer until the total number of timesteps
        # reaches `train_batch_size`.
        train_batch = sample_min_n_steps_from_buffer(
            replay_buffer=self.local_replay_buffer,
            min_steps=self.config["train_batch_size"],
            count_by_agent_steps=self._by_agent_steps,
        )
        if train_batch is None:
            return {}

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
        if self.config.get("simple_optimizer") is True:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update target network every `target_network_update_freq` steps.
        cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update >= self.config["target_network_update_freq"]:
            to_update = self.workers.local_worker().get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        # Update remote workers' weights and global vars after learning on local worker.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(
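A quick back-of-the-envelope check of the target-update cadence implied by the counters above. The worker count is an assumption; the other numbers come from the QMIX defaults shown in this diff:

num_workers = 2                   # assumption
rollout_fragment_length = 4       # QMIX default above
target_network_update_freq = 500  # QMIX default above

# Each call to training_iteration() samples one fragment per worker.
env_steps_per_iteration = num_workers * rollout_fragment_length  # 8

# Ceiling division: the target net is updated roughly every 63 iterations.
iterations_between_updates = -(-target_network_update_freq // env_steps_per_iteration)
print(iterations_between_updates)  # 63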
@@ -1,22 +1,24 @@
from gym.spaces import Tuple, Discrete, Dict
import gym
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType

# Torch must be installed.
torch, nn = try_import_torch(error=True)
@@ -152,7 +154,7 @@ class QMixLoss(nn.Module):


# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
class QMixTorchPolicy(Policy):
class QMixTorchPolicy(TorchPolicy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents

@@ -168,7 +170,7 @@ class QMixTorchPolicy(Policy):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)

        self.n_agents = len(obs_space.original_space.spaces)
        config["model"]["n_agents"] = self.n_agents
        self.n_actions = action_space.spaces[0].n

@@ -180,7 +182,7 @@ class QMixTorchPolicy(Policy):
        )

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
        if isinstance(agent_obs_space, gym.spaces.Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError("Dict obs space must have subspace labeled `obs`")

@@ -228,6 +230,8 @@ class QMixTorchPolicy(Policy):
            default_model=RNNModel,
        ).to(self.device)

        super().__init__(obs_space, action_space, config, model=self.model)

        self.exploration = self._create_exploration()

        # Setup the mixer network.

@@ -273,19 +277,22 @@ class QMixTorchPolicy(Policy):
            eps=config["optim_eps"],
        )

    @override(Policy)
    def compute_actions(
    @override(TorchPolicy)
    def compute_actions_from_input_dict(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs
    ):
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        obs_batch = input_dict[SampleBatch.OBS]
        state_batches = []
        i = 0
        while f"state_in_{i}" in input_dict:
            state_batches.append(input_dict[f"state_in_{i}"])
            i += 1

        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state

@@ -319,7 +326,11 @@ class QMixTorchPolicy(Policy):

        return tuple(actions.transpose([1, 0])), hiddens, {}

    @override(Policy)
    @override(TorchPolicy)
    def compute_actions(self, *args, **kwargs):
        return self.compute_actions_from_input_dict(*args, **kwargs)

    @override(TorchPolicy)
    def compute_log_likelihoods(
        self,
        actions,

@@ -331,7 +342,7 @@ class QMixTorchPolicy(Policy):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        return np.zeros(obs_batch.size()[0])

    @override(Policy)
    @override(TorchPolicy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS]

@@ -456,14 +467,14 @@ class QMixTorchPolicy(Policy):
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    @override(TorchPolicy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    @override(TorchPolicy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),

@@ -474,7 +485,7 @@ class QMixTorchPolicy(Policy):
            else None,
        }

    @override(Policy)
    @override(TorchPolicy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(self._device_dict(weights["target_model"]))

@@ -484,13 +495,13 @@ class QMixTorchPolicy(Policy):
            self._device_dict(weights["target_mixer"])
        )

    @override(Policy)
    @override(TorchPolicy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    @override(TorchPolicy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

@@ -564,20 +575,20 @@ class QMixTorchPolicy(Policy):

def _validate(obs_space, action_space):
    if not hasattr(obs_space, "original_space") or not isinstance(
        obs_space.original_space, Tuple
        obs_space.original_space, gym.spaces.Tuple
    ):
        raise ValueError(
            "Obs space must be a Tuple, got {}. Use ".format(obs_space)
            + "MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )
    if not isinstance(action_space, Tuple):
    if not isinstance(action_space, gym.spaces.Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space)
            + "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )
    if not isinstance(action_space.spaces[0], Discrete):
    if not isinstance(action_space.spaces[0], gym.spaces.Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                action_space.spaces[0]
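The _validate() errors above direct users to MultiAgentEnv.with_agent_groups(). A hedged sketch of that grouping step, modeled on RLlib's TwoStepGame QMIX example; the example env, the concrete spaces, and the tune call are assumptions, not part of this diff:

from gym.spaces import Discrete, Tuple
from ray import tune
from ray.tune import register_env
from ray.rllib.examples.env.two_step_game import TwoStepGame  # assumed example env

# Group both agents into one group so QMIX sees Tuple obs/action spaces.
grouping = {"group_1": [0, 1]}
obs_space = Tuple([Discrete(6), Discrete(6)])  # assumed per-agent observation space
act_space = Tuple([Discrete(2), Discrete(2)])  # assumed per-agent action space

register_env(
    "grouped_twostep",
    lambda cfg: TwoStepGame(cfg).with_agent_groups(
        grouping, obs_space=obs_space, act_space=act_space
    ),
)

tune.run(
    "QMIX",
    stop={"timesteps_total": 20000},
    config={"env": "grouped_twostep", "framework": "torch"},
)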
@@ -33,7 +33,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers.replay_buffer import validate_buffer_config
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config

logger = logging.getLogger(__name__)
@@ -1469,14 +1469,13 @@ class Trainer(Trainable):
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)

        # Update weights - after learning on the local worker - on all remote
        # workers.
        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        if self.workers.remote_workers():
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)
        with self._timers[WORKER_UPDATE_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        return train_results
rllib/env/multi_agent_env.py
@@ -17,6 +17,7 @@ from ray.rllib.utils.typing import (
    MultiAgentDict,
    MultiEnvDict,
)
from ray.util import log_once

# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"

@@ -156,7 +157,8 @@ class MultiAgentEnv(gym.Env):
        if self._spaces_in_preferred_format:
            return self.action_space.contains(x)

        logger.warning("action_space_contains() has not been implemented")
        if log_once("action_space_contains"):
            logger.warning("action_space_contains() has not been implemented")
        return True

    @ExperimentalAPI

@@ -219,7 +221,8 @@ class MultiAgentEnv(gym.Env):
            samples = self.observation_space.sample()
            samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
            return samples
        logger.warning("observation_space_sample() has not been implemented")
        if log_once("observation_space_sample"):
            logger.warning("observation_space_sample() has not been implemented")
        return {}

    @PublicAPI
rllib/env/wrappers/atari_wrappers.py
@@ -95,9 +95,13 @@ class NoopResetEnv(gym.Wrapper):
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            # this environment now uses the pcg64 random number generator which
            # does not have randint as an attribute only has integers
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            # This environment now uses the PCG64 random number generator, which
            # does not have `randint` as an attribute, only `integers`.
            try:
                noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            # Also still support older versions that only provide `randint`.
            except AttributeError:
                noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
@@ -207,6 +207,7 @@ class WorkerSet:
        )

        # Only sync if we have remote workers or `from_worker` is provided.
        weights = None
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker or self.local_worker()).get_weights(policies)
            # Put weights only once into object store and use same object

@@ -216,14 +217,14 @@ class WorkerSet:
            for to_worker in self.remote_workers():
                to_worker.set_weights.remote(weights_ref, global_vars=global_vars)

            # If `from_worker` is provided, also sync to this WorkerSet's
            # local worker.
            if from_worker is not None and self.local_worker() is not None:
                self.local_worker().set_weights(weights, global_vars=global_vars)
            # If `global_vars` is provided and local worker exists -> Update its
            # global_vars.
            elif self.local_worker() is not None and global_vars is not None:
                self.local_worker().set_global_vars(global_vars)
        # If `from_worker` is provided, also sync to this WorkerSet's
        # local worker.
        if from_worker is not None and self.local_worker() is not None:
            self.local_worker().set_weights(weights, global_vars=global_vars)
        # If `global_vars` is provided and local worker exists -> Update its
        # global_vars.
        elif self.local_worker() is not None and global_vars is not None:
            self.local_worker().set_global_vars(global_vars)

    def add_workers(self, num_workers: int) -> None:
        """Creates and adds a number of remote workers to this worker set.
@@ -134,6 +134,9 @@ if __name__ == "__main__":
        # Evaluate every other training iteration (together
        # with every other call to Trainer.train()).
        "evaluation_interval": args.evaluation_interval,
        "evaluation_config": {
            "input_evaluation": ["is"],
        },
        # Run for n episodes/timesteps (properly distribute load amongst
        # all eval workers). The longer it takes to evaluate, the more sense
        # it makes to use `evaluation_parallel_to_training=True`.
@@ -1,26 +1,28 @@
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
    ReplayMode,
)
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
    PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
    MultiAgentMixInReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
    MultiAgentPrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
    ReplayMode,
)
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
    PrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
from ray.rllib.utils.replay_buffers.reservoir_buffer import ReservoirBuffer
from ray.rllib.utils.replay_buffers.simple_replay_buffer import SimpleReplayBuffer

__all__ = [
    "ReplayBuffer",
    "StorageUnit",
    "MultiAgentReplayBuffer",
    "ReplayMode",
    "ReservoirBuffer",
    "PrioritizedReplayBuffer",
    "MultiAgentMixInReplayBuffer",
    "MultiAgentPrioritizedReplayBuffer",
    "MultiAgentReplayBuffer",
    "PrioritizedReplayBuffer",
    "ReplayMode",
    "ReplayBuffer",
    "ReservoirBuffer",
    "SimpleReplayBuffer",
    "StorageUnit",
]
@@ -17,12 +17,6 @@ from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.from_config import from_config

# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"

@@ -37,220 +31,6 @@ class StorageUnit(Enum):
    EPISODES = "episodes"

(The entire ~220-line body of validate_buffer_config() is removed from this module here; the function moves verbatim into ray.rllib.utils.replay_buffers.utils — see its full text in the addition further below.)

@ExperimentalAPI
class ReplayBuffer:
    def __init__(

@@ -316,18 +96,6 @@ class ReplayBuffer:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @ExperimentalAPI
    @Deprecated(old="add_batch", new="add", error=False)
    def add_batch(self, batch: SampleBatchType, **kwargs) -> None:
        """Deprecated in favor of new ReplayBuffer API."""
        return self.add(batch, **kwargs)

    @ExperimentalAPI
    @Deprecated(old="replay", new="sample", error=False)
    def replay(self, num_items: int = 1, **kwargs) -> Optional[SampleBatchType]:
        """Deprecated in favor of new ReplayBuffer API."""
        return self.sample(num_items, **kwargs)

    @ExperimentalAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

@@ -522,3 +290,11 @@ class ReplayBuffer:
            name could not be determined.
        """
        return platform.node()

    @Deprecated(old="ReplayBuffer.add_batch()", new="ReplayBuffer.add()", error=False)
    def add_batch(self, *args, **kwargs):
        return self.add(*args, **kwargs)

    @Deprecated(old="ReplayBuffer.replay()", new="ReplayBuffer.sample()", error=False)
    def replay(self, *args, **kwargs):
        return self.sample(*args, **kwargs)
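A hedged sketch of what these deprecation shims mean for callers (the buffer contents are synthetic, and directly instantiating the base ReplayBuffer with timestep storage is an assumption):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer

buf = ReplayBuffer(capacity=100, storage_unit="timesteps")
b = SampleBatch({"obs": [1, 2, 3], "rewards": [0.0, 0.0, 1.0]})

# Old-style calls still work through the shims above, but emit
# deprecation warnings ...
buf.add_batch(b)
old_sample = buf.replay(1)

# ... the new API is simply:
buf.add(b)
new_sample = buf.sample(1)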
rllib/utils/replay_buffers/simple_replay_buffer.py (new file)
@@ -0,0 +1,35 @@
import random

from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.replay_buffers.utils import warn_replay_buffer_capacity
from ray.rllib.utils.typing import SampleBatchType


class SimpleReplayBuffer(ReplayBuffer):
    """Simple replay buffer that operates over entire batches."""

    def __init__(self, capacity: int, storage_unit: str = "timesteps", **kwargs):
        """Initialize a SimpleReplayBuffer instance."""
        super().__init__(capacity=capacity, storage_unit="timesteps", **kwargs)
        self.replay_batches = []
        self.replay_index = 0

    @override(ReplayBuffer)
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        warn_replay_buffer_capacity(item=batch, capacity=self.capacity)
        if self.capacity > 0:
            if len(self.replay_batches) < self.capacity:
                self.replay_batches.append(batch)
            else:
                self.replay_batches[self.replay_index] = batch
                self.replay_index += 1
                self.replay_index %= self.capacity

    @override(ReplayBuffer)
    def sample(self, num_items: int, **kwargs) -> SampleBatchType:
        return random.choice(self.replay_batches)

    @override(ReplayBuffer)
    def __len__(self):
        return len(self.replay_batches)
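A short, hedged usage sketch for the new buffer (the SampleBatch contents are made up for illustration):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.simple_replay_buffer import SimpleReplayBuffer

# Capacity counts whole batches, mirroring the QMIX default config above.
buffer = SimpleReplayBuffer(capacity=1000)

for i in range(5):
    # Each stored item is an entire batch, not individual timesteps.
    buffer.add(SampleBatch({"obs": [i, i + 1], "rewards": [0.0, 1.0]}))

print(len(buffer))           # 5 batches stored
replayed = buffer.sample(1)  # one randomly chosen stored batch
print(replayed["obs"])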
@@ -1,12 +1,25 @@
import logging
import psutil
from typing import Optional

from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers import (
    MultiAgentPrioritizedReplayBuffer,
    ReplayBuffer,
)
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict
from ray.util import log_once

logger = logging.getLogger(__name__)


def update_priorities_in_replay_buffer(

@@ -63,3 +76,288 @@ def update_priorities_in_replay_buffer(
    # Make the actual buffer API call to update the priority weights on all
    # policies.
    replay_buffer.update_priorities(prio_dict)


def sample_min_n_steps_from_buffer(
    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
    """Samples a minimum of n timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps with
    sampling a given number of timesteps from a buffer that stores its samples
    in units of sequences or complete episodes. It samples batches from the
    replay buffer until the total number of timesteps reaches at least
    `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample.
        count_by_agent_steps: Whether to count agent steps or env steps.

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer.
    """
    train_batch_size = 0
    train_batches = []
    while train_batch_size < min_steps:
        batch = replay_buffer.sample(num_items=1)
        if batch is None:
            return None
        train_batches.append(batch)
        train_batch_size += (
            train_batches[-1].agent_steps()
            if count_by_agent_steps
            else train_batches[-1].env_steps()
        )
    # All sampled batches are of the same type, hence we can use
    # SampleBatch.concat_samples() to concatenate them.
    train_batch = SampleBatch.concat_samples(train_batches)
    return train_batch
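A hedged illustration of what this helper returns when the buffer stores whole batches (all data below is synthetic):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.simple_replay_buffer import SimpleReplayBuffer
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer

# QMIX's default buffer stores whole batches; here each batch holds 4 timesteps.
buf = SimpleReplayBuffer(capacity=100)
for _ in range(10):
    buf.add(SampleBatch({"obs": [0, 1, 2, 3], "rewards": [0.0] * 4}))

# Each sample() draw returns one stored 4-step batch, so reaching at least
# 10 timesteps takes 3 draws and the concatenated result has 12 env steps.
train_batch = sample_min_n_steps_from_buffer(
    replay_buffer=buf, min_steps=10, count_by_agent_steps=False
)
print(train_batch.env_steps())  # 12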
@ExperimentalAPI
def validate_buffer_config(config: dict):
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    prioritized_replay = config.get("prioritized_replay")
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config[{}]".format(k),
                help="config['replay_buffer_config'][{}] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}."
                "".format(k, k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer.
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'no_local_replay_buffer'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"][
            "no_local_replay_buffer"
        ] = no_local_replay_buffer

    # TODO (Artur):
    if config["replay_buffer_config"].get("no_local_replay_buffer", False):
        return

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated.
    buffer_type = config["replay_buffer_config"]["type"]
    if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Prepend old-style buffers' path.
            assert buffer_type == "MultiAgentReplayBuffer", (
                "Without "
                "ReplayBuffer "
                "API, only "
                "MultiAgentReplayBuffer "
                "is supported!"
            )
            # Create valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], (
                "Without ReplayBuffer API, only "
                "MultiAgentReplayBuffer is supported!"
            )

        config["replay_buffer_config"]["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor.
        config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_mode"] = replay_mode
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay mode specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because this is also a deliberate
        # value set for some algorithms.
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation.
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            config["replay_buffer_config"][
                "replay_sequence_length"
            ] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
            )

        replay_zero_init_states = config.get(
            "replay_zero_init_states", DEPRECATED_VALUE
        )
        if replay_zero_init_states != DEPRECATED_VALUE:
            config["replay_buffer_config"][
                "replay_zero_init_states"
            ] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects.
        if config["replay_buffer_config"].get("prioritized_replay", False):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False
            # This triggers non-prioritization in old-style replay buffer.
            config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Create valid full [module].[class] string for from_config.
            config["replay_buffer_config"]["type"] = (
                "ray.rllib.utils.replay_buffers." + buffer_type
            )
        test_buffer = from_config(buffer_type, config["replay_buffer_config"])
        if hasattr(test_buffer, "update_priorities"):
            is_prioritized_buffer = True
        else:
            is_prioritized_buffer = False

    if is_prioritized_buffer:
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    else:
        if config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False."
            )

    if config["replay_buffer_config"].get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]

    # Pop prioritized replay because it's not a valid parameter for older
    # replay buffers.
    config["replay_buffer_config"].pop("prioritized_replay", None)
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Warn if the configured replay buffer capacity is too large for the machine's memory.

    Args:
        item: An (example) item that's supposed to be added to the buffer.
            This is used to compute the overall memory footprint estimate for the
            buffer.
        capacity: The capacity value of the buffer. This is interpreted as the
            number of items (such as given `item`) that will eventually be stored in
            the buffer.

    Raises:
        ValueError: If the computed memory footprint for the buffer exceeds the
            machine's RAM.
    """
    if log_once("warn_replay_buffer_capacity"):
        item_size = item.size_bytes()
        psutil_mem = psutil.virtual_memory()
        total_gb = psutil_mem.total / 1e9
        mem_size = capacity * item_size / 1e9
        msg = (
            "Estimated max memory usage for replay buffer is {} GB "
            "({} batches of size {}, {} bytes each), "
            "available system memory is {} GB".format(
                mem_size, capacity, item.count, item_size, total_gb
            )
        )
        if mem_size > total_gb:
            raise ValueError(msg)
        elif mem_size > 0.2 * total_gb:
            logger.warning(msg)
        else:
            logger.info(msg)
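A back-of-the-envelope check of the memory estimate above (all numbers are hypothetical):

item_size = 400_000  # bytes, as item.size_bytes() might report for one stored batch
capacity = 50_000    # number of items the buffer is configured to hold
mem_size = capacity * item_size / 1e9  # 20.0 GB estimated buffer footprint

total_gb = 16.0  # e.g. a 16 GB machine
print(mem_size > total_gb)        # True -> warn_replay_buffer_capacity() raises
print(mem_size > 0.2 * total_gb)  # the warning threshold (only reached otherwise)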