import logging
from typing import Tuple

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, PartialTrainerConfigDict

logger = logging.getLogger(__name__)


def check_multi_agent(
    config: PartialTrainerConfigDict,
) -> Tuple[MultiAgentPolicyConfigDict, bool]:
    """Validates and normalizes the "multiagent" section of a (partial) config.

    Note that `config["multiagent"]` is modified in place: a missing/empty
    or set/list/tuple "policies" entry is replaced by a dict mapping policy
    IDs to PolicySpecs, list/tuple policy specs inside an existing
    "policies" dict are converted to PolicySpec objects, `None` policy
    configs are replaced by `{}`, and a dict-valued "policy_mapping_fn" is
    replaced by the object created via `from_config`.

    Args:
        config: The user/Trainer/Policy config to check for multi-agent.

    Returns:
        Tuple consisting of the resulting (all fixed) multi-agent policy
        dict and bool indicating whether we have a multi-agent setup or not.

    Raises:
        KeyError: If `config` does not contain a "multiagent" key or if there
            is an invalid key inside the "multiagent" config or if any
            policy in the "policies" dict has a non-str ID (key).
        ValueError: If any subkey of the "multiagent" dict has an invalid
            value.
    """
    if "multiagent" not in config:
        raise KeyError(
            "Your `config` to be checked for a multi-agent setup must have "
            "the 'multiagent' key defined!"
        )
    multiagent_config = config["multiagent"]
    policies = multiagent_config.get("policies")

    # Check for invalid sub-keys of multiagent config.
    # NOTE(review): imported locally, presumably to avoid a circular import
    # with the trainer module - confirm before moving to module level.
    from ray.rllib.agents.trainer import COMMON_CONFIG

    allowed = list(COMMON_CONFIG["multiagent"].keys())
    if any(k not in allowed for k in multiagent_config.keys()):
        raise KeyError(
            f"You have invalid keys in your 'multiagent' config dict! "
            f"The only allowed keys are: {allowed}."
        )

    # Nothing specified in config dict -> Assume simple single agent setup
    # with DEFAULT_POLICY_ID as only policy.
    if not policies:
        policies = {DEFAULT_POLICY_ID}
    # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
    # automatically via empty PolicySpec (will make RLlib infer obs- and
    # action spaces as well as the Policy's class).
    if isinstance(policies, (set, list, tuple)):
        policies = multiagent_config["policies"] = {
            pid: PolicySpec() for pid in policies
        }

    # Check each defined policy ID and spec.
    # Iterate over a copy, because entries of `policies` get re-assigned
    # inside the loop.
    for pid, policy_spec in policies.copy().items():
        # Policy IDs must be strings.
        if not isinstance(pid, str):
            raise KeyError(f"Policy IDs must always be of type `str`, got {type(pid)}")
        # Convert to PolicySpec if plain list/tuple.
        if not isinstance(policy_spec, PolicySpec):
            # Values must be lists/tuples of len 4.
            if not isinstance(policy_spec, (list, tuple)) or len(policy_spec) != 4:
                raise ValueError(
                    "Policy specs must be tuples/lists of "
                    "(cls or None, obs_space, action_space, config), "
                    f"got {policy_spec}"
                )
            policies[pid] = PolicySpec(*policy_spec)

        # Config is None -> Set to {}.
        if policies[pid].config is None:
            policies[pid] = policies[pid]._replace(config={})
        # Config not a dict.
        elif not isinstance(policies[pid].config, dict):
            raise ValueError(
                f"Multiagent policy config for {pid} must be a dict, "
                f"but got {type(policies[pid].config)}!"
            )

    # Check other "multiagent" sub-keys' values.
    if multiagent_config.get("count_steps_by", "env_steps") not in [
        "env_steps",
        "agent_steps",
    ]:
        raise ValueError(
            "config.multiagent.count_steps_by must be one of "
            "[env_steps|agent_steps], not "
            f"{multiagent_config['count_steps_by']}!"
        )
    if multiagent_config.get("replay_mode", "independent") not in [
        "independent",
        "lockstep",
    ]:
        raise ValueError(
            "`config.multiagent.replay_mode` must be "
            "[independent|lockstep], not "
            f"{multiagent_config['replay_mode']}!"
        )
    # Attempt to create a `policy_mapping_fn` from config dict. Helpful
    # if users would like to specify custom callable classes in yaml files.
    if isinstance(multiagent_config.get("policy_mapping_fn"), dict):
        multiagent_config["policy_mapping_fn"] = from_config(
            multiagent_config["policy_mapping_fn"]
        )

    # Check `policies_to_train` for invalid entries.
    # Use `.get()` (like for all other optional sub-keys above), so that a
    # partial config without `policies_to_train` does not fail with a bare
    # KeyError. Only set/list/tuple values are validated here; any other
    # value (e.g. None) is passed through unchecked.
    policies_to_train = multiagent_config.get("policies_to_train")
    if isinstance(policies_to_train, (list, set, tuple)):
        if not policies_to_train:
            logger.warning(
                "`config.multiagent.policies_to_train` is empty! "
                "Make sure - if you would like to learn at least one policy - "
                "to add its ID to that list."
            )
        for pid in policies_to_train:
            if pid not in policies:
                raise ValueError(
                    "`config.multiagent.policies_to_train` contains policy "
                    f"ID ({pid}) that was not defined in `config.multiagent.policies`!"
                )

    # Is this a multi-agent setup? True, unless DEFAULT_POLICY_ID is the
    # only PolicyID found in the policies dict.
    is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies

    return policies, is_multiagent