Revert "[RLlib] Simplify multiagent config (automatically infer class/spaces/config). (#16565)" (#17036)

This reverts commit e4123fff27.
Amog Kamsetty 2021-07-13 09:57:15 -07:00 committed by GitHub
parent 27d80c4c88
commit 38b5b6d24c
22 changed files with 227 additions and 228 deletions
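
For orientation, here is an illustrative sketch (not part of the commit) of the two multiagent config styles this revert switches between. CartPole-v0 and the "shared_policy" ID are placeholders borrowed from the examples further down.

import gym

single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space

# Style introduced by #16565 and removed by this revert: a plain set of policy
# IDs; class, spaces, and per-policy config are inferred automatically.
simplified_multiagent = {
    "policies": {"shared_policy"},
    "policy_mapping_fn": lambda agent_id, episode, **kwargs: "shared_policy",
}

# Style restored by this revert: explicit 4-tuples of
# (policy class or None, obs_space, act_space, config) per policy ID.
explicit_multiagent = {
    "policies": {
        "shared_policy": (None, obs_space, act_space, {}),
    },
    "policy_mapping_fn": lambda agent_id, episode, **kwargs: "shared_policy",
}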

View file

@ -23,8 +23,6 @@ mlagents==0.26.0
mlagents_envs==0.26.0
# For tests on PettingZoo's multi-agent envs.
pettingzoo==1.8.2
pymunk
supersuit
# For testing in MuJoCo-like envs (in PyBullet).
pybullet==3.1.7
# For tests on RecSim and Kaggle envs.

View file

@ -2427,14 +2427,6 @@ py_test(
args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"]
)
py_test(
name = "examples/pettingzoo_env",
main = "examples/pettingzoo_env.py",
tags = ["examples", "examples_P"],
size = "medium",
srcs = ["examples/pettingzoo_env.py"],
)
py_test(
name = "examples/restore_1_of_n_agents_from_checkpoint",
tags = ["examples", "examples_R"],

View file

@ -2,8 +2,8 @@ import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check_compute_single_action, \
framework_iterator

View file

@ -23,9 +23,9 @@ from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, TensorStructType
@ -1366,44 +1366,6 @@ class Trainer(Trainable):
if simple_optim_setting != DEPRECATED_VALUE:
deprecation_warning(old="simple_optimizer", error=False)
# Loop through all policy definitions in multi-agent policies.
multiagent_config = config["multiagent"]
policies = multiagent_config.get("policies")
if not policies:
policies = {DEFAULT_POLICY_ID}
if isinstance(policies, set):
policies = multiagent_config["policies"] = {
pid: PolicySpec()
for pid in policies
}
is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
for pid, policy_spec in policies.copy().items():
# Policy IDs must be strings.
if not isinstance(pid, str):
raise ValueError("Policy keys must be strs, got {}".format(
type(pid)))
# Convert to PolicySpec if plain list/tuple.
if not isinstance(policy_spec, PolicySpec):
# Values must be lists/tuples of len 4.
if not isinstance(policy_spec, (list, tuple)) or \
len(policy_spec) != 4:
raise ValueError(
"Policy specs must be tuples/lists of "
"(cls or None, obs_space, action_space, config), "
f"got {policy_spec}")
policies[pid] = PolicySpec(*policy_spec)
# Config is None -> Set to {}.
if policies[pid].config is None:
policies[pid] = policies[pid]._replace(config={})
# Config not a dict.
elif not isinstance(policies[pid].config, dict):
raise ValueError(
f"Multiagent policy config for {pid} must be a dict, "
f"but got {type(policies[pid].config)}!")
framework = config.get("framework")
# Multi-GPU setting: Must use TFMultiGPU if tf.
if config.get("num_gpus", 0) > 1:
@ -1423,7 +1385,7 @@ class Trainer(Trainable):
config["simple_optimizer"] = True
# TF + Multi-agent case: Try using MultiGPU optimizer (only
# if all policies used are DynamicTFPolicies).
elif is_multiagent:
elif len(config["multiagent"]["policies"]) > 0:
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
default_policy_cls = None if trainer_obj_or_none is None else \
getattr(trainer_obj_or_none, "_policy_class", None)
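
A minimal sketch of the condition change in this hunk, assuming RLlib's DEFAULT_POLICY_ID value of "default_policy"; it only restates the two checks shown in the diff.

DEFAULT_POLICY_ID = "default_policy"

def multiagent_check_restored(config):
    # Condition restored by this revert: any non-empty "policies" dict.
    return len(config["multiagent"]["policies"]) > 0

def multiagent_check_removed(config):
    # Condition removed by this revert: more than one policy, or a single
    # policy whose ID is not the default one.
    policies = config["multiagent"]["policies"]
    return len(policies) > 1 or DEFAULT_POLICY_ID not in policies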

View file

@ -27,7 +27,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
@ -139,7 +139,9 @@ class RolloutWorker(ParallelIteratorWorker):
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType, EnvContext],
None]] = None,
policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None,
policy_spec: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
@ -181,7 +183,9 @@ class RolloutWorker(ParallelIteratorWorker):
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
policy=None,
policy: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
monitor_path=None,
):
"""Initialize a rollout worker.
@ -192,11 +196,11 @@ class RolloutWorker(ParallelIteratorWorker):
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
Optional callable to validate the generated environment (only
on worker=0).
policy_spec (Optional[Union[Type[Policy],
MultiAgentPolicyConfigDict]]): The MultiAgentPolicyConfigDict
mapping policy IDs (str) to PolicySpec's or a single policy
class to use.
If a dict is specified, then we are in multi-agent mode and a
policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
or a dict of policy id strings to
(Policy class, obs_space, action_space, config)-tuples. If a
dict is specified, then we are in multi-agent mode and a
policy_mapping_fn can also be set (if not, will map all agents
to DEFAULT_POLICY_ID).
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
@ -305,26 +309,15 @@ class RolloutWorker(ParallelIteratorWorker):
gym.spaces.Space]]]): An optional space dict mapping policy IDs
to (obs_space, action_space)-tuples. This is used in case no
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
monitor_path: Obsoleted arg. Use `record_env` instead.
"""
# Deprecated args.
if policy is not None:
deprecation_warning("policy", "policy_spec", error=False)
policy_spec = policy
assert policy_spec is not None, \
"Must provide `policy_spec` when creating RolloutWorker!"
# Do quick translation into MultiAgentPolicyConfigDict.
if not isinstance(policy_spec, dict):
policy_spec = {
DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)
}
policy_spec = {
pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
for pid, spec in policy_spec.copy().items()
}
assert policy_spec is not None, "Must provide `policy_spec` when " \
"creating RolloutWorker!"
if monitor_path is not None:
deprecation_warning("monitor_path", "record_env", error=False)
record_env = monitor_path
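
A hedged usage sketch of the restored `policy_spec` format described in the docstring above. MultiAgentCartPole, PGTFPolicy, the PG default config, and the two policy IDs are illustrative choices, not prescribed by this diff; depending on the exact RLlib version, additional config keys may be required.

import gym
from ray.rllib.agents.pg import DEFAULT_CONFIG as PG_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space

worker = RolloutWorker(
    env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
    # Multi-agent mode: policy ID -> (cls or None, obs_space, act_space, config).
    policy_spec={
        "pol0": (PGTFPolicy, obs_space, act_space, {}),
        "pol1": (PGTFPolicy, obs_space, act_space, {}),
    },
    policy_mapping_fn=lambda agent_id, episode=None, **kwargs: f"pol{agent_id % 2}",
    # PG defaults, so each policy gets a complete base config.
    policy_config=PG_CONFIG,
)
batch = worker.sample()  # One rollout as a MultiAgentBatch.
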
@ -490,7 +483,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.make_env_fn = make_env
self.tf_sess = None
policy_dict = _determine_spaces_for_multi_agent_dict(
policy_dict = _validate_and_canonicalize(
policy_spec, self.env, spaces=spaces, policy_config=policy_config)
# List of IDs of those policies, which should be trained.
# By default, these are all policies found in the policy_dict.
@ -1309,7 +1302,7 @@ class RolloutWorker(ParallelIteratorWorker):
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
logger.debug("Creating policy for {}".format(name))
merged_conf = merge_dicts(policy_config, conf or {})
merged_conf = merge_dicts(policy_config, conf)
merged_conf["num_workers"] = self.num_workers
merged_conf["worker_index"] = self.worker_index
if self.preprocessing_enabled:
@ -1386,57 +1379,74 @@ class RolloutWorker(ParallelIteratorWorker):
self.sampler.shutdown = True
def _determine_spaces_for_multi_agent_dict(
multi_agent_dict: MultiAgentPolicyConfigDict,
env: Optional[EnvType] = None,
def _validate_and_canonicalize(
policy: Union[Type[Policy], MultiAgentPolicyConfigDict],
env: Optional[EnvType],
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
policy_config: Optional[PartialTrainerConfigDict] = None,
gym.spaces.Space]]],
policy_config: Optional[PartialTrainerConfigDict],
) -> MultiAgentPolicyConfigDict:
# Try extracting spaces from env.
env_obs_space = None
env_act_space = None
if env is not None and hasattr(env, "observation_space") and isinstance(
env.observation_space, gym.Space):
env_obs_space = env.observation_space
if env is not None and hasattr(env, "action_space") and isinstance(
env.action_space, gym.Space):
env_act_space = env.action_space
if isinstance(policy, dict):
_validate_multiagent_config(policy)
return policy
elif not issubclass(policy, Policy):
raise ValueError(f"`policy` ({policy}) must be a rllib.Policy class!")
else:
if (isinstance(env, MultiAgentEnv)
and not hasattr(env, "observation_space")):
raise ValueError(
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
if env is not None:
return {
DEFAULT_POLICY_ID: (policy, env.observation_space,
env.action_space, {})
}
for pid, policy_spec in multi_agent_dict.copy().items():
if policy_spec.observation_space is None:
if spaces is not None:
obs_space = spaces[pid][0]
elif env_obs_space is not None:
obs_space = env_obs_space
elif "observation_space" in policy_config:
obs_space = policy_config["observation_space"]
else:
if spaces is None:
if "action_space" not in policy_config or \
"observation_space" not in policy_config:
raise ValueError(
"`observation_space` not provided in PolicySpec for "
f"{pid} and env does not have an observation space OR "
"no spaces received from other workers' env(s) OR no "
"`observation_space` specified in config!")
multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
observation_space=obs_space)
"If no env given, must provide obs/action spaces either "
"in the `multiagent.policies` dict or under "
"`config.[observation|action]_space`!")
spaces = {
DEFAULT_POLICY_ID: (policy_config["observation_space"],
policy_config["action_space"])
}
return {
DEFAULT_POLICY_ID: (policy, spaces[DEFAULT_POLICY_ID][0],
spaces[DEFAULT_POLICY_ID][1], {})
}
if policy_spec.action_space is None:
if spaces is not None:
act_space = spaces[pid][1]
elif env_act_space is not None:
act_space = env_act_space
elif "action_space" in policy_config:
act_space = policy_config["action_space"]
else:
raise ValueError(
"`action_space` not provided in PolicySpec for "
f"{pid} and env does not have an action space OR "
"no spaces received from other workers' env(s) OR no "
"`action_space` specified in config!")
multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
action_space=act_space)
return multi_agent_dict
def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
allow_none_graph: bool = False) -> None:
# Loop through all policy definitions in multi-agent policies.
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("Policy key must be str, got {}!".format(k))
if not isinstance(v, (tuple, list)) or len(v) != 4:
raise ValueError(
"policy values must be tuples/lists of "
"(cls or None, obs_space, action_space, config), got {}".
format(v))
if allow_none_graph and v[0] is None:
pass
elif not issubclass(v[0], Policy):
raise ValueError("policy tuple value 0 must be a rllib.Policy "
"class or None, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError("policy tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError("policy tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
def _validate_env(env: Any) -> EnvType:
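
For reference, a minimal sketch of dict values that the restored `_validate_multiagent_config` accepts or rejects; the spaces and the `gamma` override are arbitrary placeholders.

from gym.spaces import Box, Discrete
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

# Accepted: str keys mapping to 4-tuples of
# (Policy subclass, obs_space, act_space, config dict). A None class in the
# first slot is only accepted when allow_none_graph=True (as WorkerSet uses it).
valid_policies = {
    "pol1": (PGTFPolicy, Box(-1.0, 1.0, (4, )), Discrete(2), {"gamma": 0.99}),
}

# Rejected with ValueError: not a 4-tuple, and a non-dict config, respectively.
invalid_policies = {
    "pol2": PGTFPolicy,
    "pol3": (PGTFPolicy, Box(-1.0, 1.0, (4, )), Discrete(2), None),
}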

View file

@ -6,12 +6,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
_validate_multiagent_config
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput, D4RLReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy import Policy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
@ -364,21 +365,20 @@ class WorkerSet:
else:
input_evaluation = config["input_evaluation"]
# Assert everything is correct in "multiagent" config dict (if given).
ma_policies = config["multiagent"]["policies"]
if ma_policies:
for pid, policy_spec in ma_policies.copy().items():
assert isinstance(policy_spec, (PolicySpec, list, tuple))
# Class is None -> Use `policy_cls`.
if policy_spec.policy_class is None:
ma_policies[pid] = ma_policies[pid]._replace(
policy_class=policy_cls)
policies = ma_policies
# Create a policy_spec (MultiAgentPolicyConfigDict),
# even if no "multiagent" setup given by user.
# Fill in the default policy_cls if 'None' is specified in multiagent.
if config["multiagent"]["policies"]:
tmp = config["multiagent"]["policies"]
_validate_multiagent_config(tmp, allow_none_graph=True)
# TODO: (sven) Allow for setting observation and action spaces to
# None as well, in which case, spaces are taken from env.
# It's tedious to have to provide these in a multi-agent config.
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_cls, v[1], v[2], v[3])
policy_spec = tmp
# Otherwise, policy spec is simply the policy class itself.
else:
policies = policy_cls
policy_spec = policy_cls
if worker_index == 0:
extra_python_environs = config.get(
@ -390,7 +390,7 @@ class WorkerSet:
worker = cls(
env_creator=env_creator,
validate_env=validate_env,
policy_spec=policies,
policy_spec=policy_spec,
policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
policies_to_train=config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
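
The fill-in step restored in this file can be read as the following standalone sketch (hypothetical helper name; `policy_cls` stands for the Trainer's default policy class).

def fill_in_default_policy_class(policies, policy_cls):
    # Replace a None class (first tuple element) with the Trainer's default
    # policy class, leaving spaces and per-policy configs untouched.
    return {
        pid: (policy_cls if cls is None else cls, obs, act, conf)
        for pid, (cls, obs, act, conf) in policies.items()
    }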

View file

@ -10,6 +10,7 @@ execution, set the TF_TIMELINE_DIR environment variable.
"""
import argparse
import gym
import os
import random
@ -20,7 +21,6 @@ from ray.rllib.examples.models.shared_weights_model import \
SharedWeightsModel1, SharedWeightsModel2, TF2SharedWeightsModel, \
TorchSharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
@ -73,6 +73,11 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model("model1", mod1)
ModelCatalog.register_custom_model("model2", mod2)
# Get obs- and action Spaces.
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
# Each policy can have a different configuration (including custom model).
def gen_policy(i):
config = {
@ -81,7 +86,7 @@ if __name__ == "__main__":
},
"gamma": random.choice([0.95, 0.99]),
}
return PolicySpec(config=config)
return (None, obs_space, act_space, config)
# Setup PPO with an ensemble of `num_policies` different policies.
policies = {

View file

@ -14,6 +14,7 @@ Result for PG_multi_cartpole_0:
"""
import argparse
import gym
import os
import ray
@ -21,7 +22,6 @@ from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved
parser = argparse.ArgumentParser()
@ -58,6 +58,9 @@ if __name__ == "__main__":
# Simple environment with 4 independent cartpole entities
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 4}))
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
stop = {
"training_iteration": args.stop_iters,
@ -71,9 +74,9 @@ if __name__ == "__main__":
# The multiagent Policy map.
"policies": {
# The Policy we are actually learning.
"pg_policy": PolicySpec(config={"framework": args.framework}),
"pg_policy": (None, obs_space, act_space, {}),
# Random policy we are playing against.
"random": PolicySpec(policy_class=RandomPolicy),
"random": (RandomPolicy, obs_space, act_space, {}),
},
# Map to either random behavior or PR learning behavior based on
# the agent's ID.

View file

@ -14,6 +14,11 @@ if __name__ == "__main__":
env = env_creator({})
register_env("waterworld", env_creator)
obs_space = env.observation_space
act_spc = env.action_space
policies = {agent: (None, obs_space, act_spc, {}) for agent in env.agents}
tune.run(
"APEX_DDPG",
stop={"episodes_total": 60000},
@ -26,7 +31,7 @@ if __name__ == "__main__":
"num_workers": 2,
# Method specific
"multiagent": {
"policies": set(env.agents),
"policies": policies,
"policy_mapping_fn": (
lambda agent_id, episode, **kwargs: agent_id),
},

View file

@ -8,15 +8,27 @@ from pettingzoo.sisl import waterworld_v0
if __name__ == "__main__":
# RDQN - Rainbow DQN
# ADQN - Apex DQN
def env_creator(args):
return PettingZooEnv(waterworld_v0.env())
register_env("waterworld", lambda _: PettingZooEnv(waterworld_v0.env()))
env = env_creator({})
register_env("waterworld", env_creator)
obs_space = env.observation_space
act_space = env.action_space
policies = {"shared_policy": (None, obs_space, act_space, {})}
# for all methods
policy_ids = list(policies.keys())
tune.run(
"APEX_DDPG",
stop={"episodes_total": 60000},
checkpoint_freq=10,
config={
# Enviroment specific.
# Enviroment specific
"env": "waterworld",
# General
@ -36,13 +48,9 @@ if __name__ == "__main__":
"target_network_update_freq": 50000,
"timesteps_per_iteration": 25000,
# Method specific.
# Method specific
"multiagent": {
# We only have one policy (calling it "shared").
# Class, obs/act-spaces, and config will be derived
# automatically.
"policies": {"shared_policy"},
# Always use "shared" policy.
"policies": policies,
"policy_mapping_fn": (
lambda agent_id, episode, **kwargs: "shared_policy"),
},

View file

@ -6,7 +6,7 @@ from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env import PettingZooEnv
from pettingzoo.butterfly import pistonball_v4
from pettingzoo.butterfly import pistonball_v1
from ray.tune.registry import register_env
@ -21,9 +21,9 @@ if __name__ == "__main__":
"""
alg_name = "PPO"
# Function that outputs the environment you wish to register.
# function that outputs the environment you wish to register.
def env_creator(config):
env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2))
env = pistonball_v1.env(local_ratio=config.get("local_ratio", 0.2))
env = dtype_v0(env, dtype=float32)
env = color_reduction_v0(env, mode="R")
env = normalize_obs_v0(env)
@ -32,23 +32,29 @@ if __name__ == "__main__":
num_cpus = 1
num_rollouts = 2
# Gets default training configuration and specifies the POMgame to load.
# 1. Gets default training configuration and specifies the POMgame to load.
config = deepcopy(get_trainer_class(alg_name)._default_config)
# Set environment config. This will be passed to
# 2. Set environment config. This will be passed to
# the env_creator function via the register env lambda below.
config["env_config"] = {"local_ratio": 0.5}
# Register env
# 3. Register env
register_env("pistonball",
lambda config: PettingZooEnv(env_creator(config)))
# Configuration for multiagent setup with policy sharing:
# 4. Extract space dimensions
test_env = PettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
# 5. Configuration for multiagent setup with policy sharing:
config["multiagent"] = {
# Setup a single, shared policy for all agents.
"policies": {"av"},
# Map all agents to that policy.
"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av",
"policies": {
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
}
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
@ -68,9 +74,11 @@ if __name__ == "__main__":
# If no_done_at_end = True, environment is not resetted
# when dones[__all__]= True.
# Initialize ray and trainer object
# 6. Initialize ray and trainer object
ray.init(num_cpus=num_cpus + 1)
trainer = get_trainer_class(alg_name)(env="pistonball", config=config)
# Train once
# 7. Train once
trainer.train()
test_env.reset()
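
A possible continuation of this example (not part of the diff): keep training for the remaining `num_rollouts` iterations and write a checkpoint using standard Trainer API calls.

for _ in range(num_rollouts - 1):
    results = trainer.train()
checkpoint_path = trainer.save()  # Write a checkpoint to the local log dir.
print("Checkpoint written to", checkpoint_path)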

View file

@ -8,6 +8,7 @@ This demonstrates running the following policies in competition:
"""
import argparse
from gym.spaces import Discrete
import os
import random
@ -17,7 +18,6 @@ from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
from ray.rllib.examples.policy.rock_paper_scissors_dummies import \
BeatLastHeuristic, AlwaysSameHeuristic
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
@ -93,9 +93,10 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
"multiagent": {
"policies_to_train": ["learned"],
"policies": {
"always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
"beat_last": PolicySpec(policy_class=BeatLastHeuristic),
"learned": PolicySpec(config={
"always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3),
{}),
"beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
"learned": (None, Discrete(3), Discrete(3), {
"model": {
"use_lstm": use_lstm
},

View file

@ -9,7 +9,7 @@ See also: centralized_critic.py for centralized critic PPO on this game.
"""
import argparse
from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete
from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
import os
import ray
@ -17,7 +17,6 @@ from ray import tune
from ray.tune import register_env, grid_search
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved
parser = argparse.ArgumentParser()
@ -81,8 +80,14 @@ if __name__ == "__main__":
grouping, obs_space=obs_space, act_space=act_space))
if args.run == "contrib/MADDPG":
obs_space = Discrete(6)
act_space = TwoStepGame.action_space
obs_space_dict = {
"agent_1": Discrete(6),
"agent_2": Discrete(6),
}
act_space_dict = {
"agent_1": TwoStepGame.action_space,
"agent_2": TwoStepGame.action_space,
}
config = {
"learning_starts": 100,
"env_config": {
@ -90,14 +95,12 @@ if __name__ == "__main__":
},
"multiagent": {
"policies": {
"pol1": PolicySpec(
observation_space=obs_space,
action_space=act_space,
config={"agent_id": 0}),
"pol2": PolicySpec(
observation_space=obs_space,
action_space=act_space,
config={"agent_id": 1}),
"pol1": (None, Discrete(6), TwoStepGame.action_space, {
"agent_id": 0,
}),
"pol2": (None, Discrete(6), TwoStepGame.action_space, {
"agent_id": 1,
}),
},
"policy_mapping_fn": (
lambda aid, **kwargs: "pol2" if aid else "pol1"),

View file

@ -6,6 +6,7 @@ via a custom training workflow.
"""
import argparse
import gym
import os
import ray
@ -123,14 +124,17 @@ if __name__ == "__main__":
# Simple environment with 4 independent cartpole entities
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 4}))
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
# Note that since the trainer below does not include a default policy or
# policy configs, we have to explicitly set it in the multiagent config:
policies = {
"ppo_policy": (PPOTorchPolicy if args.torch or args.mixed_torch_tf else
PPOTFPolicy, None, None, PPO_CONFIG),
"dqn_policy": (DQNTorchPolicy
if args.torch else DQNTFPolicy, None, None, DQN_CONFIG),
PPOTFPolicy, obs_space, act_space, PPO_CONFIG),
"dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy,
obs_space, act_space, DQN_CONFIG),
}
def policy_mapping_fn(agent_id, episode, **kwargs):

View file

@ -1,5 +1,4 @@
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import gym
from gym.spaces import Box
import logging
@ -24,7 +23,6 @@ torch, _ = try_import_torch()
if TYPE_CHECKING:
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.typing import PartialTrainerConfigDict
logger = logging.getLogger(__name__)
@ -32,33 +30,6 @@ logger = logging.getLogger(__name__)
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"
# A policy spec used in the "config.multiagent.policies" specification dict
# as values (keys are the policy IDs (str)). E.g.:
# config:
# multiagent:
# policies: {
# "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
# "pol2": PolicySpec(config={"lr": 0.001}),
# }
PolicySpec = namedtuple(
"PolicySpec",
[
# If None, use the Trainer's default policy class stored under
# `Trainer._policy_class`.
"policy_class", # type: Union[type, None]
# If None, use the env's observation space. If None and there is no Env
# (e.g. offline RL), an error is thrown.
"observation_space", # type: Union[gym.Space, None]
# If None, use the env's action space. If None and there is no Env
# (e.g. offline RL), an error is thrown.
"action_space", # type: Union[gym.Space, None]
# Overrides defined keys in the main Trainer config.
# If None, use {}.
"config", # type: Union[PartialTrainerConfigDict, None]
])
# From 3.7 on, we could pass `defaults` into the above constructor.
PolicySpec.__new__.__defaults__ = (None, None, None, None)
@DeveloperAPI
class Policy(metaclass=ABCMeta):
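
Translating the removed docstring example above into the tuple format that applies after this revert; the spaces shown are placeholders, and note that spaces and config can no longer be omitted.

from gym.spaces import Box, Discrete

multiagent_policies = {
    # Was: "pol1": PolicySpec(None, Box(...), Discrete(2), {"lr": 0.0001})
    "pol1": (None, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 0.0001}),
    # Was: "pol2": PolicySpec(config={"lr": 0.001})
    "pol2": (None, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 0.001}),
}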

View file

@ -1,4 +1,5 @@
import glob
import gym
import json
import numpy as np
import os
@ -12,6 +13,7 @@ import ray
from ray.tune.registry import register_env, register_input, \
registry_get_input, registry_contains_input
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.offline import IOContext, JsonWriter, JsonReader, InputReader, \
ShuffledInput
@ -177,6 +179,12 @@ class AgentIOTest(unittest.TestCase):
def testMultiAgent(self):
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 10}))
single_env = gym.make("CartPole-v0")
def gen_policy():
obs_space = single_env.observation_space
act_space = single_env.action_space
return (PGTFPolicy, obs_space, act_space, {})
for fw in framework_iterator():
pg = PGTrainer(
@ -185,7 +193,10 @@ class AgentIOTest(unittest.TestCase):
"num_workers": 0,
"output": self.test_dir,
"multiagent": {
"policies": {"policy_1", "policy_2"},
"policies": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
"policy_mapping_fn": (
lambda aid, **kwargs: random.choice(
["policy_1", "policy_2"])),
@ -204,7 +215,10 @@ class AgentIOTest(unittest.TestCase):
"input_evaluation": ["simulation"],
"train_batch_size": 2000,
"multiagent": {
"policies": {"policy_1", "policy_2"},
"policies": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
"policy_mapping_fn": (
lambda aid, **kwargs: random.choice(
["policy_1", "policy_2"])),

View file

@ -16,7 +16,6 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import check
@ -424,13 +423,16 @@ class TestMultiAgentEnv(unittest.TestCase):
n = 10
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": n}))
single_env = gym.make("CartPole-v0")
def gen_policy():
config = {
"gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
"n_step": random.choice([1, 2, 3, 4, 5]),
}
return PolicySpec(config=config)
obs_space = single_env.observation_space
act_space = single_env.action_space
return (None, obs_space, act_space, config)
pg = PGTrainer(
env="multi_agent_cartpole",

View file

@ -8,6 +8,7 @@ import unittest
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
@ -450,10 +451,10 @@ class NestedSpacesTest(unittest.TestCase):
"multiagent": {
"policies": {
"tuple_policy": (
None, TUPLE_SPACE, act_space,
PGTFPolicy, TUPLE_SPACE, act_space,
{"model": {"custom_model": "tuple_spy"}}),
"dict_policy": (
None, DICT_SPACE, act_space,
PGTFPolicy, DICT_SPACE, act_space,
{"model": {"custom_model": "dict_spy"}}),
},
"policy_mapping_fn": lambda aid, **kwargs: {

View file

@ -24,12 +24,16 @@ class TestPettingZooEnv(unittest.TestCase):
config = deepcopy(agent_class._default_config)
test_env = PettingZooEnv(simple_spread_v2.env())
obs_space = test_env.observation_space
act_space = test_env.action_space
test_env.close()
config["multiagent"] = {
# Set of policy IDs (by default, will use Trainer's
# default policy class, the env's obs/act spaces and config={}).
"policies": {"av"},
# Mapping function that always returns "av" as policy ID to use
# (for any agent).
"policies": {
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
}

View file

@ -1,3 +1,4 @@
from gym.spaces import Box, Discrete
import os
from pathlib import Path
import re
@ -149,6 +150,9 @@ def learn_test_multi_agent_plus_rollout(algo):
def policy_fn(agent_id, episode, **kwargs):
return "pol{}".format(agent_id)
observation_space = Box(float("-inf"), float("inf"), (4, ))
action_space = Discrete(2)
config = {
"num_gpus": 0,
"num_workers": 1,
@ -158,7 +162,10 @@ def learn_test_multi_agent_plus_rollout(algo):
"framework": fw,
"env": MultiAgentCartPole,
"multiagent": {
"policies": {"pol0", "pol1"},
"policies": {
"pol0": (None, observation_space, action_space, {}),
"pol1": (None, observation_space, action_space, {}),
},
"policy_mapping_fn": policy_fn,
},
}

View file

@ -6,7 +6,7 @@ from collections import defaultdict
import random
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
logger = logging.getLogger(__name__)

View file

@ -1,10 +1,10 @@
from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
import gym
if TYPE_CHECKING:
from ray.rllib.utils import try_import_tf, try_import_torch
_, tf, _ = try_import_tf()
torch, _ = try_import_torch()
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.view_requirement import ViewRequirement
@ -42,7 +42,8 @@ AgentID = Any
PolicyID = str
# Type of the config["multiagent"]["policies"] dict for multi-agent training.
MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"]
MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[
type, None], gym.Space, gym.Space, PartialTrainerConfigDict]]
# Represents an environment id. These could be:
# - An int index for a sub-env within a vectorized env.
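
A value conforming to the restored MultiAgentPolicyConfigDict alias might look like this; the policy ID, spaces, and `lr` override are illustrative only.

from gym.spaces import Box, Discrete
from ray.rllib.utils.typing import MultiAgentPolicyConfigDict

# Policy ID -> (policy class or None, observation space, action space,
# partial config overriding the Trainer defaults).
policies: MultiAgentPolicyConfigDict = {
    "learned": (None, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 0.0005}),
}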