[RLlib] [MultiAgentEnv Refactor #2] Change space types for BaseEnvs and MultiAgentEnvs (#21063)

Avnish Narayan 2022-01-06 14:34:20 -08:00 committed by GitHub
parent 8b4cb45088
commit 39f8072eac
5 changed files with 277 additions and 60 deletions

rllib/env/base_env.py (41 changed lines)

@@ -1,6 +1,6 @@
import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union
Union, Set
import gym
import ray
@@ -198,14 +198,13 @@ class BaseEnv:
return []
@PublicAPI
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
"""Return the agent ids for each sub-environment.
def get_agent_ids(self) -> Set[AgentID]:
"""Return the agent ids for the sub_environment.
Returns:
A dict mapping from env_id to a list of agent_ids.
All agent ids for the environment.
"""
logger.warning("get_agent_ids() has not been implemented")
return {}
return {_DUMMY_AGENT_ID}
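For illustration, a minimal sketch of the new default contract; it assumes only that BaseEnv and _DUMMY_AGENT_ID are importable from ray.rllib.env.base_env, as in this diff:

from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID

# The default now returns a flat set of agent ids (the dummy id for
# single-agent cases) instead of a dict mapping env_id to agent lists.
env = BaseEnv()
assert env.get_agent_ids() == {_DUMMY_AGENT_ID}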
@PublicAPI
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -234,8 +233,8 @@ class BaseEnv:
@PublicAPI
@property
def observation_space(self) -> gym.spaces.Dict:
"""Returns the observation space for each environment.
def observation_space(self) -> gym.Space:
"""Returns the observation space for each agent.
Note: samples from the observation space need to be preprocessed into a
`MultiEnvDict` before being used by a policy.
@@ -248,7 +247,7 @@
@PublicAPI
@property
def action_space(self) -> gym.Space:
"""Returns the action space for each environment.
"""Returns the action space for each agent.
Note: samples from the action space need to be preprocessed into a
`MultiEnvDict` before being passed to `send_actions`.
@@ -270,6 +269,7 @@
Returns:
A random action for each environment.
"""
logger.warning("action_space_sample() has not been implemented")
del agent_id
return {}
@@ -286,6 +286,7 @@
A random observation for each environment.
"""
logger.warning("observation_space_sample() has not been implemented")
del agent_id
return {}
@PublicAPI
@@ -326,8 +327,7 @@ class BaseEnv:
"""
return self._space_contains(self.action_space, x)
@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.
Args:
@@ -337,17 +337,14 @@ class BaseEnv:
Returns:
True if the observations of x are contained in space.
"""
# this removes the agent_id key and inner dicts
# in MultiEnvDicts
flattened_obs = {
env_id: list(obs.values())
for env_id, obs in x.items()
}
ret = True
for env_id in flattened_obs:
for obs in flattened_obs[env_id]:
ret = ret and space[env_id].contains(obs)
return ret
agents = set(self.get_agent_ids())
for multi_agent_dict in x.values():
for agent_id, obs in multi_agent_dict.items():
if (agent_id not in agents) or (
not space[agent_id].contains(obs)):
return False
return True
# Fixed agent identifier when there is only a single agent in the env.
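A hedged sketch of what the rewritten containment walk does, using plain gym spaces instead of a concrete BaseEnv (the agent ids and spaces below are made up for illustration):

import gym

agents = {"agent_0", "agent_1"}
space = gym.spaces.Dict({aid: gym.spaces.Discrete(2) for aid in agents})

def space_contains(space, agents, x):
    # x is a MultiEnvDict: {env_id: {agent_id: obs}}.
    for multi_agent_dict in x.values():
        for agent_id, obs in multi_agent_dict.items():
            if agent_id not in agents or not space[agent_id].contains(obs):
                return False
    return True

assert space_contains(space, agents, {0: {"agent_0": 1, "agent_1": 0}})
assert not space_contains(space, agents, {0: {"unknown_agent": 1}})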

rllib/env/multi_agent_env.py

@@ -1,15 +1,19 @@
import gym
from typing import Callable, Dict, List, Tuple, Type, Optional, Union
import logging
from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \
DeveloperAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
MultiEnvDict
# If the obs space is Dict type, look for the global state under this key.
ENV_STATE = "state"
logger = logging.getLogger(__name__)
@PublicAPI
class MultiAgentEnv(gym.Env):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
referred to as "agents" or "RL agents".
"""
def __init__(self):
self.observation_space = None
self.action_space = None
self._agent_ids = {}
# do the action and observation spaces map from agent ids to spaces
# for the individual agents?
self._spaces_in_preferred_format = None
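Subclasses are expected to fill these attributes in themselves; a minimal sketch of an env in the preferred format (TwoAgentEnv and its agent ids are hypothetical):

import gym
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class TwoAgentEnv(MultiAgentEnv):
    def __init__(self):
        super().__init__()
        # Preferred format: Dict spaces keyed by agent id.
        self.observation_space = gym.spaces.Dict({
            "a0": gym.spaces.Discrete(3),
            "a1": gym.spaces.Discrete(3),
        })
        self.action_space = gym.spaces.Dict({
            "a0": gym.spaces.Discrete(2),
            "a1": gym.spaces.Discrete(2),
        })
        self._agent_ids = {"a0", "a1"}
        # reset() and step() omitted for brevity.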
@PublicAPI
def reset(self) -> MultiAgentDict:
"""Resets the env and returns observations from ready agents.
@@ -81,6 +94,113 @@ class MultiAgentEnv(gym.Env):
"""
raise NotImplementedError
@ExperimentalAPI
def observation_space_contains(self, x: MultiAgentDict) -> bool:
"""Checks if the observation space contains the given key.
Args:
x: Observations to check.
Returns:
True if the observation space contains all observations in x.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
return self.observation_space.contains(x)
logger.warning("observation_space_contains() has not been implemented")
return True
@ExperimentalAPI
def action_space_contains(self, x: MultiAgentDict) -> bool:
"""Checks if the action space contains the given action.
Args:
x: Actions to check.
Returns:
True if the action space contains all actions in x.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
return self.action_space.contains(x)
logger.warning("action_space_contains() has not been implemented")
return True
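In the preferred format, both contains checks reduce to a single gym Dict-space containment test; a small self-contained sketch (the spaces are illustrative):

import gym

action_space = gym.spaces.Dict({
    "a0": gym.spaces.Discrete(2),
    "a1": gym.spaces.Discrete(2),
})
# Every agent's entry must be present and valid for the check to pass.
assert action_space.contains({"a0": 0, "a1": 1})
assert not action_space.contains({"a0": 5, "a1": 1})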
@ExperimentalAPI
def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
"""Returns a random action for each environment, and potentially each
agent in that environment.
Args:
agent_ids: List of agent ids to sample actions for. If None or
empty list, sample actions for all agents in the
environment.
Returns:
A random action for each agent.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.action_space.sample()
return {agent_id: samples[agent_id] for agent_id in agent_ids}
logger.warning("action_space_sample() has not been implemented")
del agent_ids
return {}
@ExperimentalAPI
def observation_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
"""Returns a random observation from the observation space for each
agent if agent_ids is None, otherwise returns a random observation for
the agents in agent_ids.
Args:
agent_ids: List of agent ids to sample observations for. If None or
empty list, sample observations for all agents in the
environment.
Returns:
A random observation for each agent.
"""
if not hasattr(self, "_spaces_in_preferred_format") or \
self._spaces_in_preferred_format is None:
self._spaces_in_preferred_format = \
self._check_if_space_maps_agent_id_to_sub_space()
if self._spaces_in_preferred_format:
if agent_ids is None:
agent_ids = self.get_agent_ids()
samples = self.observation_space.sample()
samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
return samples
logger.warning("observation_space_sample() has not been implemented")
del agent_ids
return {}
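Both sample methods follow the same pattern in the preferred format: sample the whole Dict space once, then keep only the requested agent ids. A rough sketch (spaces and ids are illustrative):

import gym

obs_space = gym.spaces.Dict({
    "a0": gym.spaces.Box(-1.0, 1.0, (2,)),
    "a1": gym.spaces.Discrete(4),
})
samples = obs_space.sample()            # {"a0": ndarray, "a1": int}
requested = ["a0"]
subset = {aid: samples[aid] for aid in requested}
assert set(subset) == {"a0"}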
@PublicAPI
def get_agent_ids(self) -> Set[AgentID]:
"""Returns a set of agent ids in the environment.
Returns:
set of agent ids.
"""
if not isinstance(self._agent_ids, set):
self._agent_ids = set(self._agent_ids)
return self._agent_ids
@PublicAPI
def render(self, mode=None) -> None:
"""Tries to render the environment."""
@@ -88,13 +208,13 @@ class MultiAgentEnv(gym.Env):
# By default, do nothing.
pass
# yapf: disable
# __grouping_doc_begin__
# yapf: disable
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
self,
groups: Dict[str, List[AgentID]],
obs_space: gym.Space = None,
self,
groups: Dict[str, List[AgentID]],
obs_space: gym.Space = None,
act_space: gym.Space = None) -> "MultiAgentEnv":
"""Convenience method for grouping together agents in this env.
@@ -132,8 +252,9 @@ class MultiAgentEnv(gym.Env):
from ray.rllib.env.wrappers.group_agents_wrapper import \
GroupAgentsWrapper
return GroupAgentsWrapper(self, groups, obs_space, act_space)
# __grouping_doc_end__
# yapf: enable
# __grouping_doc_end__
# yapf: enable
@PublicAPI
def to_base_env(
@@ -182,6 +303,20 @@ class MultiAgentEnv(gym.Env):
return env
@DeveloperAPI
def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
# do the action and observation spaces map from agent ids to spaces
# for the individual agents?
obs_space_check = (
hasattr(self, "observation_space")
and isinstance(self.observation_space, gym.spaces.Dict)
and set(self.observation_space.keys()) == self.get_agent_ids())
action_space_check = (
hasattr(self, "action_space")
and isinstance(self.action_space, gym.spaces.Dict)
and set(self.action_space.keys()) == self.get_agent_ids())
return obs_space_check and action_space_check
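The "preferred format" check, in other words: both spaces are gym Dict spaces whose keys exactly equal the env's agent ids. A quick illustrative check outside of any env class:

import gym

agent_ids = {"a0", "a1"}
obs_space = gym.spaces.Dict(
    {aid: gym.spaces.Discrete(3) for aid in agent_ids})
# Keys must match the agent ids exactly; a missing or extra key
# means the spaces are not in the preferred format.
assert set(obs_space.spaces.keys()) == agent_ids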
def make_multi_agent(
env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
@@ -242,6 +377,40 @@ def make_multi_agent(
self.dones = set()
self.observation_space = self.agents[0].observation_space
self.action_space = self.agents[0].action_space
self._agent_ids = set(range(num))
@override(MultiAgentEnv)
def observation_space_sample(self,
agent_ids: list = None) -> MultiAgentDict:
if agent_ids is None:
agent_ids = list(range(len(self.agents)))
obs = {
agent_id: self.observation_space.sample()
for agent_id in agent_ids
}
return obs
@override(MultiAgentEnv)
def action_space_sample(self,
agent_ids: list = None) -> MultiAgentDict:
if agent_ids is None:
agent_ids = list(range(len(self.agents)))
actions = {
agent_id: self.action_space.sample()
for agent_id in agent_ids
}
return actions
@override(MultiAgentEnv)
def action_space_contains(self, x: MultiAgentDict) -> bool:
return all(self.action_space.contains(val) for val in x.values())
@override(MultiAgentEnv)
def observation_space_contains(self, x: MultiAgentDict) -> bool:
return all(
self.observation_space.contains(val) for val in x.values())
@override(MultiAgentEnv)
def reset(self):
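A usage sketch of the env produced by make_multi_agent, mirroring the new test file below; here the per-agent spaces are copies of the single CartPole spaces rather than a Dict keyed by agent id, so the overrides above do the sampling and checking:

from ray.rllib.env.multi_agent_env import make_multi_agent

env = make_multi_agent("CartPole-v1")({"num_agents": 2})
obs = env.observation_space_sample()     # {0: obs, 1: obs}
assert env.observation_space_contains(obs)
actions = env.action_space_sample([0])   # only agent 0
assert env.action_space_contains(actions)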
@@ -277,7 +446,7 @@ class MultiAgentEnvWrapper(BaseEnv):
Args:
make_env (Callable[[int], EnvType]): Factory that produces a new
MultiAgentEnv intance. Must be defined, if the number of
MultiAgentEnv instance. Must be defined, if the number of
existing envs is less than num_envs.
existing_envs (List[MultiAgentEnv]): List of already existing
multi-agent envs.
@@ -355,18 +524,31 @@ class MultiAgentEnvWrapper(BaseEnv):
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
space = {
_id: env.observation_space
for _id, env in enumerate(self.envs)
}
return gym.spaces.Dict(space)
return self.envs[0].observation_space
@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
space = {_id: env.action_space for _id, env in enumerate(self.envs)}
return gym.spaces.Dict(space)
return self.envs[0].action_space
@override(BaseEnv)
def observation_space_contains(self, x: MultiEnvDict) -> bool:
return all(
self.envs[0].observation_space_contains(val) for val in x.values())
@override(BaseEnv)
def action_space_contains(self, x: MultiEnvDict) -> bool:
return all(
self.envs[0].action_space_contains(val) for val in x.values())
@override(BaseEnv)
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
return self.envs[0].observation_space_sample(agent_ids)
@override(BaseEnv)
def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
return self.envs[0].action_space_sample(agent_ids)
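Delegating every space query to envs[0] is sound because all sub-envs come from the same factory and therefore share identical spaces; a small sketch of that invariant (the factory setup is illustrative):

from ray.rllib.env.multi_agent_env import make_multi_agent

env_cls = make_multi_agent("CartPole-v1")
envs = [env_cls({"num_agents": 2}) for _ in range(2)]
# Identical factories yield identical spaces, so answering space
# queries from the first sub-env alone loses no information.
assert envs[0].observation_space == envs[1].observation_space
assert envs[0].action_space == envs[1].action_space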
class _MultiAgentEnvState:

rllib/env/tests/test_multi_agent_env.py (new file, 50 lines)

@@ -0,0 +1,50 @@
import pytest
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.tests.test_nested_observation_spaces import NestedMultiAgentEnv
class TestMultiAgentEnv:
def test_space_in_preferred_format(self):
env = NestedMultiAgentEnv()
spaces_in_preferred_format = \
env._check_if_space_maps_agent_id_to_sub_space()
assert spaces_in_preferred_format, "Space is not in preferred " \
"format"
env2 = make_multi_agent("CartPole-v1")()
spaces_in_preferred_format = \
env2._check_if_space_maps_agent_id_to_sub_space()
assert not spaces_in_preferred_format, "Space should not be in " \
"preferred format but is."
def test_spaces_sample_contain_in_preferred_format(self):
env = NestedMultiAgentEnv()
# This environment's spaces are in the preferred format for
# multi-agent environments: Dict spaces mapping agent ids to
# the individual agents' sub-spaces.
obs = env.observation_space_sample()
assert env.observation_space_contains(obs), "Observation space does " \
"not contain obs"
action = env.action_space_sample()
assert env.action_space_contains(action), "Action space does " \
"not contain action"
def test_spaces_sample_contain_not_in_preferred_format(self):
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
# This environment's spaces are not in the preferred format.
# When spaces are not in the preferred format, users must
# override observation_space_contains, action_space_contains,
# observation_space_sample, and action_space_sample in order
# to do proper checks.
obs = env.observation_space_sample()
assert env.observation_space_contains(obs), "Observation space does " \
"not contain obs"
action = env.action_space_sample()
assert env.action_space_contains(action), "Action space does " \
"not contain action"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))

rllib/env/vector_env.py

@@ -339,24 +339,3 @@ class VectorEnvWrapper(BaseEnv):
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.
Args:
space: The space to check x's observations against.
x: The observations to check.
Note: With vector envs, we can process the raw observations
and ignore the agent ids and env ids, since vector envs'
sub-environments are guaranteed to be the same.
Returns:
True if the observations of x are contained in space.
"""
for _, multi_agent_dict in x.items():
for _, element in multi_agent_dict.items():
if not space.contains(element):
return False
return True

rllib/tests/test_nested_observation_spaces.py

@@ -120,6 +120,15 @@ class RepeatedSpaceEnv(gym.Env):
class NestedMultiAgentEnv(MultiAgentEnv):
def __init__(self):
self.observation_space = spaces.Dict({
"dict_agent": DICT_SPACE,
"tuple_agent": TUPLE_SPACE
})
self.action_space = spaces.Dict({
"dict_agent": spaces.Discrete(1),
"tuple_agent": spaces.Discrete(1)
})
self._agent_ids = {"dict_agent", "tuple_agent"}
self.steps = 0
def reset(self):