mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
parent
f48da50e1c
commit
ee0eb44a32
5 changed files with 15 additions and 17 deletions
|
@ -292,7 +292,7 @@ def test_gather_async(ray_start_regular_shared):
|
|||
|
||||
def test_gather_async_queue(ray_start_regular_shared):
|
||||
it = from_range(100)
|
||||
it = it.gather_async(async_queue_depth=4)
|
||||
it = it.gather_async(num_async=4)
|
||||
assert sorted(it) == list(range(100))
|
||||
|
||||
|
||||
|
|
|
@ -415,14 +415,14 @@ class ParallelIterator(Generic[T]):
|
|||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
|
||||
def gather_async(self, num_async=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
|
||||
New items will be fetched from the shards asynchronously as soon as
|
||||
the previous one is computed. Items arrive in non-deterministic order.
|
||||
|
||||
Arguments:
|
||||
async_queue_depth (int): The max number of async requests in flight
|
||||
num_async (int): The max number of async requests in flight
|
||||
per actor. Increasing this improves the amount of pipeline
|
||||
parallelism in the iterator.
|
||||
|
||||
|
@ -436,7 +436,7 @@ class ParallelIterator(Generic[T]):
|
|||
... 1
|
||||
"""
|
||||
|
||||
if async_queue_depth < 1:
|
||||
if num_async < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
|
||||
# Forward reference to the returned iterator.
|
||||
|
@ -448,7 +448,7 @@ class ParallelIterator(Generic[T]):
|
|||
actor_set.init_actors()
|
||||
all_actors.extend(actor_set.actors)
|
||||
futures = {}
|
||||
for _ in range(async_queue_depth):
|
||||
for _ in range(num_async):
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
while futures:
|
||||
|
|
|
@ -143,7 +143,7 @@ def execution_plan(workers, config):
|
|||
# We execute the following steps concurrently:
|
||||
# (1) Generate rollouts and store them in our replay buffer actors. Update
|
||||
# the weights of the worker that generated the batch.
|
||||
rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2)
|
||||
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
|
||||
store_op = rollouts \
|
||||
.for_each(StoreToReplayBuffer(actors=replay_actors)) \
|
||||
.zip_with_source_actor() \
|
||||
|
@ -154,7 +154,7 @@ def execution_plan(workers, config):
|
|||
|
||||
# (2) Read experiences from the replay buffer actors and send to the
|
||||
# learner thread via its in-queue.
|
||||
replay_op = Replay(actors=replay_actors, async_queue_depth=4) \
|
||||
replay_op = Replay(actors=replay_actors, num_async=4) \
|
||||
.zip_with_source_actor() \
|
||||
.for_each(Enqueue(learner_thread.inqueue))
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class StoreToReplayBuffer:
|
|||
def Replay(*,
|
||||
local_buffer: LocalReplayBuffer = None,
|
||||
actors: List["ActorHandle"] = None,
|
||||
async_queue_depth=4):
|
||||
num_async=4):
|
||||
"""Replay experiences from the given buffer or actors.
|
||||
|
||||
This should be combined with the StoreToReplayActors operation using the
|
||||
|
@ -63,7 +63,7 @@ def Replay(*,
|
|||
and replay_actors can be specified.
|
||||
actors (list): List of replay actors. Only one of this and
|
||||
local_buffer can be specified.
|
||||
async_queue_depth (int): In async mode, the max number of async
|
||||
num_async (int): In async mode, the max number of async
|
||||
requests in flight per actor.
|
||||
|
||||
Examples:
|
||||
|
@ -79,8 +79,8 @@ def Replay(*,
|
|||
|
||||
if actors:
|
||||
replay = from_actors(actors)
|
||||
return replay.gather_async(async_queue_depth=async_queue_depth).filter(
|
||||
lambda x: x is not None)
|
||||
return replay.gather_async(
|
||||
num_async=num_async).filter(lambda x: x is not None)
|
||||
|
||||
def gen_replay(_):
|
||||
while True:
|
||||
|
|
|
@ -17,10 +17,8 @@ from ray.rllib.utils.sgd import standardized
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ParallelRollouts(workers: WorkerSet,
|
||||
*,
|
||||
mode="bulk_sync",
|
||||
async_queue_depth=1) -> LocalIterator[SampleBatch]:
|
||||
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
|
||||
num_async=1) -> LocalIterator[SampleBatch]:
|
||||
"""Operator to collect experiences in parallel from rollout workers.
|
||||
|
||||
If there are no remote workers, experiences will be collected serially from
|
||||
|
@ -36,7 +34,7 @@ def ParallelRollouts(workers: WorkerSet,
|
|||
- In 'raw' mode, the ParallelIterator object is returned directly
|
||||
and the caller is responsible for implementing gather and
|
||||
updating the timesteps counter.
|
||||
async_queue_depth (int): In async mode, the max number of async
|
||||
num_async (int): In async mode, the max number of async
|
||||
requests in flight per actor.
|
||||
|
||||
Returns:
|
||||
|
@ -83,7 +81,7 @@ def ParallelRollouts(workers: WorkerSet,
|
|||
.for_each(report_timesteps)
|
||||
elif mode == "async":
|
||||
return rollouts.gather_async(
|
||||
async_queue_depth=async_queue_depth).for_each(report_timesteps)
|
||||
num_async=num_async).for_each(report_timesteps)
|
||||
elif mode == "raw":
|
||||
return rollouts
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue