import logging
from typing import Tuple, Callable, Optional

import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict

logger = logging.getLogger(__name__)


@PublicAPI
class RemoteVectorEnv(BaseEnv):
    """Vector env that executes envs in remote workers.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single-agent and multi-agent child
    envs are supported, and envs can be stepped synchronously or
    asynchronously.

    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs` option for Trainers.
    """
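
    # A minimal usage sketch for the `remote_worker_envs` option mentioned
    # above (PPOTrainer and "CartPole-v0" are just illustrative choices):
    #
    #   from ray.rllib.agents.ppo import PPOTrainer
    #   trainer = PPOTrainer(
    #       env="CartPole-v0",
    #       config={
    #           "num_envs_per_worker": 4,
    #           "remote_worker_envs": True,  # wrap child envs in this class
    #           "remote_env_batch_wait_ms": 10,
    #       })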

    def __init__(self, make_env: Callable[[int], EnvType], num_envs: int,
                 multiagent: bool, remote_env_batch_wait_ms: int):
        self.make_local_env = make_env
        self.num_envs = num_envs
        self.multiagent = multiagent
        # Convert the batching wait time from milliseconds to seconds for
        # the `ray.wait()` call in `poll()`.
        self.poll_timeout = remote_env_batch_wait_ms / 1000

        self.actors = None  # lazy init
        self.pending = None  # lazy init

    @override(BaseEnv)
    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        # Lazily launch the remote env actors on the first poll.
        if self.actors is None:

            def make_remote_env(i):
                logger.info("Launching env {} in remote actor".format(i))
                if self.multiagent:
                    return _RemoteMultiAgentEnv.remote(self.make_local_env, i)
                else:
                    return _RemoteSingleAgentEnv.remote(self.make_local_env, i)

            self.actors = [make_remote_env(i) for i in range(self.num_envs)]

        # Kick off an initial reset on every actor; the resulting observations
        # are collected below, just like regular step results.
        if self.pending is None:
            self.pending = {a.reset.remote(): a for a in self.actors}

        # each keyed by env_id in [0, num_remote_envs)
        obs, rewards, dones, infos = {}, {}, {}, {}
        ready = []

        # Wait for at least 1 env to be ready here. The timeout bounds how
        # long we wait for additional envs, so that results from several envs
        # can be batched into a single poll.
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        # Get and return observations for each of the ready envs
        env_ids = set()
        for obj_ref in ready:
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            env_ids.add(env_id)
            ob, rew, done, info = ray.get(obj_ref)
            obs[env_id] = ob
            rewards[env_id] = rew
            dones[env_id] = done
            infos[env_id] = info

        logger.debug("Got obs batch for actors {}".format(env_ids))
        return obs, rewards, dones, infos, {}

    @PublicAPI
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        # Dispatch the actions asynchronously; the step results are picked up
        # by subsequent `poll()` calls.
        for env_id, actions in action_dict.items():
            actor = self.actors[env_id]
            obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @PublicAPI
    def try_reset(self,
                  env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        # ASYNC_RESET_RETURN signals that the reset happens asynchronously and
        # that the resulting observation will be returned by a later `poll()`.
        return ASYNC_RESET_RETURN

    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

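
# A minimal driver-loop sketch for the async stepping API above (illustrative
# only: `make_env` is a user-provided env factory, `policy.compute_action` is
# a hypothetical placeholder, and inside RLlib the sampler performs this loop
# for you):
#
#   env = RemoteVectorEnv(make_env, num_envs=4, multiagent=False,
#                         remote_env_batch_wait_ms=10)
#   while True:
#       obs, rewards, dones, infos, _ = env.poll()
#       # Actions must mirror the obs structure: {env_id: {agent_id: action}}.
#       actions = {
#           env_id: {agent_id: policy.compute_action(ob)
#                    for agent_id, ob in agent_obs.items()}
#           for env_id, agent_obs in obs.items()
#       }
#       env.send_actions(actions)
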

@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
    """Wrapper class for making a multi-agent env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)

    def reset(self):
        obs = self.env.reset()
        # each keyed by agent_id in the env
        rew = {agent_id: 0 for agent_id in obs.keys()}
        info = {agent_id: {} for agent_id in obs.keys()}
        done = {"__all__": False}
        return obs, rew, done, info

    def step(self, action_dict):
        return self.env.step(action_dict)


@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)

    def reset(self):
        obs = {_DUMMY_AGENT_ID: self.env.reset()}
        rew = {agent_id: 0 for agent_id in obs.keys()}
        info = {agent_id: {} for agent_id in obs.keys()}
        done = {"__all__": False}
        return obs, rew, done, info

    def step(self, action):
        obs, rew, done, info = self.env.step(action[_DUMMY_AGENT_ID])
        # Wrap the flat gym results into per-agent dicts keyed by the dummy
        # agent id, so the output matches the multi-agent format.
        obs, rew, done, info = [{
            _DUMMY_AGENT_ID: x
        } for x in [obs, rew, done, info]]
        done["__all__"] = done[_DUMMY_AGENT_ID]
        return obs, rew, done, info