mirror of https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00

[RLlib] Make to_base_env() a method of all RLlib-supported Env classes (#20811)

This commit is contained in:
parent f481081904
commit 74dd0e4085

5 changed files with 136 additions and 126 deletions
rllib/env/base_env.py (vendored): 131 lines changed
@@ -3,7 +3,7 @@ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING
 import ray
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
-    MultiEnvDict, PartialTrainerConfigDict
+    MultiEnvDict

 if TYPE_CHECKING:
     from ray.rllib.models.preprocessors import Preprocessor
@@ -78,18 +78,12 @@ class BaseEnv:
         }
         """

-    @staticmethod
+    @Deprecated(
+        old="ray.rllib.env.base_env.BaseEnv.to_base_env",
+        new="ray.rllib.env.base_env.convert_to_base_env",
+        error=False)
     def to_base_env(
-            env: EnvType,
+            self,
             make_env: Callable[[int], EnvType] = None,
             num_envs: int = 1,
             remote_envs: bool = False,
             remote_env_batch_wait_ms: int = 0,
-            policy_config: Optional[PartialTrainerConfigDict] = None,
     ) -> "BaseEnv":
         """Converts an RLlib-supported env into a BaseEnv object.
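The @Deprecated annotation keeps the old entry point callable while pointing users at convert_to_base_env. For readers unfamiliar with the pattern, here is a minimal sketch of such a decorator; it is illustrative only, not RLlib's actual implementation:

    import functools
    import warnings

    def deprecated(old: str, new: str, error: bool = False):
        """Warn (or raise) when a renamed API is still used under its old name."""
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                msg = f"`{old}` is deprecated; use `{new}` instead."
                if error:
                    raise DeprecationWarning(msg)
                warnings.warn(msg, DeprecationWarning, stacklevel=2)
                return func(*args, **kwargs)
            return wrapper
        return decorator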
@@ -126,83 +120,8 @@ class BaseEnv:
         Returns:
             The resulting BaseEnv object.
         """
-
-        from ray.rllib.env.remote_vector_env import RemoteBaseEnv
-        from ray.rllib.env.external_env import ExternalEnv, ExternalEnvWrapper
-        from ray.rllib.env.multi_agent_env import MultiAgentEnv, \
-            MultiAgentEnvWrapper
-        from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper
-        if remote_envs and num_envs == 1:
-            raise ValueError(
-                "Remote envs only make sense to use if num_envs > 1 "
-                "(i.e. vectorization is enabled).")
-
-        # Given `env` is already a BaseEnv -> Return as is.
-        if isinstance(env, BaseEnv):
-            return env
-
-        # `env` is not a BaseEnv yet -> Need to convert/vectorize.
-
-        # MultiAgentEnv (which is a gym.Env).
-        if isinstance(env, MultiAgentEnv):
-            # Sub-environments are ray.remote actors:
-            if remote_envs:
-                env = RemoteBaseEnv(
-                    make_env,
-                    num_envs,
-                    multiagent=True,
-                    remote_env_batch_wait_ms=remote_env_batch_wait_ms)
-            # Sub-environments are not ray.remote actors.
-            else:
-                env = MultiAgentEnvWrapper(
-                    make_env=make_env, existing_envs=[env], num_envs=num_envs)
-        # ExternalEnv.
-        elif isinstance(env, ExternalEnv):
-            if num_envs != 1:
-                raise ValueError(
-                    "External(MultiAgent)Env does not currently support "
-                    "num_envs > 1. One way of solving this would be to "
-                    "treat your Env as a MultiAgentEnv hosting only one "
-                    "type of agent but with several copies.")
-            env = ExternalEnvWrapper(env)
-        # VectorEnv.
-        # Note that all BaseEnvs are also vectorized, but the user may want to
-        # define custom vectorization logic and thus implement a custom
-        # VectorEnv class.
-        elif isinstance(env, VectorEnv):
-            env = VectorEnvWrapper(env)
-        # Anything else: This usually implies that env is a gym.Env object.
-        else:
-            # Sub-environments are ray.remote actors:
-            if remote_envs:
-                # Determine, whether the already existing sub-env (could
-                # be a ray.actor) is multi-agent or not.
-                multiagent = ray.get(env._is_multi_agent.remote()) if \
-                    hasattr(env, "_is_multi_agent") else False
-                env = RemoteBaseEnv(
-                    make_env,
-                    num_envs,
-                    multiagent=multiagent,
-                    remote_env_batch_wait_ms=remote_env_batch_wait_ms,
-                    existing_envs=[env],
-                )
-            # Sub-environments are not ray.remote actors.
-            else:
-                # Convert gym.Env to VectorEnv ...
-                env = VectorEnv.vectorize_gym_envs(
-                    make_env=make_env,
-                    existing_envs=[env],
-                    num_envs=num_envs,
-                    action_space=env.action_space,
-                    observation_space=env.observation_space,
-                )
-                # ... then the resulting VectorEnv to a BaseEnv.
-                env = VectorEnvWrapper(env)
-
-        # Make sure conversion went well.
-        assert isinstance(env, BaseEnv), env
-
-        return env
+        del make_env, num_envs, remote_envs, remote_env_batch_wait_ms
+        return self

     @PublicAPI
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
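Since BaseEnv.to_base_env() now simply returns self, existing callers must migrate to the free function. A hedged migration sketch (assumes ray[rllib] and gym are installed; CartPole is illustrative):

    import gym
    from ray.rllib.env.base_env import BaseEnv, convert_to_base_env

    # Before this commit: BaseEnv.to_base_env(my_env, make_env=..., num_envs=...)
    # After: use the module-level helper instead.
    base_env = convert_to_base_env(
        gym.make("CartPole-v0"),
        make_env=lambda idx: gym.make("CartPole-v0"),
        num_envs=2,
    )
    assert isinstance(base_env, BaseEnv)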
@@ -627,7 +546,6 @@ def convert_to_base_env(
         num_envs: int = 1,
         remote_envs: bool = False,
         remote_env_batch_wait_ms: int = 0,
-        policy_config: Optional[PartialTrainerConfigDict] = None,
 ) -> "BaseEnv":
     """Converts an RLlib-supported env into a BaseEnv object.
@@ -659,56 +577,23 @@ def convert_to_base_env(
         remote_env_batch_wait_ms: The wait time (in ms) to poll remote
             sub-environments for, if applicable. Only used if
             `remote_envs` is True.
-        policy_config: Optional policy config dict.

     Returns:
         The resulting BaseEnv object.
     """
-
-    from ray.rllib.env.remote_vector_env import RemoteBaseEnv
-    from ray.rllib.env.external_env import ExternalEnv, ExternalEnvWrapper
-    from ray.rllib.env.multi_agent_env import MultiAgentEnv, \
-        MultiAgentEnvWrapper
+    from ray.rllib.env.external_env import ExternalEnv
+    from ray.rllib.env.multi_agent_env import MultiAgentEnv
     from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper
     if remote_envs and num_envs == 1:
         raise ValueError("Remote envs only make sense to use if num_envs > 1 "
                          "(i.e. vectorization is enabled).")

     # Given `env` is already a BaseEnv -> Return as is.
-    if isinstance(env, BaseEnv):
-        return env
-
+    if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
+        return env.to_base_env()
     # `env` is not a BaseEnv yet -> Need to convert/vectorize.
-
-    # MultiAgentEnv (which is a gym.Env).
-    if isinstance(env, MultiAgentEnv):
-        # Sub-environments are ray.remote actors:
-        if remote_envs:
-            env = RemoteBaseEnv(
-                make_env,
-                num_envs,
-                multiagent=True,
-                remote_env_batch_wait_ms=remote_env_batch_wait_ms)
-        # Sub-environments are not ray.remote actors.
-        else:
-            env = MultiAgentEnvWrapper(
-                make_env=make_env, existing_envs=[env], num_envs=num_envs)
-    # ExternalEnv.
-    elif isinstance(env, ExternalEnv):
-        if num_envs != 1:
-            raise ValueError(
-                "External(MultiAgent)Env does not currently support "
-                "num_envs > 1. One way of solving this would be to "
-                "treat your Env as a MultiAgentEnv hosting only one "
-                "type of agent but with several copies.")
-        env = ExternalEnvWrapper(env)
-    # VectorEnv.
-    # Note that all BaseEnvs are also vectorized, but the user may want to
-    # define custom vectorization logic and thus implement a custom
-    # VectorEnv class.
-    elif isinstance(env, VectorEnv):
-        env = VectorEnvWrapper(env)
-    # Anything else: This usually implies that env is a gym.Env object.
     else:
         # Sub-environments are ray.remote actors:
         if remote_envs:
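The refactor turns the long isinstance chain into delegation: every RLlib env class that can convert itself is asked to do so via its own to_base_env(), and only plain gym.Env objects still take the inline vectorize-and-wrap path. Caller-side, the entry point stays uniform; a hedged sketch (the five-way unpacking of poll() follows the pre-2.0 BaseEnv API, whose signature is truncated in the hunk above):

    import gym
    from ray.rllib.env.base_env import convert_to_base_env

    # A plain gym.Env is vectorized and wrapped; RLlib env classes would
    # instead convert themselves via their own to_base_env() methods.
    base_env = convert_to_base_env(gym.make("CartPole-v0"))
    obs, rewards, dones, infos, off_policy_actions = base_env.poll()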
rllib/env/external_env.py (vendored): 46 lines changed
@@ -2,12 +2,12 @@ from six.moves import queue
 import gym
 import threading
 import uuid
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Callable, Tuple, Optional, TYPE_CHECKING

 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType, \
-    MultiEnvDict
+    EnvType, MultiEnvDict

 if TYPE_CHECKING:
     from ray.rllib.models.preprocessors import Preprocessor
@@ -188,6 +188,48 @@ class ExternalEnv(threading.Thread):

         return self._episodes[episode_id]

+    def to_base_env(
+            self,
+            make_env: Callable[[int], EnvType] = None,
+            num_envs: int = 1,
+            remote_envs: bool = False,
+            remote_env_batch_wait_ms: int = 0,
+    ) -> "BaseEnv":
+        """Converts an RLlib ExternalEnv into a BaseEnv object.
+
+        The resulting BaseEnv is always vectorized (contains n
+        sub-environments) to support batched forward passes, where n may
+        also be 1. BaseEnv also supports async execution via the `poll` and
+        `send_actions` methods and thus supports external simulators.
+
+        Args:
+            make_env: A callable taking an int as input (which indicates
+                the number of individual sub-environments within the final
+                vectorized BaseEnv) and returning one individual
+                sub-environment.
+            num_envs: The number of sub-environments to create in the
+                resulting (vectorized) BaseEnv. The already existing `env`
+                will be one of the `num_envs`.
+            remote_envs: Whether each sub-env should be a @ray.remote
+                actor. You can set this behavior in your config via the
+                `remote_worker_envs=True` option.
+            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
+                sub-environments for, if applicable. Only used if
+                `remote_envs` is True.
+
+        Returns:
+            The resulting BaseEnv object.
+        """
+        if num_envs != 1:
+            raise ValueError(
+                "External(MultiAgent)Env does not currently support "
+                "num_envs > 1. One way of solving this would be to "
+                "treat your Env as a MultiAgentEnv hosting only one "
+                "type of agent but with several copies.")
+        env = ExternalEnvWrapper(self)
+
+        return env
+

 class _ExternalEnvEpisode:
     """Tracked state for each active episode."""
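With the method in place, an ExternalEnv converts itself rather than being special-cased by the caller. A minimal, illustrative stub (the MyExternalEnv class and its empty run() are placeholders; a real env would drive episodes from run()):

    import gym
    from ray.rllib.env.external_env import ExternalEnv

    class MyExternalEnv(ExternalEnv):
        """Illustrative stub only."""

        def __init__(self):
            super().__init__(
                action_space=gym.spaces.Discrete(2),
                observation_space=gym.spaces.Box(-1.0, 1.0, (4,)))

        def run(self):
            # A real env would call start_episode()/get_action()/log_returns().
            pass

    base_env = MyExternalEnv().to_base_env()  # num_envs must remain 1.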
rllib/env/multi_agent_env.py (vendored): 47 lines changed
@@ -132,6 +132,53 @@ class MultiAgentEnv(gym.Env):
         from ray.rllib.env.wrappers.group_agents_wrapper import \
             GroupAgentsWrapper
         return GroupAgentsWrapper(self, groups, obs_space, act_space)

+    @PublicAPI
+    def to_base_env(self,
+                    make_env: Callable[[int], EnvType] = None,
+                    num_envs: int = 1,
+                    remote_envs: bool = False,
+                    remote_env_batch_wait_ms: int = 0,
+                    ) -> "BaseEnv":
+        """Converts an RLlib MultiAgentEnv into a BaseEnv object.
+
+        The resulting BaseEnv is always vectorized (contains n
+        sub-environments) to support batched forward passes, where n may
+        also be 1. BaseEnv also supports async execution via the `poll` and
+        `send_actions` methods and thus supports external simulators.
+
+        Args:
+            make_env: A callable taking an int as input (which indicates
+                the number of individual sub-environments within the final
+                vectorized BaseEnv) and returning one individual
+                sub-environment.
+            num_envs: The number of sub-environments to create in the
+                resulting (vectorized) BaseEnv. The already existing `env`
+                will be one of the `num_envs`.
+            remote_envs: Whether each sub-env should be a @ray.remote
+                actor. You can set this behavior in your config via the
+                `remote_worker_envs=True` option.
+            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
+                sub-environments for, if applicable. Only used if
+                `remote_envs` is True.
+
+        Returns:
+            The resulting BaseEnv object.
+        """
+        from ray.rllib.env.remote_vector_env import RemoteBaseEnv
+        if remote_envs:
+            env = RemoteBaseEnv(
+                make_env,
+                num_envs,
+                multiagent=True,
+                remote_env_batch_wait_ms=remote_env_batch_wait_ms)
+        # Sub-environments are not ray.remote actors.
+        else:
+            env = MultiAgentEnvWrapper(
+                make_env=make_env, existing_envs=[self], num_envs=num_envs)
+
+        return env
+
     # __grouping_doc_end__
     # yapf: enable
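Usage mirrors the other env types; a hedged sketch with a throwaway two-agent env (the TwoAgentEcho class is illustrative, not part of RLlib):

    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoAgentEcho(MultiAgentEnv):
        """Illustrative stub: two agents observe a step counter for 5 steps."""

        def reset(self):
            self.t = 0
            return {"agent_0": 0, "agent_1": 0}

        def step(self, action_dict):
            self.t += 1
            obs = {agent_id: self.t for agent_id in action_dict}
            rewards = {agent_id: 1.0 for agent_id in action_dict}
            dones = {"__all__": self.t >= 5}
            return obs, rewards, dones, {}

    # The existing instance becomes one of the two sub-envs;
    # make_env builds the rest.
    base_env = TwoAgentEcho().to_base_env(
        make_env=lambda idx: TwoAgentEcho(), num_envs=2)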
rllib/env/vector_env.py (vendored): 37 lines changed
@@ -135,6 +135,43 @@ class VectorEnv:
     def get_unwrapped(self) -> List[EnvType]:
         return self.get_sub_environments()

+    @PublicAPI
+    def to_base_env(
+            self,
+            make_env: Callable[[int], EnvType] = None,
+            num_envs: int = 1,
+            remote_envs: bool = False,
+            remote_env_batch_wait_ms: int = 0,
+    ) -> "BaseEnv":
+        """Converts an RLlib VectorEnv into a BaseEnv object.
+
+        The resulting BaseEnv is always vectorized (contains n
+        sub-environments) to support batched forward passes, where n may
+        also be 1. BaseEnv also supports async execution via the `poll` and
+        `send_actions` methods and thus supports external simulators.
+
+        Args:
+            make_env: A callable taking an int as input (which indicates
+                the number of individual sub-environments within the final
+                vectorized BaseEnv) and returning one individual
+                sub-environment.
+            num_envs: The number of sub-environments to create in the
+                resulting (vectorized) BaseEnv. The already existing `env`
+                will be one of the `num_envs`.
+            remote_envs: Whether each sub-env should be a @ray.remote
+                actor. You can set this behavior in your config via the
+                `remote_worker_envs=True` option.
+            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
+                sub-environments for, if applicable. Only used if
+                `remote_envs` is True.
+
+        Returns:
+            The resulting BaseEnv object.
+        """
+        del make_env, num_envs, remote_envs, remote_env_batch_wait_ms
+        env = VectorEnvWrapper(self)
+        return env
+

 class _VectorizedGymEnv(VectorEnv):
     """Internal wrapper to translate any gym.Envs into a VectorEnv object.
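Note that VectorEnv.to_base_env() deliberately ignores all of its arguments (the del statement): a VectorEnv has already fixed its own sub-environment count, so conversion is a plain wrap. A hedged sketch that mirrors the vectorize_gym_envs call removed from base_env.py above:

    import gym
    from ray.rllib.env.vector_env import VectorEnv

    first = gym.make("CartPole-v0")
    vec_env = VectorEnv.vectorize_gym_envs(
        make_env=lambda idx: gym.make("CartPole-v0"),
        existing_envs=[first],
        num_envs=4,
        action_space=first.action_space,
        observation_space=first.observation_space,
    )
    base_env = vec_env.to_base_env()  # Any arguments would be discarded anyway.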
rllib/evaluation/rollout_worker.py: 1 line changed (the remaining change implied by the 262-line total above)

@@ -636,7 +636,6 @@ class RolloutWorker(ParallelIteratorWorker):
             num_envs=num_envs,
             remote_envs=remote_worker_envs,
             remote_env_batch_wait_ms=remote_env_batch_wait_ms,
-            policy_config=policy_config,
         )

         # `truncate_episodes`: Allow a batch to contain more than one episode
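Downstream, call sites simply stop forwarding policy_config. A hedged sketch of an updated call (variable names mirror the hunk; the surrounding RolloutWorker code is truncated in this view):

    import gym
    from ray.rllib.env.base_env import convert_to_base_env

    remote_worker_envs = False
    remote_env_batch_wait_ms = 0
    base_env = convert_to_base_env(
        gym.make("CartPole-v0"),
        make_env=lambda idx: gym.make("CartPole-v0"),
        num_envs=1,
        remote_envs=remote_worker_envs,
        remote_env_batch_wait_ms=remote_env_batch_wait_ms,
        # policy_config=...  <- this keyword was removed by the commit.
    )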