[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes (#20832)

Avnish Narayan 2021-12-09 05:40:40 -08:00 committed by GitHub
parent 8bb9bfe632
commit 6996eaa986
6 changed files with 192 additions and 21 deletions

@@ -290,7 +290,7 @@ py_test(
     name = "learning_cartpole_simpleq_fake_gpus",
     main = "tests/run_regression_tests.py",
     tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
-    size = "large",
+    size = "medium",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
     args = ["--yaml-dir=tuned_examples/dqn"]
@@ -468,7 +468,7 @@ py_test(
 py_test(
     name = "learning_tests_transformed_actions_pendulum_sac",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
@@ -478,7 +478,7 @@ py_test(
 py_test(
     name = "learning_pendulum_sac_fake_gpus",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],

rllib/env/base_env.py

@@ -1,6 +1,7 @@
 from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
     Union
 
+import gym
 import ray
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
@@ -31,12 +32,6 @@ class BaseEnv:
     rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
     rllib.ExternalEnv => rllib.BaseEnv
 
-    Attributes:
-        action_space (gym.Space): Action space. This must be defined for
-            single-agent envs. Multi-agent envs can set this to None.
-        observation_space (gym.Space): Observation space. This must be defined
-            for single-agent envs. Multi-agent envs can set this to None.
-
     Examples:
         >>> env = MyBaseEnv()
         >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
@@ -185,12 +180,18 @@ class BaseEnv:
         return None
 
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
         """Return a reference to the underlying sub environments, if any.
 
+        Args:
+            as_dict: If True, return a dict mapping from env_id to env.
+
         Returns:
-            List of the underlying sub environments or [].
+            List or dictionary of the underlying sub environments or [] / {}.
         """
+        if as_dict:
+            return {}
         return []
 
     @PublicAPI
@@ -218,6 +219,61 @@ class BaseEnv:
     def get_unwrapped(self) -> List[EnvType]:
         return self.get_sub_environments()
 
+    @PublicAPI
+    @property
+    def observation_space(self) -> gym.spaces.Dict:
+        """Returns the observation space for each environment.
+
+        Note: samples from the observation space need to be preprocessed into a
+            `MultiEnvDict` before being used by a policy.
+
+        Returns:
+            The observation space for each environment.
+        """
+        raise NotImplementedError
+
+    @PublicAPI
+    @property
+    def action_space(self) -> gym.Space:
+        """Returns the action space for each environment.
+
+        Note: samples from the action space need to be preprocessed into a
+            `MultiEnvDict` before being passed to `send_actions`.
+
+        Returns:
+            The action space for each environment.
+        """
+        raise NotImplementedError
+
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.observation_space, x)
+
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.action_space, x)
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        # This removes the agent_id keys and the inner dicts of the
+        # MultiEnvDict, leaving {env_id: [obs, ...]}.
+        flattened_obs = {
+            env_id: list(obs.values())
+            for env_id, obs in x.items()
+        }
+        ret = True
+        for env_id in flattened_obs:
+            for obs in flattened_obs[env_id]:
+                ret = ret and space[env_id].contains(obs)
+        return ret
+
 
 # Fixed agent identifier when there is only the single agent in the env
 _DUMMY_AGENT_ID = "agent0"
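For illustration, a minimal sketch (not part of this commit; all names and values are made up) of how the new `_space_contains` helper treats a `MultiEnvDict`: outer keys are env ids, inner keys are agent ids, and each observation is checked against the per-env entry of a `gym.spaces.Dict` keyed by env id.

    import gym
    import numpy as np

    # Two identical sub-environments, each observing a 3-vector (made-up spaces).
    per_env_space = gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32)
    obs_space = gym.spaces.Dict({0: per_env_space, 1: per_env_space})

    # A MultiEnvDict: {env_id: {agent_id: observation}}.
    obs = {
        0: {"agent0": np.zeros(3, dtype=np.float32)},
        1: {"agent0": np.full(3, 0.5, dtype=np.float32)},
    }

    def space_contains(space, x):
        # Drop the agent_id level, then check each observation against the
        # sub-space registered under its env id (mirrors BaseEnv._space_contains).
        flattened = {env_id: list(agent_obs.values()) for env_id, agent_obs in x.items()}
        return all(
            space[env_id].contains(o)
            for env_id, obs_list in flattened.items()
            for o in obs_list
        )

    assert space_contains(obs_space, obs)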

@@ -337,11 +337,11 @@ class ExternalEnvWrapper(BaseEnv):
         self.external_env = external_env
         self.prep = preprocessor
         self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
-        self.action_space = external_env.action_space
+        self._action_space = external_env.action_space
         if preprocessor:
-            self.observation_space = preprocessor.observation_space
+            self._observation_space = preprocessor.observation_space
         else:
-            self.observation_space = external_env.observation_space
+            self._observation_space = external_env.observation_space
         external_env.start()
 
     @override(BaseEnv)
@@ -413,3 +413,15 @@ class ExternalEnvWrapper(BaseEnv):
             with_dummy_agent_id(all_dones, "__all__"), \
             with_dummy_agent_id(all_infos), \
             with_dummy_agent_id(off_policy_actions)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
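For illustration, a minimal sketch (not part of the commit; the class names below are made up) of the precedence the constructor change above encodes: the action space always comes from the external env, while a preprocessor's observation space, when present, takes priority over the external env's own observation space.

    import gym
    import numpy as np

    class FakePreprocessor:
        # Stand-in for an RLlib preprocessor that flattens observations.
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)

    class FakeExternalEnv:
        # Stand-in for an ExternalEnv with its own spaces.
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(2, 2), dtype=np.float32)
        action_space = gym.spaces.Discrete(2)

    def resolve_spaces(external_env, preprocessor=None):
        # Mirrors ExternalEnvWrapper.__init__ above.
        action_space = external_env.action_space
        if preprocessor:
            observation_space = preprocessor.observation_space
        else:
            observation_space = external_env.observation_space
        return observation_space, action_space

    print(resolve_spaces(FakeExternalEnv(), FakePreprocessor()))  # Box(4,), Discrete(2)
    print(resolve_spaces(FakeExternalEnv()))                      # Box(2, 2), Discrete(2)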

@@ -336,7 +336,12 @@ class MultiAgentEnvWrapper(BaseEnv):
         return obs
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {
+                _id: env_state
+                for _id, env_state in enumerate(self.env_states)
+            }
         return [state.env for state in self.env_states]
 
     @override(BaseEnv)
@@ -346,6 +351,23 @@ class MultiAgentEnvWrapper(BaseEnv):
         assert isinstance(env_id, int)
         return self.envs[env_id].render()
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        space = {
+            _id: env.observation_space
+            for _id, env in enumerate(self.envs)
+        }
+        return gym.spaces.Dict(space)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
+        return gym.spaces.Dict(space)
+
 
 class _MultiAgentEnvState:
     def __init__(self, env: MultiAgentEnv):
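For illustration, a minimal sketch (not part of the commit; `ToyEnv` is made up) of the space the properties above construct: one `gym.spaces.Dict` whose keys are the integer env ids of the wrapped sub-environments, so callers can look up a sub-env's space by id.

    import gym
    import numpy as np

    class ToyEnv(gym.Env):
        # Made-up sub-environment with fixed spaces.
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
        action_space = gym.spaces.Discrete(4)

        def reset(self):
            return self.observation_space.sample()

        def step(self, action):
            return self.observation_space.sample(), 0.0, False, {}

    envs = [ToyEnv(), ToyEnv(), ToyEnv()]

    # Same construction as MultiAgentEnvWrapper.observation_space / action_space:
    obs_space = gym.spaces.Dict({i: e.observation_space for i, e in enumerate(envs)})
    act_space = gym.spaces.Dict({i: e.action_space for i, e in enumerate(envs)})

    print(obs_space[0])  # Per-env lookup by env id.
    print(act_space[2])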

@@ -1,6 +1,8 @@
 import logging
 from typing import Callable, Dict, List, Optional, Tuple
 
+import gym
+
 import ray
 from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI
@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
     from the remote simulator actors. Both single and multi-agent child envs
     are supported, and envs can be stepped synchronously or asynchronously.
 
+    NOTE: This class implicitly assumes that the remote envs are gym.Envs.
+
     You shouldn't need to instantiate this class directly. It's automatically
     inserted when you use the `remote_worker_envs=True` option in your
     Trainer's config.
@@ -61,6 +65,8 @@ class RemoteBaseEnv(BaseEnv):
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
         self.actors: Optional[List[ray.actor.ActorHandle]] = None
+        self._observation_space = None
+        self._action_space = None
         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors
@@ -97,6 +103,10 @@ class RemoteBaseEnv(BaseEnv):
             self.actors = [
                 make_remote_env(i) for i in range(self.num_envs)
             ]
+            self._observation_space = ray.get(
+                self.actors[0].observation_space.remote())
+            self._action_space = ray.get(
+                self.actors[0].action_space.remote())
 
             # Lazy initialization. Call `reset()` on all @ray.remote
             # sub-environment actors at the beginning.
@@ -199,9 +209,23 @@ class RemoteBaseEnv(BaseEnv):
 
     @override(BaseEnv)
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {env_id: actor for env_id, actor in enumerate(self.actors)}
         return self.actors
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteMultiAgentEnv:
@@ -221,6 +245,14 @@ class _RemoteMultiAgentEnv:
     def step(self, action_dict):
         return self.env.step(action_dict)
 
+    # Defining these two as methods (rather than attributes) so this
+    # information can be queried with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteSingleAgentEnv:
@@ -243,3 +275,11 @@ class _RemoteSingleAgentEnv:
         } for x in [obs, rew, done, info]]
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info
+
+    # Defining these two as methods (rather than attributes) so this
+    # information can be queried with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
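For illustration, a minimal sketch (not part of the commit; the actor class is made up) of why the remote env actors expose `observation_space()` / `action_space()` as methods: attributes of a Ray actor cannot be read directly from the driver, so the wrapper calls the method remotely and fetches the result with `ray.get()`, exactly as in the constructor change above.

    import gym
    import ray

    @ray.remote(num_cpus=0)
    class ToyEnvActor:
        # Made-up remote sub-environment actor.
        def __init__(self):
            self.env = gym.make("CartPole-v0")

        # Methods (not attributes), so the spaces can be fetched via ray.get().
        def observation_space(self):
            return self.env.observation_space

        def action_space(self):
            return self.env.action_space

    ray.init(ignore_reinit_error=True)
    actor = ToyEnvActor.remote()
    obs_space = ray.get(actor.observation_space.remote())
    act_space = ray.get(actor.action_space.remote())
    print(obs_space, act_space)  # Box(4,) Discrete(2)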

@@ -1,7 +1,7 @@
 import logging
 import gym
 import numpy as np
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
@@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):
 
     def __init__(self, vector_env: VectorEnv):
         self.vector_env = vector_env
-        self.action_space = vector_env.action_space
-        self.observation_space = vector_env.observation_space
         self.num_envs = vector_env.num_envs
         self.new_obs = None  # lazily initialized
         self.cur_rewards = [None for _ in range(self.num_envs)]
         self.cur_dones = [False for _ in range(self.num_envs)]
         self.cur_infos = [None for _ in range(self.num_envs)]
+        self._observation_space = vector_env.observation_space
+        self._action_space = vector_env.action_space
 
     @override(BaseEnv)
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
@@ -312,10 +312,51 @@ class VectorEnvWrapper(BaseEnv):
         }
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
-        return self.vector_env.get_sub_environments()
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
+        if not as_dict:
+            return self.vector_env.get_sub_environments()
+        else:
+            return {
+                _id: env
+                for _id, env in enumerate(
+                    self.vector_env.get_sub_environments())
+            }
 
     @override(BaseEnv)
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
         assert env_id is None or isinstance(env_id, int)
         return self.vector_env.try_render_at(env_id)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Note: With vector envs, we can process the raw observations and
+            ignore the agent ids and env ids, since all sub-environments
+            of a vector env are guaranteed to share the same spaces.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        for _, multi_agent_dict in x.items():
+            for _, element in multi_agent_dict.items():
+                if not space.contains(element):
+                    return False
+        return True
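For illustration, a minimal sketch (not part of the commit; names and values are made up) of the simplified containment check above: because every sub-environment of a vector env shares the same space, each element of the `MultiEnvDict` is checked against that single space and the env/agent ids are ignored.

    import gym
    import numpy as np

    # One space shared by all sub-environments of the vector env.
    space = gym.spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)

    # A MultiEnvDict: {env_id: {agent_id: observation}}.
    obs = {
        0: {"agent0": np.zeros(2, dtype=np.float32)},
        1: {"agent0": np.array([0.3, -0.7], dtype=np.float32)},
    }

    def space_contains(space, x):
        # Mirrors VectorEnvWrapper._space_contains: no per-env lookup needed.
        return all(
            space.contains(element)
            for multi_agent_dict in x.values()
            for element in multi_agent_dict.values()
        )

    assert space_contains(space, obs)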