[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes (#20832)
This commit is contained in:
parent: 8bb9bfe632
commit: 6996eaa986
6 changed files with 192 additions and 21 deletions
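
Before the per-file diffs, a minimal sketch (not part of the commit) of how a custom BaseEnv subclass might satisfy the new interface introduced below: the `observation_space`/`action_space` properties and the `as_dict` flag on `get_sub_environments()`. The class name `MyTwoEnvBaseEnv` and the CartPole sub-environments are illustrative assumptions, not RLlib code; `poll()`/`send_actions()` are omitted.

import gym

from ray.rllib.env.base_env import BaseEnv


class MyTwoEnvBaseEnv(BaseEnv):
    """Illustrative BaseEnv subclass exposing the fields added by this commit."""

    def __init__(self):
        self.envs = [gym.make("CartPole-v0") for _ in range(2)]

    @property
    def observation_space(self) -> gym.spaces.Dict:
        # One entry per sub-environment, keyed by env id (mirrors the wrappers below).
        return gym.spaces.Dict(
            {i: e.observation_space for i, e in enumerate(self.envs)})

    @property
    def action_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict(
            {i: e.action_space for i, e in enumerate(self.envs)})

    def get_sub_environments(self, as_dict: bool = False):
        if as_dict:
            return dict(enumerate(self.envs))
        return self.envs
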
@@ -290,7 +290,7 @@ py_test(
     name = "learning_cartpole_simpleq_fake_gpus",
     main = "tests/run_regression_tests.py",
     tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
-    size = "large",
+    size = "medium",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
     args = ["--yaml-dir=tuned_examples/dqn"]

@@ -468,7 +468,7 @@ py_test(
 py_test(
     name = "learning_tests_transformed_actions_pendulum_sac",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],

@@ -478,7 +478,7 @@ py_test(
 py_test(
     name = "learning_pendulum_sac_fake_gpus",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],
rllib/env/base_env.py (vendored, 72 lines changed)
@@ -1,6 +1,7 @@
-from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING
+from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
+    Union
 
 import gym
 import ray
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \

@@ -31,12 +32,6 @@ class BaseEnv:
         rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
         rllib.ExternalEnv => rllib.BaseEnv
 
-    Attributes:
-        action_space (gym.Space): Action space. This must be defined for
-            single-agent envs. Multi-agent envs can set this to None.
-        observation_space (gym.Space): Observation space. This must be defined
-            for single-agent envs. Multi-agent envs can set this to None.
-
     Examples:
         >>> env = MyBaseEnv()
         >>> obs, rewards, dones, infos, off_policy_actions = env.poll()

@@ -185,12 +180,18 @@ class BaseEnv:
         return None
 
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
         """Return a reference to the underlying sub environments, if any.
 
+        Args:
+            as_dict: If True, return a dict mapping from env_id to env.
+
         Returns:
-            List of the underlying sub environments or [].
+            List or dictionary of the underlying sub environments or [] / {}.
         """
+        if as_dict:
+            return {}
         return []
 
     @PublicAPI

@@ -218,6 +219,61 @@ class BaseEnv:
     def get_unwrapped(self) -> List[EnvType]:
         return self.get_sub_environments()
 
+    @PublicAPI
+    @property
+    def observation_space(self) -> gym.spaces.Dict:
+        """Returns the observation space for each environment.
+
+        Note: samples from the observation space need to be preprocessed into a
+            `MultiEnvDict` before being used by a policy.
+
+        Returns:
+            The observation space for each environment.
+        """
+        raise NotImplementedError
+
+    @PublicAPI
+    @property
+    def action_space(self) -> gym.Space:
+        """Returns the action space for each environment.
+
+        Note: samples from the action space need to be preprocessed into a
+            `MultiEnvDict` before being passed to `send_actions`.
+
+        Returns:
+            The action space for each environment.
+        """
+        raise NotImplementedError
+
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.observation_space, x)
+
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.action_space, x)
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        # This removes the agent_id keys and inner dicts in MultiEnvDicts.
+        flattened_obs = {
+            env_id: list(obs.values())
+            for env_id, obs in x.items()
+        }
+        ret = True
+        for env_id in flattened_obs:
+            for obs in flattened_obs[env_id]:
+                ret = ret and space[env_id].contains(obs)
+        return ret
 
 
 # Fixed agent identifier when there is only the single agent in the env
 _DUMMY_AGENT_ID = "agent0"
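
To make the new `_space_contains()` helper concrete, here is a small, self-contained sketch (an assumption for illustration, not from the diff) of the `MultiEnvDict` layout it expects, `{env_id: {agent_id: obs}}`, checked against a `gym.spaces.Dict` keyed by env id:

import gym
import numpy as np

# Per-env observation spaces, keyed by env id (as BaseEnv.observation_space returns).
space = gym.spaces.Dict({
    0: gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    1: gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
})

# A MultiEnvDict: two sub-environments, one agent each ("agent0" is the dummy agent id).
multi_env_obs = {
    0: {"agent0": np.zeros(2, dtype=np.float32)},
    1: {"agent0": np.full(2, 0.5, dtype=np.float32)},
}

# Same flattening the helper performs: drop agent ids, keep env ids, then check
# every observation against its env's space.
flattened = {env_id: list(obs.values()) for env_id, obs in multi_env_obs.items()}
contained = all(
    space[env_id].contains(o) for env_id, obs in flattened.items() for o in obs)
print(contained)  # True
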
rllib/env/external_env.py (vendored, 18 lines changed)
@@ -337,11 +337,11 @@ class ExternalEnvWrapper(BaseEnv):
         self.external_env = external_env
         self.prep = preprocessor
         self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
-        self.action_space = external_env.action_space
+        self._action_space = external_env.action_space
         if preprocessor:
-            self.observation_space = preprocessor.observation_space
+            self._observation_space = preprocessor.observation_space
         else:
-            self.observation_space = external_env.observation_space
+            self._observation_space = external_env.observation_space
         external_env.start()
 
     @override(BaseEnv)

@@ -413,3 +413,15 @@ class ExternalEnvWrapper(BaseEnv):
             with_dummy_agent_id(all_dones, "__all__"), \
             with_dummy_agent_id(all_infos), \
             with_dummy_agent_id(off_policy_actions)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
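
The `ExternalEnvWrapper` change above illustrates the pattern this commit applies to every wrapper: spaces are stored on private `_observation_space`/`_action_space` attributes in `__init__` and exposed through read-only properties, so `BaseEnv` can declare the spaces as properties without the wrappers shadowing them with plain attributes. A stripped-down sketch of that pattern (illustrative, not RLlib code):

import gym


class SpacesViaProperties:
    """Toy example of the private-attribute + property pattern used by the wrappers."""

    def __init__(self, observation_space: gym.Space, action_space: gym.Space):
        self._observation_space = observation_space
        self._action_space = action_space

    @property
    def observation_space(self) -> gym.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.Space:
        return self._action_space


wrapper = SpacesViaProperties(
    observation_space=gym.spaces.Box(-1.0, 1.0, shape=(4,)),
    action_space=gym.spaces.Discrete(2))
print(wrapper.action_space)  # Discrete(2)
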
rllib/env/multi_agent_env.py (vendored, 24 lines changed)
@@ -336,7 +336,12 @@ class MultiAgentEnvWrapper(BaseEnv):
         return obs
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {
+                _id: env_state
+                for _id, env_state in enumerate(self.env_states)
+            }
         return [state.env for state in self.env_states]
 
     @override(BaseEnv)

@@ -346,6 +351,23 @@ class MultiAgentEnvWrapper(BaseEnv):
         assert isinstance(env_id, int)
         return self.envs[env_id].render()
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        space = {
+            _id: env.observation_space
+            for _id, env in enumerate(self.envs)
+        }
+        return gym.spaces.Dict(space)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
+        return gym.spaces.Dict(space)
+
 
 class _MultiAgentEnvState:
     def __init__(self, env: MultiAgentEnv):
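
Note that in the `MultiAgentEnvWrapper` diff above, `get_sub_environments(as_dict=True)` returns the `_MultiAgentEnvState` holders keyed by index, while the list form returns the unwrapped envs. A small sketch (illustrative assumption, not RLlib code) of what the new `observation_space`/`action_space` properties compute, a `gym.spaces.Dict` keyed by sub-environment index:

import gym

# Pretend we wrapped two sub-environments that each expose these action spaces.
sub_env_action_spaces = [gym.spaces.Discrete(3), gym.spaces.Discrete(3)]

# What MultiAgentEnvWrapper.action_space now returns: one entry per sub-env index.
action_space = gym.spaces.Dict(
    {i: space for i, space in enumerate(sub_env_action_spaces)})

print(action_space[0])        # Discrete(3)
print(action_space.sample())  # e.g. OrderedDict([(0, 2), (1, 0)])
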
rllib/env/remote_base_env.py (vendored, 42 lines changed)
@@ -1,6 +1,8 @@
 import logging
 from typing import Callable, Dict, List, Optional, Tuple
 
+import gym
+
 import ray
 from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI

@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
     from the remote simulator actors. Both single and multi-agent child envs
     are supported, and envs can be stepped synchronously or asynchronously.
 
+    NOTE: This class implicitly assumes that the remote envs are gym.Envs.
+
     You shouldn't need to instantiate this class directly. It's automatically
     inserted when you use the `remote_worker_envs=True` option in your
     Trainer's config.

@@ -61,6 +65,8 @@ class RemoteBaseEnv(BaseEnv):
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
         self.actors: Optional[List[ray.actor.ActorHandle]] = None
+        self._observation_space = None
+        self._action_space = None
         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors

@@ -97,6 +103,10 @@ class RemoteBaseEnv(BaseEnv):
             self.actors = [
                 make_remote_env(i) for i in range(self.num_envs)
             ]
+            self._observation_space = ray.get(
+                self.actors[0].observation_space.remote())
+            self._action_space = ray.get(
+                self.actors[0].action_space.remote())
 
         # Lazy initialization. Call `reset()` on all @ray.remote
         # sub-environment actors at the beginning.

@@ -199,9 +209,23 @@ class RemoteBaseEnv(BaseEnv):
 
     @override(BaseEnv)
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {env_id: actor for env_id, actor in enumerate(self.actors)}
         return self.actors
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteMultiAgentEnv:

@@ -221,6 +245,14 @@ class _RemoteMultiAgentEnv:
     def step(self, action_dict):
         return self.env.step(action_dict)
 
+    # Define these two methods so that this information can be queried
+    # with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteSingleAgentEnv:

@@ -243,3 +275,11 @@ class _RemoteSingleAgentEnv:
         } for x in [obs, rew, done, info]]
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info
+
+    # Define these two methods so that this information can be queried
+    # with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
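
The remote wrappers cannot read spaces as plain attributes because the sub-environments live in Ray actors; the diff therefore adds `observation_space()`/`action_space()` methods to the actor classes and fetches their values once with `ray.get()`. A self-contained sketch of that pattern (the `ToySpaceActor` name and CartPole env are assumptions for illustration, not RLlib's actor classes):

import gym
import ray


@ray.remote(num_cpus=0)
class ToySpaceActor:
    def __init__(self):
        self.env = gym.make("CartPole-v0")

    # Methods (not attributes), so their return values can be fetched remotely.
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space


ray.init(ignore_reinit_error=True)
actor = ToySpaceActor.remote()
observation_space = ray.get(actor.observation_space.remote())
action_space = ray.get(actor.action_space.remote())
print(observation_space, action_space)
ray.shutdown()
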
rllib/env/vector_env.py (vendored, 51 lines changed)
@@ -1,7 +1,7 @@
 import logging
 import gym
 import numpy as np
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI

@@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):
 
     def __init__(self, vector_env: VectorEnv):
         self.vector_env = vector_env
-        self.action_space = vector_env.action_space
-        self.observation_space = vector_env.observation_space
         self.num_envs = vector_env.num_envs
         self.new_obs = None  # lazily initialized
         self.cur_rewards = [None for _ in range(self.num_envs)]
         self.cur_dones = [False for _ in range(self.num_envs)]
         self.cur_infos = [None for _ in range(self.num_envs)]
+        self._observation_space = vector_env.observation_space
+        self._action_space = vector_env.action_space
 
     @override(BaseEnv)
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,

@@ -312,10 +312,51 @@ class VectorEnvWrapper(BaseEnv):
         }
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
-        return self.vector_env.get_sub_environments()
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
+        if not as_dict:
+            return self.vector_env.get_sub_environments()
+        else:
+            return {
+                _id: env
+                for _id, env in enumerate(
+                    self.vector_env.get_sub_environments())
+            }
 
     @override(BaseEnv)
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
         assert env_id is None or isinstance(env_id, int)
         return self.vector_env.try_render_at(env_id)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Note: With vector envs, we can process the raw observations
+            and ignore the agent ids and env ids, since vector envs'
+            sub environments are guaranteed to be the same.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        for _, multi_agent_dict in x.items():
+            for _, element in multi_agent_dict.items():
+                if not space.contains(element):
+                    return False
+        return True
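
Unlike `BaseEnv._space_contains()`, which indexes a per-env `gym.spaces.Dict`, the `VectorEnvWrapper` version above checks every leaf observation against one shared space, since all sub-environments of a `VectorEnv` are identical. A short sketch of that check (illustrative, not from the diff):

import gym
import numpy as np

shared_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

# MultiEnvDict layout: {env_id: {agent_id: observation}}.
multi_env_obs = {
    0: {"agent0": np.zeros(2, dtype=np.float32)},
    1: {"agent0": np.full(2, 0.25, dtype=np.float32)},
}


def vector_style_contains(space: gym.Space, x: dict) -> bool:
    # Same loop as VectorEnvWrapper._space_contains: env ids and agent ids are
    # ignored; every leaf element must lie in the single shared space.
    return all(
        space.contains(element)
        for multi_agent_dict in x.values()
        for element in multi_agent_dict.values())


print(vector_style_contains(shared_space, multi_env_obs))  # True
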