[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes (#20832)

Avnish Narayan 2021-12-09 05:40:40 -08:00 committed by GitHub
parent 8bb9bfe632
commit 6996eaa986
6 changed files with 192 additions and 21 deletions


@ -290,7 +290,7 @@ py_test(
name = "learning_cartpole_simpleq_fake_gpus",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
size = "large",
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
args = ["--yaml-dir=tuned_examples/dqn"]
@ -468,7 +468,7 @@ py_test(
py_test(
name = "learning_tests_transformed_actions_pendulum_sac",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
@ -478,7 +478,7 @@ py_test(
py_test(
name = "learning_pendulum_sac_fake_gpus",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],

rllib/env/base_env.py

@ -1,6 +1,7 @@
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
Union
import gym
import ray
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
@ -31,12 +32,6 @@ class BaseEnv:
rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
rllib.ExternalEnv => rllib.BaseEnv
Attributes:
action_space (gym.Space): Action space. This must be defined for
single-agent envs. Multi-agent envs can set this to None.
observation_space (gym.Space): Observation space. This must be defined
for single-agent envs. Multi-agent envs can set this to None.
Examples:
>>> env = MyBaseEnv()
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
@ -185,12 +180,18 @@ class BaseEnv:
return None
@PublicAPI
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(
self, as_dict: bool = False) -> Union[List[EnvType], dict]:
"""Return a reference to the underlying sub environments, if any.
Args:
as_dict: If True, return a dict mapping from env_id to env.
Returns:
List of the underlying sub environments or [].
List or dictionary of the underlying sub environments or [] / {}.
"""
if as_dict:
return {}
return []
@PublicAPI
@ -218,6 +219,61 @@ class BaseEnv:
def get_unwrapped(self) -> List[EnvType]:
return self.get_sub_environments()
@PublicAPI
@property
def observation_space(self) -> gym.spaces.Dict:
"""Returns the observation space for each environment.
Note: samples from the observation space need to be preprocessed into a
`MultiEnvDict` before being used by a policy.
Returns:
The observation space for each environment.
"""
raise NotImplementedError
@PublicAPI
@property
def action_space(self) -> gym.Space:
"""Returns the action space for each environment.
Note: samples from the action space need to be preprocessed into a
`MultiEnvDict` before being passed to `send_actions`.
Returns:
The action space for each environment.
"""
raise NotImplementedError
def observation_space_contains(self, x: MultiEnvDict) -> bool:
return self._space_contains(self.observation_space, x)
def action_space_contains(self, x: MultiEnvDict) -> bool:
return self._space_contains(self.action_space, x)
@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.
Args:
space: The space to check whether x's observations are contained in.
x: The observations to check.
Returns:
True if the observations of x are contained in space.
"""
# Strip the agent_id keys from the inner dicts of the
# MultiEnvDict, keeping only the per-env observation values.
flattened_obs = {
env_id: list(obs.values())
for env_id, obs in x.items()
}
ret = True
for env_id in flattened_obs:
for obs in flattened_obs[env_id]:
ret = ret and space[env_id].contains(obs)
return ret
# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "agent0"
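Taken together, the new observation_space / action_space properties and the *_space_contains helpers let callers validate a MultiEnvDict against the per-sub-environment spaces. A minimal usage sketch, assuming a hypothetical TwoCartPoles subclass (not part of this commit) whose sub environments are plain gym CartPole envs:

import gym
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID

class TwoCartPoles(BaseEnv):
    # Hypothetical BaseEnv with two identical sub environments.
    def __init__(self):
        self.envs = [gym.make("CartPole-v0") for _ in range(2)]

    @property
    def observation_space(self) -> gym.spaces.Dict:
        # One entry per sub environment, keyed by env_id.
        return gym.spaces.Dict(
            {i: e.observation_space for i, e in enumerate(self.envs)})

    @property
    def action_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict(
            {i: e.action_space for i, e in enumerate(self.envs)})

env = TwoCartPoles()
# MultiEnvDict layout: env_id -> agent_id -> value.
obs = {i: {_DUMMY_AGENT_ID: e.reset()} for i, e in enumerate(env.envs)}
acts = {i: {_DUMMY_AGENT_ID: e.action_space.sample()} for i, e in enumerate(env.envs)}
assert env.observation_space_contains(obs)
assert env.action_space_contains(acts)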


@ -337,11 +337,11 @@ class ExternalEnvWrapper(BaseEnv):
self.external_env = external_env
self.prep = preprocessor
self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
self.action_space = external_env.action_space
self._action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
self._observation_space = preprocessor.observation_space
else:
self.observation_space = external_env.observation_space
self._observation_space = external_env.observation_space
external_env.start()
@override(BaseEnv)
@ -413,3 +413,15 @@ class ExternalEnvWrapper(BaseEnv):
with_dummy_agent_id(all_dones, "__all__"), \
with_dummy_agent_id(all_infos), \
with_dummy_agent_id(off_policy_actions)
@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space
@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space


@ -336,7 +336,12 @@ class MultiAgentEnvWrapper(BaseEnv):
return obs
@override(BaseEnv)
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
if as_dict:
return {
_id: env_state.env
for _id, env_state in enumerate(self.env_states)
}
return [state.env for state in self.env_states]
@override(BaseEnv)
@ -346,6 +351,23 @@ class MultiAgentEnvWrapper(BaseEnv):
assert isinstance(env_id, int)
return self.envs[env_id].render()
@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
space = {
_id: env.observation_space
for _id, env in enumerate(self.envs)
}
return gym.spaces.Dict(space)
@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
space = {_id: env.action_space for _id, env in enumerate(self.envs)}
return gym.spaces.Dict(space)
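The wrapper's observation_space and action_space are single top-level gym.spaces.Dict spaces keyed by sub-environment index, i.e. the same env_ids that get_sub_environments(as_dict=True) uses as keys. A standalone sketch of that pattern, with plain gym envs standing in for the wrapped multi-agent envs (this is not the wrapper's actual constructor):

import gym

sub_envs = [gym.make("CartPole-v0") for _ in range(3)]

# List vs. dict view of the sub environments (the as_dict=True behavior).
as_list = sub_envs
as_dict = {_id: env for _id, env in enumerate(sub_envs)}

# The exposed spaces are Dict spaces keyed by those same env ids.
obs_space = gym.spaces.Dict(
    {_id: env.observation_space for _id, env in enumerate(sub_envs)})
act_space = gym.spaces.Dict(
    {_id: env.action_space for _id, env in enumerate(sub_envs)})

assert set(as_dict) == set(obs_space.spaces) == set(act_space.spaces)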
class _MultiAgentEnvState:
def __init__(self, env: MultiAgentEnv):


@ -1,6 +1,8 @@
import logging
from typing import Callable, Dict, List, Optional, Tuple
import gym
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
from the remote simulator actors. Both single and multi-agent child envs
are supported, and envs can be stepped synchronously or asynchronously.
NOTE: This class implicitly assumes that the remote envs are gym.Env instances.
You shouldn't need to instantiate this class directly. It's automatically
inserted when you use the `remote_worker_envs=True` option in your
Trainer's config.
@ -61,6 +65,8 @@ class RemoteBaseEnv(BaseEnv):
# List of ray actor handles (each handle points to one @ray.remote
# sub-environment).
self.actors: Optional[List[ray.actor.ActorHandle]] = None
self._observation_space = None
self._action_space = None
# Dict mapping object refs (return values of @ray.remote calls),
# whose actual values we are waiting for (via ray.wait in
# `self.poll()`) to their corresponding actor handles (the actors
@ -97,6 +103,10 @@ class RemoteBaseEnv(BaseEnv):
self.actors = [
make_remote_env(i) for i in range(self.num_envs)
]
self._observation_space = ray.get(
self.actors[0].observation_space.remote())
self._action_space = ray.get(
self.actors[0].action_space.remote())
# Lazy initialization. Call `reset()` on all @ray.remote
# sub-environment actors at the beginning.
@ -199,9 +209,23 @@ class RemoteBaseEnv(BaseEnv):
@override(BaseEnv)
@PublicAPI
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
if as_dict:
return {env_id: actor for env_id, actor in enumerate(self.actors)}
return self.actors
@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space
@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
@ -221,6 +245,14 @@ class _RemoteMultiAgentEnv:
def step(self, action_dict):
return self.env.step(action_dict)
# These two methods are defined so that this information can be
# queried remotely with a call to ray.get().
def observation_space(self):
return self.env.observation_space
def action_space(self):
return self.env.action_space
@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
@ -243,3 +275,11 @@ class _RemoteSingleAgentEnv:
} for x in [obs, rew, done, info]]
done["__all__"] = done[_DUMMY_AGENT_ID]
return obs, rew, done, info
# These two methods are defined so that this information can be
# queried remotely with a call to ray.get().
def observation_space(self):
return self.env.observation_space
def action_space(self):
return self.env.action_space
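Because each sub environment lives inside a separate Ray actor, its spaces can only be read via remote calls; the actor classes therefore expose them as plain methods whose results the wrapper fetches with ray.get() and caches in _observation_space / _action_space. A minimal standalone sketch of that pattern (the GymActor class is illustrative only, not the actual _RemoteSingleAgentEnv / _RemoteMultiAgentEnv actors):

import gym
import ray

@ray.remote(num_cpus=0)
class GymActor:
    # Illustrative remote env holder exposing its spaces as methods.
    def __init__(self, env_name):
        self.env = gym.make(env_name)

    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space

ray.init(ignore_reinit_error=True)
actor = GymActor.remote("CartPole-v0")
# Same call pattern RemoteBaseEnv uses once its actors are created:
obs_space = ray.get(actor.observation_space.remote())
act_space = ray.get(actor.action_space.remote())
print(obs_space, act_space)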


@ -1,7 +1,7 @@
import logging
import gym
import numpy as np
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):
def __init__(self, vector_env: VectorEnv):
self.vector_env = vector_env
self.action_space = vector_env.action_space
self.observation_space = vector_env.observation_space
self.num_envs = vector_env.num_envs
self.new_obs = None # lazily initialized
self.cur_rewards = [None for _ in range(self.num_envs)]
self.cur_dones = [False for _ in range(self.num_envs)]
self.cur_infos = [None for _ in range(self.num_envs)]
self._observation_space = vector_env.observation_space
self._action_space = vector_env.action_space
@override(BaseEnv)
def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
@ -312,10 +312,51 @@ class VectorEnvWrapper(BaseEnv):
}
@override(BaseEnv)
def get_sub_environments(self) -> List[EnvType]:
def get_sub_environments(
self, as_dict: bool = False) -> Union[List[EnvType], dict]:
if not as_dict:
return self.vector_env.get_sub_environments()
else:
return {
_id: env
for _id, env in enumerate(
self.vector_env.get_sub_environments())
}
@override(BaseEnv)
def try_render(self, env_id: Optional[EnvID] = None) -> None:
assert env_id is None or isinstance(env_id, int)
return self.vector_env.try_render_at(env_id)
@property
@override(BaseEnv)
@PublicAPI
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space
@property
@override(BaseEnv)
@PublicAPI
def action_space(self) -> gym.Space:
return self._action_space
@staticmethod
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
"""Check if the given space contains the observations of x.
Args:
space: The space to check whether x's observations are contained in.
x: The observations to check.
Note: With vector envs, we can check the raw observations
and ignore the agent ids and env ids, since a vector env's
sub environments are guaranteed to be identical.
Returns:
True if the observations of x are contained in space.
"""
for _, multi_agent_dict in x.items():
for _, element in multi_agent_dict.items():
if not space.contains(element):
return False
return True
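Because a vector env's sub environments are all copies of the same underlying env, the check can use that single space directly and skip the per-env lookup that BaseEnv._space_contains does. A short sketch of the same check on a hand-built MultiEnvDict (CartPole's observation space is used only as an example):

import gym
import numpy as np

# The single space shared by every sub environment of the vector env.
space = gym.make("CartPole-v0").observation_space

# MultiEnvDict layout: env_id -> agent_id -> observation.
x = {
    0: {"agent0": space.sample()},
    1: {"agent0": space.sample()},
}

def space_contains(space, x):
    # Same logic as VectorEnvWrapper._space_contains: flat check of every
    # element, ignoring env ids and agent ids.
    for _, multi_agent_dict in x.items():
        for _, element in multi_agent_dict.items():
            if not space.contains(element):
                return False
    return True

assert space_contains(space, x)
# Out-of-bounds observations are rejected.
assert not space_contains(space, {0: {"agent0": np.full(4, 1e9, dtype=np.float32)}})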