Rename async_queue_depth -> num_async (#8207)

* rename

* lint
This commit is contained in:
Eric Liang 2020-05-05 01:38:10 -07:00 committed by GitHub
parent f48da50e1c
commit ee0eb44a32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 17 deletions

View file

@ -292,7 +292,7 @@ def test_gather_async(ray_start_regular_shared):
def test_gather_async_queue(ray_start_regular_shared):
it = from_range(100)
it = it.gather_async(async_queue_depth=4)
it = it.gather_async(num_async=4)
assert sorted(it) == list(range(100))

View file

@ -415,14 +415,14 @@ class ParallelIterator(Generic[T]):
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
def gather_async(self, num_async=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
New items will be fetched from the shards asynchronously as soon as
the previous one is computed. Items arrive in non-deterministic order.
Arguments:
async_queue_depth (int): The max number of async requests in flight
num_async (int): The max number of async requests in flight
per actor. Increasing this improves the amount of pipeline
parallelism in the iterator.
@ -436,7 +436,7 @@ class ParallelIterator(Generic[T]):
... 1
"""
if async_queue_depth < 1:
if num_async < 1:
raise ValueError("queue depth must be positive")
# Forward reference to the returned iterator.
@ -448,7 +448,7 @@ class ParallelIterator(Generic[T]):
actor_set.init_actors()
all_actors.extend(actor_set.actors)
futures = {}
for _ in range(async_queue_depth):
for _ in range(num_async):
for a in all_actors:
futures[a.par_iter_next.remote()] = a
while futures:

View file

@ -143,7 +143,7 @@ def execution_plan(workers, config):
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our replay buffer actors. Update
# the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2)
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts \
.for_each(StoreToReplayBuffer(actors=replay_actors)) \
.zip_with_source_actor() \
@ -154,7 +154,7 @@ def execution_plan(workers, config):
# (2) Read experiences from the replay buffer actors and send to the
# learner thread via its in-queue.
replay_op = Replay(actors=replay_actors, async_queue_depth=4) \
replay_op = Replay(actors=replay_actors, num_async=4) \
.zip_with_source_actor() \
.for_each(Enqueue(learner_thread.inqueue))

View file

@ -52,7 +52,7 @@ class StoreToReplayBuffer:
def Replay(*,
local_buffer: LocalReplayBuffer = None,
actors: List["ActorHandle"] = None,
async_queue_depth=4):
num_async=4):
"""Replay experiences from the given buffer or actors.
This should be combined with the StoreToReplayActors operation using the
@ -63,7 +63,7 @@ def Replay(*,
and replay_actors can be specified.
actors (list): List of replay actors. Only one of this and
local_buffer can be specified.
async_queue_depth (int): In async mode, the max number of async
num_async (int): In async mode, the max number of async
requests in flight per actor.
Examples:
@ -79,8 +79,8 @@ def Replay(*,
if actors:
replay = from_actors(actors)
return replay.gather_async(async_queue_depth=async_queue_depth).filter(
lambda x: x is not None)
return replay.gather_async(
num_async=num_async).filter(lambda x: x is not None)
def gen_replay(_):
while True:

View file

@ -17,10 +17,8 @@ from ray.rllib.utils.sgd import standardized
logger = logging.getLogger(__name__)
def ParallelRollouts(workers: WorkerSet,
*,
mode="bulk_sync",
async_queue_depth=1) -> LocalIterator[SampleBatch]:
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
num_async=1) -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.
If there are no remote workers, experiences will be collected serially from
@ -36,7 +34,7 @@ def ParallelRollouts(workers: WorkerSet,
- In 'raw' mode, the ParallelIterator object is returned directly
and the caller is responsible for implementing gather and
updating the timesteps counter.
async_queue_depth (int): In async mode, the max number of async
num_async (int): In async mode, the max number of async
requests in flight per actor.
Returns:
@ -83,7 +81,7 @@ def ParallelRollouts(workers: WorkerSet,
.for_each(report_timesteps)
elif mode == "async":
return rollouts.gather_async(
async_queue_depth=async_queue_depth).for_each(report_timesteps)
num_async=num_async).for_each(report_timesteps)
elif mode == "raw":
return rollouts
else: