[RLlib] Discussion 247: Allow remote sub-envs (within vectorized) to be used with custom APIs. (#17118)
commit 0c5c70b584 (parent 29768a7c01)

6 changed files with 274 additions and 42 deletions
rllib/BUILD (26 changed lines)

@@ -2253,6 +2253,15 @@ py_test(
     args = ["--stop-iters=2"]
 )
 
+py_test(
+    name = "examples/fractional_gpus",
+    main = "examples/fractional_gpus.py",
+    tags = ["examples", "examples_F"],
+    size = "medium",
+    srcs = ["examples/fractional_gpus.py"],
+    args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"]
+)
+
 py_test(
     name = "examples/hierarchical_training_tf",
     main = "examples/hierarchical_training.py",

@@ -2417,15 +2426,6 @@ py_test(
     args = ["--as-test", "--framework=torch", "--stop-reward=60.0", "--run=DQN"]
 )
 
-py_test(
-    name = "examples/fractional_gpus",
-    main = "examples/fractional_gpus.py",
-    tags = ["examples", "examples_P"],
-    size = "medium",
-    srcs = ["examples/fractional_gpus.py"],
-    args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"]
-)
-
 py_test(
     name = "examples/pettingzoo_env",
     main = "examples/pettingzoo_env.py",

@@ -2434,6 +2434,14 @@ py_test(
     srcs = ["examples/pettingzoo_env.py"],
 )
 
+py_test(
+    name = "examples/remote_vector_env_with_custom_api",
+    tags = ["examples", "examples_R"],
+    size = "medium",
+    srcs = ["examples/remote_vector_env_with_custom_api.py"],
+    args = ["--stop-iters=3"]
+)
+
 py_test(
     name = "examples/restore_1_of_n_agents_from_checkpoint",
     tags = ["examples", "examples_R"],
rllib/agents/trainer.py

@@ -551,7 +551,8 @@ class Trainer(Trainable):
         config = config or {}
 
         # Trainers allow env ids to be passed directly to the constructor.
-        self._env_id = self._register_if_needed(env or config.get("env"))
+        self._env_id = self._register_if_needed(
+            env or config.get("env"), config)
 
         # Create a default logger creator if no logger_creator is specified
         if logger_creator is None:

@@ -1611,12 +1612,26 @@ class Trainer(Trainable):
            "that were generated via the `ray.rllib.agents.trainer_template."
            "build_trainer()` function!")
 
-    def _register_if_needed(self, env_object: Union[str, EnvType, None]):
+    def _register_if_needed(self, env_object: Union[str, EnvType, None],
+                            config):
         if isinstance(env_object, str):
             return env_object
         elif isinstance(env_object, type):
             name = env_object.__name__
-            register_env(name, lambda config: env_object(config))
+
+            # Add convenience `_get_spaces` method.
+
+            def _get_spaces(s):
+                return s.observation_space, s.action_space
+
+            env_object._get_spaces = _get_spaces
+
+            if config.get("remote_worker_envs"):
+                register_env(
+                    name,
+                    lambda cfg: ray.remote(num_cpus=0)(env_object).remote(cfg))
+            else:
+                register_env(name, lambda config: env_object(config))
             return name
         elif env_object is None:
             return None
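A note on the new registration path above: with `remote_worker_envs` enabled, the registered env creator no longer returns a plain gym.Env instance but a ray ActorHandle, and the observation/action spaces are recovered through the auto-attached `_get_spaces` method. Below is a minimal, hypothetical sketch of that mechanism (not part of the commit; `MyEnv` and the standalone demo code are made up for illustration):

    import gym
    import ray


    class MyEnv(gym.Env):
        """Made-up single-agent env, used only for illustration."""

        def __init__(self, config):
            self.observation_space = gym.spaces.Box(0.0, 1.0, shape=(2, ))
            self.action_space = gym.spaces.Box(0.0, 1.0, shape=(1, ))

        def reset(self):
            return self.observation_space.sample()

        def step(self, action):
            return self.observation_space.sample(), 0.0, True, {}


    ray.init()

    # Mirror what `_register_if_needed` now does for remote sub-envs:
    # attach a `_get_spaces` helper, then register a creator that returns
    # an ActorHandle instead of a plain env object.
    def _get_spaces(s):
        return s.observation_space, s.action_space


    MyEnv._get_spaces = _get_spaces
    env_creator = lambda cfg: ray.remote(num_cpus=0)(MyEnv).remote(cfg)

    # The created "env" is now an ActorHandle; spaces are fetched remotely
    # via the auto-added method.
    actor = env_creator({"some": "config"})
    obs_space, act_space = ray.get(actor._get_spaces.remote())
    print(obs_space, act_space)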
rllib/env/base_env.py (4 changed lines)

@@ -125,7 +125,9 @@ class BaseEnv:
                 make_env,
                 num_envs,
                 multiagent=False,
-                remote_env_batch_wait_ms=remote_env_batch_wait_ms)
+                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+                existing_envs=[env],
+            )
         else:
             env = VectorEnv.wrap(
                 make_env=make_env,
rllib/env/remote_vector_env.py (80 changed lines)

@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple, Callable, Optional
+from typing import Tuple, Callable, List, Optional
 
 import ray
 from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN

@@ -21,9 +21,17 @@ class RemoteVectorEnv(BaseEnv):
     inserted when you use the `remote_worker_envs` option for Trainers.
     """
 
-    def __init__(self, make_env: Callable[[int], EnvType], num_envs: int,
-                 multiagent: bool, remote_env_batch_wait_ms: int):
-        self.make_local_env = make_env
+    def __init__(self,
+                 make_env: Callable[[int], EnvType],
+                 num_envs: int,
+                 multiagent: bool,
+                 remote_env_batch_wait_ms: int,
+                 existing_envs: Optional[List[ray.actor.ActorHandle]] = None):
+        # Could be creating local or remote envs.
+        self.make_env = make_env
+        self.make_env_creates_actors = False
+        # Already existing env objects (generated by the RolloutWorker).
+        self.existing_envs = existing_envs or []
         self.num_envs = num_envs
         self.multiagent = multiagent
         self.poll_timeout = remote_env_batch_wait_ms / 1000

@@ -35,15 +43,28 @@ class RemoteVectorEnv(BaseEnv):
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                             MultiEnvDict, MultiEnvDict]:
         if self.actors is None:
-
-            def make_remote_env(i):
-                logger.info("Launching env {} in remote actor".format(i))
-                if self.multiagent:
-                    return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
-                else:
-                    return _RemoteSingleAgentEnv.remote(self.make_local_env, i)
-
-            self.actors = [make_remote_env(i) for i in range(self.num_envs)]
+            # `self.make_env` already produces Actors: Use it directly.
+            if len(self.existing_envs) > 0 and isinstance(
+                    self.existing_envs[0], ray.actor.ActorHandle):
+                self.make_env_creates_actors = True
+                self.actors = []
+                while len(self.actors) < self.num_envs:
+                    self.actors.append(self.make_env(len(self.actors)))
+            # `self.make_env` produces gym.Envs (or other similar types, such
+            # as MultiAgentEnv): Need to auto-wrap it here. The problem with
+            # this is that custom methods will get lost.
+            else:
+
+                def make_remote_env(i):
+                    logger.info("Launching env {} in remote actor".format(i))
+                    if self.multiagent:
+                        return _RemoteMultiAgentEnv.remote(self.make_env, i)
+                    else:
+                        return _RemoteSingleAgentEnv.remote(self.make_env, i)
+
+                self.actors = [
+                    make_remote_env(i) for i in range(self.num_envs)
+                ]
 
         if self.pending is None:
             self.pending = {a.reset.remote(): a for a in self.actors}

@@ -65,7 +86,32 @@ class RemoteVectorEnv(BaseEnv):
             actor = self.pending.pop(obj_ref)
             env_id = self.actors.index(actor)
             env_ids.add(env_id)
-            ob, rew, done, info = ray.get(obj_ref)
+            ret = ray.get(obj_ref)
+            # Our sub-envs are simple Actor-turned gym.Envs or MultiAgentEnvs.
+            if self.make_env_creates_actors:
+                rew, done, info = None, None, None
+                if self.multiagent:
+                    if isinstance(ret, tuple) and len(ret) == 4:
+                        ob, rew, done, info = ret
+                    else:
+                        ob = ret
+                else:
+                    if isinstance(ret, tuple) and len(ret) == 4:
+                        ob = {_DUMMY_AGENT_ID: ret[0]}
+                        rew = {_DUMMY_AGENT_ID: ret[1]}
+                        done = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
+                        info = {_DUMMY_AGENT_ID: ret[3]}
+                    else:
+                        ob = {_DUMMY_AGENT_ID: ret}
+
+                if rew is None:
+                    rew = {agent_id: 0 for agent_id in ob.keys()}
+                    done = {"__all__": False}
+                    info = {agent_id: {} for agent_id in ob.keys()}
+            # Our sub-envs are auto-wrapped and already behave like multi-agent
+            # envs.
+            else:
+                ob, rew, done, info = ret
             obs[env_id] = ob
             rewards[env_id] = rew
             dones[env_id] = done

@@ -74,6 +120,7 @@ class RemoteVectorEnv(BaseEnv):
         logger.debug("Got obs batch for actors {}".format(env_ids))
         return obs, rewards, dones, infos, {}
 
+    @override(BaseEnv)
     @PublicAPI
     def send_actions(self, action_dict: MultiEnvDict) -> None:
         for env_id, actions in action_dict.items():

@@ -81,6 +128,7 @@ class RemoteVectorEnv(BaseEnv):
             obj_ref = actor.step.remote(actions)
             self.pending[obj_ref] = actor
 
+    @override(BaseEnv)
     @PublicAPI
     def try_reset(self,
                   env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:

@@ -89,12 +137,18 @@ class RemoteVectorEnv(BaseEnv):
             self.pending[obj_ref] = actor
         return ASYNC_RESET_RETURN
 
+    @override(BaseEnv)
     @PublicAPI
     def stop(self) -> None:
         if self.actors is not None:
             for actor in self.actors:
                 actor.__ray_terminate__.remote()
 
+    @override(BaseEnv)
     @PublicAPI
     def get_unwrapped(self):
         return self.actors
 
 
 @ray.remote(num_cpus=0)
 class _RemoteMultiAgentEnv:

@@ -125,8 +179,8 @@ class _RemoteSingleAgentEnv:
     def reset(self):
         obs = {_DUMMY_AGENT_ID: self.env.reset()}
         rew = {agent_id: 0 for agent_id in obs.keys()}
-        info = {agent_id: {} for agent_id in obs.keys()}
         done = {"__all__": False}
+        info = {agent_id: {} for agent_id in obs.keys()}
         return obs, rew, done, info
 
     def step(self, action):
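Most of the new `poll()` logic above normalizes whatever a raw, Actor-turned sub-env returns into RLlib's multi-agent dict format keyed by `_DUMMY_AGENT_ID`. Below is a minimal, hypothetical sketch of that normalization for the single-agent case (not part of the commit; the helper name and the dummy-agent value are assumptions, RLlib imports its own constant from `base_env.py`):

    # Hypothetical standalone helper mirroring the normalization the new
    # `poll()` branch performs for Actor-turned single-agent envs.
    _DUMMY_AGENT_ID = "agent0"  # assumed placeholder value


    def normalize_single_agent_return(ret):
        """Turn a raw env return into (obs, rew, done, info) multi-agent dicts.

        `ret` is either a 4-tuple (a step result) or a bare observation
        (a reset result) coming back from a remote, Actor-turned gym.Env.
        """
        if isinstance(ret, tuple) and len(ret) == 4:
            ob = {_DUMMY_AGENT_ID: ret[0]}
            rew = {_DUMMY_AGENT_ID: ret[1]}
            done = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
            info = {_DUMMY_AGENT_ID: ret[3]}
        else:
            # Reset result: only an observation; fill in neutral values.
            ob = {_DUMMY_AGENT_ID: ret}
            rew = {agent_id: 0 for agent_id in ob.keys()}
            done = {"__all__": False}
            info = {agent_id: {} for agent_id in ob.keys()}
        return ob, rew, done, info


    # A reset return (just an observation) vs. a step return (4-tuple).
    print(normalize_single_agent_return([0.1, 0.2]))
    print(normalize_single_agent_return(([0.1, 0.2], 0.0, False, {})))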
rllib/evaluation/rollout_worker.py

@@ -429,7 +429,7 @@ class RolloutWorker(ParallelIteratorWorker):
                    return env
 
            # We can't auto-wrap a BaseEnv.
-           elif isinstance(self.env, BaseEnv):
+           elif isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):
 
                def wrap(env):
                    return env

@@ -994,8 +994,11 @@ class RolloutWorker(ParallelIteratorWorker):
            return []
 
        envs = self.async_env.get_unwrapped()
+       # `get_unwrapped` not implemented (returned empty list). Call func
+       # directly on `self.async_env`.
        if not envs:
            return [func(self.async_env)]
+       # Vectorized env. Call func on all sub-envs.
        else:
            return [func(e) for e in envs]

@@ -1390,20 +1393,28 @@ def _determine_spaces_for_multi_agent_dict(
     # Try extracting spaces from env or from given spaces dict.
     env_obs_space = None
     env_act_space = None
-    # Extract the observation space from the env directly, if provided.
-    if env is not None and hasattr(env, "observation_space") and isinstance(
-            env.observation_space, gym.Space):
-        env_obs_space = env.observation_space
-    # Try getting the env's spaces from the spaces dict's special __env__ key.
-    elif spaces is not None:
-        env_obs_space = spaces.get("__env__", [None])[0]
-    # Extract the action space from the env directly, if provided.
-    if env is not None and hasattr(env, "action_space") and isinstance(
-            env.action_space, gym.Space):
-        env_act_space = env.action_space
-    # Try getting the env's spaces from the spaces dict's special __env__ key.
-    elif spaces is not None:
-        env_act_space = spaces.get("__env__", [None, None])[1]
+
+    # Env is a ray.remote: Get spaces via its (automatically added)
+    # `_get_spaces()` method.
+    if isinstance(env, ray.actor.ActorHandle):
+        env_obs_space, env_act_space = ray.get(env._get_spaces.remote())
+    # Normal env (gym.Env or MultiAgentEnv): These should have the
+    # `observation_space` and `action_space` properties.
+    elif env is not None:
+        if hasattr(env, "observation_space") and isinstance(
+                env.observation_space, gym.Space):
+            env_obs_space = env.observation_space
+
+        if hasattr(env, "action_space") and isinstance(env.action_space,
+                                                       gym.Space):
+            env_act_space = env.action_space
+    # Last resort: Try getting the env's spaces from the spaces
+    # dict's special __env__ key.
+    if spaces is not None:
+        if env_obs_space is None:
+            env_obs_space = spaces.get("__env__", [None])[0]
+        if env_act_space is None:
+            env_act_space = spaces.get("__env__", [None, None])[1]
 
     for pid, policy_spec in multi_agent_dict.copy().items():
         if policy_spec.observation_space is None:

@@ -1445,7 +1456,10 @@ def _validate_env(env: Any) -> EnvType:
     if hasattr(env, "observation_space") and hasattr(env, "action_space"):
         return env
 
-    allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv]
+    allowed_types = [
+        gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
+        ray.actor.ActorHandle
+    ]
     if not any(isinstance(env, tpe) for tpe in allowed_types):
         raise ValueError(
             "Returned env should be an instance of gym.Env, MultiAgentEnv, "
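Because `get_unwrapped()` now hands back ray ActorHandles when `remote_worker_envs=True`, any code that reaches into the sub-envs via `foreach_env` has to call the custom API remotely. A short sketch of that usage pattern (a fragment, assuming `trainer` is an already-built RLlib Trainer configured with `remote_worker_envs=True`):

    import ray

    # Fire-and-forget: set the task on every remote sub-env. This is the
    # same pattern the example callback in the new script below uses.
    trainer.workers.foreach_env(lambda env: env.set_task.remote(2))

    # Each remote call returns an ObjectRef; if return values are needed,
    # collect the refs from `foreach_env` and resolve them with ray.get().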
rllib/examples/remote_vector_env_with_custom_api.py (new file, 139 lines)

@@ -0,0 +1,139 @@
"""
This script demonstrates how one can specify custom env APIs in
combination with RLlib's `remote_worker_envs` setting, which
parallelizes individual sub-envs within a vector env by making each
one a ray Actor.

You can access your Env's API via a custom callback as shown below.
"""
import argparse
import gym
import os

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.apis.task_settable_env import TaskSettableEnv
from ray.rllib.utils.test_utils import check_learning_achieved
from ray import tune

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-workers", type=int, default=1)

# This should be >1, otherwise, remote envs make no sense.
parser.add_argument("--num-envs-per-worker", type=int, default=4)

parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=50,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=180.0,
    help="Reward at which we stop training.")


class NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv(TaskSettableEnv):
    """Class for a single sub-env to be vectorized into RemoteVectorEnv.

    If you specify this class directly under the "env" config key, RLlib
    will auto-wrap the individual sub-envs into ray Actors.

    Note that you may implement your own custom APIs. Here, we demonstrate
    using RLlib's TaskSettableEnv API (which is a simple sub-class
    of gym.Env).
    """

    def __init__(self, config):
        self.action_space = gym.spaces.Box(0, 1, shape=(1, ))
        self.observation_space = gym.spaces.Box(0, 1, shape=(2, ))
        self.task = 1

    def reset(self):
        self.steps = 0
        return self.observation_space.sample()

    def step(self, action):
        self.steps += 1
        return self.observation_space.sample(), 0, self.steps > 10, {}

    def set_task(self, task) -> None:
        """We can set the task of each sub-env (ray actor)."""
        print("Task set to {}".format(task))
        self.task = task


class TaskSettingCallback(DefaultCallbacks):
    """Custom callback to verify that we can set the task on each remote
    sub-env.
    """

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        """Curriculum learning as seen in Ray docs."""
        if result["episode_reward_mean"] > 0.0:
            phase = 0
        else:
            phase = 1

        # Sub-envs are now ray.actor.ActorHandles, so we have to add
        # `remote()` here.
        trainer.workers.foreach_env(lambda env: env.set_task.remote(phase))


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=6, local_mode=True)

    config = {
        # Specify your custom (single, non-vectorized) env directly as a
        # class. This way, RLlib can auto-create Actors from this class
        # and handle everything correctly.
        # TODO: Test for multi-agent case.
        "env": NonVectorizedEnvToBeVectorizedIntoRemoteVectorEnv,
        # Set up our own callbacks.
        "callbacks": TaskSettingCallback,
        # Force sub-envs to be ray.actor.ActorHandles, so we can step
        # through them in parallel.
        "remote_worker_envs": True,
        # How many RolloutWorkers (each with n environment copies:
        # `num_envs_per_worker`)?
        "num_workers": args.num_workers,
        # This setting should not really matter as it does not affect the
        # number of GPUs reserved for each worker.
        "num_envs_per_worker": args.num_envs_per_worker,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
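To try the new example outside of Bazel, the script can be run directly with the same flag the BUILD entry above passes (any of the argparse options defined in the file also work), for example:

    python rllib/examples/remote_vector_env_with_custom_api.py --stop-iters=3

This is the invocation the new `examples/remote_vector_env_with_custom_api` py_test effectively performs as a medium-sized example test.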