From 38b5b6d24c47b72a60f1746f636ba287cc10bc16 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 13 Jul 2021 09:57:15 -0700 Subject: [PATCH] Revert "[RLlib] Simplify multiagent config (automatically infer class/spaces/config). (#16565)" (#17036) This reverts commit e4123fff271b1a358712850ba07195a6fb02e8b0. --- .../requirements/rllib/requirements_rllib.txt | 2 - rllib/BUILD | 8 - rllib/agents/ppo/tests/test_ddppo.py | 2 +- rllib/agents/trainer.py | 44 +----- rllib/evaluation/rollout_worker.py | 148 ++++++++++-------- rllib/evaluation/worker_set.py | 34 ++-- rllib/examples/multi_agent_cartpole.py | 9 +- rllib/examples/multi_agent_custom_policy.py | 9 +- .../multi_agent_independent_learning.py | 7 +- .../examples/multi_agent_parameter_sharing.py | 24 ++- rllib/examples/pettingzoo_env.py | 34 ++-- .../rock_paper_scissors_multiagent.py | 9 +- rllib/examples/two_step_game.py | 27 ++-- rllib/examples/two_trainer_workflow.py | 10 +- rllib/policy/policy.py | 29 ---- rllib/tests/test_io.py | 18 ++- rllib/tests/test_multi_agent_env.py | 6 +- rllib/tests/test_nested_observation_spaces.py | 5 +- rllib/tests/test_pettingzoo_env.py | 14 +- rllib/tests/test_rollout.py | 9 +- rllib/utils/sgd.py | 2 +- rllib/utils/typing.py | 5 +- 22 files changed, 227 insertions(+), 228 deletions(-) diff --git a/python/requirements/rllib/requirements_rllib.txt b/python/requirements/rllib/requirements_rllib.txt index e28c3b5cd..ddd76da17 100644 --- a/python/requirements/rllib/requirements_rllib.txt +++ b/python/requirements/rllib/requirements_rllib.txt @@ -23,8 +23,6 @@ mlagents==0.26.0 mlagents_envs==0.26.0 # For tests on PettingZoo's multi-agent envs. pettingzoo==1.8.2 -pymunk -supersuit # For testing in MuJoCo-like envs (in PyBullet). pybullet==3.1.7 # For tests on RecSim and Kaggle envs. 
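Note on the restored config format: after this revert, config["multiagent"]["policies"] must again map each policy ID to an explicit (policy_class_or_None, obs_space, act_space, config) 4-tuple rather than a bare set of IDs or a PolicySpec. The sketch below is illustrative only and is not part of the patch; the CartPole spaces, policy IDs, and per-policy configs are assumptions mirroring the example scripts changed further down.

    # Illustrative only: the tuple-based multi-agent "policies" format this
    # revert restores. Spaces must be supplied explicitly, since automatic
    # inference of class/spaces/config is removed.
    import gym

    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    multiagent_config = {
        "policies": {
            # None -> use the Trainer's default policy class.
            "pol1": (None, obs_space, act_space, {"gamma": 0.95}),
            "pol2": (None, obs_space, act_space, {"gamma": 0.99}),
        },
        "policy_mapping_fn": lambda agent_id, episode, **kwargs: "pol1",
    }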
diff --git a/rllib/BUILD b/rllib/BUILD index 7307cf2f3..c2017d3af 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2427,14 +2427,6 @@ py_test( args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"] ) -py_test( - name = "examples/pettingzoo_env", - main = "examples/pettingzoo_env.py", - tags = ["examples", "examples_P"], - size = "medium", - srcs = ["examples/pettingzoo_env.py"], -) - py_test( name = "examples/restore_1_of_n_agents_from_checkpoint", tags = ["examples", "examples_R"], diff --git a/rllib/agents/ppo/tests/test_ddppo.py b/rllib/agents/ppo/tests/test_ddppo.py index 489d86f59..273df0de5 100644 --- a/rllib/agents/ppo/tests/test_ddppo.py +++ b/rllib/agents/ppo/tests/test_ddppo.py @@ -2,8 +2,8 @@ import unittest import ray import ray.rllib.agents.ppo as ppo -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.policy.policy import LEARNER_STATS_KEY +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.test_utils import check_compute_single_action, \ framework_iterator diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 571489a09..31c6dbb8f 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -23,9 +23,9 @@ from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils import deep_update, FilterManager, merge_dicts +from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf, TensorStructType @@ -1366,44 +1366,6 @@ class Trainer(Trainable): if simple_optim_setting != DEPRECATED_VALUE: deprecation_warning(old="simple_optimizer", error=False) - # Loop through all policy definitions in multi-agent policies. - multiagent_config = config["multiagent"] - policies = multiagent_config.get("policies") - if not policies: - policies = {DEFAULT_POLICY_ID} - if isinstance(policies, set): - policies = multiagent_config["policies"] = { - pid: PolicySpec() - for pid in policies - } - is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies - - for pid, policy_spec in policies.copy().items(): - # Policy IDs must be strings. - if not isinstance(pid, str): - raise ValueError("Policy keys must be strs, got {}".format( - type(pid))) - - # Convert to PolicySpec if plain list/tuple. - if not isinstance(policy_spec, PolicySpec): - # Values must be lists/tuples of len 4. - if not isinstance(policy_spec, (list, tuple)) or \ - len(policy_spec) != 4: - raise ValueError( - "Policy specs must be tuples/lists of " - "(cls or None, obs_space, action_space, config), " - f"got {policy_spec}") - policies[pid] = PolicySpec(*policy_spec) - - # Config is None -> Set to {}. - if policies[pid].config is None: - policies[pid] = policies[pid]._replace(config={}) - # Config not a dict. - elif not isinstance(policies[pid].config, dict): - raise ValueError( - f"Multiagent policy config for {pid} must be a dict, " - f"but got {type(policies[pid].config)}!") - framework = config.get("framework") # Multi-GPU setting: Must use TFMultiGPU if tf. 
if config.get("num_gpus", 0) > 1: @@ -1423,7 +1385,7 @@ class Trainer(Trainable): config["simple_optimizer"] = True # TF + Multi-agent case: Try using MultiGPU optimizer (only # if all policies used are DynamicTFPolicies). - elif is_multiagent: + elif len(config["multiagent"]["policies"]) > 0: from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy default_policy_cls = None if trainer_obj_or_none is None else \ getattr(trainer_obj_or_none, "_policy_class", None) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 8c38e6da2..6971869bb 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -27,7 +27,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID -from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import merge_dicts @@ -139,7 +139,9 @@ class RolloutWorker(ParallelIteratorWorker): env_creator: Callable[[EnvContext], EnvType], validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None, - policy_spec: Union[type, Dict[PolicyID, PolicySpec]] = None, + policy_spec: Union[type, Dict[ + str, Tuple[Optional[type], gym.Space, gym.Space, + PartialTrainerConfigDict]]] = None, policy_mapping_fn: Optional[Callable[ [AgentID, "MultiAgentEpisode"], PolicyID]] = None, policies_to_train: Optional[List[PolicyID]] = None, @@ -181,7 +183,9 @@ class RolloutWorker(ParallelIteratorWorker): fake_sampler: bool = False, spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]] = None, - policy=None, + policy: Union[type, Dict[ + str, Tuple[Optional[type], gym.Space, gym.Space, + PartialTrainerConfigDict]]] = None, monitor_path=None, ): """Initialize a rollout worker. @@ -192,11 +196,11 @@ class RolloutWorker(ParallelIteratorWorker): validate_env (Optional[Callable[[EnvType, EnvContext], None]]): Optional callable to validate the generated environment (only on worker=0). - policy_spec (Optional[Union[Type[Policy], - MultiAgentPolicyConfigDict]]): The MultiAgentPolicyConfigDict - mapping policy IDs (str) to PolicySpec's or a single policy - class to use. - If a dict is specified, then we are in multi-agent mode and a + policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space, + gym.Space, PartialTrainerConfigDict]]]): Either a Policy class + or a dict of policy id strings to + (Policy class, obs_space, action_space, config)-tuples. If a + dict is specified, then we are in multi-agent mode and a policy_mapping_fn can also be set (if not, will map all agents to DEFAULT_POLICY_ID). policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode], @@ -305,26 +309,15 @@ class RolloutWorker(ParallelIteratorWorker): gym.spaces.Space]]]): An optional space dict mapping policy IDs to (obs_space, action_space)-tuples. This is used in case no Env is created on this RolloutWorker. + policy: Obsoleted arg. Use `policy_spec` instead. monitor_path: Obsoleted arg. Use `record_env` instead. """ - # Deprecated args. 
if policy is not None: deprecation_warning("policy", "policy_spec", error=False) policy_spec = policy - assert policy_spec is not None, \ - "Must provide `policy_spec` when creating RolloutWorker!" - - # Do quick translation into MultiAgentPolicyConfigDict. - if not isinstance(policy_spec, dict): - policy_spec = { - DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec) - } - policy_spec = { - pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec) - for pid, spec in policy_spec.copy().items() - } - + assert policy_spec is not None, "Must provide `policy_spec` when " \ + "creating RolloutWorker!" if monitor_path is not None: deprecation_warning("monitor_path", "record_env", error=False) record_env = monitor_path @@ -490,7 +483,7 @@ class RolloutWorker(ParallelIteratorWorker): self.make_env_fn = make_env self.tf_sess = None - policy_dict = _determine_spaces_for_multi_agent_dict( + policy_dict = _validate_and_canonicalize( policy_spec, self.env, spaces=spaces, policy_config=policy_config) # List of IDs of those policies, which should be trained. # By default, these are all policies found in the policy_dict. @@ -1309,7 +1302,7 @@ class RolloutWorker(ParallelIteratorWorker): for name, (cls, obs_space, act_space, conf) in sorted(policy_dict.items()): logger.debug("Creating policy for {}".format(name)) - merged_conf = merge_dicts(policy_config, conf or {}) + merged_conf = merge_dicts(policy_config, conf) merged_conf["num_workers"] = self.num_workers merged_conf["worker_index"] = self.worker_index if self.preprocessing_enabled: @@ -1386,57 +1379,74 @@ class RolloutWorker(ParallelIteratorWorker): self.sampler.shutdown = True -def _determine_spaces_for_multi_agent_dict( - multi_agent_dict: MultiAgentPolicyConfigDict, - env: Optional[EnvType] = None, +def _validate_and_canonicalize( + policy: Union[Type[Policy], MultiAgentPolicyConfigDict], + env: Optional[EnvType], spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space, - gym.spaces.Space]]] = None, - policy_config: Optional[PartialTrainerConfigDict] = None, + gym.spaces.Space]]], + policy_config: Optional[PartialTrainerConfigDict], ) -> MultiAgentPolicyConfigDict: - # Try extracting spaces from env. 
- env_obs_space = None - env_act_space = None - if env is not None and hasattr(env, "observation_space") and isinstance( - env.observation_space, gym.Space): - env_obs_space = env.observation_space - if env is not None and hasattr(env, "action_space") and isinstance( - env.action_space, gym.Space): - env_act_space = env.action_space + if isinstance(policy, dict): + _validate_multiagent_config(policy) + return policy + elif not issubclass(policy, Policy): + raise ValueError(f"`policy` ({policy}) must be a rllib.Policy class!") + else: + if (isinstance(env, MultiAgentEnv) + and not hasattr(env, "observation_space")): + raise ValueError( + "MultiAgentEnv must have observation_space defined if run " + "in a single-agent configuration.") + if env is not None: + return { + DEFAULT_POLICY_ID: (policy, env.observation_space, + env.action_space, {}) + } - for pid, policy_spec in multi_agent_dict.copy().items(): - if policy_spec.observation_space is None: - if spaces is not None: - obs_space = spaces[pid][0] - elif env_obs_space is not None: - obs_space = env_obs_space - elif "observation_space" in policy_config: - obs_space = policy_config["observation_space"] - else: + if spaces is None: + if "action_space" not in policy_config or \ + "observation_space" not in policy_config: raise ValueError( - "`observation_space` not provided in PolicySpec for " - f"{pid} and env does not have an observation space OR " - "no spaces received from other workers' env(s) OR no " - "`observation_space` specified in config!") - multi_agent_dict[pid] = multi_agent_dict[pid]._replace( - observation_space=obs_space) + "If no env given, must provide obs/action spaces either " + "in the `multiagent.policies` dict or under " + "`config.[observation|action]_space`!") + spaces = { + DEFAULT_POLICY_ID: (policy_config["observation_space"], + policy_config["action_space"]) + } + return { + DEFAULT_POLICY_ID: (policy, spaces[DEFAULT_POLICY_ID][0], + spaces[DEFAULT_POLICY_ID][1], {}) + } - if policy_spec.action_space is None: - if spaces is not None: - act_space = spaces[pid][1] - elif env_act_space is not None: - act_space = env_act_space - elif "action_space" in policy_config: - act_space = policy_config["action_space"] - else: - raise ValueError( - "`action_space` not provided in PolicySpec for " - f"{pid} and env does not have an action space OR " - "no spaces received from other workers' env(s) OR no " - "`action_space` specified in config!") - multi_agent_dict[pid] = multi_agent_dict[pid]._replace( - action_space=act_space) - return multi_agent_dict + +def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict, + allow_none_graph: bool = False) -> None: + # Loop through all policy definitions in multi-agent policie + for k, v in policy.items(): + if not isinstance(k, str): + raise ValueError("Policy key must be str, got {}!".format(k)) + if not isinstance(v, (tuple, list)) or len(v) != 4: + raise ValueError( + "policy values must be tuples/lists of " + "(cls or None, obs_space, action_space, config), got {}". 
+ format(v)) + if allow_none_graph and v[0] is None: + pass + elif not issubclass(v[0], Policy): + raise ValueError("policy tuple value 0 must be a rllib.Policy " + "class or None, got {}".format(v[0])) + if not isinstance(v[1], gym.Space): + raise ValueError( + "policy tuple value 1 (observation_space) must be a " + "gym.Space, got {}".format(type(v[1]))) + if not isinstance(v[2], gym.Space): + raise ValueError("policy tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) + if not isinstance(v[3], dict): + raise ValueError("policy tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) def _validate_env(env: Any) -> EnvType: diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 2a5a045af..12d2ca216 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -6,12 +6,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import ray from ray.actor import ActorHandle -from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.rollout_worker import RolloutWorker, \ + _validate_multiagent_config from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ ShuffledInput, D4RLReader -from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy import Policy from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_tf @@ -364,21 +365,20 @@ class WorkerSet: else: input_evaluation = config["input_evaluation"] - # Assert everything is correct in "multiagent" config dict (if given). - ma_policies = config["multiagent"]["policies"] - if ma_policies: - for pid, policy_spec in ma_policies.copy().items(): - assert isinstance(policy_spec, (PolicySpec, list, tuple)) - # Class is None -> Use `policy_cls`. - if policy_spec.policy_class is None: - ma_policies[pid] = ma_policies[pid]._replace( - policy_class=policy_cls) - policies = ma_policies - - # Create a policy_spec (MultiAgentPolicyConfigDict), - # even if no "multiagent" setup given by user. + # Fill in the default policy_cls if 'None' is specified in multiagent. + if config["multiagent"]["policies"]: + tmp = config["multiagent"]["policies"] + _validate_multiagent_config(tmp, allow_none_graph=True) + # TODO: (sven) Allow for setting observation and action spaces to + # None as well, in which case, spaces are taken from env. + # It's tedious to have to provide these in a multi-agent config. + for k, v in tmp.items(): + if v[0] is None: + tmp[k] = (policy_cls, v[1], v[2], v[3]) + policy_spec = tmp + # Otherwise, policy spec is simply the policy class itself. else: - policies = policy_cls + policy_spec = policy_cls if worker_index == 0: extra_python_environs = config.get( @@ -390,7 +390,7 @@ class WorkerSet: worker = cls( env_creator=env_creator, validate_env=validate_env, - policy_spec=policies, + policy_spec=policy_spec, policy_mapping_fn=config["multiagent"]["policy_mapping_fn"], policies_to_train=config["multiagent"]["policies_to_train"], tf_session_creator=(session_creator diff --git a/rllib/examples/multi_agent_cartpole.py b/rllib/examples/multi_agent_cartpole.py index 05444fb14..1bb58986d 100644 --- a/rllib/examples/multi_agent_cartpole.py +++ b/rllib/examples/multi_agent_cartpole.py @@ -10,6 +10,7 @@ execution, set the TF_TIMELINE_DIR environment variable. 
""" import argparse +import gym import os import random @@ -20,7 +21,6 @@ from ray.rllib.examples.models.shared_weights_model import \ SharedWeightsModel1, SharedWeightsModel2, TF2SharedWeightsModel, \ TorchSharedWeightsModel from ray.rllib.models import ModelCatalog -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved @@ -73,6 +73,11 @@ if __name__ == "__main__": ModelCatalog.register_custom_model("model1", mod1) ModelCatalog.register_custom_model("model2", mod2) + # Get obs- and action Spaces. + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + # Each policy can have a different configuration (including custom model). def gen_policy(i): config = { @@ -81,7 +86,7 @@ if __name__ == "__main__": }, "gamma": random.choice([0.95, 0.99]), } - return PolicySpec(config=config) + return (None, obs_space, act_space, config) # Setup PPO with an ensemble of `num_policies` different policies. policies = { diff --git a/rllib/examples/multi_agent_custom_policy.py b/rllib/examples/multi_agent_custom_policy.py index 2aea9568e..52d1e88bf 100644 --- a/rllib/examples/multi_agent_custom_policy.py +++ b/rllib/examples/multi_agent_custom_policy.py @@ -14,6 +14,7 @@ Result for PG_multi_cartpole_0: """ import argparse +import gym import os import ray @@ -21,7 +22,6 @@ from ray import tune from ray.tune.registry import register_env from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.policy.random_policy import RandomPolicy -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.test_utils import check_learning_achieved parser = argparse.ArgumentParser() @@ -58,6 +58,9 @@ if __name__ == "__main__": # Simple environment with 4 independent cartpole entities register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4})) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space stop = { "training_iteration": args.stop_iters, @@ -71,9 +74,9 @@ if __name__ == "__main__": # The multiagent Policy map. "policies": { # The Policy we are actually learning. - "pg_policy": PolicySpec(config={"framework": args.framework}), + "pg_policy": (None, obs_space, act_space, {}), # Random policy we are playing against. - "random": PolicySpec(policy_class=RandomPolicy), + "random": (RandomPolicy, obs_space, act_space, {}), }, # Map to either random behavior or PR learning behavior based on # the agent's ID. 
diff --git a/rllib/examples/multi_agent_independent_learning.py b/rllib/examples/multi_agent_independent_learning.py index d425cc065..34a68707d 100644 --- a/rllib/examples/multi_agent_independent_learning.py +++ b/rllib/examples/multi_agent_independent_learning.py @@ -14,6 +14,11 @@ if __name__ == "__main__": env = env_creator({}) register_env("waterworld", env_creator) + obs_space = env.observation_space + act_spc = env.action_space + + policies = {agent: (None, obs_space, act_spc, {}) for agent in env.agents} + tune.run( "APEX_DDPG", stop={"episodes_total": 60000}, @@ -26,7 +31,7 @@ if __name__ == "__main__": "num_workers": 2, # Method specific "multiagent": { - "policies": set(env.agents), + "policies": policies, "policy_mapping_fn": ( lambda agent_id, episode, **kwargs: agent_id), }, diff --git a/rllib/examples/multi_agent_parameter_sharing.py b/rllib/examples/multi_agent_parameter_sharing.py index 5dae06923..06c71f9ff 100644 --- a/rllib/examples/multi_agent_parameter_sharing.py +++ b/rllib/examples/multi_agent_parameter_sharing.py @@ -8,15 +8,27 @@ from pettingzoo.sisl import waterworld_v0 if __name__ == "__main__": # RDQN - Rainbow DQN # ADQN - Apex DQN + def env_creator(args): + return PettingZooEnv(waterworld_v0.env()) - register_env("waterworld", lambda _: PettingZooEnv(waterworld_v0.env())) + env = env_creator({}) + register_env("waterworld", env_creator) + + obs_space = env.observation_space + act_space = env.action_space + + policies = {"shared_policy": (None, obs_space, act_space, {})} + + # for all methods + policy_ids = list(policies.keys()) tune.run( "APEX_DDPG", stop={"episodes_total": 60000}, checkpoint_freq=10, config={ - # Enviroment specific. + + # Enviroment specific "env": "waterworld", # General @@ -36,13 +48,9 @@ if __name__ == "__main__": "target_network_update_freq": 50000, "timesteps_per_iteration": 25000, - # Method specific. + # Method specific "multiagent": { - # We only have one policy (calling it "shared"). - # Class, obs/act-spaces, and config will be derived - # automatically. - "policies": {"shared_policy"}, - # Always use "shared" policy. + "policies": policies, "policy_mapping_fn": ( lambda agent_id, episode, **kwargs: "shared_policy"), }, diff --git a/rllib/examples/pettingzoo_env.py b/rllib/examples/pettingzoo_env.py index cb4916857..2e4bb531c 100644 --- a/rllib/examples/pettingzoo_env.py +++ b/rllib/examples/pettingzoo_env.py @@ -6,7 +6,7 @@ from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0 import ray from ray.rllib.agents.registry import get_trainer_class from ray.rllib.env import PettingZooEnv -from pettingzoo.butterfly import pistonball_v4 +from pettingzoo.butterfly import pistonball_v1 from ray.tune.registry import register_env @@ -21,9 +21,9 @@ if __name__ == "__main__": """ alg_name = "PPO" - # Function that outputs the environment you wish to register. + # function that outputs the environment you wish to register. def env_creator(config): - env = pistonball_v4.env(local_ratio=config.get("local_ratio", 0.2)) + env = pistonball_v1.env(local_ratio=config.get("local_ratio", 0.2)) env = dtype_v0(env, dtype=float32) env = color_reduction_v0(env, mode="R") env = normalize_obs_v0(env) @@ -32,23 +32,29 @@ if __name__ == "__main__": num_cpus = 1 num_rollouts = 2 - # Gets default training configuration and specifies the POMgame to load. + # 1. Gets default training configuration and specifies the POMgame to load. config = deepcopy(get_trainer_class(alg_name)._default_config) - # Set environment config. This will be passed to + # 2. 
Set environment config. This will be passed to # the env_creator function via the register env lambda below. config["env_config"] = {"local_ratio": 0.5} - # Register env + # 3. Register env register_env("pistonball", lambda config: PettingZooEnv(env_creator(config))) - # Configuration for multiagent setup with policy sharing: + # 4. Extract space dimensions + test_env = PettingZooEnv(env_creator({})) + obs_space = test_env.observation_space + act_space = test_env.action_space + + # 5. Configuration for multiagent setup with policy sharing: config["multiagent"] = { - # Setup a single, shared policy for all agents. - "policies": {"av"}, - # Map all agents to that policy. - "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av", + "policies": { + # the first tuple value is None -> uses default policy + "av": (None, obs_space, act_space, {}), + }, + "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av" } # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. @@ -68,9 +74,11 @@ if __name__ == "__main__": # If no_done_at_end = True, environment is not resetted # when dones[__all__]= True. - # Initialize ray and trainer object + # 6. Initialize ray and trainer object ray.init(num_cpus=num_cpus + 1) trainer = get_trainer_class(alg_name)(env="pistonball", config=config) - # Train once + # 7. Train once trainer.train() + + test_env.reset() diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index a690985a2..cbcbd0540 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -8,6 +8,7 @@ This demonstrates running the following policies in competition: """ import argparse +from gym.spaces import Discrete import os import random @@ -17,7 +18,6 @@ from ray.rllib.agents.registry import get_trainer_class from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors from ray.rllib.examples.policy.rock_paper_scissors_dummies import \ BeatLastHeuristic, AlwaysSameHeuristic -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved @@ -93,9 +93,10 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"): "multiagent": { "policies_to_train": ["learned"], "policies": { - "always_same": PolicySpec(policy_class=AlwaysSameHeuristic), - "beat_last": PolicySpec(policy_class=BeatLastHeuristic), - "learned": PolicySpec(config={ + "always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3), + {}), + "beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}), + "learned": (None, Discrete(3), Discrete(3), { "model": { "use_lstm": use_lstm }, diff --git a/rllib/examples/two_step_game.py b/rllib/examples/two_step_game.py index e3c83bbde..d5ac76048 100644 --- a/rllib/examples/two_step_game.py +++ b/rllib/examples/two_step_game.py @@ -9,7 +9,7 @@ See also: centralized_critic.py for centralized critic PPO on this game. 
""" import argparse -from gym.spaces import Dict, Discrete, Tuple, MultiDiscrete +from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete import os import ray @@ -17,7 +17,6 @@ from ray import tune from ray.tune import register_env, grid_search from ray.rllib.env.multi_agent_env import ENV_STATE from ray.rllib.examples.env.two_step_game import TwoStepGame -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.test_utils import check_learning_achieved parser = argparse.ArgumentParser() @@ -81,8 +80,14 @@ if __name__ == "__main__": grouping, obs_space=obs_space, act_space=act_space)) if args.run == "contrib/MADDPG": - obs_space = Discrete(6) - act_space = TwoStepGame.action_space + obs_space_dict = { + "agent_1": Discrete(6), + "agent_2": Discrete(6), + } + act_space_dict = { + "agent_1": TwoStepGame.action_space, + "agent_2": TwoStepGame.action_space, + } config = { "learning_starts": 100, "env_config": { @@ -90,14 +95,12 @@ if __name__ == "__main__": }, "multiagent": { "policies": { - "pol1": PolicySpec( - observation_space=obs_space, - action_space=act_space, - config={"agent_id": 0}), - "pol2": PolicySpec( - observation_space=obs_space, - action_space=act_space, - config={"agent_id": 1}), + "pol1": (None, Discrete(6), TwoStepGame.action_space, { + "agent_id": 0, + }), + "pol2": (None, Discrete(6), TwoStepGame.action_space, { + "agent_id": 1, + }), }, "policy_mapping_fn": ( lambda aid, **kwargs: "pol2" if aid else "pol1"), diff --git a/rllib/examples/two_trainer_workflow.py b/rllib/examples/two_trainer_workflow.py index c403c5bb1..8eb34ab8c 100644 --- a/rllib/examples/two_trainer_workflow.py +++ b/rllib/examples/two_trainer_workflow.py @@ -6,6 +6,7 @@ via a custom training workflow. """ import argparse +import gym import os import ray @@ -123,14 +124,17 @@ if __name__ == "__main__": # Simple environment with 4 independent cartpole entities register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4})) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space # Note that since the trainer below does not include a default policy or # policy configs, we have to explicitly set it in the multiagent config: policies = { "ppo_policy": (PPOTorchPolicy if args.torch or args.mixed_torch_tf else - PPOTFPolicy, None, None, PPO_CONFIG), - "dqn_policy": (DQNTorchPolicy - if args.torch else DQNTFPolicy, None, None, DQN_CONFIG), + PPOTFPolicy, obs_space, act_space, PPO_CONFIG), + "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy, + obs_space, act_space, DQN_CONFIG), } def policy_mapping_fn(agent_id, episode, **kwargs): diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index e8a6f4d33..eb52c1041 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -1,5 +1,4 @@ from abc import ABCMeta, abstractmethod -from collections import namedtuple import gym from gym.spaces import Box import logging @@ -24,7 +23,6 @@ torch, _ = try_import_torch() if TYPE_CHECKING: from ray.rllib.evaluation import MultiAgentEpisode - from ray.rllib.utils.typing import PartialTrainerConfigDict logger = logging.getLogger(__name__) @@ -32,33 +30,6 @@ logger = logging.getLogger(__name__) # `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. LEARNER_STATS_KEY = "learner_stats" -# A policy spec used in the "config.multiagent.policies" specification dict -# as values (keys are the policy IDs (str)). 
E.g.: -# config: -# multiagent: -# policies: { -# "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}), -# "pol2": PolicySpec(config={"lr": 0.001}), -# } -PolicySpec = namedtuple( - "PolicySpec", - [ - # If None, use the Trainer's default policy class stored under - # `Trainer._policy_class`. - "policy_class", # type: Union[type, None] - # If None, use the env's observation space. If None and there is no Env - # (e.g. offline RL), an error is thrown. - "observation_space", # type: Union[gym.Space, None] - # If None, use the env's action space. If None and there is no Env - # (e.g. offline RL), an error is thrown. - "action_space", # type: Union[gym.Space, None] - # Overrides defined keys in the main Trainer config. - # If None, use {}. - "config", # type: Union[PartialTrainerConfigDict, None] - ]) -# From 3.7 on, we could pass `defaults` into the above constructor. -PolicySpec.__new__.__defaults__ = (None, None, None, None) - @DeveloperAPI class Policy(metaclass=ABCMeta): diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py index 480b857da..02e8da83b 100644 --- a/rllib/tests/test_io.py +++ b/rllib/tests/test_io.py @@ -1,4 +1,5 @@ import glob +import gym import json import numpy as np import os @@ -12,6 +13,7 @@ import ray from ray.tune.registry import register_env, register_input, \ registry_get_input, registry_contains_input from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.offline import IOContext, JsonWriter, JsonReader, InputReader, \ ShuffledInput @@ -177,6 +179,12 @@ class AgentIOTest(unittest.TestCase): def testMultiAgent(self): register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 10})) + single_env = gym.make("CartPole-v0") + + def gen_policy(): + obs_space = single_env.observation_space + act_space = single_env.action_space + return (PGTFPolicy, obs_space, act_space, {}) for fw in framework_iterator(): pg = PGTrainer( @@ -185,7 +193,10 @@ class AgentIOTest(unittest.TestCase): "num_workers": 0, "output": self.test_dir, "multiagent": { - "policies": {"policy_1", "policy_2"}, + "policies": { + "policy_1": gen_policy(), + "policy_2": gen_policy(), + }, "policy_mapping_fn": ( lambda aid, **kwargs: random.choice( ["policy_1", "policy_2"])), @@ -204,7 +215,10 @@ class AgentIOTest(unittest.TestCase): "input_evaluation": ["simulation"], "train_batch_size": 2000, "multiagent": { - "policies": {"policy_1", "policy_2"}, + "policies": { + "policy_1": gen_policy(), + "policy_2": gen_policy(), + }, "policy_mapping_fn": ( lambda aid, **kwargs: random.choice( ["policy_1", "policy_2"])), diff --git a/rllib/tests/test_multi_agent_env.py b/rllib/tests/test_multi_agent_env.py index ae2c63b76..8fa0b2d1b 100644 --- a/rllib/tests/test_multi_agent_env.py +++ b/rllib/tests/test_multi_agent_env.py @@ -16,7 +16,6 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv -from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.test_utils import check @@ -424,13 +423,16 @@ class TestMultiAgentEnv(unittest.TestCase): n = 10 register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": n})) + single_env = gym.make("CartPole-v0") def gen_policy(): config = { "gamma": 
random.choice([0.5, 0.8, 0.9, 0.95, 0.99]), "n_step": random.choice([1, 2, 3, 4, 5]), } - return PolicySpec(config=config) + obs_space = single_env.observation_space + act_space = single_env.action_space + return (None, obs_space, act_space, config) pg = PGTrainer( env="multi_agent_cartpole", diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index abe03e885..582681726 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -8,6 +8,7 @@ import unittest import ray from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv @@ -450,10 +451,10 @@ class NestedSpacesTest(unittest.TestCase): "multiagent": { "policies": { "tuple_policy": ( - None, TUPLE_SPACE, act_space, + PGTFPolicy, TUPLE_SPACE, act_space, {"model": {"custom_model": "tuple_spy"}}), "dict_policy": ( - None, DICT_SPACE, act_space, + PGTFPolicy, DICT_SPACE, act_space, {"model": {"custom_model": "dict_spy"}}), }, "policy_mapping_fn": lambda aid, **kwargs: { diff --git a/rllib/tests/test_pettingzoo_env.py b/rllib/tests/test_pettingzoo_env.py index c4f3bd4f4..eb36b7eb5 100644 --- a/rllib/tests/test_pettingzoo_env.py +++ b/rllib/tests/test_pettingzoo_env.py @@ -24,12 +24,16 @@ class TestPettingZooEnv(unittest.TestCase): config = deepcopy(agent_class._default_config) + test_env = PettingZooEnv(simple_spread_v2.env()) + obs_space = test_env.observation_space + act_space = test_env.action_space + test_env.close() + config["multiagent"] = { - # Set of policy IDs (by default, will use Trainer's - # default policy class, the env's obs/act spaces and config={}). - "policies": {"av"}, - # Mapping function that always returns "av" as policy ID to use - # (for any agent). 
+ "policies": { + # the first tuple value is None -> uses default policy + "av": (None, obs_space, act_space, {}), + }, "policy_mapping_fn": lambda agent_id, episode, **kwargs: "av" } diff --git a/rllib/tests/test_rollout.py b/rllib/tests/test_rollout.py index 53f6e8773..b3427d44d 100644 --- a/rllib/tests/test_rollout.py +++ b/rllib/tests/test_rollout.py @@ -1,3 +1,4 @@ +from gym.spaces import Box, Discrete import os from pathlib import Path import re @@ -149,6 +150,9 @@ def learn_test_multi_agent_plus_rollout(algo): def policy_fn(agent_id, episode, **kwargs): return "pol{}".format(agent_id) + observation_space = Box(float("-inf"), float("inf"), (4, )) + action_space = Discrete(2) + config = { "num_gpus": 0, "num_workers": 1, @@ -158,7 +162,10 @@ def learn_test_multi_agent_plus_rollout(algo): "framework": fw, "env": MultiAgentCartPole, "multiagent": { - "policies": {"pol0", "pol1"}, + "policies": { + "pol0": (None, observation_space, action_space, {}), + "pol1": (None, observation_space, action_space, {}), + }, "policy_mapping_fn": policy_fn, }, } diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index 5130989ce..77e655559 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -6,7 +6,7 @@ from collections import defaultdict import random from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch logger = logging.getLogger(__name__) diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 5a8e7ff1c..89aba715e 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -1,10 +1,10 @@ from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING +import gym if TYPE_CHECKING: from ray.rllib.utils import try_import_tf, try_import_torch _, tf, _ = try_import_tf() torch, _ = try_import_torch() - from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.policy.view_requirement import ViewRequirement @@ -42,7 +42,8 @@ AgentID = Any PolicyID = str # Type of the config["multiagent"]["policies"] dict for multi-agent training. -MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"] +MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[ + type, None], gym.Space, gym.Space, PartialTrainerConfigDict]] # Represents an environment id. These could be: # - An int index for a sub-env within a vectorized env.