mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
98 lines
4.1 KiB
Python
98 lines
4.1 KiB
Python
![]() |
from typing import Tuple
|
||
|
|
||
|
from ray.rllib.policy.policy import PolicySpec
|
||
|
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||
|
from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \
|
||
|
PartialTrainerConfigDict
|
||
|
|
||
|
|
||
|
def check_multi_agent(config: PartialTrainerConfigDict) -> \
        Tuple[MultiAgentPolicyConfigDict, bool]:
    """Checks, whether a (partial) config defines a multi-agent setup.

    Side effect: the "policies" entry of `config["multiagent"]` is normalized
    in place into a dict mapping PolicyIDs (str) to PolicySpecs, where a
    PolicySpec `config` of None is replaced by an empty dict.

    Args:
        config: The user/Trainer/Policy config to check for multi-agent.
            `config["multiagent"]["policies"]` may be missing/empty (single
            default policy), a set/list/tuple of PolicyIDs, or a dict mapping
            PolicyIDs to PolicySpecs or 4-item lists/tuples of
            (cls or None, obs_space, action_space, config).

    Returns:
        Tuple consisting of the resulting (all fixed) multi-agent policy
        dict and bool indicating whether we have a multi-agent setup or not.

    Raises:
        KeyError: If `config` does not contain a "multiagent" key or if there
            is an invalid key inside the "multiagent" config or if any policy
            in the "policies" dict has a non-str ID (key).
        ValueError: If any subkey of the "multiagent" dict has an invalid
            value.
    """
    if "multiagent" not in config:
        raise KeyError(
            "Your `config` to be checked for a multi-agent setup must have "
            "the 'multiagent' key defined!")
    multiagent_config = config["multiagent"]

    policies = multiagent_config.get("policies")

    # Check for invalid sub-keys of multiagent config.
    # NOTE(review): imported locally rather than at module level, presumably
    # to avoid a circular import with the trainer module -- confirm.
    from ray.rllib.agents.trainer import COMMON_CONFIG
    allowed = list(COMMON_CONFIG["multiagent"].keys())
    invalid = [k for k in multiagent_config.keys() if k not in allowed]
    if invalid:
        raise KeyError(
            f"You have invalid keys in your 'multiagent' config dict "
            f"({invalid})! The only allowed keys are: {allowed}.")

    # Nothing specified in config dict -> Assume simple single agent setup
    # with DEFAULT_POLICY_ID as only policy.
    if not policies:
        policies = {DEFAULT_POLICY_ID}
    # Policies given as a collection of PolicyIDs (set, list or tuple) ->
    # Setup each policy automatically via empty PolicySpec (will make RLlib
    # infer obs- and action spaces as well as the Policy's class). The
    # resulting dict is written back into the user's multiagent config.
    if isinstance(policies, (set, list, tuple)):
        policies = multiagent_config["policies"] = {
            pid: PolicySpec()
            for pid in policies
        }

    # Check each defined policy ID and spec. Iterate over a snapshot of the
    # items, since entries of `policies` are replaced inside the loop.
    for pid, spec in list(policies.items()):
        # Policy IDs must be strings.
        if not isinstance(pid, str):
            raise KeyError(
                f"Policy IDs must always be of type `str`, got {type(pid)}")
        # Convert to PolicySpec if plain list/tuple. The PolicySpec check
        # must come first: PolicySpec supports `_replace`, i.e. it is (likely)
        # itself a tuple subclass and would otherwise hit the branch below.
        if not isinstance(spec, PolicySpec):
            # Values must be lists/tuples of len 4.
            if not isinstance(spec, (list, tuple)) or len(spec) != 4:
                raise ValueError(
                    "Policy specs must be tuples/lists of "
                    "(cls or None, obs_space, action_space, config), "
                    f"got {spec}")
            spec = PolicySpec(*spec)

        # Config is None -> Set to {}.
        if spec.config is None:
            spec = spec._replace(config={})
        # Config not a dict.
        elif not isinstance(spec.config, dict):
            raise ValueError(
                f"Multiagent policy config for {pid} must be a dict, "
                f"but got {type(spec.config)}!")

        policies[pid] = spec

    # Check other "multiagent" sub-keys' values.
    count_steps_by = multiagent_config.get("count_steps_by", "env_steps")
    if count_steps_by not in ["env_steps", "agent_steps"]:
        raise ValueError("config.multiagent.count_steps_by must be "
                         "[env_steps|agent_steps], not "
                         f"{count_steps_by}!")
    replay_mode = multiagent_config.get("replay_mode", "independent")
    if replay_mode not in ["independent", "lockstep"]:
        raise ValueError("config.multiagent.replay_mode must be "
                         "[independent|lockstep], not "
                         f"{replay_mode}!")

    # Is this a multi-agent setup? True, UNLESS DEFAULT_POLICY_ID is the
    # only PolicyID found in the policies dict.
    is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
    return policies, is_multiagent
|