Rename async_queue_depth -> num_async (#8207)

* rename

* lint
This commit is contained in:
Eric Liang 2020-05-05 01:38:10 -07:00 committed by GitHub
parent f48da50e1c
commit ee0eb44a32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 17 deletions

View file

@ -292,7 +292,7 @@ def test_gather_async(ray_start_regular_shared):
def test_gather_async_queue(ray_start_regular_shared): def test_gather_async_queue(ray_start_regular_shared):
it = from_range(100) it = from_range(100)
it = it.gather_async(async_queue_depth=4) it = it.gather_async(num_async=4)
assert sorted(it) == list(range(100)) assert sorted(it) == list(range(100))

View file

@ -415,14 +415,14 @@ class ParallelIterator(Generic[T]):
name = "{}.batch_across_shards()".format(self) name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, SharedMetrics(), name=name) return LocalIterator(base_iterator, SharedMetrics(), name=name)
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]": def gather_async(self, num_async=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration. """Returns a local iterable for asynchronous iteration.
New items will be fetched from the shards asynchronously as soon as New items will be fetched from the shards asynchronously as soon as
the previous one is computed. Items arrive in non-deterministic order. the previous one is computed. Items arrive in non-deterministic order.
Arguments: Arguments:
async_queue_depth (int): The max number of async requests in flight num_async (int): The max number of async requests in flight
per actor. Increasing this improves the amount of pipeline per actor. Increasing this improves the amount of pipeline
parallelism in the iterator. parallelism in the iterator.
@ -436,7 +436,7 @@ class ParallelIterator(Generic[T]):
... 1 ... 1
""" """
if async_queue_depth < 1: if num_async < 1:
raise ValueError("queue depth must be positive") raise ValueError("queue depth must be positive")
# Forward reference to the returned iterator. # Forward reference to the returned iterator.
@ -448,7 +448,7 @@ class ParallelIterator(Generic[T]):
actor_set.init_actors() actor_set.init_actors()
all_actors.extend(actor_set.actors) all_actors.extend(actor_set.actors)
futures = {} futures = {}
for _ in range(async_queue_depth): for _ in range(num_async):
for a in all_actors: for a in all_actors:
futures[a.par_iter_next.remote()] = a futures[a.par_iter_next.remote()] = a
while futures: while futures:

View file

@ -143,7 +143,7 @@ def execution_plan(workers, config):
# We execute the following steps concurrently: # We execute the following steps concurrently:
# (1) Generate rollouts and store them in our replay buffer actors. Update # (1) Generate rollouts and store them in our replay buffer actors. Update
# the weights of the worker that generated the batch. # the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2) rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts \ store_op = rollouts \
.for_each(StoreToReplayBuffer(actors=replay_actors)) \ .for_each(StoreToReplayBuffer(actors=replay_actors)) \
.zip_with_source_actor() \ .zip_with_source_actor() \
@ -154,7 +154,7 @@ def execution_plan(workers, config):
# (2) Read experiences from the replay buffer actors and send to the # (2) Read experiences from the replay buffer actors and send to the
# learner thread via its in-queue. # learner thread via its in-queue.
replay_op = Replay(actors=replay_actors, async_queue_depth=4) \ replay_op = Replay(actors=replay_actors, num_async=4) \
.zip_with_source_actor() \ .zip_with_source_actor() \
.for_each(Enqueue(learner_thread.inqueue)) .for_each(Enqueue(learner_thread.inqueue))

View file

@ -52,7 +52,7 @@ class StoreToReplayBuffer:
def Replay(*, def Replay(*,
local_buffer: LocalReplayBuffer = None, local_buffer: LocalReplayBuffer = None,
actors: List["ActorHandle"] = None, actors: List["ActorHandle"] = None,
async_queue_depth=4): num_async=4):
"""Replay experiences from the given buffer or actors. """Replay experiences from the given buffer or actors.
This should be combined with the StoreToReplayActors operation using the This should be combined with the StoreToReplayActors operation using the
@ -63,7 +63,7 @@ def Replay(*,
and replay_actors can be specified. and replay_actors can be specified.
actors (list): List of replay actors. Only one of this and actors (list): List of replay actors. Only one of this and
local_buffer can be specified. local_buffer can be specified.
async_queue_depth (int): In async mode, the max number of async num_async (int): In async mode, the max number of async
requests in flight per actor. requests in flight per actor.
Examples: Examples:
@ -79,8 +79,8 @@ def Replay(*,
if actors: if actors:
replay = from_actors(actors) replay = from_actors(actors)
return replay.gather_async(async_queue_depth=async_queue_depth).filter( return replay.gather_async(
lambda x: x is not None) num_async=num_async).filter(lambda x: x is not None)
def gen_replay(_): def gen_replay(_):
while True: while True:

View file

@ -17,10 +17,8 @@ from ray.rllib.utils.sgd import standardized
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def ParallelRollouts(workers: WorkerSet, def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
*, num_async=1) -> LocalIterator[SampleBatch]:
mode="bulk_sync",
async_queue_depth=1) -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers. """Operator to collect experiences in parallel from rollout workers.
If there are no remote workers, experiences will be collected serially from If there are no remote workers, experiences will be collected serially from
@ -36,7 +34,7 @@ def ParallelRollouts(workers: WorkerSet,
- In 'raw' mode, the ParallelIterator object is returned directly - In 'raw' mode, the ParallelIterator object is returned directly
and the caller is responsible for implementing gather and and the caller is responsible for implementing gather and
updating the timesteps counter. updating the timesteps counter.
async_queue_depth (int): In async mode, the max number of async num_async (int): In async mode, the max number of async
requests in flight per actor. requests in flight per actor.
Returns: Returns:
@ -83,7 +81,7 @@ def ParallelRollouts(workers: WorkerSet,
.for_each(report_timesteps) .for_each(report_timesteps)
elif mode == "async": elif mode == "async":
return rollouts.gather_async( return rollouts.gather_async(
async_queue_depth=async_queue_depth).for_each(report_timesteps) num_async=num_async).for_each(report_timesteps)
elif mode == "raw": elif mode == "raw":
return rollouts return rollouts
else: else: