# Mirror of https://github.com/vale981/ray
# Synced 2025-03-06 10:31:39 -05:00
# 131 lines, 5.3 KiB, Python
import logging
|
|
from typing import Tuple
|
|
|
|
from ray.rllib.policy.policy import PolicySpec
|
|
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
|
from ray.rllib.utils.from_config import from_config
|
|
from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, PartialTrainerConfigDict
|
|
|
|
# Module-level logger, named after this module; used below to emit warnings
# about suspicious (but not invalid) multi-agent settings.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def check_multi_agent(
    config: PartialTrainerConfigDict,
) -> Tuple[MultiAgentPolicyConfigDict, bool]:
    """Checks, whether a (partial) config defines a multi-agent setup.

    Validates the "multiagent" sub-dict of `config` and normalizes its
    "policies" entry. Note that normalization happens in place:
    - If no policies are given, a single `DEFAULT_POLICY_ID` policy is
      assumed and written back into `config["multiagent"]["policies"]`.
    - Policies given as set/list/tuple of IDs are converted into a dict
      mapping each ID to an empty `PolicySpec`.
    - Policy specs given as 4-item lists/tuples are converted into
      `PolicySpec` objects, and a spec `config` of None is replaced by `{}`.
    - A "policy_mapping_fn" given as a dict is replaced by the object
      that `from_config` builds from it.

    Args:
        config: The user/Trainer/Policy config to check for multi-agent.

    Returns:
        Tuple consisting of the resulting (all fixed) multi-agent policy
        dict and bool indicating whether we have a multi-agent setup or not.

    Raises:
        KeyError: If `config` does not contain a "multiagent" key or if there
            is an invalid key inside the "multiagent" config or if any policy
            in the "policies" dict has a non-str ID (key).
        ValueError: If any subkey of the "multiagent" dict has an invalid
            value.
    """
    if "multiagent" not in config:
        raise KeyError(
            "Your `config` to be checked for a multi-agent setup must have "
            "the 'multiagent' key defined!"
        )
    multiagent_config = config["multiagent"]

    policies = multiagent_config.get("policies")

    # Check for invalid sub-keys of multiagent config.
    # NOTE(review): imported at call time rather than at module level,
    # presumably to avoid a circular import with the trainer module - confirm.
    from ray.rllib.agents.trainer import COMMON_CONFIG

    allowed = list(COMMON_CONFIG["multiagent"].keys())
    if any(k not in allowed for k in multiagent_config):
        raise KeyError(
            "You have invalid keys in your 'multiagent' config dict! "
            f"The only allowed keys are: {allowed}."
        )

    # Nothing specified in config dict -> Assume simple single agent setup
    # with DEFAULT_POLICY_ID as only policy.
    if not policies:
        policies = {DEFAULT_POLICY_ID}
    # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
    # automatically via empty PolicySpec (will make RLlib infer obs- and
    # action spaces as well as the Policy's class).
    if isinstance(policies, (set, list, tuple)):
        policies = multiagent_config["policies"] = {
            pid: PolicySpec() for pid in policies
        }

    # Check each defined policy ID and spec.
    # Iterate over a copy, because entries of `policies` are re-assigned
    # inside the loop.
    for pid, policy_spec in policies.copy().items():
        # Policy IDs must be strings.
        if not isinstance(pid, str):
            raise KeyError(f"Policy IDs must always be of type `str`, got {type(pid)}")
        # Convert to PolicySpec if plain list/tuple.
        if not isinstance(policy_spec, PolicySpec):
            # Values must be lists/tuples of len 4.
            if not isinstance(policy_spec, (list, tuple)) or len(policy_spec) != 4:
                raise ValueError(
                    "Policy specs must be tuples/lists of "
                    "(cls or None, obs_space, action_space, config), "
                    f"got {policy_spec}"
                )
            policies[pid] = PolicySpec(*policy_spec)

        # Config is None -> Set to {}.
        if policies[pid].config is None:
            policies[pid] = policies[pid]._replace(config={})
        # Config not a dict.
        elif not isinstance(policies[pid].config, dict):
            raise ValueError(
                f"Multiagent policy config for {pid} must be a dict, "
                f"but got {type(policies[pid].config)}!"
            )

    # Check other "multiagent" sub-keys' values.
    if multiagent_config.get("count_steps_by", "env_steps") not in [
        "env_steps",
        "agent_steps",
    ]:
        raise ValueError(
            "config.multiagent.count_steps_by must be one of "
            "[env_steps|agent_steps], not "
            f"{multiagent_config['count_steps_by']}!"
        )
    if multiagent_config.get("replay_mode", "independent") not in [
        "independent",
        "lockstep",
    ]:
        raise ValueError(
            "`config.multiagent.replay_mode` must be "
            "[independent|lockstep], not "
            f"{multiagent_config['replay_mode']}!"
        )
    # Attempt to create a `policy_mapping_fn` from config dict. Helpful
    # if users would like to specify custom callable classes in yaml files.
    if isinstance(multiagent_config.get("policy_mapping_fn"), dict):
        multiagent_config["policy_mapping_fn"] = from_config(
            multiagent_config["policy_mapping_fn"]
        )
    # Check `policies_to_train` for invalid entries.
    # Use `.get()` (like for all other sub-keys above): `config` may be a
    # partial config that does not define `policies_to_train` at all, in which
    # case there is nothing to validate.
    policies_to_train = multiagent_config.get("policies_to_train")
    if isinstance(policies_to_train, (list, set, tuple)):
        if len(policies_to_train) == 0:
            logger.warning(
                "`config.multiagent.policies_to_train` is empty! "
                "Make sure - if you would like to learn at least one policy - "
                "to add its ID to that list."
            )
        for pid in policies_to_train:
            if pid not in policies:
                raise ValueError(
                    "`config.multiagent.policies_to_train` contains policy "
                    f"ID ({pid}) that was not defined in `config.multiagent.policies`!"
                )

    # Is this a multi-agent setup? True, unless DEFAULT_POLICY_ID is the one
    # and only PolicyID found in the policies dict.
    is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
    return policies, is_multiagent