mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[multiprocessing] Modify Ray's map_async() to match Multiprocessing's map_async() behavior (#19403)
This commit is contained in:
parent
2f8da8f8c8
commit
cfae64ebe8
2 changed files with 122 additions and 55 deletions
|
@ -36,6 +36,15 @@ def pool_4_processes():
|
|||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pool_4_processes_python_multiprocessing_lib():
|
||||
import multiprocessing as mp
|
||||
pool = mp.Pool(processes=4)
|
||||
yield pool
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
|
||||
def test_ray_init(shutdown_only):
|
||||
def getpid(args):
|
||||
return os.getpid()
|
||||
|
@ -343,15 +352,8 @@ def test_starmap(pool):
|
|||
assert pool.starmap(lambda x, y: x + y, zip([1, 2], [3, 4])) == [4, 6]
|
||||
|
||||
|
||||
def test_callbacks(pool_4_processes):
|
||||
def f(args):
|
||||
time.sleep(0.1 * random.random())
|
||||
index = args[0]
|
||||
err_indices = args[1]
|
||||
if index in err_indices:
|
||||
raise Exception("intentional failure")
|
||||
return index
|
||||
|
||||
def test_callbacks(pool_4_processes,
|
||||
pool_4_processes_python_multiprocessing_lib):
|
||||
callback_queue = queue.Queue()
|
||||
|
||||
def callback(result):
|
||||
|
@ -361,41 +363,77 @@ def test_callbacks(pool_4_processes):
|
|||
callback_queue.put(error)
|
||||
|
||||
# Will not error, check that callback is called.
|
||||
result = pool_4_processes.apply_async(f, ((0, [1]), ), callback=callback)
|
||||
result = pool_4_processes.apply_async(
|
||||
callback_test_helper, ((0, [1]), ), callback=callback)
|
||||
assert callback_queue.get() == 0
|
||||
result.get()
|
||||
|
||||
# Will error, check that error_callback is called.
|
||||
result = pool_4_processes.apply_async(
|
||||
f, ((0, [0]), ), error_callback=error_callback)
|
||||
callback_test_helper, ((0, [0]), ), error_callback=error_callback)
|
||||
assert isinstance(callback_queue.get(), Exception)
|
||||
with pytest.raises(Exception, match="intentional failure"):
|
||||
result.get()
|
||||
|
||||
# Test callbacks for map_async.
|
||||
error_indices = [2, 50, 98]
|
||||
result = pool_4_processes.map_async(
|
||||
f, [(index, error_indices) for index in range(100)],
|
||||
callback=callback,
|
||||
error_callback=error_callback)
|
||||
callback_results = []
|
||||
while len(callback_results) < 100:
|
||||
callback_results.append(callback_queue.get())
|
||||
# Ensure Ray's map_async behavior matches Multiprocessing's map_async
|
||||
process_pools = [
|
||||
pool_4_processes, pool_4_processes_python_multiprocessing_lib
|
||||
]
|
||||
|
||||
assert result.ready()
|
||||
assert not result.successful()
|
||||
for process_pool in process_pools:
|
||||
# Test error callbacks for map_async.
|
||||
test_callback_types = ["regular callback", "error callback"]
|
||||
|
||||
# Check that callbacks were called on every result, error or not.
|
||||
assert len(callback_results) == 100
|
||||
# Check that callbacks were processed in the order that the tasks finished.
|
||||
# NOTE: this could be flaky if the calls happened to finish in order due
|
||||
# to the random sleeps, but it's very unlikely.
|
||||
assert not all(i in error_indices or i == result
|
||||
for i, result in enumerate(callback_results))
|
||||
# Check that the correct callbacks were called on errors/successes.
|
||||
assert all(index not in callback_results for index in error_indices)
|
||||
assert [isinstance(result, Exception)
|
||||
for result in callback_results].count(True) == len(error_indices)
|
||||
for callback_type in test_callback_types:
|
||||
# Reinitialize queue to track number of callback calls made by
|
||||
# the current process_pool and callback_type in map_async
|
||||
callback_queue = queue.Queue()
|
||||
|
||||
indices, error_indices = list(range(100)), []
|
||||
if callback_type == "error callback":
|
||||
error_indices = [2, 50, 98]
|
||||
result = process_pool.map_async(
|
||||
callback_test_helper,
|
||||
[(index, error_indices) for index in indices],
|
||||
callback=callback,
|
||||
error_callback=error_callback)
|
||||
callback_results = None
|
||||
result.wait()
|
||||
|
||||
callback_results = callback_queue.get()
|
||||
callback_queue.task_done()
|
||||
|
||||
# Ensure that callback or error_callback was called only once
|
||||
assert callback_queue.qsize() == 0
|
||||
|
||||
if callback_type == "regular callback":
|
||||
assert result.successful()
|
||||
else:
|
||||
assert not result.successful()
|
||||
|
||||
if callback_type == "regular callback":
|
||||
# Check that regular callback returned a list of all indices
|
||||
for index in callback_results:
|
||||
assert index in indices
|
||||
indices.remove(index)
|
||||
assert len(indices) == 0
|
||||
else:
|
||||
# Check that error callback returned a single exception
|
||||
assert isinstance(callback_results, Exception)
|
||||
|
||||
|
||||
def callback_test_helper(args):
|
||||
"""
|
||||
This is a helper function for the test_callbacks test. It must be placed
|
||||
outside the test because Python's Multiprocessing library uses Pickle to
|
||||
serialize functions, but Pickle cannot serialize local functions.
|
||||
"""
|
||||
time.sleep(0.1 * random.random())
|
||||
index = args[0]
|
||||
err_indices = args[1]
|
||||
if index in err_indices:
|
||||
raise Exception("intentional failure")
|
||||
return index
|
||||
|
||||
|
||||
def test_imap(pool_4_processes):
|
||||
|
|
|
@ -156,6 +156,7 @@ class PoolTaskError(Exception):
|
|||
class ResultThread(threading.Thread):
|
||||
def __init__(self,
|
||||
object_refs,
|
||||
single_result=False,
|
||||
callback=None,
|
||||
error_callback=None,
|
||||
total_object_refs=None):
|
||||
|
@ -165,6 +166,7 @@ class ResultThread(threading.Thread):
|
|||
self._num_ready = 0
|
||||
self._results = []
|
||||
self._ready_index_queue = queue.Queue()
|
||||
self._single_result = single_result
|
||||
self._callback = callback
|
||||
self._error_callback = error_callback
|
||||
self._total_object_refs = total_object_refs or len(object_refs)
|
||||
|
@ -185,6 +187,7 @@ class ResultThread(threading.Thread):
|
|||
|
||||
def run(self):
|
||||
unready = copy.copy(self._object_refs)
|
||||
aggregated_batch_results = []
|
||||
while self._num_ready < self._total_object_refs:
|
||||
# Get as many new IDs from the queue as possible without blocking,
|
||||
# unless we have no IDs to wait on, in which case we block.
|
||||
|
@ -203,18 +206,39 @@ class ResultThread(threading.Thread):
|
|||
batch = ray.get(ready_id)
|
||||
except ray.exceptions.RayError as e:
|
||||
batch = [e]
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
self._got_error = True
|
||||
if self._error_callback is not None:
|
||||
self._error_callback(result)
|
||||
elif self._callback is not None:
|
||||
self._callback(result)
|
||||
|
||||
# The exception callback is called only once on the first result
|
||||
# that errors. If no result errors, it is never called.
|
||||
if not self._got_error:
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
self._got_error = True
|
||||
if self._error_callback is not None:
|
||||
self._error_callback(result)
|
||||
break
|
||||
else:
|
||||
aggregated_batch_results.append(result)
|
||||
|
||||
self._num_ready += 1
|
||||
self._results[self._indices[ready_id]] = batch
|
||||
self._ready_index_queue.put(self._indices[ready_id])
|
||||
|
||||
# The regular callback is called only once on the entire List of
|
||||
# results as long as none of the results were errors. If any results
|
||||
# were errors, the regular callback is never called; instead, the
|
||||
# exception callback is called on the first erroring result.
|
||||
#
|
||||
# This callback is called outside the while loop to ensure that it's
|
||||
# called on the entire list of results, not just a single batch.
|
||||
if not self._got_error and self._callback is not None:
|
||||
if not self._single_result:
|
||||
self._callback(aggregated_batch_results)
|
||||
else:
|
||||
# On a thread handling a function with a single result
|
||||
# (e.g. apply_async), we call the callback on just that result
|
||||
# instead of on a list encapsulating that result
|
||||
self._callback(aggregated_batch_results[0])
|
||||
|
||||
def got_error(self):
|
||||
# Should only be called after the thread finishes.
|
||||
return self._got_error
|
||||
|
@ -247,8 +271,8 @@ class AsyncResult:
|
|||
error_callback=None,
|
||||
single_result=False):
|
||||
self._single_result = single_result
|
||||
self._result_thread = ResultThread(chunk_object_refs, callback,
|
||||
error_callback)
|
||||
self._result_thread = ResultThread(chunk_object_refs, single_result,
|
||||
callback, error_callback)
|
||||
self._result_thread.start()
|
||||
|
||||
def wait(self, timeout=None):
|
||||
|
@ -569,8 +593,8 @@ class Pool:
|
|||
func,
|
||||
args=None,
|
||||
kwargs=None,
|
||||
callback=None,
|
||||
error_callback=None):
|
||||
callback: Callable[[Any], None] = None,
|
||||
error_callback: Callable[[Exception], None] = None):
|
||||
"""Run the given function on a random actor process and return an
|
||||
asynchronous interface to the result.
|
||||
|
||||
|
@ -579,9 +603,9 @@ class Pool:
|
|||
args: optional arguments to the function.
|
||||
kwargs: optional keyword arguments to the function.
|
||||
callback: callback to be executed on the result once it is finished
|
||||
if it succeeds.
|
||||
only if it succeeds.
|
||||
error_callback: callback to be executed on the result once it is
|
||||
finished if the task errors. The exception raised by the
|
||||
finished only if the task errors. The exception raised by the
|
||||
task will be passed as the only argument to the callback.
|
||||
|
||||
Returns:
|
||||
|
@ -712,8 +736,8 @@ class Pool:
|
|||
func,
|
||||
iterable,
|
||||
chunksize=None,
|
||||
callback=None,
|
||||
error_callback=None):
|
||||
callback: Callable[[List], None] = None,
|
||||
error_callback: Callable[[Exception], None] = None):
|
||||
"""Run the given function on each element in the iterable round-robin
|
||||
on the actor processes and return an asynchronous interface to the
|
||||
results.
|
||||
|
@ -724,11 +748,13 @@ class Pool:
|
|||
func.
|
||||
chunksize: number of tasks to submit as a batch to each actor
|
||||
process. If unspecified, a suitable chunksize will be chosen.
|
||||
callback: callback to be executed on each successful result once it
|
||||
is finished.
|
||||
error_callback: callback to be executed on each errored result once
|
||||
it is finished. The exception raised by the task will be passed
|
||||
as the only argument to the callback.
|
||||
callback: Will only be called if none of the results were errors,
|
||||
and will only be called once after all results are finished.
|
||||
A Python List of all the finished results will be passed as the
|
||||
only argument to the callback.
|
||||
error_callback: callback executed on the first errored result.
|
||||
The Exception raised by the task will be passed as the only
|
||||
argument to the callback.
|
||||
|
||||
Returns:
|
||||
AsyncResult
|
||||
|
@ -749,8 +775,11 @@ class Pool:
|
|||
return self._map_async(
|
||||
func, iterable, chunksize=chunksize, unpack_args=True).get()
|
||||
|
||||
def starmap_async(self, func, iterable, callback=None,
|
||||
error_callback=None):
|
||||
def starmap_async(self,
|
||||
func,
|
||||
iterable,
|
||||
callback: Callable[[List], None] = None,
|
||||
error_callback: Callable[[Exception], None] = None):
|
||||
"""Same as `map_async`, but unpacks each element of the iterable as the
|
||||
arguments to func like: [func(*args) for args in iterable].
|
||||
"""
|
||||
|
|
Loading…
Add table
Reference in a new issue