Allow Ray API to be used from multiple threads (#2422)

Hao Chen 2018-07-21 06:39:01 +08:00 committed by Robert Nishihara
parent 4b6157ed09
commit 05f485e274
5 changed files with 167 additions and 93 deletions
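The user-visible effect is that the core API calls (task submission, ray.get, ray.put, ray.wait) no longer have to run on the main thread. Below is a minimal sketch of the pattern this enables, modeled on the new testMultithreading test at the bottom of the diff; the function names are illustrative and the ray.init() arguments are omitted.

import threading

import ray

ray.init()


@ray.remote
def square(x):
    return x * x


def submit_from_thread(n):
    # Each thread submits tasks and blocks on results on its own; the
    # worker state these calls touch is now protected by locks.
    assert ray.get([square.remote(i) for i in range(n)]) == [
        i * i for i in range(n)
    ]


threads = [
    threading.Thread(target=submit_from_thread, args=(10, ))
    for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()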

python/ray/actor.py

@@ -594,7 +594,6 @@ class ActorClass(object):
A handle to the newly created actor.
"""
worker = ray.worker.get_global_worker()
ray.worker.check_main_thread()
if worker.mode is None:
raise Exception("Actors cannot be created before ray.init() "
"has been called.")
@@ -773,7 +772,6 @@ class ActorHandle(object):
worker = ray.worker.get_global_worker()
worker.check_connected()
ray.worker.check_main_thread()
function_signature = self._ray_method_signatures[method_name]
if args is None:
@@ -929,7 +927,6 @@ class ActorHandle(object):
"""
worker = ray.worker.get_global_worker()
worker.check_connected()
ray.worker.check_main_thread()
if state["ray_forking"]:
actor_handle_id = compute_actor_handle_id(

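With check_main_thread gone from actor creation and method invocation, actor handles no longer reject calls made off the main thread. A hedged sketch of what that allows (the PR's own test exercises plain tasks rather than actors; the Counter class is illustrative):

import threading

import ray

ray.init()


@ray.remote
class Counter(object):
    def __init__(self):
        self.value = 0

    def add(self, delta):
        self.value += delta
        return self.value


counter = Counter.remote()


def call_from_background_thread():
    # Before this change the removed check raised: "The Ray methods are
    # not thread safe and must be called from the main thread."
    assert ray.get(counter.add.remote(1)) == 1


t = threading.Thread(target=call_from_background_thread)
t.start()
t.join()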
python/ray/remote_function.py

@@ -114,7 +114,6 @@ class RemoteFunction(object):
"""An experimental alternate way to submit remote functions."""
worker = ray.worker.get_global_worker()
worker.check_connected()
ray.worker.check_main_thread()
kwargs = {} if kwargs is None else kwargs
args = ray.signature.extend_args(self._function_signature, args,
kwargs)

python/ray/utils.py

@@ -3,10 +3,12 @@ from __future__ import division
from __future__ import print_function
import binascii
import functools
import hashlib
import numpy as np
import os
import sys
import threading
import time
import uuid
@@ -295,3 +297,57 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id.id())

class _ThreadSafeProxy(object):
    """This class is used to create a thread-safe proxy for a given object.
    Every method call will be guarded with a lock.

    Attributes:
        orig_obj (object): the original object.
        lock (threading.Lock): the lock object.
        _wrapper_cache (dict): a cache from original object's methods to
            the proxy methods.
    """

    def __init__(self, orig_obj, lock):
        self.orig_obj = orig_obj
        self.lock = lock
        self._wrapper_cache = {}

    def __getattr__(self, attr):
        orig_attr = getattr(self.orig_obj, attr)
        if not callable(orig_attr):
            # If the original attr is a field, just return it.
            return orig_attr
        else:
            # If the original attr is a method,
            # return a wrapper that guards the original method with a lock.
            wrapper = self._wrapper_cache.get(attr)
            if wrapper is None:

                @functools.wraps(orig_attr)
                def _wrapper(*args, **kwargs):
                    with self.lock:
                        return orig_attr(*args, **kwargs)

                self._wrapper_cache[attr] = _wrapper
                wrapper = _wrapper
            return wrapper


def thread_safe_client(client, lock=None):
    """Create a thread-safe proxy which locks every method call
    for the given client.

    Args:
        client: the client object to be guarded.
        lock: the lock object that will be used to lock client's methods.
            If None, a new lock will be used.

    Returns:
        A thread-safe proxy for the given client.
    """
    if lock is None:
        lock = threading.Lock()
    return _ThreadSafeProxy(client, lock)
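
A usage sketch for the helper above (not part of the diff): any object whose methods are not thread safe can be wrapped this way; later in this commit worker.py wraps its Redis and Plasma clients with the same call. The UnsafeCounter class here is a toy stand-in for a real client, and the import assumes the helper lives in ray.utils, as the worker.py import below indicates.

import threading

from ray.utils import thread_safe_client


class UnsafeCounter(object):
    """A toy client whose read-modify-write method is not thread safe."""

    def __init__(self):
        self.value = 0

    def increment(self):
        current = self.value
        self.value = current + 1


counter = thread_safe_client(UnsafeCounter())


def bump(n):
    for _ in range(n):
        # Every proxied call holds the shared lock for its full duration.
        counter.increment()


threads = [threading.Thread(target=bump, args=(1000, )) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Non-callable attributes pass through the proxy unchanged.
assert counter.value == 8000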

python/ray/worker.py

@@ -37,6 +37,7 @@ from ray.utils import (
check_oversized_pickle,
is_cython,
random_string,
thread_safe_client,
)
SCRIPT_MODE = 0
@@ -200,6 +201,13 @@ class Worker(object):
cached_functions_to_run (List): A list of functions to run on all of
the workers that should be exported as soon as connect is called.
profiler: the profiler used to aggregate profiling information.
state_lock (Lock):
Used to lock worker's non-thread-safe internal states:
1) task_index increment: make sure we generate unique task ids;
2) Object reconstruction: because the node manager will
recycle/return the worker's resources before/after reconstruction,
it's unsafe for multiple threads to call object
reconstruction simultaneously.
"""
def __init__(self):
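
A standalone illustration of point 1) in the state_lock docstring above (plain Python, not Ray code): the index has to be read and incremented under a single lock acquisition, otherwise two threads can observe the same value and submit tasks with duplicate indices.

import threading

lock = threading.Lock()
task_index = 0


def next_task_index():
    # Read and increment atomically, mirroring the submit_task change below.
    global task_index
    with lock:
        index = task_index
        task_index += 1
    return index


seen = []
seen_lock = threading.Lock()


def submit_many(n):
    for _ in range(n):
        index = next_task_index()
        with seen_lock:
            seen.append(index)


threads = [
    threading.Thread(target=submit_many, args=(100, )) for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every simulated submission received a unique index.
assert sorted(seen) == list(range(400))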
@@ -236,6 +244,7 @@ class Worker(object):
# CUDA_VISIBLE_DEVICES environment variable.
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
self.profiler = profiling.Profiler(self)
self.state_lock = threading.Lock()
def check_connected(self):
"""Check if the worker is connected.
@@ -365,7 +374,7 @@ class Worker(object):
# Serialize and put the object in the object store.
try:
self.store_and_register(object_id, value)
except pyarrow.PlasmaObjectExists as e:
except pyarrow.PlasmaObjectExists:
# The object already exists in the object store, so there is no
# need to add it again. TODO(rkn): We need to compare the hashes
# and make sure that the objects are in fact the same. We also
@@ -393,7 +402,7 @@ class Worker(object):
i + ray._config.worker_get_request_size())],
timeout, self.serialization_context)
return results
except pyarrow.lib.ArrowInvalid as e:
except pyarrow.lib.ArrowInvalid:
# TODO(ekl): the local scheduler could include relevant
# metadata in the task kill case for a better error message
invalid_error = RayTaskError(
@@ -401,7 +410,7 @@ class Worker(object):
"Invalid return value: likely worker died or was killed "
"while executing the task.")
return [invalid_error] * len(object_ids)
except pyarrow.DeserializationCallbackError as e:
except pyarrow.DeserializationCallbackError:
# Wait a little bit for the import thread to import the class.
# If we currently have the worker lock, we need to release it
# so that the import thread can acquire it.
@@ -466,52 +475,59 @@ class Worker(object):
for (i, val) in enumerate(final_results)
if val is plasma.ObjectNotAvailable
}
was_blocked = (len(unready_ids) > 0)
# Try reconstructing any objects we haven't gotten yet. Try to get them
# until at least get_timeout_milliseconds milliseconds passes, then
# repeat.
while len(unready_ids) > 0:
for unready_id in unready_ids:
if not self.use_raylet:
self.local_scheduler_client.reconstruct_objects(
[ray.ObjectID(unready_id)], False)
# Do another fetch for objects that aren't available locally yet,
# in case they were evicted since the last fetch. We divide the
# fetch into smaller fetches so as to not block the manager for a
# prolonged period of time in a single call.
object_ids_to_fetch = list(
map(plasma.ObjectID, unready_ids.keys()))
ray_object_ids_to_fetch = list(
map(ray.ObjectID, unready_ids.keys()))
for i in range(0, len(object_ids_to_fetch),
ray._config.worker_fetch_request_size()):
if not self.use_raylet:
self.plasma_client.fetch(object_ids_to_fetch[i:(
i + ray._config.worker_fetch_request_size())])
else:
self.local_scheduler_client.reconstruct_objects(
ray_object_ids_to_fetch[i:(
i + ray._config.worker_fetch_request_size())],
False)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
ray._config.get_timeout_milliseconds(),
int(0.01 * len(unready_ids))
]))
# Remove any entries for objects we received during this iteration
# so we don't retrieve the same object twice.
for i, val in enumerate(results):
if val is not plasma.ObjectNotAvailable:
object_id = object_ids_to_fetch[i].binary()
index = unready_ids[object_id]
final_results[index] = val
unready_ids.pop(object_id)
# If there were objects that we weren't able to get locally, let the
# local scheduler know that we're now unblocked.
if was_blocked:
self.local_scheduler_client.notify_unblocked()
if len(unready_ids) > 0:
with self.state_lock:
# Try reconstructing any objects we haven't gotten yet. Try to
# get them until at least get_timeout_milliseconds
# milliseconds passes, then repeat.
while len(unready_ids) > 0:
for unready_id in unready_ids:
if not self.use_raylet:
self.local_scheduler_client.reconstruct_objects(
[ray.ObjectID(unready_id)], False)
# Do another fetch for objects that aren't available
# locally yet, in case they were evicted since the last
# fetch. We divide the fetch into smaller fetches so as
# to not block the manager for a prolonged period of time
# in a single call.
object_ids_to_fetch = [
plasma.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
ray_object_ids_to_fetch = [
ray.ObjectID(unready_id)
for unready_id in unready_ids.keys()
]
fetch_request_size = (
ray._config.worker_fetch_request_size())
for i in range(0, len(object_ids_to_fetch),
fetch_request_size):
if not self.use_raylet:
self.plasma_client.fetch(object_ids_to_fetch[i:(
i + fetch_request_size)])
else:
self.local_scheduler_client.reconstruct_objects(
ray_object_ids_to_fetch[i:(
i + fetch_request_size)], False)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
ray._config.get_timeout_milliseconds(),
int(0.01 * len(unready_ids))
]))
# Remove any entries for objects we received during this
# iteration so we don't retrieve the same object twice.
for i, val in enumerate(results):
if val is not plasma.ObjectNotAvailable:
object_id = object_ids_to_fetch[i].binary()
index = unready_ids[object_id]
final_results[index] = val
unready_ids.pop(object_id)
# If there were objects that we weren't able to get locally,
# let the local scheduler know that we're now unblocked.
self.local_scheduler_client.notify_unblocked()
assert len(final_results) == len(object_ids)
return final_results
@@ -563,7 +579,6 @@ class Worker(object):
The return object IDs for this task.
"""
with profiling.profile("submit_task", worker=self):
check_main_thread()
if actor_id is None:
assert actor_handle_id is None
actor_id = ray.ObjectID(NIL_ACTOR_ID)
@@ -607,17 +622,19 @@ class Worker(object):
raise ValueError(
"Resource quantities must all be whole numbers.")
with self.state_lock:
# Increment the worker's task index to track how many tasks
# have been submitted by the current task so far.
task_index = self.task_index
self.task_index += 1
# Submit the task to local scheduler.
task = ray.local_scheduler.Task(
driver_id, ray.ObjectID(
function_id.id()), args_for_local_scheduler,
num_return_vals, self.current_task_id, self.task_index,
num_return_vals, self.current_task_id, task_index,
actor_creation_id, actor_creation_dummy_object_id, actor_id,
actor_handle_id, actor_counter, is_actor_checkpoint_method,
execution_dependencies, resources, self.use_raylet)
# Increment the worker's task index to track how many tasks have
# been submitted by the current task so far.
self.task_index += 1
self.local_scheduler_client.submit(task)
return task.returns()
@@ -635,7 +652,6 @@ class Worker(object):
decorated_function: The decorated function (this is used to enable
the remote function to recursively call itself).
"""
check_main_thread()
if self.mode not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("export_remote_function can only be called on a "
"driver.")
@@ -687,7 +703,6 @@ class Worker(object):
should not take any arguments. If it returns anything, its
return values will not be used.
"""
check_main_thread()
# If ray.init has not been called yet, then cache the function and
# export it when connect is called. Otherwise, run the function on all
# workers.
@@ -1041,7 +1056,6 @@ class Worker(object):
signal.signal(signal.SIGTERM, exit)
check_main_thread()
while True:
task = self._get_next_task_from_local_scheduler()
self._wait_for_and_process_task(task)
@@ -1143,20 +1157,6 @@ class RayConnectionError(Exception):
pass
def check_main_thread():
"""Check that we are currently on the main thread.
Raises:
Exception: An exception is raised if this is called on a thread other
than the main thread.
"""
if threading.current_thread().getName() != "MainThread":
raise Exception("The Ray methods are not thread safe and must be "
"called from the main thread. This method was called "
"from thread {}."
.format(threading.current_thread().getName()))
def print_failed_task(task_status):
"""Print information about failed tasks.
@@ -1191,12 +1191,9 @@ def error_applies_to_driver(error_key, worker=global_worker):
def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
check_main_thread()
if worker.use_raylet:
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
errors = []
for error_key in error_keys:
@@ -1388,7 +1385,7 @@ def get_address_info_from_redis(redis_address,
try:
return get_address_info_from_redis_helper(
redis_address, node_ip_address, use_raylet=use_raylet)
except Exception as e:
except Exception:
if counter == num_retries:
raise
# Some of the information may not be in Redis yet, so wait a little
@@ -1521,7 +1518,6 @@ def _init(address_info=None,
Exception: An exception is raised if an inappropriate combination of
arguments is passed in.
"""
check_main_thread()
if driver_mode not in [SCRIPT_MODE, LOCAL_MODE, SILENT_MODE]:
raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, "
"ray.LOCAL_MODE, ray.SILENT_MODE].")
@@ -1988,7 +1984,6 @@ def connect(info,
LOCAL_MODE, and SILENT_MODE.
use_raylet: True if the new raylet code path should be used.
"""
check_main_thread()
# Do some basic checking to make sure we didn't call ray.init twice.
error_message = "Perhaps you called ray.init twice by accident?"
assert not worker.connected, error_message
@@ -2021,8 +2016,8 @@ def connect(info,
# Create a Redis client.
redis_ip_address, redis_port = info["redis_address"].split(":")
worker.redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port))
worker.redis_client = thread_safe_client(
redis.StrictRedis(host=redis_ip_address, port=int(redis_port)))
# For driver's check that the version information matches the version
# information that the Ray cluster was started with.
@@ -2102,11 +2097,12 @@ def connect(info,
# Create an object store client.
if not worker.use_raylet:
worker.plasma_client = plasma.connect(info["store_socket_name"],
info["manager_socket_name"], 64)
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"],
info["manager_socket_name"], 64))
else:
worker.plasma_client = plasma.connect(info["store_socket_name"], "",
64)
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"], "", 64))
if not worker.use_raylet:
local_scheduler_socket = info["local_scheduler_socket_name"]
@@ -2348,7 +2344,7 @@ def register_custom_serializer(cls,
# worker. However, determinism is not guaranteed, and the result
# may be different on different workers.
class_id = _try_to_compute_deterministic_class_id(cls)
except Exception as e:
except Exception:
raise serialization.CloudPickleError("Failed to pickle class "
"'{}'".format(cls))
else:
@@ -2399,8 +2395,6 @@ def get(object_ids, worker=global_worker):
"""
worker.check_connected()
with profiling.profile("ray.get", worker=worker):
check_main_thread()
if worker.mode == LOCAL_MODE:
# In LOCAL_MODE, ray.get is the identity operation (the input will
# actually be a value not an objectid).
@@ -2432,8 +2426,6 @@ def put(value, worker=global_worker):
"""
worker.check_connected()
with profiling.profile("ray.put", worker=worker):
check_main_thread()
if worker.mode == LOCAL_MODE:
# In LOCAL_MODE, ray.put is the identity operation.
return value
@@ -2491,8 +2483,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
worker.check_connected()
with profiling.profile("ray.wait", worker=worker):
check_main_thread()
# When Ray is run in LOCAL_MODE, all functions are run immediately,
# so all objects in object_id are ready.
if worker.mode == LOCAL_MODE:

test/runtest.py

@@ -4,6 +4,7 @@ import os
import re
import string
import sys
import threading
import time
import unittest
from collections import defaultdict, namedtuple, OrderedDict
@@ -1144,6 +1145,37 @@ class APITest(unittest.TestCase):
with self.assertRaises(Exception):
ray.get(3)

    def testMultithreading(self):
        self.init_ray(driver_mode=ray.SILENT_MODE)

        @ray.remote
        def f():
            pass

        def g(n):
            for _ in range(1000 // n):
                ray.get([f.remote() for _ in range(n)])
            res = [ray.put(i) for i in range(1000 // n)]
            ray.wait(res, len(res))

        def test_multi_threading():
            threads = [
                threading.Thread(target=g, args=(n, ))
                for n in [1, 5, 10, 100, 1000]
            ]
            [thread.start() for thread in threads]
            [thread.join() for thread in threads]

        @ray.remote
        def test_multi_threading_in_worker():
            test_multi_threading()

        # test multi-threading in the driver
        test_multi_threading()
        # test multi-threading in the worker
        ray.get(test_multi_threading_in_worker.remote())
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),