import gym
import logging
from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import (
    ExperimentalAPI,
    override,
    PublicAPI,
    DeveloperAPI,
)
from ray.rllib.utils.typing import (
    AgentID,
    EnvCreator,
    EnvID,
    EnvType,
    MultiAgentDict,
    MultiEnvDict,
)
from ray.util import log_once

# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"

logger = logging.getLogger(__name__)


@PublicAPI
class MultiAgentEnv(gym.Env):
    """An environment that hosts multiple independent agents.

    Agents are identified by (string) agent ids. Note that these "agents" here
    are not to be confused with RLlib Algorithms, which are also sometimes
    referred to as "agents" or "RL agents".
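
    Example (a sketch of a minimal subclass; the agent ids and spaces below
    are illustrative only):
        >>> from gym.spaces import Dict, Discrete # doctest: +SKIP
        >>> class MyTwoAgentEnv(MultiAgentEnv): # doctest: +SKIP
        ...     def __init__(self, config=None): # doctest: +SKIP
        ...         super().__init__() # doctest: +SKIP
        ...         self._agent_ids = {"agent_0", "agent_1"} # doctest: +SKIP
        ...         # Spaces in the "preferred format": Dicts mapping agent ids
        ...         # to the individual agents' spaces.
        ...         self.observation_space = Dict( # doctest: +SKIP
        ...             {aid: Discrete(4) for aid in self._agent_ids}) # doctest: +SKIP
        ...         self.action_space = Dict( # doctest: +SKIP
        ...             {aid: Discrete(2) for aid in self._agent_ids}) # doctest: +SKIP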
    """

    def __init__(self):
        if not hasattr(self, "observation_space"):
            self.observation_space = None
        if not hasattr(self, "action_space"):
            self.action_space = None
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()

        # Do the action and observation spaces map from agent ids to spaces
        # for the individual agents?
        if not hasattr(self, "_spaces_in_preferred_format"):
            self._spaces_in_preferred_format = None

    @PublicAPI
    def reset(self) -> MultiAgentDict:
        """Resets the env and returns observations from ready agents.

        Returns:
            New observations for each ready agent.

        Examples:
            >>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
            >>> class MyMultiAgentEnv(MultiAgentEnv): # doctest: +SKIP
            ...     # Define your env here. # doctest: +SKIP
            ...     ... # doctest: +SKIP
            >>> env = MyMultiAgentEnv() # doctest: +SKIP
            >>> obs = env.reset() # doctest: +SKIP
            >>> print(obs) # doctest: +SKIP
            {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
                "traffic_light_1": [0, 3, 5, 1],
            }
        """
        raise NotImplementedError

    @PublicAPI
    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        """Returns observations from ready agents.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time.

        Returns:
            Tuple containing 1) new observations for
            each ready agent, 2) reward values for each ready agent. If
            the episode is just started, the value will be None.
            3) Done values for each ready agent. The special key
            "__all__" (required) is used to indicate env termination.
            4) Optional info values for each agent id.

        Examples:
            >>> env = ... # doctest: +SKIP
            >>> obs, rewards, dones, infos = env.step( # doctest: +SKIP
            ...    action_dict={ # doctest: +SKIP
            ...        "car_0": 1, "car_1": 0, "traffic_light_1": 2, # doctest: +SKIP
            ...    }) # doctest: +SKIP
            >>> print(rewards) # doctest: +SKIP
            {
                "car_0": 3,
                "car_1": -1,
                "traffic_light_1": 0,
            }
            >>> print(dones) # doctest: +SKIP
            {
                "car_0": False,  # car_0 is still running
                "car_1": True,   # car_1 is done
                "__all__": False,  # the env is not done
            }
            >>> print(infos) # doctest: +SKIP
            {
                "car_0": {},  # info for car_0
                "car_1": {},  # info for car_1
            }
        """
        raise NotImplementedError

    @ExperimentalAPI
    def observation_space_contains(self, x: MultiAgentDict) -> bool:
        """Checks if the observation space contains the given observations.

        Args:
            x: Observations to check.

        Returns:
            True if the observation space contains all observations in x.
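
        Example (a sketch; assuming `env`'s observation space is defined in the
        "preferred format", i.e. a gym.spaces.Dict mapping agent ids to the
        individual agents' observation spaces, here Discrete spaces):
            >>> env.observation_space_contains({"agent_0": 1}) # doctest: +SKIP
            True
            >>> env.observation_space_contains({"agent_0": -5.0}) # doctest: +SKIP
            False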
        """
        if (
            not hasattr(self, "_spaces_in_preferred_format")
            or self._spaces_in_preferred_format is None
        ):
            self._spaces_in_preferred_format = (
                self._check_if_space_maps_agent_id_to_sub_space()
            )
        if self._spaces_in_preferred_format:
            for key, agent_obs in x.items():
                if not self.observation_space[key].contains(agent_obs):
                    return False
            if not all(k in self.observation_space.spaces for k in x):
                if log_once("possibly_bad_multi_agent_dict_missing_agent_observations"):
                    logger.warning(
                        "Your environment returns observations that are "
                        "MultiAgentDicts with incomplete information. "
                        "Meaning that they only contain information on a subset of"
                        " participating agents. Ignore this warning if this is "
                        "intended, for example if your environment is a turn-based "
                        "simulation."
                    )
            return True

        logger.warning("observation_space_contains() has not been implemented")
        return True

    @ExperimentalAPI
    def action_space_contains(self, x: MultiAgentDict) -> bool:
        """Checks if the action space contains the given action.

        Args:
            x: Actions to check.

        Returns:
            True if the action space contains all actions in x.
        """
        if (
            not hasattr(self, "_spaces_in_preferred_format")
            or self._spaces_in_preferred_format is None
        ):
            self._spaces_in_preferred_format = (
                self._check_if_space_maps_agent_id_to_sub_space()
            )
        if self._spaces_in_preferred_format:
            return all([self.action_space[agent].contains(x[agent]) for agent in x])

        if log_once("action_space_contains"):
            logger.warning("action_space_contains() has not been implemented")
        return True

    @ExperimentalAPI
    def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
        """Returns a random action for each environment, and potentially each
        agent in that environment.

        Args:
            agent_ids: List of agent ids to sample actions for. If None or
                empty list, sample actions for all agents in the
                environment.

        Returns:
            A random action for each agent in the environment.
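
        Example (a sketch; assuming the action space is a gym.spaces.Dict
        mapping agent ids "agent_0" and "agent_1" to Discrete spaces):
            >>> env.action_space_sample() # doctest: +SKIP
            {"agent_0": 0, "agent_1": 1}
            >>> env.action_space_sample(["agent_1"]) # doctest: +SKIP
            {"agent_1": 0}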
        """
        if (
            not hasattr(self, "_spaces_in_preferred_format")
            or self._spaces_in_preferred_format is None
        ):
            self._spaces_in_preferred_format = (
                self._check_if_space_maps_agent_id_to_sub_space()
            )
        if self._spaces_in_preferred_format:
            if agent_ids is None:
                agent_ids = self.get_agent_ids()
            samples = self.action_space.sample()
            return {
                agent_id: samples[agent_id]
                for agent_id in agent_ids
                if agent_id != "__all__"
            }
        logger.warning("action_space_sample() has not been implemented")
        return {}

    @ExperimentalAPI
    def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
        """Returns a random observation from the observation space for each
        agent if agent_ids is None, otherwise returns a random observation for
        the agents in agent_ids.

        Args:
            agent_ids: List of agent ids to sample observations for. If None or
                empty list, sample observations for all agents in the
                environment.

        Returns:
            A random observation for each agent in the environment.
        """

        if (
            not hasattr(self, "_spaces_in_preferred_format")
            or self._spaces_in_preferred_format is None
        ):
            self._spaces_in_preferred_format = (
                self._check_if_space_maps_agent_id_to_sub_space()
            )
        if self._spaces_in_preferred_format:
            if agent_ids is None:
                agent_ids = self.get_agent_ids()
            samples = self.observation_space.sample()
            samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
            return samples
        if log_once("observation_space_sample"):
            logger.warning("observation_space_sample() has not been implemented")
        return {}

    @PublicAPI
    def get_agent_ids(self) -> Set[AgentID]:
        """Returns a set of agent ids in the environment.

        Returns:
            Set of agent ids.
        """
        if not isinstance(self._agent_ids, set):
            self._agent_ids = set(self._agent_ids)
        return self._agent_ids

    @PublicAPI
    def render(self, mode=None) -> None:
        """Tries to render the environment."""

        # By default, do nothing.
        pass

    # fmt: off
    # __grouping_doc_begin__
    def with_agent_groups(
        self,
        groups: Dict[str, List[AgentID]],
        obs_space: gym.Space = None,
        act_space: gym.Space = None) -> "MultiAgentEnv":
        """Convenience method for grouping together agents in this env.

        An agent group is a list of agent IDs that are mapped to a single
        logical agent. All agents of the group must act at the same time in the
        environment. The grouped agent exposes Tuple action and observation
        spaces that are the concatenated action and obs spaces of the
        individual agents.

        The rewards of all the agents in a group are summed. The individual
        agent rewards are available under the "individual_rewards" key of the
        group info return.

        Agent grouping is required to leverage algorithms such as Q-Mix.

        Args:
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new agent ID
                in the final environment.
            obs_space: Optional observation space for the grouped
                env. Must be a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a group).
            act_space: Optional action space for the grouped env.
                Must be a tuple space. If not provided, will infer this to be a Tuple
                of n individual agents spaces (n=num agents in a group).

        Examples:
            >>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
            >>> class MyMultiAgentEnv(MultiAgentEnv): # doctest: +SKIP
            ...     # define your env here
            ...     ... # doctest: +SKIP
            >>> env = MyMultiAgentEnv(...) # doctest: +SKIP
            >>> grouped_env = env.with_agent_groups({ # doctest: +SKIP
            ...   "group1": ["agent1", "agent2", "agent3"], # doctest: +SKIP
            ...   "group2": ["agent4", "agent5"], # doctest: +SKIP
            ... }) # doctest: +SKIP
        """

        from ray.rllib.env.wrappers.group_agents_wrapper import \
            GroupAgentsWrapper
        return GroupAgentsWrapper(self, groups, obs_space, act_space)

    # __grouping_doc_end__
    # fmt: on

    @PublicAPI
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib MultiAgentEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the index of the individual sub-environment within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, we will try to
                restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.

        Returns:
            The resulting BaseEnv object.
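
        Example (a sketch; `MyMultiAgentEnv` stands in for any user-defined
        MultiAgentEnv subclass):
            >>> env = MyMultiAgentEnv() # doctest: +SKIP
            >>> base_env = env.to_base_env() # doctest: +SKIP
            >>> obs, rewards, dones, infos, _ = base_env.poll() # doctest: +SKIP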
        """
        from ray.rllib.env.remote_base_env import RemoteBaseEnv

        if remote_envs:
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=True,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
        # Sub-environments are not ray.remote actors.
        else:
            env = MultiAgentEnvWrapper(
                make_env=make_env,
                existing_envs=[self],
                num_envs=num_envs,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )

        return env

    @DeveloperAPI
    def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
        # Do the action and observation spaces map from agent ids to spaces
        # for the individual agents?
        obs_space_check = (
            hasattr(self, "observation_space")
            and isinstance(self.observation_space, gym.spaces.Dict)
            and set(self.observation_space.spaces.keys()) == self.get_agent_ids()
        )
        action_space_check = (
            hasattr(self, "action_space")
            and isinstance(self.action_space, gym.spaces.Dict)
            and set(self.action_space.spaces.keys()) == self.get_agent_ids()
        )
        return obs_space_check and action_space_check


@PublicAPI
def make_multi_agent(
    env_name_or_creator: Union[str, EnvCreator],
) -> Type["MultiAgentEnv"]:
    """Convenience wrapper for any single-agent env to be converted into MA.

    Allows you to convert a simple (single-agent) `gym.Env` class
    into a `MultiAgentEnv` class. This function simply stacks n instances
    of the given `gym.Env` class into one unified `MultiAgentEnv` class
    and returns this class, thus pretending the agents act together in the
    same environment, whereas - under the hood - they live separately from
    each other in n parallel single-agent envs.

    Agent IDs in the resulting MultiAgentEnv are int numbers starting from 0
    (first agent).

    Args:
        env_name_or_creator: String specifier or env_maker function taking
            an EnvContext object as only arg and returning a gym.Env.

    Returns:
        New MultiAgentEnv class to be used as env.
        The constructor takes a config dict with `num_agents` key
        (default=1). The rest of the config dict will be passed on to the
        underlying single-agent env's constructor.

    Examples:
        >>> from ray.rllib.env.multi_agent_env import make_multi_agent
        >>> # By gym string:
        >>> ma_cartpole_cls = make_multi_agent("CartPole-v0") # doctest: +SKIP
        >>> # Create a 2 agent multi-agent cartpole.
        >>> ma_cartpole = ma_cartpole_cls({"num_agents": 2}) # doctest: +SKIP
        >>> obs = ma_cartpole.reset() # doctest: +SKIP
        >>> print(obs) # doctest: +SKIP
        {0: [...], 1: [...]}
        >>> # By env-maker callable:
        >>> from ray.rllib.examples.env.stateless_cartpole # doctest: +SKIP
        ...    import StatelessCartPole
        >>> ma_stateless_cartpole_cls = make_multi_agent( # doctest: +SKIP
        ...    lambda config: StatelessCartPole(config)) # doctest: +SKIP
        >>> # Create a 3 agent multi-agent stateless cartpole.
        >>> ma_stateless_cartpole = ma_stateless_cartpole_cls( # doctest: +SKIP
        ...    {"num_agents": 3}) # doctest: +SKIP
        >>> obs = ma_stateless_cartpole.reset() # doctest: +SKIP
        >>> print(obs) # doctest: +SKIP
        {0: [...], 1: [...], 2: [...]}
    """

    class MultiEnv(MultiAgentEnv):
        def __init__(self, config: EnvContext = None):
            MultiAgentEnv.__init__(self)
            # Note(jungong) : explicitly check for None here, because config
            # can have an empty dict but meaningful data fields (worker_index,
            # vector_index) etc.
            # TODO(jungong) : clean this up, so we are not mixing up dict fields
            # with data fields.
            if config is None:
                config = {}
            num = config.pop("num_agents", 1)
            if isinstance(env_name_or_creator, str):
                self.agents = [gym.make(env_name_or_creator) for _ in range(num)]
            else:
                self.agents = [env_name_or_creator(config) for _ in range(num)]
            self.dones = set()
            self.observation_space = self.agents[0].observation_space
            self.action_space = self.agents[0].action_space
            self._agent_ids = set(range(num))

        @override(MultiAgentEnv)
        def observation_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
            if agent_ids is None:
                agent_ids = list(range(len(self.agents)))
            obs = {agent_id: self.observation_space.sample() for agent_id in agent_ids}

            return obs

        @override(MultiAgentEnv)
        def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
            if agent_ids is None:
                agent_ids = list(range(len(self.agents)))
            actions = {agent_id: self.action_space.sample() for agent_id in agent_ids}

            return actions

        @override(MultiAgentEnv)
        def action_space_contains(self, x: MultiAgentDict) -> bool:
            if not isinstance(x, dict):
                return False
            return all(self.action_space.contains(val) for val in x.values())

        @override(MultiAgentEnv)
        def observation_space_contains(self, x: MultiAgentDict) -> bool:
            if not isinstance(x, dict):
                return False
            return all(self.observation_space.contains(val) for val in x.values())

        @override(MultiAgentEnv)
        def reset(self):
            self.dones = set()
            return {i: a.reset() for i, a in enumerate(self.agents)}

        @override(MultiAgentEnv)
        def step(self, action_dict):
            obs, rew, done, info = {}, {}, {}, {}
            for i, action in action_dict.items():
                obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
                if done[i]:
                    self.dones.add(i)
            done["__all__"] = len(self.dones) == len(self.agents)
            return obs, rew, done, info

        @override(MultiAgentEnv)
        def render(self, mode=None):
            return self.agents[0].render(mode)

    return MultiEnv


@PublicAPI
class MultiAgentEnvWrapper(BaseEnv):
    """Internal adapter of MultiAgentEnv to BaseEnv.

    This also supports vectorization if num_envs > 1.
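
    Example (a sketch; `MyMultiAgentEnv` stands in for any user-defined
    MultiAgentEnv subclass; normally you would obtain this wrapper via
    `MultiAgentEnv.to_base_env()` rather than constructing it directly):
        >>> wrapped = MultiAgentEnvWrapper( # doctest: +SKIP
        ...     make_env=lambda idx: MyMultiAgentEnv(), # doctest: +SKIP
        ...     existing_envs=[], # doctest: +SKIP
        ...     num_envs=2, # doctest: +SKIP
        ... ) # doctest: +SKIP
        >>> obs, rewards, dones, infos, _ = wrapped.poll() # doctest: +SKIP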
    """

    def __init__(
        self,
        make_env: Callable[[int], EnvType],
        existing_envs: List["MultiAgentEnv"],
        num_envs: int,
        restart_failed_sub_environments: bool = False,
    ):
        """Wraps MultiAgentEnv(s) into the BaseEnv API.

        Args:
            make_env: Factory that produces a new MultiAgentEnv instance taking the
                vector index as only call argument.
                Must be defined, if the number of existing envs is less than num_envs.
            existing_envs: List of already existing multi-agent envs.
            num_envs: Desired num multiagent envs to have at the end in
                total. This will include the given (already created)
                `existing_envs`.
            restart_failed_sub_environments: If True and any sub-environment (within
                this vectorized env) throws any error during env stepping, we will try
                to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
        """
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        self.restart_failed_sub_environments = restart_failed_sub_environments

        self.dones = set()
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self._init_env_state(idx=None)
        self._unwrapped_env = self.envs[0].unwrapped

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        obs, rewards, dones, infos = {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
        return obs, rewards, dones, infos, {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError(
                    f"Env {env_id} is already done and cannot accept new actions"
                )
            env = self.envs[env_id]
            try:
                obs, rewards, dones, infos = env.step(agent_dict)
            except Exception as e:
                if self.restart_failed_sub_environments:
                    logger.exception(e.args[0])
                    self.try_restart(env_id=env_id)
                    obs, rewards, dones, infos = e, {}, {"__all__": True}, {}
                else:
                    raise e

            assert isinstance(
                obs, (dict, Exception)
            ), "Not a multi-agent obs dict or an Exception!"
            assert isinstance(rewards, dict), "Not a multi-agent reward dict!"
            assert isinstance(dones, dict), "Not a multi-agent done dict!"
            assert isinstance(infos, dict), "Not a multi-agent info dict!"
            if isinstance(obs, dict) and set(infos).difference(set(obs)):
                raise ValueError(
                    "Key set for infos must be a subset of obs: "
                    "{} vs {}".format(infos.keys(), obs.keys())
                )
            if "__all__" not in dones:
                raise ValueError(
                    "In multi-agent environments, '__all__': True|False must "
                    "be included in the 'done' dict: got {}.".format(dones)
                )

            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> Optional[MultiEnvDict]:
        ret = {}
        if isinstance(env_id, int):
            env_id = [env_id]
        if env_id is None:
            env_id = list(range(len(self.envs)))
        for idx in env_id:
            obs = self.env_states[idx].reset()
            if isinstance(obs, Exception):
                if self.restart_failed_sub_environments:
                    self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
                else:
                    raise obs
            else:
                assert isinstance(obs, dict), "Not a multi-agent obs dict!"
            if obs is not None and idx in self.dones:
                self.dones.remove(idx)
            ret[idx] = obs
        return ret

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        if isinstance(env_id, int):
            env_id = [env_id]
        if env_id is None:
            env_id = list(range(len(self.envs)))
        for idx in env_id:
            # Recreate the sub-env.
            logger.warning(f"Trying to restart sub-environment at index {idx}.")
            self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
            logger.warning(f"Sub-environment at index {idx} restarted successfully.")

    @override(BaseEnv)
    def get_sub_environments(
        self, as_dict: bool = False
    ) -> Union[Dict[str, EnvType], List[EnvType]]:
        if as_dict:
            return {_id: env_state.env for _id, env_state in enumerate(self.env_states)}
        return [state.env for state in self.env_states]

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        if env_id is None:
            env_id = 0
        assert isinstance(env_id, int)
        return self.envs[env_id].render()

    @property
    @override(BaseEnv)
    @PublicAPI
    def observation_space(self) -> gym.spaces.Dict:
        return self.envs[0].observation_space

    @property
    @override(BaseEnv)
    @PublicAPI
    def action_space(self) -> gym.Space:
        return self.envs[0].action_space

    @override(BaseEnv)
    def observation_space_contains(self, x: MultiEnvDict) -> bool:
        return all(self.envs[0].observation_space_contains(val) for val in x.values())

    @override(BaseEnv)
    def action_space_contains(self, x: MultiEnvDict) -> bool:
        return all(self.envs[0].action_space_contains(val) for val in x.values())

    @override(BaseEnv)
    def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
        return {0: self.envs[0].observation_space_sample(agent_ids)}

    @override(BaseEnv)
    def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
        return {0: self.envs[0].action_space_sample(agent_ids)}

    @override(BaseEnv)
    def get_agent_ids(self) -> Set[AgentID]:
        return self.envs[0].get_agent_ids()

    def _init_env_state(self, idx: Optional[int] = None) -> None:
        """Resets all or one particular sub-environment's state (by index).

        Args:
            idx: The index to reset at. If None, reset all the sub-environments' states.
        """
        # If index is None, reset all sub-envs' states:
        if idx is None:
            self.env_states = [
                _MultiAgentEnvState(env, self.restart_failed_sub_environments)
                for env in self.envs
            ]
        # Index provided, reset only the sub-env's state at the given index.
        else:
            assert isinstance(idx, int)
            self.env_states[idx] = _MultiAgentEnvState(
                self.envs[idx], self.restart_failed_sub_environments
            )


class _MultiAgentEnvState:
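    """Bookkeeping for one sub-environment inside a MultiAgentEnvWrapper.

    `observe()` stores the latest observations and accumulates rewards, dones,
    and infos per agent; `poll()` releases the accumulated values only for
    those agents that also have an observation, or releases everything at once
    when the episode is done (or the sub-environment returned an error as its
    observation).
    """
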
    def __init__(self, env: MultiAgentEnv, return_error_as_obs: bool = False):
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.return_error_as_obs = return_error_as_obs

        self.initialized = False
        self.last_obs = {}
        self.last_rewards = {}
        self.last_dones = {"__all__": False}
        self.last_infos = {}

    def poll(
        self,
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        if not self.initialized:
            self.reset()
            self.initialized = True

        observations = self.last_obs
        rewards = {}
        dones = {"__all__": self.last_dones["__all__"]}
        infos = {}

        # If episode is done or we have an error, release everything we have.
        if dones["__all__"] or isinstance(observations, Exception):
            rewards = self.last_rewards
            self.last_rewards = {}
            dones = self.last_dones
            if isinstance(observations, Exception):
                dones["__all__"] = True
            self.last_dones = {}
            self.last_obs = {}
            infos = self.last_infos
            self.last_infos = {}
        # Only release those agents' rewards/dones/infos, whose
        # observations we have.
        else:
            for ag in observations.keys():
                if ag in self.last_rewards:
                    rewards[ag] = self.last_rewards[ag]
                    del self.last_rewards[ag]
                if ag in self.last_dones:
                    dones[ag] = self.last_dones[ag]
                    del self.last_dones[ag]
                if ag in self.last_infos:
                    infos[ag] = self.last_infos[ag]
                    del self.last_infos[ag]

        self.last_dones["__all__"] = False
        return observations, rewards, dones, infos

    def observe(
        self,
        obs: MultiAgentDict,
        rewards: MultiAgentDict,
        dones: MultiAgentDict,
        infos: MultiAgentDict,
    ):
        self.last_obs = obs
        for ag, r in rewards.items():
            if ag in self.last_rewards:
                self.last_rewards[ag] += r
            else:
                self.last_rewards[ag] = r
        for ag, d in dones.items():
            if ag in self.last_dones:
                self.last_dones[ag] = self.last_dones[ag] or d
            else:
                self.last_dones[ag] = d
        self.last_infos = infos

    def reset(self) -> MultiAgentDict:
        try:
            self.last_obs = self.env.reset()
        except Exception as e:
            if self.return_error_as_obs:
                logger.exception(e.args[0])
                self.last_obs = e
            else:
                raise e
        self.last_rewards = {}
        self.last_dones = {"__all__": False}
        self.last_infos = {}
        return self.last_obs