ray/rllib/execution/parallel_requests.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

155 lines
6.2 KiB
Python

import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set
import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@ExperimentalAPI
def asynchronous_parallel_requests(
    remote_requests_in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]],
    actors: List[ActorHandle],
    ray_wait_timeout_s: Optional[float] = None,
    max_remote_requests_in_flight_per_actor: int = 2,
    remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
    remote_args: Optional[List[List[Any]]] = None,
    remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous requests on the given remote actors.

    Each call (1) sends at most one new request to every actor in `actors`
    that has fewer than `max_remote_requests_in_flight_per_actor` requests
    pending, (2) waits for pending requests via `ray.wait()`, and (3) returns
    the results that are ready.

    By default `actor.sample.remote()` is called on each actor. Alternatively,
    the user can provide a `remote_fn`, which is then sent to the actor(s) via
    `actor.apply.remote(remote_fn, *args, **kwargs)`.

    At most ONE result per actor is returned by a single call. If several
    requests of the same actor are ready at the same time, only one of them
    is consumed; the others stay registered in `remote_requests_in_flight`
    and are handed out by subsequent calls, so no result is ever dropped.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). This dict is updated in place: new
            requests are added and consumed ones are removed. If you have an
            RLlib Trainer that calls this function, you can use its
            `self.remote_requests_in_flight` property here.
        actors: The List of ActorHandles to perform the remote requests on.
            Pending requests of actors that are not in this list are ignored.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
            Ignored if `remote_fn` is None.
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
            Ignored if `remote_fn` is None.

    Returns:
        A dict mapping actor handles to the result received from that actor.
        Only actors with a ready result are contained. An empty dict is
        returned if a timeout was given and nothing was ready within it.

    Raises:
        AssertionError: If `remote_args` or `remote_kwargs` do not have one
            entry per actor, or if there is no pending request to wait for
            (e.g. `actors` is empty).

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> # On a first call, expect a timeout to have happened: Nothing
        >>> # is ready yet, so the returned dict is empty.
        >>> print(len(batches))
        0
    """
    if remote_args is not None:
        assert len(remote_args) == len(
            actors
        ), "`remote_args` must contain exactly one entry per actor!"
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(
            actors
        ), "`remote_kwargs` must contain exactly one entry per actor!"

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, in_flight in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= in_flight
            for obj_ref in in_flight:
                remote_to_actor[obj_ref] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if (
            len(remote_requests_in_flight[actor])
            < max_remote_requests_in_flight_per_actor
        ):
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0, "No pending remote requests to wait for!"
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First do a non-blocking poll (`timeout=0`) to collect everything
        # that is already done without waiting.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remote_list), timeout=0
        )
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait() call` w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remote_list),
            timeout=ray_wait_timeout_s,
        )

    # Return empty results if nothing ready after the timeout.
    if not ready:
        return {}

    # The returned dict has room for only one result per actor. Pick a single
    # ready ref per actor and leave any further ready refs of that actor
    # in flight, so their results are returned by a later call instead of
    # being silently overwritten (and lost).
    ref_per_actor = {}
    for obj_ref in ready:
        ref_per_actor.setdefault(remote_to_actor[obj_ref], obj_ref)

    # Remove in-flight records for the refs we are about to consume.
    for actor, obj_ref in ref_per_actor.items():
        remote_requests_in_flight[actor].remove(obj_ref)

    # Do one ray.get().
    consumed_refs = list(ref_per_actor.values())
    results = ray.get(consumed_refs)
    assert len(consumed_refs) == len(results)

    # Return mapping from (ready) actors to their results. Dicts preserve
    # insertion order, so keys and `results` line up.
    return dict(zip(ref_per_actor, results))