[multiprocessing] Modify Ray's map_async() to match Multiprocessing's map_async() behavior (#19403)

This commit is contained in:
shrekris-anyscale 2021-10-22 14:31:34 -07:00 committed by GitHub
parent 2f8da8f8c8
commit cfae64ebe8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 122 additions and 55 deletions

View file

@ -36,6 +36,15 @@ def pool_4_processes():
ray.shutdown()
@pytest.fixture
def pool_4_processes_python_multiprocessing_lib():
    """Yield a 4-process pool from Python's standard multiprocessing library.

    The pool is terminated and joined after the code that follows the
    ``yield`` resumes, i.e. once the consumer of the fixture is done with it.
    """
    import multiprocessing as mp

    pool = mp.Pool(processes=4)
    yield pool
    # Teardown: stop the worker processes, then wait for them to exit.
    pool.terminate()
    pool.join()
def test_ray_init(shutdown_only):
def getpid(args):
return os.getpid()
@ -343,15 +352,8 @@ def test_starmap(pool):
assert pool.starmap(lambda x, y: x + y, zip([1, 2], [3, 4])) == [4, 6]
def test_callbacks(pool_4_processes):
def f(args):
time.sleep(0.1 * random.random())
index = args[0]
err_indices = args[1]
if index in err_indices:
raise Exception("intentional failure")
return index
def test_callbacks(pool_4_processes,
pool_4_processes_python_multiprocessing_lib):
callback_queue = queue.Queue()
def callback(result):
@ -361,41 +363,77 @@ def test_callbacks(pool_4_processes):
callback_queue.put(error)
# Will not error, check that callback is called.
result = pool_4_processes.apply_async(f, ((0, [1]), ), callback=callback)
result = pool_4_processes.apply_async(
callback_test_helper, ((0, [1]), ), callback=callback)
assert callback_queue.get() == 0
result.get()
# Will error, check that error_callback is called.
result = pool_4_processes.apply_async(
f, ((0, [0]), ), error_callback=error_callback)
callback_test_helper, ((0, [0]), ), error_callback=error_callback)
assert isinstance(callback_queue.get(), Exception)
with pytest.raises(Exception, match="intentional failure"):
result.get()
# Test callbacks for map_async.
error_indices = [2, 50, 98]
result = pool_4_processes.map_async(
f, [(index, error_indices) for index in range(100)],
callback=callback,
error_callback=error_callback)
callback_results = []
while len(callback_results) < 100:
callback_results.append(callback_queue.get())
# Ensure Ray's map_async behavior matches Multiprocessing's map_async
process_pools = [
pool_4_processes, pool_4_processes_python_multiprocessing_lib
]
assert result.ready()
assert not result.successful()
for process_pool in process_pools:
# Test error callbacks for map_async.
test_callback_types = ["regular callback", "error callback"]
# Check that callbacks were called on every result, error or not.
assert len(callback_results) == 100
# Check that callbacks were processed in the order that the tasks finished.
# NOTE: this could be flaky if the calls happened to finish in order due
# to the random sleeps, but it's very unlikely.
assert not all(i in error_indices or i == result
for i, result in enumerate(callback_results))
# Check that the correct callbacks were called on errors/successes.
assert all(index not in callback_results for index in error_indices)
assert [isinstance(result, Exception)
for result in callback_results].count(True) == len(error_indices)
for callback_type in test_callback_types:
# Reinitialize queue to track number of callback calls made by
# the current process_pool and callback_type in map_async
callback_queue = queue.Queue()
indices, error_indices = list(range(100)), []
if callback_type == "error callback":
error_indices = [2, 50, 98]
result = process_pool.map_async(
callback_test_helper,
[(index, error_indices) for index in indices],
callback=callback,
error_callback=error_callback)
callback_results = None
result.wait()
callback_results = callback_queue.get()
callback_queue.task_done()
# Ensure that callback or error_callback was called only once
assert callback_queue.qsize() == 0
if callback_type == "regular callback":
assert result.successful()
else:
assert not result.successful()
if callback_type == "regular callback":
# Check that regular callback returned a list of all indices
for index in callback_results:
assert index in indices
indices.remove(index)
assert len(indices) == 0
else:
# Check that error callback returned a single exception
assert isinstance(callback_results, Exception)
def callback_test_helper(args):
    """Sleep briefly, then return the task index or raise for a failing one.

    This is a helper function for the test_callbacks test. It must be placed
    outside the test because Python's Multiprocessing library uses Pickle to
    serialize functions, but Pickle cannot serialize local functions.

    Args:
        args: indexable pair whose first element is the task index and whose
            second element is the collection of indices that must fail.

    Returns:
        The task index, when it is not among the failing indices.

    Raises:
        Exception: with the message "intentional failure" when the task
            index is among the failing indices.
    """
    # Random delay of up to 0.1 seconds before producing a result.
    time.sleep(0.1 * random.random())
    index = args[0]
    err_indices = args[1]
    if index in err_indices:
        raise Exception("intentional failure")
    return index
def test_imap(pool_4_processes):

View file

@ -156,6 +156,7 @@ class PoolTaskError(Exception):
class ResultThread(threading.Thread):
def __init__(self,
object_refs,
single_result=False,
callback=None,
error_callback=None,
total_object_refs=None):
@ -165,6 +166,7 @@ class ResultThread(threading.Thread):
self._num_ready = 0
self._results = []
self._ready_index_queue = queue.Queue()
self._single_result = single_result
self._callback = callback
self._error_callback = error_callback
self._total_object_refs = total_object_refs or len(object_refs)
@ -185,6 +187,7 @@ class ResultThread(threading.Thread):
def run(self):
unready = copy.copy(self._object_refs)
aggregated_batch_results = []
while self._num_ready < self._total_object_refs:
# Get as many new IDs from the queue as possible without blocking,
# unless we have no IDs to wait on, in which case we block.
@ -203,18 +206,39 @@ class ResultThread(threading.Thread):
batch = ray.get(ready_id)
except ray.exceptions.RayError as e:
batch = [e]
for result in batch:
if isinstance(result, Exception):
self._got_error = True
if self._error_callback is not None:
self._error_callback(result)
elif self._callback is not None:
self._callback(result)
# The exception callback is called only once on the first result
# that errors. If no result errors, it is never called.
if not self._got_error:
for result in batch:
if isinstance(result, Exception):
self._got_error = True
if self._error_callback is not None:
self._error_callback(result)
break
else:
aggregated_batch_results.append(result)
self._num_ready += 1
self._results[self._indices[ready_id]] = batch
self._ready_index_queue.put(self._indices[ready_id])
# The regular callback is called only once on the entire List of
# results as long as none of the results were errors. If any results
# were errors, the regular callback is never called; instead, the
# exception callback is called on the first erroring result.
#
# This callback is called outside the while loop to ensure that it's
# called on the entire list of results not just a single batch.
if not self._got_error and self._callback is not None:
if not self._single_result:
self._callback(aggregated_batch_results)
else:
# On a thread handling a function with a single result
# (e.g. apply_async), we call the callback on just that result
# instead of on a list encapsulating that result
self._callback(aggregated_batch_results[0])
def got_error(self):
    """Return the value of the ``_got_error`` flag.

    Should only be called after the thread finishes.
    """
    return self._got_error
@ -247,8 +271,8 @@ class AsyncResult:
error_callback=None,
single_result=False):
self._single_result = single_result
self._result_thread = ResultThread(chunk_object_refs, callback,
error_callback)
self._result_thread = ResultThread(chunk_object_refs, single_result,
callback, error_callback)
self._result_thread.start()
def wait(self, timeout=None):
@ -569,8 +593,8 @@ class Pool:
func,
args=None,
kwargs=None,
callback=None,
error_callback=None):
callback: Callable[[Any], None] = None,
error_callback: Callable[[Exception], None] = None):
"""Run the given function on a random actor process and return an
asynchronous interface to the result.
@ -579,9 +603,9 @@ class Pool:
args: optional arguments to the function.
kwargs: optional keyword arguments to the function.
callback: callback to be executed on the result once it is finished
if it succeeds.
only if it succeeds.
error_callback: callback to be executed on the result once it is
finished if the task errors. The exception raised by the
finished only if the task errors. The exception raised by the
task will be passed as the only argument to the callback.
Returns:
@ -712,8 +736,8 @@ class Pool:
func,
iterable,
chunksize=None,
callback=None,
error_callback=None):
callback: Callable[[List], None] = None,
error_callback: Callable[[Exception], None] = None):
"""Run the given function on each element in the iterable round-robin
on the actor processes and return an asynchronous interface to the
results.
@ -724,11 +748,13 @@ class Pool:
func.
chunksize: number of tasks to submit as a batch to each actor
process. If unspecified, a suitable chunksize will be chosen.
callback: callback to be executed on each successful result once it
is finished.
error_callback: callback to be executed on each errored result once
it is finished. The exception raised by the task will be passed
as the only argument to the callback.
callback: Will only be called if none of the results were errors,
and will only be called once after all results are finished.
A Python List of all the finished results will be passed as the
only argument to the callback.
error_callback: callback executed on the first errored result.
The Exception raised by the task will be passed as the only
argument to the callback.
Returns:
AsyncResult
@ -749,8 +775,11 @@ class Pool:
return self._map_async(
func, iterable, chunksize=chunksize, unpack_args=True).get()
def starmap_async(self, func, iterable, callback=None,
error_callback=None):
def starmap_async(self,
func,
iterable,
callback: Callable[[List], None] = None,
error_callback: Callable[[Exception], None] = None):
"""Same as `map_async`, but unpacks each element of the iterable as the
arguments to func like: [func(*args) for args in iterable].
"""