Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
Revert "[RLlib] Simplify multiagent config (automatically infer class/spaces/config). (#16565)" (#17036)
This reverts commit e4123fff27.

Parent: 27d80c4c88
Commit: 38b5b6d24c
22 changed files with 227 additions and 228 deletions
@@ -23,8 +23,6 @@ mlagents==0.26.0
mlagents_envs==0.26.0
# For tests on PettingZoo's multi-agent envs.
pettingzoo==1.8.2
pymunk
supersuit
# For testing in MuJoCo-like envs (in PyBullet).
pybullet==3.1.7
# For tests on RecSim and Kaggle envs.

@@ -2427,14 +2427,6 @@ py_test(
    args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"]
)

py_test(
    name = "examples/pettingzoo_env",
    main = "examples/pettingzoo_env.py",
    tags = ["examples", "examples_P"],
    size = "medium",
    srcs = ["examples/pettingzoo_env.py"],
)

py_test(
    name = "examples/restore_1_of_n_agents_from_checkpoint",
    tags = ["examples", "examples_R"],

@@ -2,8 +2,8 @@ import unittest

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

@@ -23,9 +23,9 @@ from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, TensorStructType

@@ -1366,44 +1366,6 @@ class Trainer(Trainable):
        if simple_optim_setting != DEPRECATED_VALUE:
            deprecation_warning(old="simple_optimizer", error=False)

        # Loop through all policy definitions in multi-agent policies.
        multiagent_config = config["multiagent"]
        policies = multiagent_config.get("policies")
        if not policies:
            policies = {DEFAULT_POLICY_ID}
        if isinstance(policies, set):
            policies = multiagent_config["policies"] = {
                pid: PolicySpec()
                for pid in policies
            }
        is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies

        for pid, policy_spec in policies.copy().items():
            # Policy IDs must be strings.
            if not isinstance(pid, str):
                raise ValueError("Policy keys must be strs, got {}".format(
                    type(pid)))

            # Convert to PolicySpec if plain list/tuple.
            if not isinstance(policy_spec, PolicySpec):
                # Values must be lists/tuples of len 4.
                if not isinstance(policy_spec, (list, tuple)) or \
                        len(policy_spec) != 4:
                    raise ValueError(
                        "Policy specs must be tuples/lists of "
                        "(cls or None, obs_space, action_space, config), "
                        f"got {policy_spec}")
                policies[pid] = PolicySpec(*policy_spec)

            # Config is None -> Set to {}.
            if policies[pid].config is None:
                policies[pid] = policies[pid]._replace(config={})
            # Config not a dict.
            elif not isinstance(policies[pid].config, dict):
                raise ValueError(
                    f"Multiagent policy config for {pid} must be a dict, "
                    f"but got {type(policies[pid].config)}!")

        framework = config.get("framework")
        # Multi-GPU setting: Must use TFMultiGPU if tf.
        if config.get("num_gpus", 0) > 1:

@@ -1423,7 +1385,7 @@ class Trainer(Trainable):
                config["simple_optimizer"] = True
            # TF + Multi-agent case: Try using MultiGPU optimizer (only
            # if all policies used are DynamicTFPolicies).
            elif is_multiagent:
            elif len(config["multiagent"]["policies"]) > 0:
                from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
                default_policy_cls = None if trainer_obj_or_none is None else \
                    getattr(trainer_obj_or_none, "_policy_class", None)
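For reference, a hedged sketch (not part of the diff) of the two `config["multiagent"]["policies"]` formats this hunk switches between: the PolicySpec form that the removed validation expanded automatically, and the explicit 4-tuple form the restored code expects. The CartPole-like spaces are purely illustrative:

    from gym.spaces import Box, Discrete
    import numpy as np

    obs_space = Box(-np.inf, np.inf, (4, ))  # illustrative CartPole-like spaces
    act_space = Discrete(2)

    # Removed (PolicySpec) style: a bare set of IDs was auto-expanded to
    # PolicySpec() entries, with class/spaces/config inferred later.
    # policies = {"pol1", "pol2"}

    # Restored (tuple) style: every entry is an explicit 4-tuple of
    # (policy_cls_or_None, obs_space, action_space, config_overrides).
    policies = {
        "pol1": (None, obs_space, act_space, {}),
        "pol2": (None, obs_space, act_space, {"gamma": 0.95}),
    }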
@@ -27,7 +27,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts

@@ -139,7 +139,9 @@ class RolloutWorker(ParallelIteratorWorker):
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType, EnvContext],
                                            None]] = None,
            policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None,
            policy_spec: Union[type, Dict[
                str, Tuple[Optional[type], gym.Space, gym.Space,
                           PartialTrainerConfigDict]]] = None,
            policy_mapping_fn: Optional[Callable[
                [AgentID, "MultiAgentEpisode"], PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,

@@ -181,7 +183,9 @@ class RolloutWorker(ParallelIteratorWorker):
            fake_sampler: bool = False,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
            policy=None,
            policy: Union[type, Dict[
                str, Tuple[Optional[type], gym.Space, gym.Space,
                           PartialTrainerConfigDict]]] = None,
            monitor_path=None,
    ):
        """Initialize a rollout worker.

@@ -192,11 +196,11 @@ class RolloutWorker(ParallelIteratorWorker):
            validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
                Optional callable to validate the generated environment (only
                on worker=0).
            policy_spec (Optional[Union[Type[Policy],
                MultiAgentPolicyConfigDict]]): The MultiAgentPolicyConfigDict
                mapping policy IDs (str) to PolicySpec's or a single policy
                class to use.
                If a dict is specified, then we are in multi-agent mode and a
            policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
                gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
                or a dict of policy id strings to
                (Policy class, obs_space, action_space, config)-tuples. If a
                dict is specified, then we are in multi-agent mode and a
                policy_mapping_fn can also be set (if not, will map all agents
                to DEFAULT_POLICY_ID).
            policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
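A hedged sketch of the two `policy_spec` shapes the docstring change above describes; the spaces and the PGTFPolicy class are stand-ins for illustration and assume a working ray[rllib] install:

    from gym.spaces import Box, Discrete
    from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

    # Single-agent mode: pass a Policy class; it is mapped to DEFAULT_POLICY_ID.
    policy_spec_single = PGTFPolicy

    # Multi-agent mode: a dict of policy ID -> (cls, obs_space, act_space, config).
    policy_spec_multi = {
        "main": (PGTFPolicy, Box(-1.0, 1.0, (4, )), Discrete(2), {}),
        "opponent": (PGTFPolicy, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 0.001}),
    }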
@@ -305,26 +309,15 @@ class RolloutWorker(ParallelIteratorWorker):
                gym.spaces.Space]]]): An optional space dict mapping policy IDs
                to (obs_space, action_space)-tuples. This is used in case no
                Env is created on this RolloutWorker.
            policy: Obsoleted arg. Use `policy_spec` instead.
            monitor_path: Obsoleted arg. Use `record_env` instead.
        """

        # Deprecated args.
        if policy is not None:
            deprecation_warning("policy", "policy_spec", error=False)
            policy_spec = policy
        assert policy_spec is not None, \
            "Must provide `policy_spec` when creating RolloutWorker!"

        # Do quick translation into MultiAgentPolicyConfigDict.
        if not isinstance(policy_spec, dict):
            policy_spec = {
                DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)
            }
        policy_spec = {
            pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
            for pid, spec in policy_spec.copy().items()
        }

        assert policy_spec is not None, "Must provide `policy_spec` when " \
            "creating RolloutWorker!"
        if monitor_path is not None:
            deprecation_warning("monitor_path", "record_env", error=False)
            record_env = monitor_path
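The removed branch canonicalized whatever was passed into a PolicySpec dict. A standalone sketch of that translation, with PolicySpec and DEFAULT_POLICY_ID redefined locally so the snippet does not depend on the reverted import:

    from collections import namedtuple

    # Local stand-ins for the reverted PolicySpec namedtuple and the default ID.
    PolicySpec = namedtuple(
        "PolicySpec",
        ["policy_class", "observation_space", "action_space", "config"])
    PolicySpec.__new__.__defaults__ = (None, None, None, None)
    DEFAULT_POLICY_ID = "default_policy"

    def canonicalize(policy_spec):
        # A bare Policy class becomes a single-policy dict.
        if not isinstance(policy_spec, dict):
            policy_spec = {DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)}
        # Plain tuples/lists are upgraded to PolicySpec entries.
        return {
            pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec)
            for pid, spec in policy_spec.items()
        }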
@@ -490,7 +483,7 @@ class RolloutWorker(ParallelIteratorWorker):
        self.make_env_fn = make_env

        self.tf_sess = None
        policy_dict = _determine_spaces_for_multi_agent_dict(
        policy_dict = _validate_and_canonicalize(
            policy_spec, self.env, spaces=spaces, policy_config=policy_config)
        # List of IDs of those policies, which should be trained.
        # By default, these are all policies found in the policy_dict.

@@ -1309,7 +1302,7 @@ class RolloutWorker(ParallelIteratorWorker):
        for name, (cls, obs_space, act_space,
                   conf) in sorted(policy_dict.items()):
            logger.debug("Creating policy for {}".format(name))
            merged_conf = merge_dicts(policy_config, conf or {})
            merged_conf = merge_dicts(policy_config, conf)
            merged_conf["num_workers"] = self.num_workers
            merged_conf["worker_index"] = self.worker_index
            if self.preprocessing_enabled:

@@ -1386,57 +1379,74 @@ class RolloutWorker(ParallelIteratorWorker):
            self.sampler.shutdown = True


def _determine_spaces_for_multi_agent_dict(
        multi_agent_dict: MultiAgentPolicyConfigDict,
        env: Optional[EnvType] = None,
def _validate_and_canonicalize(
        policy: Union[Type[Policy], MultiAgentPolicyConfigDict],
        env: Optional[EnvType],
        spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                              gym.spaces.Space]]] = None,
        policy_config: Optional[PartialTrainerConfigDict] = None,
                              gym.spaces.Space]]],
        policy_config: Optional[PartialTrainerConfigDict],
) -> MultiAgentPolicyConfigDict:

    # Try extracting spaces from env.
    env_obs_space = None
    env_act_space = None
    if env is not None and hasattr(env, "observation_space") and isinstance(
            env.observation_space, gym.Space):
        env_obs_space = env.observation_space
    if env is not None and hasattr(env, "action_space") and isinstance(
            env.action_space, gym.Space):
        env_act_space = env.action_space

    for pid, policy_spec in multi_agent_dict.copy().items():
        if policy_spec.observation_space is None:
            if spaces is not None:
                obs_space = spaces[pid][0]
            elif env_obs_space is not None:
                obs_space = env_obs_space
            elif "observation_space" in policy_config:
                obs_space = policy_config["observation_space"]
    if isinstance(policy, dict):
        _validate_multiagent_config(policy)
        return policy
    elif not issubclass(policy, Policy):
        raise ValueError(f"`policy` ({policy}) must be a rllib.Policy class!")
    else:
        if (isinstance(env, MultiAgentEnv)
                and not hasattr(env, "observation_space")):
            raise ValueError(
                "`observation_space` not provided in PolicySpec for "
                f"{pid} and env does not have an observation space OR "
                "no spaces received from other workers' env(s) OR no "
                "`observation_space` specified in config!")
            multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
                observation_space=obs_space)
                "MultiAgentEnv must have observation_space defined if run "
                "in a single-agent configuration.")
        if env is not None:
            return {
                DEFAULT_POLICY_ID: (policy, env.observation_space,
                                    env.action_space, {})
            }

        if policy_spec.action_space is None:
            if spaces is not None:
                act_space = spaces[pid][1]
            elif env_act_space is not None:
                act_space = env_act_space
            elif "action_space" in policy_config:
                act_space = policy_config["action_space"]
            else:
        if spaces is None:
            if "action_space" not in policy_config or \
                    "observation_space" not in policy_config:
                raise ValueError(
                    "`action_space` not provided in PolicySpec for "
                    f"{pid} and env does not have an action space OR "
                    "no spaces received from other workers' env(s) OR no "
                    "`action_space` specified in config!")
            multi_agent_dict[pid] = multi_agent_dict[pid]._replace(
                action_space=act_space)
    return multi_agent_dict
                    "If no env given, must provide obs/action spaces either "
                    "in the `multiagent.policies` dict or under "
                    "`config.[observation|action]_space`!")
            spaces = {
                DEFAULT_POLICY_ID: (policy_config["observation_space"],
                                    policy_config["action_space"])
            }
        return {
            DEFAULT_POLICY_ID: (policy, spaces[DEFAULT_POLICY_ID][0],
                                spaces[DEFAULT_POLICY_ID][1], {})
        }


def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
                                allow_none_graph: bool = False) -> None:
    # Loop through all policy definitions in multi-agent policie
    for k, v in policy.items():
        if not isinstance(k, str):
            raise ValueError("Policy key must be str, got {}!".format(k))
        if not isinstance(v, (tuple, list)) or len(v) != 4:
            raise ValueError(
                "policy values must be tuples/lists of "
                "(cls or None, obs_space, action_space, config), got {}".
                format(v))
        if allow_none_graph and v[0] is None:
            pass
        elif not issubclass(v[0], Policy):
            raise ValueError("policy tuple value 0 must be a rllib.Policy "
                             "class or None, got {}".format(v[0]))
        if not isinstance(v[1], gym.Space):
            raise ValueError(
                "policy tuple value 1 (observation_space) must be a "
                "gym.Space, got {}".format(type(v[1])))
        if not isinstance(v[2], gym.Space):
            raise ValueError("policy tuple value 2 (action_space) must be a "
                             "gym.Space, got {}".format(type(v[2])))
        if not isinstance(v[3], dict):
            raise ValueError("policy tuple value 3 (config) must be a dict, "
                             "got {}".format(type(v[3])))


def _validate_env(env: Any) -> EnvType:
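A small sketch of a `policies` dict that satisfies the checks in the restored `_validate_multiagent_config` (string keys; 4-tuples of a Policy subclass or None, two gym.Space objects, and a config dict). The concrete spaces are illustrative only:

    from gym.spaces import Box, Discrete
    from ray.rllib.examples.policy.random_policy import RandomPolicy

    policies = {
        # Slot 0 may be a Policy subclass ...
        "random": (RandomPolicy, Box(-1.0, 1.0, (4, )), Discrete(2), {}),
        # ... or None, which is only accepted when allow_none_graph=True.
        "learned": (None, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 1e-4}),
    }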
@@ -6,12 +6,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
    _validate_multiagent_config
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
    ShuffledInput, D4RLReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy import Policy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf

@@ -364,21 +365,20 @@ class WorkerSet:
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, (PolicySpec, list, tuple))
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls)
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        # Fill in the default policy_cls if 'None' is specified in multiagent.
        if config["multiagent"]["policies"]:
            tmp = config["multiagent"]["policies"]
            _validate_multiagent_config(tmp, allow_none_graph=True)
            # TODO: (sven) Allow for setting observation and action spaces to
            # None as well, in which case, spaces are taken from env.
            # It's tedious to have to provide these in a multi-agent config.
            for k, v in tmp.items():
                if v[0] is None:
                    tmp[k] = (policy_cls, v[1], v[2], v[3])
            policy_spec = tmp
        # Otherwise, policy spec is simply the policy class itself.
        else:
            policies = policy_cls
            policy_spec = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get(

@@ -390,7 +390,7 @@ class WorkerSet:
        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_spec=policy_spec,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
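A sketch of the restored fill-in step in isolation: entries whose class slot is None get the trainer's default policy class substituted before the dict reaches the RolloutWorker. The function name and arguments are local stand-ins, not WorkerSet internals:

    def fill_in_default_class(policies, default_policy_cls):
        # Mirror of the restored loop: replace a None class slot with the
        # Trainer's default policy class; keep spaces and config untouched.
        return {
            pid: (default_policy_cls if spec[0] is None else spec[0],
                  spec[1], spec[2], spec[3])
            for pid, spec in policies.items()
        }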
@@ -10,6 +10,7 @@ execution, set the TF_TIMELINE_DIR environment variable.
"""

import argparse
import gym
import os
import random

@@ -20,7 +21,6 @@ from ray.rllib.examples.models.shared_weights_model import \
    SharedWeightsModel1, SharedWeightsModel2, TF2SharedWeightsModel, \
    TorchSharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

@@ -73,6 +73,11 @@ if __name__ == "__main__":
    ModelCatalog.register_custom_model("model1", mod1)
    ModelCatalog.register_custom_model("model2", mod2)

    # Get obs- and action Spaces.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Each policy can have a different configuration (including custom model).
    def gen_policy(i):
        config = {

@@ -81,7 +86,7 @@ if __name__ == "__main__":
            },
            "gamma": random.choice([0.95, 0.99]),
        }
        return PolicySpec(config=config)
        return (None, obs_space, act_space, config)

    # Setup PPO with an ensemble of `num_policies` different policies.
    policies = {
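For context, the restored `gen_policy` pattern in a self-contained form. This is a hedged sketch: the env name and gamma choices mirror the example, the two-policy loop is illustrative:

    import random
    import gym

    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    def gen_policy(i):
        # Restored style: the spaces must be spelled out in the tuple.
        config = {"gamma": random.choice([0.95, 0.99])}
        return (None, obs_space, act_space, config)

    policies = {"policy_{}".format(i): gen_policy(i) for i in range(2)}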
@@ -14,6 +14,7 @@ Result for PG_multi_cartpole_0:
"""

import argparse
import gym
import os

import ray

@@ -21,7 +22,6 @@ from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()

@@ -58,6 +58,9 @@ if __name__ == "__main__":
    # Simple environment with 4 independent cartpole entities
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    stop = {
        "training_iteration": args.stop_iters,

@@ -71,9 +74,9 @@ if __name__ == "__main__":
            # The multiagent Policy map.
            "policies": {
                # The Policy we are actually learning.
                "pg_policy": PolicySpec(config={"framework": args.framework}),
                "pg_policy": (None, obs_space, act_space, {}),
                # Random policy we are playing against.
                "random": PolicySpec(policy_class=RandomPolicy),
                "random": (RandomPolicy, obs_space, act_space, {}),
            },
            # Map to either random behavior or PR learning behavior based on
            # the agent's ID.

@@ -14,6 +14,11 @@ if __name__ == "__main__":
    env = env_creator({})
    register_env("waterworld", env_creator)

    obs_space = env.observation_space
    act_spc = env.action_space

    policies = {agent: (None, obs_space, act_spc, {}) for agent in env.agents}

    tune.run(
        "APEX_DDPG",
        stop={"episodes_total": 60000},

@@ -26,7 +31,7 @@ if __name__ == "__main__":
            "num_workers": 2,
            # Method specific
            "multiagent": {
                "policies": set(env.agents),
                "policies": policies,
                "policy_mapping_fn": (
                    lambda agent_id, episode, **kwargs: agent_id),
            },
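This example swaps a bare set of agent IDs (removed form) for an explicit per-agent tuple dict (restored form). A minimal sketch with a made-up agent list and illustrative spaces:

    from gym.spaces import Box

    agents = ["pursuer_0", "pursuer_1"]  # hypothetical agent IDs
    obs_space = Box(-1.0, 1.0, (212, ))  # illustrative spaces only
    act_space = Box(-1.0, 1.0, (2, ))

    # Removed form: "policies": set(agents)
    # Restored form: one explicit 4-tuple per agent.
    policies = {agent: (None, obs_space, act_space, {}) for agent in agents}

    # Each agent maps to the policy that carries its own name.
    policy_mapping_fn = lambda agent_id, episode, **kwargs: agent_id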
@@ -8,15 +8,27 @@ from pettingzoo.sisl import waterworld_v0
if __name__ == "__main__":
    # RDQN - Rainbow DQN
    # ADQN - Apex DQN
    def env_creator(args):
        return PettingZooEnv(waterworld_v0.env())

    register_env("waterworld", lambda _: PettingZooEnv(waterworld_v0.env()))
    env = env_creator({})
    register_env("waterworld", env_creator)

    obs_space = env.observation_space
    act_space = env.action_space

    policies = {"shared_policy": (None, obs_space, act_space, {})}

    # for all methods
    policy_ids = list(policies.keys())

    tune.run(
        "APEX_DDPG",
        stop={"episodes_total": 60000},
        checkpoint_freq=10,
        config={
            # Enviroment specific.

            # Enviroment specific
            "env": "waterworld",

            # General

@@ -36,13 +48,9 @@ if __name__ == "__main__":
            "target_network_update_freq": 50000,
            "timesteps_per_iteration": 25000,

            # Method specific.
            # Method specific
            "multiagent": {
                # We only have one policy (calling it "shared").
                # Class, obs/act-spaces, and config will be derived
                # automatically.
                "policies": {"shared_policy"},
                # Always use "shared" policy.
                "policies": policies,
                "policy_mapping_fn": (
                    lambda agent_id, episode, **kwargs: "shared_policy"),
            },

@@ -6,7 +6,7 @@ from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env import PettingZooEnv
from pettingzoo.butterfly import pistonball_v4
from pettingzoo.butterfly import pistonball_v1

from ray.tune.registry import register_env

@@ -21,9 +21,9 @@ if __name__ == "__main__":
    """
    alg_name = "PPO"

    # Function that outputs the environment you wish to register.
    # function that outputs the environment you wish to register.
    def env_creator(config):
        env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2))
        env = pistonball_v1.env(local_ratio=config.get("local_ratio", 0.2))
        env = dtype_v0(env, dtype=float32)
        env = color_reduction_v0(env, mode="R")
        env = normalize_obs_v0(env)

@@ -32,23 +32,29 @@ if __name__ == "__main__":
    num_cpus = 1
    num_rollouts = 2

    # Gets default training configuration and specifies the POMgame to load.
    # 1. Gets default training configuration and specifies the POMgame to load.
    config = deepcopy(get_trainer_class(alg_name)._default_config)

    # Set environment config. This will be passed to
    # 2. Set environment config. This will be passed to
    # the env_creator function via the register env lambda below.
    config["env_config"] = {"local_ratio": 0.5}

    # Register env
    # 3. Register env
    register_env("pistonball",
                 lambda config: PettingZooEnv(env_creator(config)))

    # Configuration for multiagent setup with policy sharing:
    # 4. Extract space dimensions
    test_env = PettingZooEnv(env_creator({}))
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    # 5. Configuration for multiagent setup with policy sharing:
    config["multiagent"] = {
        # Setup a single, shared policy for all agents.
        "policies": {"av"},
        # Map all agents to that policy.
        "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av",
        "policies": {
            # the first tuple value is None -> uses default policy
            "av": (None, obs_space, act_space, {}),
        },
        "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
    }

    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.

@@ -68,9 +74,11 @@ if __name__ == "__main__":
    # If no_done_at_end = True, environment is not resetted
    # when dones[__all__]= True.

    # Initialize ray and trainer object
    # 6. Initialize ray and trainer object
    ray.init(num_cpus=num_cpus + 1)
    trainer = get_trainer_class(alg_name)(env="pistonball", config=config)

    # Train once
    # 7. Train once
    trainer.train()

    test_env.reset()
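A sketch of the restored space-extraction step on its own, assuming a PettingZoo environment wrapped in RLlib's PettingZooEnv (which exposes one shared observation_space/action_space pair) and skipping the supersuit preprocessing wrappers used in the example:

    from pettingzoo.butterfly import pistonball_v1
    from ray.rllib.env import PettingZooEnv

    # Build a throwaway env instance just to read the shared spaces.
    test_env = PettingZooEnv(pistonball_v1.env())
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    multiagent_config = {
        # The None class slot means "use the trainer's default policy class".
        "policies": {"av": (None, obs_space, act_space, {})},
        "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av",
    }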
@@ -8,6 +8,7 @@ This demonstrates running the following policies in competition:
"""

import argparse
from gym.spaces import Discrete
import os
import random

@@ -17,7 +18,6 @@ from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
from ray.rllib.examples.policy.rock_paper_scissors_dummies import \
    BeatLastHeuristic, AlwaysSameHeuristic
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved

@@ -93,9 +93,10 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
        "multiagent": {
            "policies_to_train": ["learned"],
            "policies": {
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "learned": PolicySpec(config={
                "always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3),
                                {}),
                "beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
                "learned": (None, Discrete(3), Discrete(3), {
                    "model": {
                        "use_lstm": use_lstm
                    },

@@ -9,7 +9,7 @@ See also: centralized_critic.py for centralized critic PPO on this game.
"""

import argparse
from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete
from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
import os

import ray

@@ -17,7 +17,6 @@ from ray import tune
from ray.tune import register_env, grid_search
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()

@@ -81,8 +80,14 @@ if __name__ == "__main__":
            grouping, obs_space=obs_space, act_space=act_space))

    if args.run == "contrib/MADDPG":
        obs_space = Discrete(6)
        act_space = TwoStepGame.action_space
        obs_space_dict = {
            "agent_1": Discrete(6),
            "agent_2": Discrete(6),
        }
        act_space_dict = {
            "agent_1": TwoStepGame.action_space,
            "agent_2": TwoStepGame.action_space,
        }
        config = {
            "learning_starts": 100,
            "env_config": {

@@ -90,14 +95,12 @@ if __name__ == "__main__":
            },
            "multiagent": {
                "policies": {
                    "pol1": PolicySpec(
                        observation_space=obs_space,
                        action_space=act_space,
                        config={"agent_id": 0}),
                    "pol2": PolicySpec(
                        observation_space=obs_space,
                        action_space=act_space,
                        config={"agent_id": 1}),
                    "pol1": (None, Discrete(6), TwoStepGame.action_space, {
                        "agent_id": 0,
                    }),
                    "pol2": (None, Discrete(6), TwoStepGame.action_space, {
                        "agent_id": 1,
                    }),
                },
                "policy_mapping_fn": (
                    lambda aid, **kwargs: "pol2" if aid else "pol1"),
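For reference, the restored MADDPG-style policy entries in isolation; a hedged sketch in which Discrete(2) stands in for TwoStepGame.action_space:

    from gym.spaces import Discrete

    act_space = Discrete(2)  # stand-in for TwoStepGame.action_space

    policies = {
        # Each policy carries its MADDPG agent index in its config overrides.
        "pol1": (None, Discrete(6), act_space, {"agent_id": 0}),
        "pol2": (None, Discrete(6), act_space, {"agent_id": 1}),
    }
    policy_mapping_fn = lambda aid, **kwargs: "pol2" if aid else "pol1"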
@@ -6,6 +6,7 @@ via a custom training workflow.
"""

import argparse
import gym
import os

import ray

@@ -123,14 +124,17 @@ if __name__ == "__main__":
    # Simple environment with 4 independent cartpole entities
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Note that since the trainer below does not include a default policy or
    # policy configs, we have to explicitly set it in the multiagent config:
    policies = {
        "ppo_policy": (PPOTorchPolicy if args.torch or args.mixed_torch_tf else
                       PPOTFPolicy, None, None, PPO_CONFIG),
        "dqn_policy": (DQNTorchPolicy
                       if args.torch else DQNTFPolicy, None, None, DQN_CONFIG),
                       PPOTFPolicy, obs_space, act_space, PPO_CONFIG),
        "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy,
                       obs_space, act_space, DQN_CONFIG),
    }

    def policy_mapping_fn(agent_id, episode, **kwargs):

@@ -1,5 +1,4 @@
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import gym
from gym.spaces import Box
import logging

@@ -24,7 +23,6 @@ torch, _ = try_import_torch()

if TYPE_CHECKING:
    from ray.rllib.evaluation import MultiAgentEpisode
    from ray.rllib.utils.typing import PartialTrainerConfigDict

logger = logging.getLogger(__name__)

@@ -32,33 +30,6 @@ logger = logging.getLogger(__name__)
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"

# A policy spec used in the "config.multiagent.policies" specification dict
# as values (keys are the policy IDs (str)). E.g.:
# config:
#   multiagent:
#     policies: {
#       "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
#       "pol2": PolicySpec(config={"lr": 0.001}),
#     }
PolicySpec = namedtuple(
    "PolicySpec",
    [
        # If None, use the Trainer's default policy class stored under
        # `Trainer._policy_class`.
        "policy_class",  # type: Union[type, None]
        # If None, use the env's observation space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        "observation_space",  # type: Union[gym.Space, None]
        # If None, use the env's action space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        "action_space",  # type: Union[gym.Space, None]
        # Overrides defined keys in the main Trainer config.
        # If None, use {}.
        "config",  # type: Union[PartialTrainerConfigDict, None]
    ])
# From 3.7 on, we could pass `defaults` into the above constructor.
PolicySpec.__new__.__defaults__ = (None, None, None, None)


@DeveloperAPI
class Policy(metaclass=ABCMeta):
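The namedtuple removed above is the whole PolicySpec API being reverted. A short sketch of how it was used (pre-revert only; after this commit the class no longer exists in ray.rllib.policy.policy), with the namedtuple rebuilt locally so the snippet runs without the reverted import:

    from collections import namedtuple
    from gym.spaces import Box, Discrete

    PolicySpec = namedtuple(
        "PolicySpec",
        ["policy_class", "observation_space", "action_space", "config"])
    PolicySpec.__new__.__defaults__ = (None, None, None, None)

    # All-None spec: class, spaces and config get filled in later from the
    # Trainer's default class, the env, and {} respectively.
    spec_a = PolicySpec()
    # Positional form, as in the removed docstring example.
    spec_b = PolicySpec(None, Box(-1.0, 1.0, (4, )), Discrete(2), {"lr": 0.0001})
    # Keyword form with only a config override.
    spec_c = PolicySpec(config={"lr": 0.001})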
@@ -1,4 +1,5 @@
import glob
import gym
import json
import numpy as np
import os

@@ -12,6 +13,7 @@ import ray
from ray.tune.registry import register_env, register_input, \
    registry_get_input, registry_contains_input
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.offline import IOContext, JsonWriter, JsonReader, InputReader, \
    ShuffledInput

@@ -177,6 +179,12 @@ class AgentIOTest(unittest.TestCase):
    def testMultiAgent(self):
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": 10}))
        single_env = gym.make("CartPole-v0")

        def gen_policy():
            obs_space = single_env.observation_space
            act_space = single_env.action_space
            return (PGTFPolicy, obs_space, act_space, {})

        for fw in framework_iterator():
            pg = PGTrainer(

@@ -185,7 +193,10 @@ class AgentIOTest(unittest.TestCase):
                    "num_workers": 0,
                    "output": self.test_dir,
                    "multiagent": {
                        "policies": {"policy_1", "policy_2"},
                        "policies": {
                            "policy_1": gen_policy(),
                            "policy_2": gen_policy(),
                        },
                        "policy_mapping_fn": (
                            lambda aid, **kwargs: random.choice(
                                ["policy_1", "policy_2"])),

@@ -204,7 +215,10 @@ class AgentIOTest(unittest.TestCase):
                    "input_evaluation": ["simulation"],
                    "train_batch_size": 2000,
                    "multiagent": {
                        "policies": {"policy_1", "policy_2"},
                        "policies": {
                            "policy_1": gen_policy(),
                            "policy_2": gen_policy(),
                        },
                        "policy_mapping_fn": (
                            lambda aid, **kwargs: random.choice(
                                ["policy_1", "policy_2"])),

@@ -16,7 +16,6 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import check

@@ -424,13 +423,16 @@ class TestMultiAgentEnv(unittest.TestCase):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))
        single_env = gym.make("CartPole-v0")

        def gen_policy():
            config = {
                "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
                "n_step": random.choice([1, 2, 3, 4, 5]),
            }
            return PolicySpec(config=config)
            obs_space = single_env.observation_space
            act_space = single_env.action_space
            return (None, obs_space, act_space, config)

        pg = PGTrainer(
            env="multi_agent_cartpole",

@@ -8,6 +8,7 @@ import unittest
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv

@@ -450,10 +451,10 @@ class NestedSpacesTest(unittest.TestCase):
            "multiagent": {
                "policies": {
                    "tuple_policy": (
                        None, TUPLE_SPACE, act_space,
                        PGTFPolicy, TUPLE_SPACE, act_space,
                        {"model": {"custom_model": "tuple_spy"}}),
                    "dict_policy": (
                        None, DICT_SPACE, act_space,
                        PGTFPolicy, DICT_SPACE, act_space,
                        {"model": {"custom_model": "dict_spy"}}),
                },
                "policy_mapping_fn": lambda aid, **kwargs: {

@@ -24,12 +24,16 @@ class TestPettingZooEnv(unittest.TestCase):

        config = deepcopy(agent_class._default_config)

        test_env = PettingZooEnv(simple_spread_v2.env())
        obs_space = test_env.observation_space
        act_space = test_env.action_space
        test_env.close()

        config["multiagent"] = {
            # Set of policy IDs (by default, will use Trainer's
            # default policy class, the env's obs/act spaces and config={}).
            "policies": {"av"},
            # Mapping function that always returns "av" as policy ID to use
            # (for any agent).
            "policies": {
                # the first tuple value is None -> uses default policy
                "av": (None, obs_space, act_space, {}),
            },
            "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
        }

@@ -1,3 +1,4 @@
from gym.spaces import Box, Discrete
import os
from pathlib import Path
import re

@@ -149,6 +150,9 @@ def learn_test_multi_agent_plus_rollout(algo):
    def policy_fn(agent_id, episode, **kwargs):
        return "pol{}".format(agent_id)

    observation_space = Box(float("-inf"), float("inf"), (4, ))
    action_space = Discrete(2)

    config = {
        "num_gpus": 0,
        "num_workers": 1,

@@ -158,7 +162,10 @@ def learn_test_multi_agent_plus_rollout(algo):
        "framework": fw,
        "env": MultiAgentCartPole,
        "multiagent": {
            "policies": {"pol0", "pol1"},
            "policies": {
                "pol0": (None, observation_space, action_space, {}),
                "pol1": (None, observation_space, action_space, {}),
            },
            "policy_mapping_fn": policy_fn,
        },
    }

@@ -6,7 +6,7 @@ from collections import defaultdict
import random

from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch

logger = logging.getLogger(__name__)

@@ -1,10 +1,10 @@
from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
import gym

if TYPE_CHECKING:
    from ray.rllib.utils import try_import_tf, try_import_torch
    _, tf, _ = try_import_tf()
    torch, _ = try_import_torch()
    from ray.rllib.policy.policy import PolicySpec
    from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

@@ -42,7 +42,8 @@ AgentID = Any
PolicyID = str

# Type of the config["multiagent"]["policies"] dict for multi-agent training.
MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"]
MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[
    type, None], gym.Space, gym.Space, PartialTrainerConfigDict]]

# Represents an environment id. These could be:
# - An int index for a sub-env within a vectorized env.
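A sketch of what the restored type alias describes in practice, with local stand-ins for PolicyID and PartialTrainerConfigDict so the snippet type-checks without importing RLlib's typing module:

    from typing import Dict, Tuple, Union
    import gym
    from gym.spaces import Box, Discrete

    PolicyID = str
    PartialTrainerConfigDict = dict  # simplified local stand-in

    MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[
        type, None], gym.Space, gym.Space, PartialTrainerConfigDict]]

    policies: MultiAgentPolicyConfigDict = {
        "main": (None, Box(-1.0, 1.0, (4, )), Discrete(2), {}),
    }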