Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00

parent 8b4cb45088
commit 39f8072eac

5 changed files with 277 additions and 60 deletions
rllib/env/base_env.py (vendored): 41 changes
@@ -1,6 +1,6 @@
 import logging
 from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
-    Union
+    Union, Set
 
 import gym
 import ray
@@ -198,14 +198,13 @@ class BaseEnv:
         return []
 
     @PublicAPI
-    def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
-        """Return the agent ids for each sub-environment.
+    def get_agent_ids(self) -> Set[AgentID]:
+        """Return the agent ids for the sub-environment.
 
         Returns:
-            A dict mapping from env_id to a list of agent_ids.
+            All agent ids for the environment.
         """
-        logger.warning("get_agent_ids() has not been implemented")
-        return {}
+        return {_DUMMY_AGENT_ID}
 
     @PublicAPI
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
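Review note: with this default, a plain single-agent BaseEnv now advertises the fixed dummy agent id instead of an empty dict. A minimal sketch of the new contract (assuming BaseEnv is instantiated directly, which the class allows):

from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID

env = BaseEnv()
# Default implementation from the diff above: a set holding the fixed
# single-agent identifier.
assert env.get_agent_ids() == {_DUMMY_AGENT_ID}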
@@ -234,8 +233,8 @@ class BaseEnv:
 
     @PublicAPI
     @property
-    def observation_space(self) -> gym.spaces.Dict:
-        """Returns the observation space for each environment.
+    def observation_space(self) -> gym.Space:
+        """Returns the observation space for each agent.
 
         Note: samples from the observation space need to be preprocessed into a
             `MultiEnvDict` before being used by a policy.
@@ -248,7 +247,7 @@ class BaseEnv:
     @PublicAPI
     @property
     def action_space(self) -> gym.Space:
-        """Returns the action space for each environment.
+        """Returns the action space for each agent.
 
         Note: samples from the action space need to be preprocessed into a
             `MultiEnvDict` before being passed to `send_actions`.
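Both docstring notes reference `MultiEnvDict`. For review context, the expected nesting is env id to agent id to value; a purely illustrative sketch (ids and values are made up):

# MultiEnvDict layout assumed by BaseEnv: outer keys are env ids,
# inner keys are agent ids.
multi_env_obs = {
    0: {"agent_0": 0.1, "agent_1": 0.4},
    1: {"agent_0": 0.5, "agent_1": 0.8},
}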
@@ -270,6 +269,7 @@ class BaseEnv:
         Returns:
             A random action for each environment.
         """
         logger.warning("action_space_sample() has not been implemented")
+        del agent_id
         return {}
 
@@ -286,6 +286,7 @@ class BaseEnv:
             A random action for each environment.
         """
         logger.warning("observation_space_sample() has not been implemented")
+        del agent_id
         return {}
 
     @PublicAPI
@@ -326,8 +327,7 @@ class BaseEnv:
         """
         return self._space_contains(self.action_space, x)
 
-    @staticmethod
-    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
         """Check if the given space contains the observations of x.
 
         Args:
@@ -337,17 +337,14 @@ class BaseEnv:
         Returns:
             True if the observations of x are contained in space.
         """
-        # this removes the agent_id key and inner dicts
-        # in MultiEnvDicts
-        flattened_obs = {
-            env_id: list(obs.values())
-            for env_id, obs in x.items()
-        }
-        ret = True
-        for env_id in flattened_obs:
-            for obs in flattened_obs[env_id]:
-                ret = ret and space[env_id].contains(obs)
-        return ret
+        agents = set(self.get_agent_ids())
+        for multi_agent_dict in x.values():
+            for agent_id, obs in multi_agent_dict.items():
+                if (agent_id not in agents) or (
+                        not space[agent_id].contains(obs)):
+                    return False
+
+        return True
 
 
 # Fixed agent identifier when there is only the single agent in the env
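The rewritten `_space_contains` validates each entry against the sub-space of its agent id rather than its env id. A small standalone sketch of the accepted shape, assuming a two-agent `gym.spaces.Dict` (names illustrative):

import gym

# Space keyed by agent id, matching the new per-agent lookup.
space = gym.spaces.Dict({
    "agent_0": gym.spaces.Discrete(2),
    "agent_1": gym.spaces.Discrete(2),
})
# MultiEnvDict: env id -> agent id -> observation.
x = {0: {"agent_0": 1, "agent_1": 0}}
assert all(
    space[agent_id].contains(obs)
    for multi_agent_dict in x.values()
    for agent_id, obs in multi_agent_dict.items())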
rllib/env/multi_agent_env.py (vendored): 216 changes
@@ -1,15 +1,19 @@
 import gym
-from typing import Callable, Dict, List, Tuple, Type, Optional, Union
+import logging
+from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
 
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext
-from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI
+from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \
+    DeveloperAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
     MultiEnvDict
 
 # If the obs space is Dict type, look for the global state under this key.
 ENV_STATE = "state"
 
+logger = logging.getLogger(__name__)
+
 
 @PublicAPI
 class MultiAgentEnv(gym.Env):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
     referred to as "agents" or "RL agents".
     """
 
+    def __init__(self):
+        self.observation_space = None
+        self.action_space = None
+        self._agent_ids = {}
+
+        # Do the action and observation spaces map from agent ids to spaces
+        # for the individual agents?
+        self._spaces_in_preferred_format = None
+
     @PublicAPI
     def reset(self) -> MultiAgentDict:
         """Resets the env and returns observations from ready agents.
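Here `_spaces_in_preferred_format` refers to spaces that are `gym.spaces.Dict` instances keyed by agent id. A minimal subclass sketch of that format (hypothetical env, not part of this diff):

import gym
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class TwoAgentEnv(MultiAgentEnv):
    """Hypothetical env whose spaces are in the preferred format."""

    def __init__(self):
        super().__init__()
        self._agent_ids = {"a0", "a1"}
        # Dict spaces keyed by agent id -> "preferred format".
        self.observation_space = gym.spaces.Dict({
            "a0": gym.spaces.Box(-1.0, 1.0, (2, )),
            "a1": gym.spaces.Box(-1.0, 1.0, (2, )),
        })
        self.action_space = gym.spaces.Dict({
            "a0": gym.spaces.Discrete(2),
            "a1": gym.spaces.Discrete(2),
        })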
@@ -81,6 +94,113 @@ class MultiAgentEnv(gym.Env):
         """
         raise NotImplementedError
 
+    @ExperimentalAPI
+    def observation_space_contains(self, x: MultiAgentDict) -> bool:
+        """Checks if the observation space contains the given observations.
+
+        Args:
+            x: Observations to check.
+
+        Returns:
+            True if the observation space contains all observations in x.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            return self.observation_space.contains(x)
+
+        logger.warning("observation_space_contains() has not been implemented")
+        return True
+
+    @ExperimentalAPI
+    def action_space_contains(self, x: MultiAgentDict) -> bool:
+        """Checks if the action space contains the given actions.
+
+        Args:
+            x: Actions to check.
+
+        Returns:
+            True if the action space contains all actions in x.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            return self.action_space.contains(x)
+
+        logger.warning("action_space_contains() has not been implemented")
+        return True
+
+    @ExperimentalAPI
+    def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
+        """Returns a random action for each environment, and potentially each
+            agent in that environment.
+
+        Args:
+            agent_ids: List of agent ids to sample actions for. If None or
+                empty list, sample actions for all agents in the
+                environment.
+
+        Returns:
+            A random action for each environment.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            if agent_ids is None:
+                agent_ids = self.get_agent_ids()
+            samples = self.action_space.sample()
+            return {agent_id: samples[agent_id] for agent_id in agent_ids}
+        logger.warning("action_space_sample() has not been implemented")
+        del agent_ids
+        return {}
+
+    @ExperimentalAPI
+    def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
+        """Returns a random observation from the observation space for each
+            agent if agent_ids is None, otherwise returns a random observation
+            for the agents in agent_ids.
+
+        Args:
+            agent_ids: List of agent ids to sample observations for. If None
+                or empty list, sample observations for all agents in the
+                environment.
+
+        Returns:
+            A random observation for each environment.
+        """
+        if not hasattr(self, "_spaces_in_preferred_format") or \
+                self._spaces_in_preferred_format is None:
+            self._spaces_in_preferred_format = \
+                self._check_if_space_maps_agent_id_to_sub_space()
+        if self._spaces_in_preferred_format:
+            if agent_ids is None:
+                agent_ids = self.get_agent_ids()
+            samples = self.observation_space.sample()
+            samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
+            return samples
+        logger.warning("observation_space_sample() has not been implemented")
+        del agent_ids
+        return {}
+
+    @PublicAPI
+    def get_agent_ids(self) -> Set[AgentID]:
+        """Returns a set of agent ids in the environment.
+
+        Returns:
+            Set of agent ids.
+        """
+        if not isinstance(self._agent_ids, set):
+            self._agent_ids = set(self._agent_ids)
+        return self._agent_ids
+
     @PublicAPI
     def render(self, mode=None) -> None:
         """Tries to render the environment."""
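With spaces in the preferred format, the new helpers above defer straight to the Dict spaces. A hedged usage sketch, reusing the hypothetical TwoAgentEnv from earlier in these notes:

env = TwoAgentEnv()
assert env.get_agent_ids() == {"a0", "a1"}

# Samples come back keyed by agent id and pass the containment checks.
obs = env.observation_space_sample()
assert env.observation_space_contains(obs)

# Restricting sampling to a subset of agents is also supported.
action = env.action_space_sample(agent_ids=["a0"])
assert set(action) == {"a0"}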
@@ -88,13 +208,13 @@ class MultiAgentEnv(gym.Env):
         # By default, do nothing.
         pass
 
-# yapf: disable
-# __grouping_doc_begin__
+    # yapf: disable
+    # __grouping_doc_begin__
     @ExperimentalAPI
     def with_agent_groups(
-        self,
-        groups: Dict[str, List[AgentID]],
-        obs_space: gym.Space = None,
+            self,
+            groups: Dict[str, List[AgentID]],
+            obs_space: gym.Space = None,
             act_space: gym.Space = None) -> "MultiAgentEnv":
         """Convenience method for grouping together agents in this env.
 
@@ -132,8 +252,9 @@ class MultiAgentEnv(gym.Env):
         from ray.rllib.env.wrappers.group_agents_wrapper import \
             GroupAgentsWrapper
         return GroupAgentsWrapper(self, groups, obs_space, act_space)
-# __grouping_doc_end__
-# yapf: enable
+
+    # __grouping_doc_end__
+    # yapf: enable
 
     @PublicAPI
     def to_base_env(
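Usage of the (re-indented) grouping helper is unchanged; a short sketch in the spirit of the method's docstring example (agent names illustrative, and env stands for any MultiAgentEnv instance):

# Collapse five agents into two logical agents; the grouped env exposes
# Tuple observations/actions per group.
grouped_env = env.with_agent_groups(groups={
    "group1": ["agent1", "agent2", "agent3"],
    "group2": ["agent4", "agent5"],
})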
@@ -182,6 +303,20 @@ class MultiAgentEnv(gym.Env):
 
         return env
 
+    @DeveloperAPI
+    def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
+        # Do the action and observation spaces map from agent ids to spaces
+        # for the individual agents?
+        obs_space_check = (
+            hasattr(self, "observation_space")
+            and isinstance(self.observation_space, gym.spaces.Dict)
+            and set(self.observation_space.keys()) == self.get_agent_ids())
+        action_space_check = (
+            hasattr(self, "action_space")
+            and isinstance(self.action_space, gym.spaces.Dict)
+            and set(self.action_space.keys()) == self.get_agent_ids())
+        return obs_space_check and action_space_check
+
 
 def make_multi_agent(
         env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
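The check passes exactly when both spaces are `gym.spaces.Dict` objects whose key sets equal `get_agent_ids()`. A tiny illustration of the boundary (spaces are made up):

import gym

agent_ids = {"a0", "a1"}
# Passes: Dict space whose keys match the agent ids.
per_agent = gym.spaces.Dict({
    "a0": gym.spaces.Discrete(2),
    "a1": gym.spaces.Discrete(2),
})
assert set(per_agent.spaces.keys()) == agent_ids

# Fails: one shared space (as in make_multi_agent below) is not a Dict
# keyed by agent id.
shared = gym.spaces.Discrete(2)
assert not isinstance(shared, gym.spaces.Dict)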
@@ -242,6 +377,40 @@ def make_multi_agent(
             self.dones = set()
             self.observation_space = self.agents[0].observation_space
             self.action_space = self.agents[0].action_space
+            self._agent_ids = set(range(num))
+
+        @override(MultiAgentEnv)
+        def observation_space_sample(self,
+                                     agent_ids: list = None) -> MultiAgentDict:
+            if agent_ids is None:
+                agent_ids = list(range(len(self.agents)))
+            obs = {
+                agent_id: self.observation_space.sample()
+                for agent_id in agent_ids
+            }
+
+            return obs
+
+        @override(MultiAgentEnv)
+        def action_space_sample(self,
+                                agent_ids: list = None) -> MultiAgentDict:
+            if agent_ids is None:
+                agent_ids = list(range(len(self.agents)))
+            actions = {
+                agent_id: self.action_space.sample()
+                for agent_id in agent_ids
+            }
+
+            return actions
+
+        @override(MultiAgentEnv)
+        def action_space_contains(self, x: MultiAgentDict) -> bool:
+            return all(self.action_space.contains(val) for val in x.values())
+
+        @override(MultiAgentEnv)
+        def observation_space_contains(self, x: MultiAgentDict) -> bool:
+            return all(
+                self.observation_space.contains(val) for val in x.values())
 
         @override(MultiAgentEnv)
         def reset(self):
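These overrides are needed because `make_multi_agent` reuses one shared space for every agent, so its spaces are deliberately not in the preferred Dict format. A hedged sketch of the resulting behavior:

from ray.rllib.env.multi_agent_env import make_multi_agent

# Two CartPole copies behind a single MultiAgentEnv facade.
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
actions = env.action_space_sample()
assert set(actions) == {0, 1}  # integer agent ids
assert env.action_space_contains(actions)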
@@ -277,7 +446,7 @@ class MultiAgentEnvWrapper(BaseEnv):
 
         Args:
             make_env (Callable[[int], EnvType]): Factory that produces a new
-                MultiAgentEnv intance. Must be defined, if the number of
+                MultiAgentEnv instance. Must be defined, if the number of
                 existing envs is less than num_envs.
             existing_envs (List[MultiAgentEnv]): List of already existing
                 multi-agent envs.
@@ -355,18 +524,31 @@ class MultiAgentEnvWrapper(BaseEnv):
     @override(BaseEnv)
     @PublicAPI
     def observation_space(self) -> gym.spaces.Dict:
-        space = {
-            _id: env.observation_space
-            for _id, env in enumerate(self.envs)
-        }
-        return gym.spaces.Dict(space)
+        return self.envs[0].observation_space
 
     @property
     @override(BaseEnv)
     @PublicAPI
     def action_space(self) -> gym.Space:
-        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
-        return gym.spaces.Dict(space)
+        return self.envs[0].action_space
 
+    @override(BaseEnv)
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        return all(
+            self.envs[0].observation_space_contains(val) for val in x.values())
+
+    @override(BaseEnv)
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return all(
+            self.envs[0].action_space_contains(val) for val in x.values())
+
+    @override(BaseEnv)
+    def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
+        return self.envs[0].observation_space_sample(agent_ids)
+
+    @override(BaseEnv)
+    def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
+        return self.envs[0].action_space_sample(agent_ids)
+
 
 class _MultiAgentEnvState:
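Note that the wrapper now answers all space queries from `self.envs[0]`, which implicitly assumes homogeneous sub-envs. A sketch of that implied invariant (helper name is hypothetical):

def all_sub_envs_homogeneous(envs) -> bool:
    # Delegating to envs[0] is only representative if every sub-env
    # exposes identical spaces.
    first = envs[0]
    return all(
        env.observation_space == first.observation_space
        and env.action_space == first.action_space for env in envs)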
rllib/env/tests/test_multi_agent_env.py (vendored, new file): 50 changes
@@ -0,0 +1,50 @@
+import pytest
+from ray.rllib.env.multi_agent_env import make_multi_agent
+from ray.rllib.tests.test_nested_observation_spaces import NestedMultiAgentEnv
+
+
+class TestMultiAgentEnv:
+    def test_space_in_preferred_format(self):
+        env = NestedMultiAgentEnv()
+        spaces_in_preferred_format = \
+            env._check_if_space_maps_agent_id_to_sub_space()
+        assert spaces_in_preferred_format, "Space is not in preferred " \
+                                           "format"
+        env2 = make_multi_agent("CartPole-v1")()
+        spaces_in_preferred_format = \
+            env2._check_if_space_maps_agent_id_to_sub_space()
+        assert not spaces_in_preferred_format, "Space should not be in " \
+                                               "preferred format but is."
+
+    def test_spaces_sample_contain_in_preferred_format(self):
+        env = NestedMultiAgentEnv()
+        # This environment has spaces that are in the preferred format
+        # for multi-agent environments: dict spaces mapping agent ids
+        # to sub-spaces.
+        obs = env.observation_space_sample()
+        assert env.observation_space_contains(obs), "Observation space does " \
+                                                    "not contain obs"
+
+        action = env.action_space_sample()
+        assert env.action_space_contains(action), "Action space does " \
+                                                  "not contain action"
+
+    def test_spaces_sample_contain_not_in_preferred_format(self):
+        env = make_multi_agent("CartPole-v1")({"num_agents": 2})
+        # This environment has spaces that are not in the preferred format.
+        # For multi-agent environments whose spaces are not in the preferred
+        # format, users must override the observation_space_contains,
+        # action_space_contains, observation_space_sample,
+        # and action_space_sample methods in order to do proper checks.
+        obs = env.observation_space_sample()
+        assert env.observation_space_contains(obs), "Observation space does " \
+                                                    "not contain obs"
+        action = env.action_space_sample()
+        assert env.action_space_contains(action), "Action space does " \
+                                                  "not contain action"
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
rllib/env/vector_env.py (vendored): 21 changes
@@ -339,24 +339,3 @@ class VectorEnvWrapper(BaseEnv):
     @PublicAPI
     def action_space(self) -> gym.Space:
         return self._action_space
-
-    @staticmethod
-    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
-        """Check if the given space contains the observations of x.
-
-        Args:
-            space: The space to if x's observations are contained in.
-            x: The observations to check.
-
-        Note: With vector envs, we can process the raw observations
-            and ignore the agent ids and env ids, since vector envs'
-            sub environements are guaranteed to be the same
-
-        Returns:
-            True if the observations of x are contained in space.
-        """
-        for _, multi_agent_dict in x.items():
-            for _, element in multi_agent_dict.items():
-                if not space.contains(element):
-                    return False
-        return True
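With the static helper deleted, VectorEnvWrapper falls back to BaseEnv._space_contains (rewritten earlier in this diff). For comparison, the removed logic, which ignored agent and env ids entirely, amounted to the following (standalone restatement, not part of the commit):

def contains_raw_elements(space, x) -> bool:
    # Check every raw element of the MultiEnvDict against one shared space.
    return all(
        space.contains(element)
        for multi_agent_dict in x.values()
        for element in multi_agent_dict.values())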
rllib/tests/test_nested_observation_spaces.py: 9 changes
@@ -120,6 +120,15 @@ class RepeatedSpaceEnv(gym.Env):
 
 class NestedMultiAgentEnv(MultiAgentEnv):
     def __init__(self):
+        self.observation_space = spaces.Dict({
+            "dict_agent": DICT_SPACE,
+            "tuple_agent": TUPLE_SPACE
+        })
+        self.action_space = spaces.Dict({
+            "dict_agent": spaces.Discrete(1),
+            "tuple_agent": spaces.Discrete(1)
+        })
+        self._agent_ids = {"dict_agent", "tuple_agent"}
         self.steps = 0
 
     def reset(self):
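After this change the fixture advertises Dict spaces keyed by its two agent ids, so it satisfies the preferred-format check exercised by the new tests. A quick sketch:

from ray.rllib.tests.test_nested_observation_spaces import NestedMultiAgentEnv

env = NestedMultiAgentEnv()
assert env.get_agent_ids() == {"dict_agent", "tuple_agent"}
assert env._check_if_space_maps_agent_id_to_sub_space()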