Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes (#20832)
This commit is contained in:
parent 8bb9bfe632 · commit 6996eaa986
6 changed files with 192 additions and 21 deletions
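As a minimal sketch (not taken from the diff below), here is how the new per-sub-environment space convention fits together: spaces are keyed by env_id in a gym.spaces.Dict, while samples travel as MultiEnvDicts of the form {env_id: {agent_id: obs}}. The two-env layout and Box shape are illustrative only.

import gym
import numpy as np

# Illustrative layout: two sub-environments, keyed by env_id, each with the
# same Box observation space (mirrors MultiAgentEnvWrapper.observation_space).
obs_space = gym.spaces.Dict({
    0: gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
    1: gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
})

# A MultiEnvDict sample: outer key = env_id, inner key = agent_id.
sample = {
    env_id: {"agent0": space.sample()}
    for env_id, space in obs_space.spaces.items()
}

# Same flatten-then-check idea as the new BaseEnv._space_contains() below:
# drop the agent_id level and check each observation against its env's space.
ok = all(
    obs_space.spaces[env_id].contains(obs)
    for env_id, per_agent in sample.items()
    for obs in per_agent.values()
)
print(ok)  # True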
@@ -290,7 +290,7 @@ py_test(
     name = "learning_cartpole_simpleq_fake_gpus",
     main = "tests/run_regression_tests.py",
     tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
-    size = "large",
+    size = "medium",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
     args = ["--yaml-dir=tuned_examples/dqn"]

@@ -468,7 +468,7 @@ py_test(
 py_test(
     name = "learning_tests_transformed_actions_pendulum_sac",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],

@@ -478,7 +478,7 @@ py_test(
 py_test(
     name = "learning_pendulum_sac_fake_gpus",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],
rllib/env/base_env.py (vendored): 72 changes

@@ -1,6 +1,7 @@
 from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
     Union
 
+import gym
 import ray
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \

@@ -31,12 +32,6 @@ class BaseEnv:
     rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
     rllib.ExternalEnv => rllib.BaseEnv
 
-    Attributes:
-        action_space (gym.Space): Action space. This must be defined for
-            single-agent envs. Multi-agent envs can set this to None.
-        observation_space (gym.Space): Observation space. This must be defined
-            for single-agent envs. Multi-agent envs can set this to None.
-
     Examples:
         >>> env = MyBaseEnv()
         >>> obs, rewards, dones, infos, off_policy_actions = env.poll()

@@ -185,12 +180,18 @@ class BaseEnv:
         return None
 
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
         """Return a reference to the underlying sub environments, if any.
 
+        Args:
+            as_dict: If True, return a dict mapping from env_id to env.
+
         Returns:
-            List of the underlying sub environments or [].
+            List or dictionary of the underlying sub environments or [] / {}.
         """
+        if as_dict:
+            return {}
         return []
 
     @PublicAPI

@@ -218,6 +219,61 @@ class BaseEnv:
     def get_unwrapped(self) -> List[EnvType]:
         return self.get_sub_environments()
 
+    @PublicAPI
+    @property
+    def observation_space(self) -> gym.spaces.Dict:
+        """Returns the observation space for each environment.
+
+        Note: samples from the observation space need to be preprocessed into a
+            `MultiEnvDict` before being used by a policy.
+
+        Returns:
+            The observation space for each environment.
+        """
+        raise NotImplementedError
+
+    @PublicAPI
+    @property
+    def action_space(self) -> gym.Space:
+        """Returns the action space for each environment.
+
+        Note: samples from the action space need to be preprocessed into a
+            `MultiEnvDict` before being passed to `send_actions`.
+
+        Returns:
+            The observation space for each environment.
+        """
+        raise NotImplementedError
+
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        self._space_contains(self.observation_space, x)
+
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.action_space, x)
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to if x's observations are contained in.
+            x: The observations to check.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        # this removes the agent_id key and inner dicts
+        # in MultiEnvDicts
+        flattened_obs = {
+            env_id: list(obs.values())
+            for env_id, obs in x.items()
+        }
+        ret = True
+        for env_id in flattened_obs:
+            for obs in flattened_obs[env_id]:
+                ret = ret and space[env_id].contains(obs)
+        return ret
+
 
 # Fixed agent identifier when there is only the single agent in the env
 _DUMMY_AGENT_ID = "agent0"
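The abstract properties added above are what the wrapper classes in the remaining files fill in. Below is a minimal sketch of a custom BaseEnv subclass supplying the new fields in the same style, assuming the BaseEnv API as introduced in this commit; the class name MySketchBaseEnv and the Box/Discrete spaces are illustrative, not part of the change.

import gym
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import override


class MySketchBaseEnv(BaseEnv):
    """Hypothetical subclass; names and space shapes are illustrative only."""

    def __init__(self, num_envs: int = 2):
        # Per-sub-env spaces, keyed by env_id, as in the wrapper classes below.
        self._obs_space = gym.spaces.Dict(
            {i: gym.spaces.Box(-1.0, 1.0, shape=(4,)) for i in range(num_envs)})
        self._act_space = gym.spaces.Dict(
            {i: gym.spaces.Discrete(2) for i in range(num_envs)})

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        return self._obs_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        return self._act_space


env = MySketchBaseEnv()
print(env.observation_space)  # Dict with one Box per sub-env id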
rllib/env/external_env.py (vendored): 18 changes

@@ -337,11 +337,11 @@ class ExternalEnvWrapper(BaseEnv):
         self.external_env = external_env
         self.prep = preprocessor
         self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
-        self.action_space = external_env.action_space
+        self._action_space = external_env.action_space
         if preprocessor:
-            self.observation_space = preprocessor.observation_space
+            self._observation_space = preprocessor.observation_space
         else:
-            self.observation_space = external_env.observation_space
+            self._observation_space = external_env.observation_space
         external_env.start()
 
     @override(BaseEnv)

@@ -413,3 +413,15 @@ class ExternalEnvWrapper(BaseEnv):
             with_dummy_agent_id(all_dones, "__all__"), \
             with_dummy_agent_id(all_infos), \
             with_dummy_agent_id(off_policy_actions)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
rllib/env/multi_agent_env.py (vendored): 24 changes

@@ -336,7 +336,12 @@ class MultiAgentEnvWrapper(BaseEnv):
         return obs
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {
+                _id: env_state
+                for _id, env_state in enumerate(self.env_states)
+            }
         return [state.env for state in self.env_states]
 
     @override(BaseEnv)

@@ -346,6 +351,23 @@ class MultiAgentEnvWrapper(BaseEnv):
         assert isinstance(env_id, int)
         return self.envs[env_id].render()
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        space = {
+            _id: env.observation_space
+            for _id, env in enumerate(self.envs)
+        }
+        return gym.spaces.Dict(space)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
+        return gym.spaces.Dict(space)
+
 
 class _MultiAgentEnvState:
     def __init__(self, env: MultiAgentEnv):
rllib/env/remote_base_env.py (vendored): 42 changes

@@ -1,6 +1,8 @@
 import logging
 from typing import Callable, Dict, List, Optional, Tuple
 
+import gym
+
 import ray
 from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI

@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
     from the remote simulator actors. Both single and multi-agent child envs
     are supported, and envs can be stepped synchronously or asynchronously.
 
+    NOTE: This class implicitly assumes that the remote envs are gym.Env's
+
     You shouldn't need to instantiate this class directly. It's automatically
     inserted when you use the `remote_worker_envs=True` option in your
     Trainer's config.

@@ -61,6 +65,8 @@ class RemoteBaseEnv(BaseEnv):
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
        self.actors: Optional[List[ray.actor.ActorHandle]] = None
+        self._observation_space = None
+        self._action_space = None
         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors

@@ -97,6 +103,10 @@ class RemoteBaseEnv(BaseEnv):
             self.actors = [
                 make_remote_env(i) for i in range(self.num_envs)
             ]
+            self._observation_space = ray.get(
+                self.actors[0].observation_space.remote())
+            self._action_space = ray.get(
+                self.actors[0].action_space.remote())
 
         # Lazy initialization. Call `reset()` on all @ray.remote
         # sub-environment actors at the beginning.

@@ -199,9 +209,23 @@
 
     @override(BaseEnv)
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {env_id: actor for env_id, actor in enumerate(self.actors)}
         return self.actors
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteMultiAgentEnv:

@@ -221,6 +245,14 @@
     def step(self, action_dict):
         return self.env.step(action_dict)
 
+    # defining these 2 functions that way this information can be queried
+    # with a call to ray.get()
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteSingleAgentEnv:

@@ -243,3 +275,11 @@
         } for x in [obs, rew, done, info]]
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info
+
+    # defining these 2 functions that way this information can be queried
+    # with a call to ray.get()
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
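A standalone sketch of the remote-actor pattern used above, assuming only ray and gym are installed: the sub-environment lives behind a @ray.remote actor, its spaces are exposed as plain remote methods, and the wrapper fetches them once with ray.get() right after the actors are created. The actor name _SketchRemoteEnv and the CartPole-v1 env id are illustrative.

import gym
import ray


@ray.remote(num_cpus=0)
class _SketchRemoteEnv:
    """Illustrative stand-in for _RemoteSingleAgentEnv / _RemoteMultiAgentEnv."""

    def __init__(self):
        self.env = gym.make("CartPole-v1")  # any gym.Env would do

    # Plain methods (not properties), so they can be called with .remote().
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space


ray.init()
actor = _SketchRemoteEnv.remote()
obs_space = ray.get(actor.observation_space.remote())
act_space = ray.get(actor.action_space.remote())
print(obs_space, act_space)  # e.g. Box(4,) Discrete(2)
ray.shutdown()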
rllib/env/vector_env.py (vendored): 49 changes

@@ -1,7 +1,7 @@
 import logging
 import gym
 import numpy as np
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI

@@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):
 
     def __init__(self, vector_env: VectorEnv):
         self.vector_env = vector_env
-        self.action_space = vector_env.action_space
-        self.observation_space = vector_env.observation_space
         self.num_envs = vector_env.num_envs
         self.new_obs = None  # lazily initialized
         self.cur_rewards = [None for _ in range(self.num_envs)]
         self.cur_dones = [False for _ in range(self.num_envs)]
         self.cur_infos = [None for _ in range(self.num_envs)]
+        self._observation_space = vector_env.observation_space
+        self._action_space = vector_env.action_space
 
     @override(BaseEnv)
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,

@@ -312,10 +312,51 @@
         }
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
+        if not as_dict:
            return self.vector_env.get_sub_environments()
+        else:
+            return {
+                _id: env
+                for _id, env in enumerate(
+                    self.vector_env.get_sub_environments())
+            }
 
     @override(BaseEnv)
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
         assert env_id is None or isinstance(env_id, int)
         return self.vector_env.try_render_at(env_id)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to if x's observations are contained in.
+            x: The observations to check.
+
+        Note: With vector envs, we can process the raw observations
+            and ignore the agent ids and env ids, since vector envs'
+            sub environements are guaranteed to be the same
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        for _, multi_agent_dict in x.items():
+            for _, element in multi_agent_dict.items():
+                if not space.contains(element):
+                    return False
+        return True