[RLlib] IMPALA/APPO multi-agent mix-in-buffer fixes (plus MA learning tests). (#25848)

Artur Niederfahrenhorst 2022-06-17 14:10:36 +02:00 committed by GitHub
parent 1c27469b6d
commit a322cc5765
12 changed files with 258 additions and 131 deletions
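
The core of this fix: MixInMultiAgentReplayBuffer (used by IMPALA/APPO as the local and aggregator-side mix-in buffer) gains a lockstep replay mode, so a MultiAgentBatch is stored and replayed as a whole under the shared _ALL_POLICIES key instead of being split per policy. A minimal sketch of the new buffer usage follows; the batch contents and sizes are placeholders, not taken from the diff.

from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

# Lockstep mode keeps multi-agent batches intact (this is what IMPALA's local
# and aggregator-side buffers now request).
buffer = MixInMultiAgentReplayBuffer(
    capacity=1000, replay_ratio=0.5, replay_mode=ReplayMode.LOCKSTEP
)

# Illustrative two-policy batch.
ma_batch = MultiAgentBatch(
    {
        "p0": SampleBatch({"obs": [1, 2, 3]}),
        "p1": SampleBatch({"obs": [4, 5, 6]}),
    },
    env_steps=3,
)
buffer.add_batch(ma_batch)

# In lockstep mode, replay goes through the shared _ALL_POLICIES buffer and
# returns whole MultiAgentBatches (or None if there is nothing to replay yet).
replayed = buffer.replay(_ALL_POLICIES)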

View file

@@ -195,6 +195,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/appo"]
)
py_test(
name = "learning_tests_multi_agent_cartpole_appo",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi-agent-cartpole-appo.yaml"],
args = ["--yaml-dir=tuned_examples/appo"]
)
# py_test(
# name = "learning_tests_frozenlake_appo",
# main = "tests/run_regression_tests.py",
@@ -402,6 +412,16 @@ py_test(
# args = ["--yaml-dir=tuned_examples/impala"]
# )
py_test(
name = "learning_tests_multi_agent_cartpole_impala",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/impala/multi-agent-cartpole-impala.yaml"],
args = ["--yaml-dir=tuned_examples/impala"]
)
py_test(
name = "learning_tests_cartpole_impala_fake_gpus",
main = "tests/run_regression_tests.py",

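The two new BUILD targets above consume the tuned-example YAMLs added further down in this commit. Roughly, such a regression run boils down to the sketch below; the real rllib/tests/run_regression_tests.py adds framework handling and pass/fail checking, and the YAML path assumes a checkout of the Ray repository.

import yaml
import ray
from ray import tune

# Load the tuned-example experiment spec and hand it to Tune.
with open("rllib/tuned_examples/impala/multi-agent-cartpole-impala.yaml") as f:
    experiments = yaml.safe_load(f)

ray.init()
tune.run_experiments(experiments)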
View file

@@ -6,10 +6,12 @@ from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
)
from ray.util.iter import LocalIterator
@@ -196,8 +198,9 @@ class ApexDDPG(DDPG, ApexDQN):
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
@staticmethod
@override(DDPG)

View file

@@ -652,8 +652,9 @@ class ApexDQN(DQN):
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):

View file

@@ -11,6 +11,7 @@ import gym
from typing import Dict, List, Optional, Type, Union
import ray
from ray.rllib.algorithms.appo.utils import make_appo_models
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
from ray.rllib.algorithms.impala.impala_tf_policy import (
_make_time_major,
@@ -32,11 +33,9 @@ from ray.rllib.policy.tf_mixins import (
KLCoeffMixin,
ValueNetworkMixin,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.utils.annotations import (
DeveloperAPI,
override,
)
from ray.rllib.utils.framework import try_import_tf
@@ -45,52 +44,9 @@ from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"
logger = logging.getLogger(__name__)
@DeveloperAPI
def make_appo_model(policy) -> ModelV2:
"""Builds model and target model for APPO.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target model will not be returned, just assigned to
`policy.target_model`.
"""
# Get the num_outputs for the following model construction calls.
_, logit_dim = ModelCatalog.get_action_dist(
policy.action_space, policy.config["model"]
)
# Construct the (main) model.
policy.model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=POLICY_SCOPE,
framework=policy.framework,
)
policy.model_variables = policy.model.variables()
# Construct the target model.
policy.target_model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=TARGET_POLICY_SCOPE,
framework=policy.framework,
)
policy.target_model_variables = policy.target_model.variables()
# Return only the model (not the target model).
return policy.model
class TargetNetworkMixin:
"""Target NN is updated by master learner via the `update_target` method.
@@ -182,7 +138,7 @@ def get_appo_tf_policy(base: type) -> type:
@override(base)
def make_model(self) -> ModelV2:
return make_appo_model(self)
return make_appo_models(self)
@override(base)
def loss(

View file

@@ -11,7 +11,7 @@ import logging
from typing import Any, Dict, List, Optional, Type, Union
import ray
from ray.rllib.algorithms.appo.appo_tf_policy import make_appo_model
from ray.rllib.algorithms.appo.utils import make_appo_models
import ray.rllib.algorithms.impala.vtrace_torch as vtrace
from ray.rllib.algorithms.impala.impala_torch_policy import (
make_time_major,
@@ -101,7 +101,7 @@ class APPOTorchPolicy(
@override(TorchPolicyV2)
def make_model(self) -> ModelV2:
return make_appo_model(self)
return make_appo_models(self)
@override(TorchPolicyV2)
def loss(

View file

@@ -0,0 +1,45 @@
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"
def make_appo_models(policy) -> ModelV2:
"""Builds model and target model for APPO.
Returns:
ModelV2: The Model for the Policy to use.
Note: The target model will not be returned, just assigned to
`policy.target_model`.
"""
# Get the num_outputs for the following model construction calls.
_, logit_dim = ModelCatalog.get_action_dist(
policy.action_space, policy.config["model"]
)
# Construct the (main) model.
policy.model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=POLICY_SCOPE,
framework=policy.framework,
)
policy.model_variables = policy.model.variables()
# Construct the target model.
policy.target_model = ModelCatalog.get_model_v2(
policy.observation_space,
policy.action_space,
logit_dim,
policy.config["model"],
name=TARGET_POLICY_SCOPE,
framework=policy.framework,
)
policy.target_model_variables = policy.target_model.variables()
# Return only the model (not the target model).
return policy.model
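
Both the TF and Torch APPO policies (see the surrounding diffs) now delegate model construction to this shared helper. As a sketch, a policy's make_model reduces to the call below; the policy object is assumed to provide the observation_space, action_space, config, and framework attributes the helper reads.

from ray.rllib.algorithms.appo.utils import make_appo_models

class SketchAPPOPolicy:  # Illustrative only; real policies derive from the TF/Torch policy bases.
    def make_model(self):
        # Builds and returns the main model; the target model is created as a
        # side effect and assigned to self.target_model.
        return make_appo_models(self)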

View file

@@ -1,9 +1,8 @@
import copy
import logging
import platform
import queue
from typing import Optional, Type, List, Dict, Union, Callable, Any
from typing import Any, Callable, Dict, List, Optional, Type, Union
import ray
from ray.actor import ActorHandle
@@ -11,45 +10,45 @@ from ray.rllib import SampleBatch
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.parallel_requests import (
AsyncRequestsManager,
)
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
# from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
AlgorithmConfigDict,
SampleBatchType,
T,
)
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.types import ObjectRef
@@ -470,9 +469,7 @@ class Impala(Algorithm):
return A3CTorchPolicy
elif config["framework"] == "tf":
if config["vtrace"]:
from ray.rllib.algorithms.impala.impala_tf_policy import (
ImpalaTF1Policy,
)
from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy
return ImpalaTF1Policy
else:
@@ -590,6 +587,7 @@ class Impala(Algorithm):
else 1
),
replay_ratio=self.config["replay_ratio"],
replay_mode=ReplayMode.LOCKSTEP,
)
self._sampling_actor_manager = AsyncRequestsManager(
@@ -658,7 +656,7 @@ class Impala(Algorithm):
)
def record_steps_trained(item):
count, fetches = item
count, fetches, _ = item
metrics = _get_shared_metrics()
# Manually update the steps trained counter since the learner
# thread is executing outside the pipeline.
@@ -797,7 +795,7 @@ class Impala(Algorithm):
def process_trained_results(self) -> ResultDict:
# Get learner outputs/stats from output queue.
learner_infos = []
learner_info = copy.deepcopy(self._learner_thread.learner_info)
num_env_steps_trained = 0
num_agent_steps_trained = 0
@@ -811,10 +809,9 @@ class Impala(Algorithm):
num_env_steps_trained += env_steps
num_agent_steps_trained += agent_steps
if learner_results:
learner_infos.append(learner_results)
learner_info.update(learner_results)
else:
raise RuntimeError("The learner thread died while training")
learner_info = copy.deepcopy(self._learner_thread.learner_info)
# Update the steps trained counters.
self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
@@ -839,7 +836,7 @@ class Impala(Algorithm):
for batch in batches:
batch = batch.decompress_if_needed()
self.local_mixin_buffer.add_batch(batch)
batch = self.local_mixin_buffer.replay()
batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
if batch:
processed_batches.append(batch)
return processed_batches
@@ -898,8 +895,9 @@ class Impala(Algorithm):
removed_workers: removed worker ids.
new_workers: ids of newly created workers.
"""
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
if self.config["_disable_execution_plan_api"]:
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
@@ -925,12 +923,13 @@ class AggregatorWorker:
else 1
),
replay_ratio=self.config["replay_ratio"],
replay_mode=ReplayMode.LOCKSTEP,
)
def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
batch = batch.decompress_if_needed()
self._mixin_buffer.add_batch(batch)
processed_batches = self._mixin_buffer.replay()
processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
return processed_batches
def apply(

View file

@@ -5,6 +5,8 @@ from typing import Optional
from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType
@@ -50,7 +52,12 @@ class MixInMultiAgentReplayBuffer:
[B]
"""
def __init__(self, capacity: int, replay_ratio: float):
def __init__(
self,
capacity: int,
replay_ratio: float,
replay_mode: ReplayMode = ReplayMode.INDEPENDENT,
):
"""Initializes MixInReplay instance.
Args:
@@ -67,6 +74,13 @@ class MixInMultiAgentReplayBuffer:
if self.replay_ratio != 1.0:
self.replay_proportion = self.replay_ratio / (1.0 - self.replay_ratio)
if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]:
self.replay_mode = ReplayMode.LOCKSTEP
elif replay_mode in ["independent", ReplayMode.INDEPENDENT]:
self.replay_mode = ReplayMode.INDEPENDENT
else:
raise ValueError("Unsupported replay mode: {}".format(replay_mode))
def new_buffer():
return SimpleReplayBuffer(num_slots=capacity)
@@ -98,14 +112,31 @@ class MixInMultiAgentReplayBuffer:
batch = batch.as_multi_agent()
with self.add_batch_timer:
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
if self.replay_mode == ReplayMode.LOCKSTEP:
# Lockstep mode: Store under _ALL_POLICIES key (we will always
# only sample from all policies at the same time).
# This means storing a MultiAgentBatch to the underlying buffer
self.replay_buffers[_ALL_POLICIES].add_batch(batch)
self.last_added_batches[_ALL_POLICIES].append(batch)
else:
# Store independent SampleBatches
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
self.num_added += batch.count
def replay(
self, policy_id: PolicyID = DEFAULT_POLICY_ID
) -> Optional[SampleBatchType]:
if self.replay_mode == ReplayMode.LOCKSTEP and policy_id != _ALL_POLICIES:
raise ValueError(
"Trying to sample from single policy's buffer in lockstep "
"mode. In lockstep mode, all policies' experiences are "
"sampled from a single replay buffer which is accessed "
"with the policy id `{}`".format(_ALL_POLICIES)
)
buffer = self.replay_buffers[policy_id]
# Return None, if:
# - Buffer empty or

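Independent mode (the default) keeps the previous behavior of per-policy sub-buffers; only lockstep mode routes everything through _ALL_POLICIES and rejects per-policy access. A short sketch of the difference, with placeholder batch contents:

from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode

# Independent mode: batches are stored and replayed per policy ID.
independent = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.5)
independent.add_batch(SampleBatch({"obs": [1, 2, 3]}))
sample = independent.replay()  # Defaults to DEFAULT_POLICY_ID.

# Lockstep mode: asking for a single policy's data raises a ValueError, since
# all experiences live under the shared _ALL_POLICIES key.
lockstep = MixInMultiAgentReplayBuffer(
    capacity=100, replay_ratio=0.5, replay_mode=ReplayMode.LOCKSTEP
)
try:
    lockstep.replay("p0")
except ValueError:
    pass  # Only replay(_ALL_POLICIES) is valid here.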
View file

@@ -0,0 +1,35 @@
multi-agent-cartpole-appo:
env: ray.rllib.examples.env.multi_agent.MultiAgentCartPole
run: APPO
stop:
episode_reward_mean: 600 # 600 / 4 (==num_agents) = 150
timesteps_total: 200000
config:
# Works for both torch and tf.
framework: tf
# 4-agent MA cartpole.
env_config:
config:
num_agents: 4
num_envs_per_worker: 5
num_workers: 4
num_gpus: 1
_fake_gpus: true
observation_filter: MeanStdFilter
num_sgd_iter: 1
vf_loss_coeff: 0.005
vtrace: true
vtrace_drop_last_ts: false
model:
fcnet_hiddens: [32]
fcnet_activation: linear
vf_share_layers: true
multiagent:
policies: ["p0", "p1", "p2", "p3"]
# YAML-compatible policy_mapping_fn definition: provide a callable class here.
policy_mapping_fn:
type: ray.rllib.examples.multi_agent_and_self_play.policy_mapping_fn.PolicyMappingFn
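
The same experiment can be written directly in Python. The sketch below mirrors the YAML above; the per-agent policy mapping is an assumption about what the referenced PolicyMappingFn class does (agent i -> policy "p{i}").

from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

tune.run(
    "APPO",
    stop={"episode_reward_mean": 600, "timesteps_total": 200000},
    config={
        "env": MultiAgentCartPole,
        "env_config": {"config": {"num_agents": 4}},
        "framework": "tf",
        "num_envs_per_worker": 5,
        "num_workers": 4,
        "num_gpus": 1,
        "_fake_gpus": True,  # Fake GPUs so the run also works on CPU-only machines.
        "observation_filter": "MeanStdFilter",
        "num_sgd_iter": 1,
        "vf_loss_coeff": 0.005,
        "vtrace": True,
        "vtrace_drop_last_ts": False,
        "model": {
            "fcnet_hiddens": [32],
            "fcnet_activation": "linear",
            "vf_share_layers": True,
        },
        "multiagent": {
            "policies": {"p0", "p1", "p2", "p3"},
            # Assumed equivalent of the YAML's PolicyMappingFn class.
            "policy_mapping_fn": lambda agent_id, *args, **kwargs: "p{}".format(agent_id),
        },
    },
)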

View file

@@ -0,0 +1,37 @@
multi-agent-cartpole-impala:
env: ray.rllib.examples.env.multi_agent.MultiAgentCartPole
run: IMPALA
stop:
episode_reward_mean: 600 # 600 / 4 (==num_agents) = 150
timesteps_total: 200000
config:
# Works for both torch and tf.
framework: tf
# 4-agent MA cartpole.
env_config:
config:
num_agents: 4
num_envs_per_worker: 5
num_workers: 4
num_gpus: 1
_fake_gpus: true
observation_filter: MeanStdFilter
num_sgd_iter: 1
vf_loss_coeff: 0.005
vtrace: true
vtrace_drop_last_ts: false
model:
fcnet_hiddens: [32]
fcnet_activation: linear
vf_share_layers: true
replay_ratio: 0.0
multiagent:
policies: ["p0", "p1", "p2", "p3"]
# YAML-compatible policy_mapping_fn definition: provide a callable class here.
policy_mapping_fn:
type: ray.rllib.examples.multi_agent_and_self_play.policy_mapping_fn.PolicyMappingFn
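
The only notable difference from the APPO config above is the explicit replay_ratio: 0.0, i.e. the mix-in buffer replays nothing and training stays purely on-policy. Inside MixInMultiAgentReplayBuffer a ratio below 1.0 is turned into a replay proportion of replay_ratio / (1 - replay_ratio); the helper name below is made up for illustration.

def replay_proportion(replay_ratio: float) -> float:
    # Mirrors the buffer's formula for ratios < 1.0: how many old batches are
    # replayed for every newly added one (a ratio of 1.0 means replay-only and
    # is handled separately in the buffer).
    assert 0.0 <= replay_ratio < 1.0
    return replay_ratio / (1.0 - replay_ratio)

print(replay_proportion(0.0))   # 0.0 -> no replay (this IMPALA example)
print(replay_proportion(0.5))   # 1.0 -> one replayed batch per new batch
print(replay_proportion(0.75))  # 3.0 -> three replayed batches per new batch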

View file

@@ -1,31 +1,29 @@
import collections
import random
import numpy as np
import logging
from typing import Optional, Dict, Any
import random
from typing import Any, Dict, Optional
import numpy as np
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
SampleBatch,
MultiAgentBatch,
SampleBatch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.replay_buffer import (
StorageUnit,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
merge_dicts_with_warning,
MultiAgentReplayBuffer,
ReplayMode,
merge_dicts_with_warning,
)
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES, StorageUnit
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
from ray.util.debug import log_once
from ray.util.annotations import DeveloperAPI
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@@ -133,15 +131,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
if not 0 <= replay_ratio <= 1:
raise ValueError("Replay ratio must be within [0, 1]")
if "replay_mode" in kwargs and kwargs["replay_mode"] == "lockstep":
if log_once("lockstep_mode_not_supported"):
logger.error(
"Replay mode `lockstep` is not supported for "
"MultiAgentMixInReplayBuffer."
"This buffer will run in `independent` mode."
)
del kwargs["replay_mode"]
MultiAgentPrioritizedReplayBuffer.__init__(
self,
capacity=capacity,
@@ -183,19 +172,21 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)
pids_and_batches = self._maybe_split_into_policy_batches(batch)
# We need to split batches into timesteps, sequences, or episodes
# here already to properly keep track of self.last_added_batches.
# Underlying buffers should not split up the batch any further.
with self.add_batch_timer:
if self.storage_unit == StorageUnit.TIMESTEPS:
for policy_id, sample_batch in batch.policy_batches.items():
for policy_id, sample_batch in pids_and_batches.items():
timeslices = sample_batch.timeslices(1)
for time_slice in timeslices:
self.replay_buffers[policy_id].add(time_slice, **kwargs)
self.last_added_batches[policy_id].append(time_slice)
elif self.storage_unit == StorageUnit.SEQUENCES:
for policy_id, sample_batch in batch.policy_batches.items():
for policy_id, sample_batch in pids_and_batches.items():
timeslices = timeslice_along_seq_lens_with_overlap(
sample_batch=sample_batch,
seq_lens=sample_batch.get(SampleBatch.SEQ_LENS)
@@ -210,7 +201,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
self.last_added_batches[policy_id].append(slice)
elif self.storage_unit == StorageUnit.EPISODES:
for policy_id, sample_batch in batch.policy_batches.items():
for policy_id, sample_batch in pids_and_batches.items():
for eps in sample_batch.split_by_episode():
# Only add full episodes to the buffer
if (
@@ -228,7 +219,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
"dropped."
)
elif self.storage_unit == StorageUnit.FRAGMENTS:
for policy_id, sample_batch in batch.policy_batches.items():
for policy_id, sample_batch in pids_and_batches.items():
self.replay_buffers[policy_id].add(sample_batch, **kwargs)
self.last_added_batches[policy_id].append(sample_batch)

View file

@@ -1,22 +1,22 @@
import logging
import collections
from typing import Any, Dict, Optional
import logging
from enum import Enum
from typing import Any, Dict, Optional
from ray.rllib.utils.replay_buffers.replay_buffer import (
_ALL_POLICIES,
)
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.replay_buffers.replay_buffer import (
_ALL_POLICIES,
ReplayBuffer,
StorageUnit,
)
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
from ray.rllib.utils.from_config import from_config
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import Deprecated
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@@ -229,15 +229,10 @@ class MultiAgentReplayBuffer(ReplayBuffer):
batch = batch.as_multi_agent()
with self.add_batch_timer:
if self.replay_mode == ReplayMode.LOCKSTEP:
# Lockstep mode: Store under _ALL_POLICIES key (we will always
# only sample from all policies at the same time).
# This means storing a MultiAgentBatch to the underlying buffer
self._add_to_underlying_buffer(_ALL_POLICIES, batch, **kwargs)
else:
# Store independent SampleBatches
for policy_id, sample_batch in batch.policy_batches.items():
self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
pids_and_batches = self._maybe_split_into_policy_batches(batch)
for policy_id, sample_batch in pids_and_batches.items():
self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
self._num_added += batch.count
@DeveloperAPI
@@ -391,3 +386,17 @@
buffer_states = state["replay_buffers"]
for policy_id in buffer_states.keys():
self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
def _maybe_split_into_policy_batches(self, batch: SampleBatchType):
"""Returns a dict of policy IDs and batches, depending on our replay mode.
This method splits up MultiAgentBatches only if self.replay_mode
requires it.
"""
if self.replay_mode == ReplayMode.LOCKSTEP:
return {_ALL_POLICIES: batch}
else:
return {
policy_id: sample_batch
for policy_id, sample_batch in batch.policy_batches.items()
}
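
This helper keeps the lockstep/independent split in one place: lockstep mode passes the whole MultiAgentBatch through under _ALL_POLICIES, independent mode hands out the per-policy SampleBatches. A standalone sketch of the same logic (the function name is made up; the real code lives on MultiAgentReplayBuffer as shown above):

from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

def split_by_replay_mode(batch: MultiAgentBatch, replay_mode: ReplayMode) -> dict:
    # Keep the batch whole in lockstep mode, split it per policy otherwise.
    if replay_mode == ReplayMode.LOCKSTEP:
        return {_ALL_POLICIES: batch}
    return dict(batch.policy_batches)

ma_batch = MultiAgentBatch(
    {"p0": SampleBatch({"obs": [1]}), "p1": SampleBatch({"obs": [2]})}, env_steps=1
)
print(split_by_replay_mode(ma_batch, ReplayMode.LOCKSTEP).keys())    # the _ALL_POLICIES key
print(split_by_replay_mode(ma_batch, ReplayMode.INDEPENDENT).keys()) # "p0", "p1"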