Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
Allow Ray API to be used from multiple threads (#2422)
This commit is contained in: commit 05f485e274 (parent 4b6157ed09)

5 changed files with 167 additions and 93 deletions
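For orientation, here is a minimal sketch (illustrative only, not part of the commit) of the usage pattern this change is meant to allow: Ray API calls issued from several driver threads at once, which previously failed the main-thread check. It mirrors the new testMultithreading test further down and assumes a Ray installation of this era.

# Illustrative only: call Ray APIs from several driver threads at once.
import threading

import ray


@ray.remote
def f():
    return 1


def worker_thread():
    # Each thread submits tasks and fetches results concurrently.
    assert ray.get([f.remote() for _ in range(10)]) == [1] * 10


if __name__ == "__main__":
    ray.init()
    threads = [threading.Thread(target=worker_thread) for _ in range(5)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()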
@@ -594,7 +594,6 @@ class ActorClass(object):
             A handle to the newly created actor.
         """
         worker = ray.worker.get_global_worker()
-        ray.worker.check_main_thread()
         if worker.mode is None:
             raise Exception("Actors cannot be created before ray.init() "
                             "has been called.")
@@ -773,7 +772,6 @@ class ActorHandle(object):
         worker = ray.worker.get_global_worker()
 
         worker.check_connected()
-        ray.worker.check_main_thread()
 
         function_signature = self._ray_method_signatures[method_name]
         if args is None:
@@ -929,7 +927,6 @@ class ActorHandle(object):
         """
         worker = ray.worker.get_global_worker()
         worker.check_connected()
-        ray.worker.check_main_thread()
 
         if state["ray_forking"]:
             actor_handle_id = compute_actor_handle_id(
@@ -114,7 +114,6 @@ class RemoteFunction(object):
         """An experimental alternate way to submit remote functions."""
         worker = ray.worker.get_global_worker()
         worker.check_connected()
-        ray.worker.check_main_thread()
         kwargs = {} if kwargs is None else kwargs
         args = ray.signature.extend_args(self._function_signature, args,
                                          kwargs)
@@ -3,10 +3,12 @@ from __future__ import division
 from __future__ import print_function
 
 import binascii
+import functools
 import hashlib
 import numpy as np
 import os
 import sys
+import threading
 import time
 import uuid
 
@@ -295,3 +297,57 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
         ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
         warning_message,
         driver_id=worker.task_driver_id.id())
+
+
+class _ThreadSafeProxy(object):
+    """This class is used to create a thread-safe proxy for a given object.
+    Every method call will be guarded with a lock.
+
+    Attributes:
+        orig_obj (object): the original object.
+        lock (threading.Lock): the lock object.
+        _wrapper_cache (dict): a cache from the original object's methods to
+            the proxy methods.
+    """
+
+    def __init__(self, orig_obj, lock):
+        self.orig_obj = orig_obj
+        self.lock = lock
+        self._wrapper_cache = {}
+
+    def __getattr__(self, attr):
+        orig_attr = getattr(self.orig_obj, attr)
+        if not callable(orig_attr):
+            # If the original attr is a field, just return it.
+            return orig_attr
+        else:
+            # If the original attr is a method,
+            # return a wrapper that guards the original method with a lock.
+            wrapper = self._wrapper_cache.get(attr)
+            if wrapper is None:
+
+                @functools.wraps(orig_attr)
+                def _wrapper(*args, **kwargs):
+                    with self.lock:
+                        return orig_attr(*args, **kwargs)
+
+                self._wrapper_cache[attr] = _wrapper
+                wrapper = _wrapper
+            return wrapper
+
+
+def thread_safe_client(client, lock=None):
+    """Create a thread-safe proxy which locks every method call
+    for the given client.
+
+    Args:
+        client: the client object to be guarded.
+        lock: the lock object that will be used to lock client's methods.
+            If None, a new lock will be used.
+
+    Returns:
+        A thread-safe proxy for the given client.
+    """
+    if lock is None:
+        lock = threading.Lock()
+    return _ThreadSafeProxy(client, lock)
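As a usage note, the sketch below shows how the new thread_safe_client helper serializes concurrent method calls. It is illustrative only: FakeCounter is a hypothetical stand-in client, and the import path ray.utils matches where this commit places the helper.

# Illustrative only: FakeCounter is a stand-in client, not part of the diff.
import threading

from ray.utils import thread_safe_client


class FakeCounter(object):
    def __init__(self):
        self.value = 0

    def bump(self):
        # A read-modify-write that is racy without external locking.
        current = self.value
        self.value = current + 1


counter = thread_safe_client(FakeCounter())


def hammer():
    for _ in range(10000):
        counter.bump()  # Each call acquires the proxy's shared lock.


threads = [threading.Thread(target=hammer) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# With the proxy, no increments are lost.
assert counter.value == 40000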
@@ -37,6 +37,7 @@ from ray.utils import (
     check_oversized_pickle,
     is_cython,
     random_string,
+    thread_safe_client,
 )
 
 SCRIPT_MODE = 0
@@ -200,6 +201,13 @@ class Worker(object):
         cached_functions_to_run (List): A list of functions to run on all of
             the workers that should be exported as soon as connect is called.
         profiler: the profiler used to aggregate profiling information.
+        state_lock (Lock):
+            Used to lock worker's non-thread-safe internal states:
+            1) task_index increment: make sure we generate unique task ids;
+            2) Object reconstruction: because the node manager will
+            recycle/return the worker's resources before/after reconstruction,
+            it's unsafe for multiple threads to call object
+            reconstruction simultaneously.
     """
 
     def __init__(self):
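To make the role of state_lock concrete, here is a standalone sketch of the guarded task_index increment described in the docstring above. TinyWorker is hypothetical and mimics only the relevant fields; it is not the real Worker class.

# Illustrative only: TinyWorker mimics just the fields relevant to state_lock.
import threading


class TinyWorker(object):
    def __init__(self):
        self.task_index = 0
        self.state_lock = threading.Lock()

    def next_task_index(self):
        # Read-and-increment under the lock so concurrent submitters can
        # never observe the same index (and thus the same task id).
        with self.state_lock:
            task_index = self.task_index
            self.task_index += 1
        return task_index


worker = TinyWorker()
seen = []
seen_lock = threading.Lock()


def submit_many():
    for _ in range(10000):
        index = worker.next_task_index()
        with seen_lock:
            seen.append(index)


threads = [threading.Thread(target=submit_many) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(set(seen)) == 40000  # every index is unique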
@@ -236,6 +244,7 @@ class Worker(object):
         # CUDA_VISIBLE_DEVICES environment variable.
         self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
         self.profiler = profiling.Profiler(self)
+        self.state_lock = threading.Lock()
 
     def check_connected(self):
         """Check if the worker is connected.
@@ -365,7 +374,7 @@ class Worker(object):
         # Serialize and put the object in the object store.
         try:
             self.store_and_register(object_id, value)
-        except pyarrow.PlasmaObjectExists as e:
+        except pyarrow.PlasmaObjectExists:
             # The object already exists in the object store, so there is no
             # need to add it again. TODO(rkn): We need to compare the hashes
             # and make sure that the objects are in fact the same. We also
@@ -393,7 +402,7 @@ class Worker(object):
                             i + ray._config.worker_get_request_size())],
                         timeout, self.serialization_context)
                 return results
-            except pyarrow.lib.ArrowInvalid as e:
+            except pyarrow.lib.ArrowInvalid:
                 # TODO(ekl): the local scheduler could include relevant
                 # metadata in the task kill case for a better error message
                 invalid_error = RayTaskError(
@@ -401,7 +410,7 @@ class Worker(object):
                     "Invalid return value: likely worker died or was killed "
                     "while executing the task.")
                 return [invalid_error] * len(object_ids)
-            except pyarrow.DeserializationCallbackError as e:
+            except pyarrow.DeserializationCallbackError:
                 # Wait a little bit for the import thread to import the class.
                 # If we currently have the worker lock, we need to release it
                 # so that the import thread can acquire it.
@@ -466,41 +475,49 @@ class Worker(object):
             for (i, val) in enumerate(final_results)
             if val is plasma.ObjectNotAvailable
         }
-        was_blocked = (len(unready_ids) > 0)
-        # Try reconstructing any objects we haven't gotten yet. Try to get them
-        # until at least get_timeout_milliseconds milliseconds passes, then
-        # repeat.
+
+        if len(unready_ids) > 0:
+            with self.state_lock:
+                # Try reconstructing any objects we haven't gotten yet. Try to
+                # get them until at least get_timeout_milliseconds
+                # milliseconds passes, then repeat.
                 while len(unready_ids) > 0:
                     for unready_id in unready_ids:
                         if not self.use_raylet:
                             self.local_scheduler_client.reconstruct_objects(
                                 [ray.ObjectID(unready_id)], False)
-            # Do another fetch for objects that aren't available locally yet,
-            # in case they were evicted since the last fetch. We divide the
-            # fetch into smaller fetches so as to not block the manager for a
-            # prolonged period of time in a single call.
-            object_ids_to_fetch = list(
-                map(plasma.ObjectID, unready_ids.keys()))
-            ray_object_ids_to_fetch = list(
-                map(ray.ObjectID, unready_ids.keys()))
+                    # Do another fetch for objects that aren't available
+                    # locally yet, in case they were evicted since the last
+                    # fetch. We divide the fetch into smaller fetches so as
+                    # to not block the manager for a prolonged period of time
+                    # in a single call.
+                    object_ids_to_fetch = [
+                        plasma.ObjectID(unready_id)
+                        for unready_id in unready_ids.keys()
+                    ]
+                    ray_object_ids_to_fetch = [
+                        ray.ObjectID(unready_id)
+                        for unready_id in unready_ids.keys()
+                    ]
+                    fetch_request_size = (
+                        ray._config.worker_fetch_request_size())
                     for i in range(0, len(object_ids_to_fetch),
-                           ray._config.worker_fetch_request_size()):
+                                   fetch_request_size):
                         if not self.use_raylet:
                             self.plasma_client.fetch(object_ids_to_fetch[i:(
-                        i + ray._config.worker_fetch_request_size())])
+                                i + fetch_request_size)])
                         else:
                             self.local_scheduler_client.reconstruct_objects(
                                 ray_object_ids_to_fetch[i:(
-                            i + ray._config.worker_fetch_request_size())],
-                        False)
+                                    i + fetch_request_size)], False)
                     results = self.retrieve_and_deserialize(
                         object_ids_to_fetch,
                         max([
                             ray._config.get_timeout_milliseconds(),
                             int(0.01 * len(unready_ids))
                         ]))
-            # Remove any entries for objects we received during this iteration
-            # so we don't retrieve the same object twice.
+                    # Remove any entries for objects we received during this
+                    # iteration so we don't retrieve the same object twice.
                     for i, val in enumerate(results):
                         if val is not plasma.ObjectNotAvailable:
                             object_id = object_ids_to_fetch[i].binary()
@@ -508,9 +525,8 @@ class Worker(object):
                             final_results[index] = val
                             unready_ids.pop(object_id)
 
-        # If there were objects that we weren't able to get locally, let the
-        # local scheduler know that we're now unblocked.
-        if was_blocked:
+                # If there were objects that we weren't able to get locally,
+                # let the local scheduler know that we're now unblocked.
                 self.local_scheduler_client.notify_unblocked()
 
         assert len(final_results) == len(object_ids)
@@ -563,7 +579,6 @@ class Worker(object):
             The return object IDs for this task.
         """
         with profiling.profile("submit_task", worker=self):
-            check_main_thread()
             if actor_id is None:
                 assert actor_handle_id is None
                 actor_id = ray.ObjectID(NIL_ACTOR_ID)
@@ -607,17 +622,19 @@ class Worker(object):
                 raise ValueError(
                     "Resource quantities must all be whole numbers.")
 
+            with self.state_lock:
+                # Increment the worker's task index to track how many tasks
+                # have been submitted by the current task so far.
+                task_index = self.task_index
+                self.task_index += 1
             # Submit the task to local scheduler.
             task = ray.local_scheduler.Task(
                 driver_id, ray.ObjectID(
                     function_id.id()), args_for_local_scheduler,
-                num_return_vals, self.current_task_id, self.task_index,
+                num_return_vals, self.current_task_id, task_index,
                 actor_creation_id, actor_creation_dummy_object_id, actor_id,
                 actor_handle_id, actor_counter, is_actor_checkpoint_method,
                 execution_dependencies, resources, self.use_raylet)
-            # Increment the worker's task index to track how many tasks have
-            # been submitted by the current task so far.
-            self.task_index += 1
             self.local_scheduler_client.submit(task)
 
             return task.returns()
@@ -635,7 +652,6 @@ class Worker(object):
             decorated_function: The decorated function (this is used to enable
                 the remote function to recursively call itself).
         """
-        check_main_thread()
        if self.mode not in [SCRIPT_MODE, SILENT_MODE]:
             raise Exception("export_remote_function can only be called on a "
                             "driver.")
@@ -687,7 +703,6 @@ class Worker(object):
                 should not take any arguments. If it returns anything, its
                 return values will not be used.
         """
-        check_main_thread()
         # If ray.init has not been called yet, then cache the function and
         # export it when connect is called. Otherwise, run the function on all
         # workers.
@@ -1041,7 +1056,6 @@ class Worker(object):
 
         signal.signal(signal.SIGTERM, exit)
 
-        check_main_thread()
         while True:
             task = self._get_next_task_from_local_scheduler()
             self._wait_for_and_process_task(task)
@@ -1143,20 +1157,6 @@ class RayConnectionError(Exception):
     pass
 
 
-def check_main_thread():
-    """Check that we are currently on the main thread.
-
-    Raises:
-        Exception: An exception is raised if this is called on a thread other
-            than the main thread.
-    """
-    if threading.current_thread().getName() != "MainThread":
-        raise Exception("The Ray methods are not thread safe and must be "
-                        "called from the main thread. This method was called "
-                        "from thread {}."
-                        .format(threading.current_thread().getName()))
-
-
 def print_failed_task(task_status):
     """Print information about failed tasks.
 
@@ -1191,12 +1191,9 @@ def error_applies_to_driver(error_key, worker=global_worker):
 def error_info(worker=global_worker):
     """Return information about failed tasks."""
     worker.check_connected()
-    check_main_thread()
-
     if worker.use_raylet:
         return (global_state.error_messages(job_id=worker.task_driver_id) +
                 global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
-
     error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
     errors = []
     for error_key in error_keys:
@@ -1388,7 +1385,7 @@ def get_address_info_from_redis(redis_address,
         try:
             return get_address_info_from_redis_helper(
                 redis_address, node_ip_address, use_raylet=use_raylet)
-        except Exception as e:
+        except Exception:
             if counter == num_retries:
                 raise
             # Some of the information may not be in Redis yet, so wait a little
@@ -1521,7 +1518,6 @@ def _init(address_info=None,
         Exception: An exception is raised if an inappropriate combination of
             arguments is passed in.
     """
-    check_main_thread()
     if driver_mode not in [SCRIPT_MODE, LOCAL_MODE, SILENT_MODE]:
         raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, "
                         "ray.LOCAL_MODE, ray.SILENT_MODE].")
@@ -1988,7 +1984,6 @@ def connect(info,
             LOCAL_MODE, and SILENT_MODE.
         use_raylet: True if the new raylet code path should be used.
     """
-    check_main_thread()
     # Do some basic checking to make sure we didn't call ray.init twice.
     error_message = "Perhaps you called ray.init twice by accident?"
     assert not worker.connected, error_message
@@ -2021,8 +2016,8 @@ def connect(info,
 
     # Create a Redis client.
     redis_ip_address, redis_port = info["redis_address"].split(":")
-    worker.redis_client = redis.StrictRedis(
-        host=redis_ip_address, port=int(redis_port))
+    worker.redis_client = thread_safe_client(
+        redis.StrictRedis(host=redis_ip_address, port=int(redis_port)))
 
     # For driver's check that the version information matches the version
     # information that the Ray cluster was started with.
@@ -2102,11 +2097,12 @@ def connect(info,
 
     # Create an object store client.
     if not worker.use_raylet:
-        worker.plasma_client = plasma.connect(info["store_socket_name"],
-                                              info["manager_socket_name"], 64)
+        worker.plasma_client = thread_safe_client(
+            plasma.connect(info["store_socket_name"],
+                           info["manager_socket_name"], 64))
     else:
-        worker.plasma_client = plasma.connect(info["store_socket_name"], "",
-                                              64)
+        worker.plasma_client = thread_safe_client(
+            plasma.connect(info["store_socket_name"], "", 64))
 
     if not worker.use_raylet:
         local_scheduler_socket = info["local_scheduler_socket_name"]
@@ -2348,7 +2344,7 @@ def register_custom_serializer(cls,
             # worker. However, determinism is not guaranteed, and the result
             # may be different on different workers.
             class_id = _try_to_compute_deterministic_class_id(cls)
-        except Exception as e:
+        except Exception:
             raise serialization.CloudPickleError("Failed to pickle class "
                                                  "'{}'".format(cls))
         else:
@@ -2399,8 +2395,6 @@ def get(object_ids, worker=global_worker):
     """
     worker.check_connected()
     with profiling.profile("ray.get", worker=worker):
-        check_main_thread()
-
         if worker.mode == LOCAL_MODE:
             # In LOCAL_MODE, ray.get is the identity operation (the input will
             # actually be a value not an objectid).
@@ -2432,8 +2426,6 @@ def put(value, worker=global_worker):
     """
     worker.check_connected()
     with profiling.profile("ray.put", worker=worker):
-        check_main_thread()
-
         if worker.mode == LOCAL_MODE:
             # In LOCAL_MODE, ray.put is the identity operation.
             return value
@@ -2491,8 +2483,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
 
     worker.check_connected()
     with profiling.profile("ray.wait", worker=worker):
-        check_main_thread()
-
         # When Ray is run in LOCAL_MODE, all functions are run immediately,
         # so all objects in object_id are ready.
         if worker.mode == LOCAL_MODE:
@@ -4,6 +4,7 @@ import os
 import re
 import string
 import sys
+import threading
 import time
 import unittest
 from collections import defaultdict, namedtuple, OrderedDict
@@ -1144,6 +1145,37 @@ class APITest(unittest.TestCase):
         with self.assertRaises(Exception):
             ray.get(3)
 
+    def testMultithreading(self):
+        self.init_ray(driver_mode=ray.SILENT_MODE)
+
+        @ray.remote
+        def f():
+            pass
+
+        def g(n):
+            for _ in range(1000 // n):
+                ray.get([f.remote() for _ in range(n)])
+            res = [ray.put(i) for i in range(1000 // n)]
+            ray.wait(res, len(res))
+
+        def test_multi_threading():
+            threads = [
+                threading.Thread(target=g, args=(n, ))
+                for n in [1, 5, 10, 100, 1000]
+            ]
+
+            [thread.start() for thread in threads]
+            [thread.join() for thread in threads]
+
+        @ray.remote
+        def test_multi_threading_in_worker():
+            test_multi_threading()
+
+        # test multi-threading in the driver
+        test_multi_threading()
+        # test multi-threading in the worker
+        ray.get(test_multi_threading_in_worker.remote())
+
 
     @unittest.skipIf(
         os.environ.get('RAY_USE_NEW_GCS', False),