[RLlib] IMPALA/APPO multi-agent mix-in-buffer fixes (plus MA learning tests). (#25848)
parent: 1c27469b6d
commit: a322cc5765
12 changed files with 258 additions and 131 deletions
rllib/BUILD (20 changes)

@@ -195,6 +195,16 @@ py_test(
     args = ["--yaml-dir=tuned_examples/appo"]
 )
 
+py_test(
+    name = "learning_tests_multi_agent_cartpole_appo",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/appo/multi-agent-cartpole-appo.yaml"],
+    args = ["--yaml-dir=tuned_examples/appo"]
+)
+
 # py_test(
 #     name = "learning_tests_frozenlake_appo",
 #     main = "tests/run_regression_tests.py",

@@ -402,6 +412,16 @@ py_test(
 #     args = ["--yaml-dir=tuned_examples/impala"]
 # )
 
+py_test(
+    name = "learning_tests_multi_agent_cartpole_impala",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/impala/multi-agent-cartpole-impala.yaml"],
+    args = ["--yaml-dir=tuned_examples/impala"]
+)
+
 py_test(
     name = "learning_tests_cartpole_impala_fake_gpus",
     main = "tests/run_regression_tests.py",
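Note on the new BUILD targets: each one simply feeds the corresponding tuned-example YAML through RLlib's regression-test driver with the `--yaml-dir` argument shown above. A rough standalone equivalent of what such a target runs (a sketch only; the `tune.run_experiments` call and the assumption that the working directory is `rllib/` are not part of this commit):

```python
# Sketch: run the new multi-agent APPO tuned example directly, the same way
# the BUILD target does via tests/run_regression_tests.py.
import yaml

import ray
from ray import tune

# Path relative to the rllib/ directory (an assumption for this sketch).
with open("tuned_examples/appo/multi-agent-cartpole-appo.yaml") as f:
    experiments = yaml.safe_load(f)

ray.init()
# The YAML's single top-level key ("multi-agent-cartpole-appo") becomes one
# Tune experiment; its stop criteria end the run.
tune.run_experiments(experiments)
```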
@@ -6,10 +6,12 @@ from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
 from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import AlgorithmConfigDict
-from ray.rllib.utils.typing import PartialAlgorithmConfigDict
-from ray.rllib.utils.typing import ResultDict
-from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated
+from ray.rllib.utils.typing import (
+    AlgorithmConfigDict,
+    PartialAlgorithmConfigDict,
+    ResultDict,
+)
 from ray.util.iter import LocalIterator
 
 

@@ -196,8 +198,9 @@ class ApexDDPG(DDPG, ApexDQN):
             removed_workers: removed worker ids.
             new_workers: ids of newly created workers.
         """
-        self._sampling_actor_manager.remove_workers(removed_workers)
-        self._sampling_actor_manager.add_workers(new_workers)
+        if self.config["_disable_execution_plan_api"]:
+            self._sampling_actor_manager.remove_workers(removed_workers)
+            self._sampling_actor_manager.add_workers(new_workers)
 
     @staticmethod
     @override(DDPG)
@@ -652,8 +652,9 @@ class ApexDQN(DQN):
             removed_workers: removed worker ids.
             new_workers: ids of newly created workers.
         """
-        self._sampling_actor_manager.remove_workers(removed_workers)
-        self._sampling_actor_manager.add_workers(new_workers)
+        if self.config["_disable_execution_plan_api"]:
+            self._sampling_actor_manager.remove_workers(removed_workers)
+            self._sampling_actor_manager.add_workers(new_workers)
 
     @override(Algorithm)
     def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
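Note on the two worker-failure hunks above (and the matching IMPALA hunk further below): `_sampling_actor_manager` is an `AsyncRequestsManager` used only by the new training-iteration code path, so updating it is now guarded by the `_disable_execution_plan_api` flag. A minimal sketch of the guarded callback (the surrounding class and method name are simplified stand-ins; only the attribute, config key, and manager methods come from the hunks):

```python
from typing import List

from ray.actor import ActorHandle
from ray.rllib.execution.parallel_requests import AsyncRequestsManager


class WorkerFailureHandlingSketch:
    """Simplified stand-in for ApexDQN / ApexDDPG / Impala."""

    def __init__(self, config: dict, sampling_actor_manager: AsyncRequestsManager):
        self.config = config
        self._sampling_actor_manager = sampling_actor_manager

    def on_worker_failures(
        self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
    ):
        """Handles failures on remote sampling workers.

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        # Only the training-iteration path uses the async request manager, so
        # only touch it when the old execution-plan API is disabled.
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(removed_workers)
            self._sampling_actor_manager.add_workers(new_workers)
```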
@@ -11,6 +11,7 @@ import gym
 from typing import Dict, List, Optional, Type, Union
 
 import ray
+from ray.rllib.algorithms.appo.utils import make_appo_models
 from ray.rllib.algorithms.impala import vtrace_tf as vtrace
 from ray.rllib.algorithms.impala.impala_tf_policy import (
     _make_time_major,

@@ -32,11 +33,9 @@ from ray.rllib.policy.tf_mixins import (
     KLCoeffMixin,
     ValueNetworkMixin,
 )
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
 from ray.rllib.utils.annotations import (
-    DeveloperAPI,
     override,
 )
 from ray.rllib.utils.framework import try_import_tf

@@ -45,52 +44,9 @@ from ray.rllib.utils.typing import TensorType
 
 tf1, tf, tfv = try_import_tf()
 
-POLICY_SCOPE = "func"
-TARGET_POLICY_SCOPE = "target_func"
-
 logger = logging.getLogger(__name__)
 
 
-@DeveloperAPI
-def make_appo_model(policy) -> ModelV2:
-    """Builds model and target model for APPO.
-
-    Returns:
-        ModelV2: The Model for the Policy to use.
-        Note: The target model will not be returned, just assigned to
-            `policy.target_model`.
-    """
-    # Get the num_outputs for the following model construction calls.
-    _, logit_dim = ModelCatalog.get_action_dist(
-        policy.action_space, policy.config["model"]
-    )
-
-    # Construct the (main) model.
-    policy.model = ModelCatalog.get_model_v2(
-        policy.observation_space,
-        policy.action_space,
-        logit_dim,
-        policy.config["model"],
-        name=POLICY_SCOPE,
-        framework=policy.framework,
-    )
-    policy.model_variables = policy.model.variables()
-
-    # Construct the target model.
-    policy.target_model = ModelCatalog.get_model_v2(
-        policy.observation_space,
-        policy.action_space,
-        logit_dim,
-        policy.config["model"],
-        name=TARGET_POLICY_SCOPE,
-        framework=policy.framework,
-    )
-    policy.target_model_variables = policy.target_model.variables()
-
-    # Return only the model (not the target model).
-    return policy.model
-
-
 class TargetNetworkMixin:
     """Target NN is updated by master learner via the `update_target` method.
 

@@ -182,7 +138,7 @@ def get_appo_tf_policy(base: type) -> type:
 
         @override(base)
         def make_model(self) -> ModelV2:
-            return make_appo_model(self)
+            return make_appo_models(self)
 
         @override(base)
         def loss(
@@ -11,7 +11,7 @@ import logging
 from typing import Any, Dict, List, Optional, Type, Union
 
 import ray
-from ray.rllib.algorithms.appo.appo_tf_policy import make_appo_model
+from ray.rllib.algorithms.appo.utils import make_appo_models
 import ray.rllib.algorithms.impala.vtrace_torch as vtrace
 from ray.rllib.algorithms.impala.impala_torch_policy import (
     make_time_major,

@@ -101,7 +101,7 @@ class APPOTorchPolicy(
 
     @override(TorchPolicyV2)
     def make_model(self) -> ModelV2:
-        return make_appo_model(self)
+        return make_appo_models(self)
 
     @override(TorchPolicyV2)
     def loss(
rllib/algorithms/appo/utils.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.models.modelv2 import ModelV2
+
+
+POLICY_SCOPE = "func"
+TARGET_POLICY_SCOPE = "target_func"
+
+
+def make_appo_models(policy) -> ModelV2:
+    """Builds model and target model for APPO.
+
+    Returns:
+        ModelV2: The Model for the Policy to use.
+        Note: The target model will not be returned, just assigned to
+            `policy.target_model`.
+    """
+    # Get the num_outputs for the following model construction calls.
+    _, logit_dim = ModelCatalog.get_action_dist(
+        policy.action_space, policy.config["model"]
+    )
+
+    # Construct the (main) model.
+    policy.model = ModelCatalog.get_model_v2(
+        policy.observation_space,
+        policy.action_space,
+        logit_dim,
+        policy.config["model"],
+        name=POLICY_SCOPE,
+        framework=policy.framework,
+    )
+    policy.model_variables = policy.model.variables()
+
+    # Construct the target model.
+    policy.target_model = ModelCatalog.get_model_v2(
+        policy.observation_space,
+        policy.action_space,
+        logit_dim,
+        policy.config["model"],
+        name=TARGET_POLICY_SCOPE,
+        framework=policy.framework,
+    )
+    policy.target_model_variables = policy.target_model.variables()
+
+    # Return only the model (not the target model).
+    return policy.model
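With the factory moved into `rllib/algorithms/appo/utils.py`, the TF and Torch APPO policies above now share a single `make_appo_models` implementation instead of the Torch policy importing a private helper from the TF module. A hypothetical policy subclass would reuse it the same way (sketch only; the subclass name is invented):

```python
# Sketch: a custom APPO Torch policy reusing the shared model factory.
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
from ray.rllib.algorithms.appo.utils import make_appo_models
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override


class MyAPPOTorchPolicy(APPOTorchPolicy):
    @override(APPOTorchPolicy)
    def make_model(self) -> ModelV2:
        # Builds both self.model and self.target_model; only the main model
        # is returned, the target model is assigned as a side effect.
        return make_appo_models(self)
```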
@@ -1,9 +1,8 @@
 import copy
 import logging
 import platform
-
 import queue
-from typing import Optional, Type, List, Dict, Union, Callable, Any
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import ray
 from ray.actor import ActorHandle

@@ -11,45 +10,45 @@ from ray.rllib import SampleBatch
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
-from ray.rllib.execution.learner_thread import LearnerThread
-from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
-from ray.rllib.execution.parallel_requests import (
-    AsyncRequestsManager,
-)
-from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
 from ray.rllib.execution.common import (
     STEPS_TRAINED_COUNTER,
     STEPS_TRAINED_THIS_ITER_COUNTER,
     _get_global_vars,
     _get_shared_metrics,
 )
-from ray.rllib.execution.replay_ops import MixInReplay
-from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
-from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
+from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
+from ray.rllib.execution.learner_thread import LearnerThread
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
+from ray.rllib.execution.parallel_requests import AsyncRequestsManager
+from ray.rllib.execution.replay_ops import MixInReplay
+from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
+from ray.rllib.execution.tree_agg import gather_experiences_tree_aggregation
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.actors import create_colocated_actors
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import (
+    DEPRECATED_VALUE,
+    Deprecated,
+    deprecation_warning,
+)
 from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_SAMPLED,
     NUM_AGENT_STEPS_TRAINED,
     NUM_ENV_STEPS_SAMPLED,
     NUM_ENV_STEPS_TRAINED,
 )
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
+from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
 
 # from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
 from ray.rllib.utils.typing import (
+    AlgorithmConfigDict,
     PartialAlgorithmConfigDict,
     ResultDict,
-    AlgorithmConfigDict,
     SampleBatchType,
     T,
 )
-from ray.rllib.utils.deprecation import (
-    Deprecated,
-    DEPRECATED_VALUE,
-    deprecation_warning,
-)
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 from ray.types import ObjectRef
 

@@ -470,9 +469,7 @@ class Impala(Algorithm):
             return A3CTorchPolicy
         elif config["framework"] == "tf":
             if config["vtrace"]:
-                from ray.rllib.algorithms.impala.impala_tf_policy import (
-                    ImpalaTF1Policy,
-                )
+                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy
 
                 return ImpalaTF1Policy
             else:

@@ -590,6 +587,7 @@ class Impala(Algorithm):
                     else 1
                 ),
                 replay_ratio=self.config["replay_ratio"],
+                replay_mode=ReplayMode.LOCKSTEP,
             )
 
             self._sampling_actor_manager = AsyncRequestsManager(

@@ -658,7 +656,7 @@ class Impala(Algorithm):
         )
 
         def record_steps_trained(item):
-            count, fetches = item
+            count, fetches, _ = item
             metrics = _get_shared_metrics()
             # Manually update the steps trained counter since the learner
             # thread is executing outside the pipeline.

@@ -797,7 +795,7 @@ class Impala(Algorithm):
 
     def process_trained_results(self) -> ResultDict:
         # Get learner outputs/stats from output queue.
-        learner_infos = []
+        learner_info = copy.deepcopy(self._learner_thread.learner_info)
         num_env_steps_trained = 0
         num_agent_steps_trained = 0
 

@@ -811,10 +809,9 @@ class Impala(Algorithm):
                 num_env_steps_trained += env_steps
                 num_agent_steps_trained += agent_steps
                 if learner_results:
-                    learner_infos.append(learner_results)
+                    learner_info.update(learner_results)
             else:
                 raise RuntimeError("The learner thread died in while training")
-        learner_info = copy.deepcopy(self._learner_thread.learner_info)
 
         # Update the steps trained counters.
         self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained

@@ -839,7 +836,7 @@ class Impala(Algorithm):
         for batch in batches:
             batch = batch.decompress_if_needed()
             self.local_mixin_buffer.add_batch(batch)
-            batch = self.local_mixin_buffer.replay()
+            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
             if batch:
                 processed_batches.append(batch)
         return processed_batches

@@ -898,8 +895,9 @@ class Impala(Algorithm):
             removed_workers: removed worker ids.
             new_workers: ids of newly created workers.
         """
-        self._sampling_actor_manager.remove_workers(removed_workers)
-        self._sampling_actor_manager.add_workers(new_workers)
+        if self.config["_disable_execution_plan_api"]:
+            self._sampling_actor_manager.remove_workers(removed_workers)
+            self._sampling_actor_manager.add_workers(new_workers)
 
     @override(Algorithm)
     def _compile_iteration_results(self, *, step_ctx, iteration_results=None):

@@ -925,12 +923,13 @@ class AggregatorWorker:
                 else 1
             ),
             replay_ratio=self.config["replay_ratio"],
+            replay_mode=ReplayMode.LOCKSTEP,
         )
 
     def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
         batch = batch.decompress_if_needed()
         self._mixin_buffer.add_batch(batch)
-        processed_batches = self._mixin_buffer.replay()
+        processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
         return processed_batches
 
     def apply(
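The IMPALA changes above create both the local mixin buffer and the `AggregatorWorker` buffer in lockstep mode and replay them via `_ALL_POLICIES`, so multi-agent batches are stored and replayed as one unit rather than per policy ID. A condensed sketch of that replay path (capacity and ratio values are invented for illustration):

```python
# Sketch of the multi-agent replay path IMPALA now uses.
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

buffer = MixInMultiAgentReplayBuffer(
    capacity=1000,
    replay_ratio=0.5,
    replay_mode=ReplayMode.LOCKSTEP,
)


def process(batch):
    # `batch` is a (multi-agent) sample batch coming off a rollout worker.
    batch = batch.decompress_if_needed()
    buffer.add_batch(batch)
    # In lockstep mode the whole multi-agent batch is replayed under
    # _ALL_POLICIES; replaying a single policy ID would raise a ValueError.
    return buffer.replay(_ALL_POLICIES)
```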
@@ -5,6 +5,8 @@ from typing import Optional
 
 from ray.rllib.execution.replay_ops import SimpleReplayBuffer
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
+from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
+from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
 from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
 

@@ -50,7 +52,12 @@ class MixInMultiAgentReplayBuffer:
         [B]
     """
 
-    def __init__(self, capacity: int, replay_ratio: float):
+    def __init__(
+        self,
+        capacity: int,
+        replay_ratio: float,
+        replay_mode: ReplayMode = ReplayMode.INDEPENDENT,
+    ):
         """Initializes MixInReplay instance.
 
         Args:

@@ -67,6 +74,13 @@ class MixInMultiAgentReplayBuffer:
         if self.replay_ratio != 1.0:
             self.replay_proportion = self.replay_ratio / (1.0 - self.replay_ratio)
 
+        if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]:
+            self.replay_mode = ReplayMode.LOCKSTEP
+        elif replay_mode in ["independent", ReplayMode.INDEPENDENT]:
+            self.replay_mode = ReplayMode.INDEPENDENT
+        else:
+            raise ValueError("Unsupported replay mode: {}".format(replay_mode))
+
         def new_buffer():
             return SimpleReplayBuffer(num_slots=capacity)
 

@@ -98,14 +112,31 @@ class MixInMultiAgentReplayBuffer:
             batch = batch.as_multi_agent()
 
         with self.add_batch_timer:
-            for policy_id, sample_batch in batch.policy_batches.items():
-                self.replay_buffers[policy_id].add_batch(sample_batch)
-                self.last_added_batches[policy_id].append(sample_batch)
+            if self.replay_mode == ReplayMode.LOCKSTEP:
+                # Lockstep mode: Store under _ALL_POLICIES key (we will always
+                # only sample from all policies at the same time).
+                # This means storing a MultiAgentBatch to the underlying buffer
+                self.replay_buffers[_ALL_POLICIES].add_batch(batch)
+                self.last_added_batches[_ALL_POLICIES].append(batch)
+            else:
+                # Store independent SampleBatches
+                for policy_id, sample_batch in batch.policy_batches.items():
+                    self.replay_buffers[policy_id].add_batch(sample_batch)
+                    self.last_added_batches[policy_id].append(sample_batch)
 
         self.num_added += batch.count
 
     def replay(
         self, policy_id: PolicyID = DEFAULT_POLICY_ID
     ) -> Optional[SampleBatchType]:
+        if self.replay_mode == ReplayMode.LOCKSTEP and policy_id != _ALL_POLICIES:
+            raise ValueError(
+                "Trying to sample from single policy's buffer in lockstep "
+                "mode. In lockstep mode, all policies' experiences are "
+                "sampled from a single replay buffer which is accessed "
+                "with the policy id `{}`".format(_ALL_POLICIES)
+            )
+
         buffer = self.replay_buffers[policy_id]
         # Return None, if:
         # - Buffer empty or
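The new `replay_mode` argument is what the IMPALA changes rely on: in lockstep mode whole `MultiAgentBatch`es are stored under `_ALL_POLICIES` and per-policy sampling is rejected, while the default independent mode keeps the previous per-policy behavior. A small illustration (the dummy batch contents and sizes are invented; what `replay()` returns also depends on the replay ratio):

```python
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES

dummy = SampleBatch({"obs": [0.1], "actions": [0], "rewards": [1.0]})

# Independent (default): per-policy storage, replay by policy ID.
indep = MixInMultiAgentReplayBuffer(capacity=10, replay_ratio=0.5)
indep.add_batch(dummy)  # stored under DEFAULT_POLICY_ID
print(indep.replay())   # replays that policy's buffer

# Lockstep: whole multi-agent batches are stored together.
lockstep = MixInMultiAgentReplayBuffer(
    capacity=10, replay_ratio=0.5, replay_mode=ReplayMode.LOCKSTEP
)
lockstep.add_batch(dummy)
print(lockstep.replay(_ALL_POLICIES))  # OK
# lockstep.replay()  # would raise ValueError: per-policy sampling in lockstep mode
```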
rllib/tuned_examples/appo/multi-agent-cartpole-appo.yaml (new file, 35 lines)

@@ -0,0 +1,35 @@
+multi-agent-cartpole-appo:
+    env: ray.rllib.examples.env.multi_agent.MultiAgentCartPole
+    run: APPO
+    stop:
+        episode_reward_mean: 600  # 600 / 4 (==num_agents) = 150
+        timesteps_total: 200000
+    config:
+        # Works for both torch and tf.
+        framework: tf
+
+        # 4-agent MA cartpole.
+        env_config:
+            config:
+                num_agents: 4
+
+        num_envs_per_worker: 5
+        num_workers: 4
+        num_gpus: 1
+        _fake_gpus: true
+
+        observation_filter: MeanStdFilter
+        num_sgd_iter: 1
+        vf_loss_coeff: 0.005
+        vtrace: true
+        vtrace_drop_last_ts: false
+        model:
+            fcnet_hiddens: [32]
+            fcnet_activation: linear
+            vf_share_layers: true
+
+        multiagent:
+            policies: ["p0", "p1", "p2", "p3"]
+            # YAML-capable policy_mapping_fn definition via providing a callable class here.
+            policy_mapping_fn:
+                type: ray.rllib.examples.multi_agent_and_self_play.policy_mapping_fn.PolicyMappingFn
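The example maps the four `MultiAgentCartPole` agents onto policies p0–p3 through a callable class that the YAML references by dotted path. Such a class looks roughly like this (a sketch of the idea, not the exact contents of the referenced example module; the call signature follows RLlib's policy_mapping_fn convention of this release and is an assumption here):

```python
# Sketch of a YAML-referencable policy-mapping callable: agent i -> "p{i}".
class PolicyMappingFn:
    def __call__(self, agent_id, episode, worker, **kwargs):
        return f"p{agent_id}"
```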
rllib/tuned_examples/impala/multi-agent-cartpole-impala.yaml (new file, 37 lines)

@@ -0,0 +1,37 @@
+multi-agent-cartpole-impala:
+    env: ray.rllib.examples.env.multi_agent.MultiAgentCartPole
+    run: IMPALA
+    stop:
+        episode_reward_mean: 600  # 600 / 4 (==num_agents) = 150
+        timesteps_total: 200000
+    config:
+        # Works for both torch and tf.
+        framework: tf
+
+        # 4-agent MA cartpole.
+        env_config:
+            config:
+                num_agents: 4
+
+        num_envs_per_worker: 5
+        num_workers: 4
+        num_gpus: 1
+        _fake_gpus: true
+
+        observation_filter: MeanStdFilter
+        num_sgd_iter: 1
+        vf_loss_coeff: 0.005
+        vtrace: true
+        vtrace_drop_last_ts: false
+        model:
+            fcnet_hiddens: [32]
+            fcnet_activation: linear
+            vf_share_layers: true
+
+        replay_ratio: 0.0
+
+        multiagent:
+            policies: ["p0", "p1", "p2", "p3"]
+            # YAML-capable policy_mapping_fn definition via providing a callable class here.
+            policy_mapping_fn:
+                type: ray.rllib.examples.multi_agent_and_self_play.policy_mapping_fn.PolicyMappingFn
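The IMPALA example pins `replay_ratio: 0.0`, so the mix-in buffer passes only fresh batches through. Per the mixin buffer's `__init__` shown further up, the ratio is converted into a proportion of replayed samples per newly added sample; a quick way to see the relationship (example values only):

```python
# replay_ratio r ~= fraction of trained-on samples that come from replay.
# For r < 1 the buffer derives "replayed samples per new sample" as:
def replay_proportion(replay_ratio: float) -> float:
    return replay_ratio / (1.0 - replay_ratio)


print(replay_proportion(0.0))   # 0.0 -> no replay (this tuned example)
print(replay_proportion(0.5))   # 1.0 -> one replayed sample per new sample
print(replay_proportion(0.75))  # 3.0 -> three replayed samples per new sample
```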
@@ -1,31 +1,29 @@
 import collections
-import random
-import numpy as np
 import logging
-from typing import Optional, Dict, Any
+import random
+from typing import Any, Dict, Optional
 
+import numpy as np
 
+from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
-    SampleBatch,
     MultiAgentBatch,
+    SampleBatch,
 )
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
     MultiAgentPrioritizedReplayBuffer,
 )
-from ray.rllib.utils.replay_buffers.replay_buffer import (
-    StorageUnit,
-)
 from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
-    merge_dicts_with_warning,
     MultiAgentReplayBuffer,
     ReplayMode,
+    merge_dicts_with_warning,
 )
+from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES, StorageUnit
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
-from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
-from ray.util.debug import log_once
 from ray.util.annotations import DeveloperAPI
-from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
+from ray.util.debug import log_once
 
 logger = logging.getLogger(__name__)
 

@@ -133,15 +131,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
         if not 0 <= replay_ratio <= 1:
             raise ValueError("Replay ratio must be within [0, 1]")
 
-        if "replay_mode" in kwargs and kwargs["replay_mode"] == "lockstep":
-            if log_once("lockstep_mode_not_supported"):
-                logger.error(
-                    "Replay mode `lockstep` is not supported for "
-                    "MultiAgentMixInReplayBuffer."
-                    "This buffer will run in `independent` mode."
-                )
-            del kwargs["replay_mode"]
-
         MultiAgentPrioritizedReplayBuffer.__init__(
             self,
             capacity=capacity,

@@ -183,19 +172,21 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
 
         kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)
 
+        pids_and_batches = self._maybe_split_into_policy_batches(batch)
+
         # We need to split batches into timesteps, sequences or episodes
         # here already to properly keep track of self.last_added_batches
         # underlying buffers should not split up the batch any further
         with self.add_batch_timer:
             if self.storage_unit == StorageUnit.TIMESTEPS:
-                for policy_id, sample_batch in batch.policy_batches.items():
+                for policy_id, sample_batch in pids_and_batches.items():
                     timeslices = sample_batch.timeslices(1)
                     for time_slice in timeslices:
                         self.replay_buffers[policy_id].add(time_slice, **kwargs)
                         self.last_added_batches[policy_id].append(time_slice)
 
             elif self.storage_unit == StorageUnit.SEQUENCES:
-                for policy_id, sample_batch in batch.policy_batches.items():
+                for policy_id, sample_batch in pids_and_batches.items():
                     timeslices = timeslice_along_seq_lens_with_overlap(
                         sample_batch=sample_batch,
                         seq_lens=sample_batch.get(SampleBatch.SEQ_LENS)

@@ -210,7 +201,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
                         self.last_added_batches[policy_id].append(slice)
 
             elif self.storage_unit == StorageUnit.EPISODES:
-                for policy_id, sample_batch in batch.policy_batches.items():
+                for policy_id, sample_batch in pids_and_batches.items():
                     for eps in sample_batch.split_by_episode():
                         # Only add full episodes to the buffer
                         if (

@@ -228,7 +219,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
                                 "dropped."
                             )
             elif self.storage_unit == StorageUnit.FRAGMENTS:
-                for policy_id, sample_batch in batch.policy_batches.items():
+                for policy_id, sample_batch in pids_and_batches.items():
                     self.replay_buffers[policy_id].add(sample_batch, **kwargs)
                     self.last_added_batches[policy_id].append(sample_batch)
 
@@ -1,22 +1,22 @@
-import logging
 import collections
-from typing import Any, Dict, Optional
+import logging
 from enum import Enum
+from typing import Any, Dict, Optional
 
-from ray.rllib.utils.replay_buffers.replay_buffer import (
-    _ALL_POLICIES,
-)
+from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
+from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
+from ray.rllib.utils.deprecation import Deprecated
+from ray.rllib.utils.from_config import from_config
+from ray.rllib.utils.replay_buffers.replay_buffer import (
+    _ALL_POLICIES,
+    ReplayBuffer,
+    StorageUnit,
+)
 from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
-from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
-from ray.rllib.utils.from_config import from_config
-from ray.util.debug import log_once
-from ray.rllib.utils.deprecation import Deprecated
 from ray.util.annotations import DeveloperAPI
+from ray.util.debug import log_once
 
 logger = logging.getLogger(__name__)
 

@@ -229,15 +229,10 @@ class MultiAgentReplayBuffer(ReplayBuffer):
             batch = batch.as_multi_agent()
 
         with self.add_batch_timer:
-            if self.replay_mode == ReplayMode.LOCKSTEP:
-                # Lockstep mode: Store under _ALL_POLICIES key (we will always
-                # only sample from all policies at the same time).
-                # This means storing a MultiAgentBatch to the underlying buffer
-                self._add_to_underlying_buffer(_ALL_POLICIES, batch, **kwargs)
-            else:
-                # Store independent SampleBatches
-                for policy_id, sample_batch in batch.policy_batches.items():
-                    self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
+            pids_and_batches = self._maybe_split_into_policy_batches(batch)
+            for policy_id, sample_batch in pids_and_batches.items():
+                self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
+
         self._num_added += batch.count
 
     @DeveloperAPI

@@ -391,3 +386,17 @@ class MultiAgentReplayBuffer(ReplayBuffer):
         buffer_states = state["replay_buffers"]
         for policy_id in buffer_states.keys():
             self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
+
+    def _maybe_split_into_policy_batches(self, batch: SampleBatchType):
+        """Returns a dict of policy IDs and batches, depending on our replay mode.
+
+        This method helps with splitting up MultiAgentBatches only if the
+        self.replay_mode requires it.
+        """
+        if self.replay_mode == ReplayMode.LOCKSTEP:
+            return {_ALL_POLICIES: batch}
+        else:
+            return {
+                policy_id: sample_batch
+                for policy_id, sample_batch in batch.policy_batches.items()
+            }
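The new `_maybe_split_into_policy_batches` helper is what lets `add()` treat both replay modes uniformly: lockstep keeps the whole `MultiAgentBatch` under `_ALL_POLICIES`, independent mode splits it per policy. A standalone sketch of that contract (the `FakeBatch` stand-in is invented for illustration):

```python
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES


class FakeBatch:
    """Minimal stand-in for MultiAgentBatch: just a .policy_batches dict."""

    def __init__(self, policy_batches):
        self.policy_batches = policy_batches


def maybe_split(batch, replay_mode):
    if replay_mode == ReplayMode.LOCKSTEP:
        # Keep the multi-agent batch intact under the special key.
        return {_ALL_POLICIES: batch}
    return dict(batch.policy_batches)


batch = FakeBatch({"p0": "batch_p0", "p1": "batch_p1"})
print(maybe_split(batch, ReplayMode.LOCKSTEP))     # whole batch under _ALL_POLICIES
print(maybe_split(batch, ReplayMode.INDEPENDENT))  # split per policy ID
```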