mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
move import_thread to a separate file (#2349)
* move import_thread to a separate file
* sort imports
* group imports regardless of `from`
* re-organize imports based on google style
* Update import_thread.py
* fix event_type names in profile statement
* unify duplicate code
This commit is contained in:
parent ebf4070d88
commit d6af50785e
2 changed files with 187 additions and 177 deletions
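At the call site, the refactor swaps worker.py's hand-rolled thread setup for a class that owns the same logic. A sketch of the change, paraphrased from the hunks below:

# Before: connect() wired up the daemon import thread by hand.
t = threading.Thread(target=import_thread, args=(worker, mode))
t.daemon = True  # exit when the main thread exits
t.start()

# After: the identical setup lives behind ImportThread.start().
import_thread.ImportThread(worker, mode).start()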
184
python/ray/import_thread.py
Normal file
@@ -0,0 +1,184 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import traceback

import redis

import ray
from ray import ray_constants
from ray import cloudpickle as pickle
from ray import utils


class ImportThread(object):
    """A thread used to import exports from the driver or other workers.

    Note:
        The driver also has an import thread, which is used only to
        import custom class definitions from calls to register_custom_serializer
        that happen under the hood on workers.

    Attributes:
        worker: the worker object in this process.
        mode: the worker mode.
        redis_client: the redis client used to query exports.
    """

    def __init__(self, worker, mode):
        self.worker = worker
        self.mode = mode
        self.redis_client = worker.redis_client

    def start(self):
        """Start the import thread."""
        t = threading.Thread(target=self._run)
        # Making the thread a daemon causes it to exit
        # when the main thread exits.
        t.daemon = True
        t.start()

    def _run(self):
        import_pubsub_client = self.redis_client.pubsub()
        # Exports that are published after the call to
        # import_pubsub_client.subscribe and before the call to
        # import_pubsub_client.listen will still be processed in the loop.
        import_pubsub_client.subscribe("__keyspace@0__:Exports")
        # Keep track of the number of imports that we've imported.
        num_imported = 0

        # Get the exports that occurred before the call to subscribe.
        with self.worker.lock:
            export_keys = self.redis_client.lrange("Exports", 0, -1)
            for key in export_keys:
                num_imported += 1
                self._process_key(key)
        try:
            for msg in import_pubsub_client.listen():
                with self.worker.lock:
                    if msg["type"] == "subscribe":
                        continue
                    assert msg["data"] == b"rpush"
                    num_imports = self.redis_client.llen("Exports")
                    assert num_imports >= num_imported
                    for i in range(num_imported, num_imports):
                        num_imported += 1
                        key = self.redis_client.lindex("Exports", i)
                        self._process_key(key)
        except redis.ConnectionError:
            # When Redis terminates, the listen call will throw a
            # ConnectionError, which we catch here.
            pass

    def _process_key(self, key):
        """Process the given export key from redis."""
        from ray.worker import profile, WORKER_MODE
        # Handle the driver case first.
        if self.mode != WORKER_MODE:
            if key.startswith(b"FunctionsToRun"):
                with profile("fetch_and_run_function", worker=self.worker):
                    self.fetch_and_execute_function_to_run(key)
            # Return because FunctionsToRun are the only things that
            # the driver should import.
            return

        if key.startswith(b"RemoteFunction"):
            with profile("register_remote_function", worker=self.worker):
                self.fetch_and_register_remote_function(key)
        elif key.startswith(b"FunctionsToRun"):
            with profile("fetch_and_run_function", worker=self.worker):
                self.fetch_and_execute_function_to_run(key)
        elif key.startswith(b"ActorClass"):
            # Keep track of the fact that this actor class has been
            # exported so that we know it is safe to turn this worker
            # into an actor of that class.
            self.worker.imported_actor_classes.add(key)
            # TODO(rkn): We may need to bring back the case of
            # fetching actor classes here.
        else:
            raise Exception("This code should be unreachable.")

    def fetch_and_register_remote_function(self, key):
        """Import a remote function."""
        from ray.worker import FunctionExecutionInfo
        (driver_id, function_id_str, function_name, serialized_function,
         num_return_vals, module, resources,
         max_calls) = self.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.ObjectID(function_id_str)
        function_name = utils.decode(function_name)
        max_calls = int(max_calls)
        module = utils.decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f():
            raise Exception("This function was not imported properly.")

        self.worker.function_execution_info[driver_id][function_id.id()] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self.worker.num_task_executions[driver_id][function_id.id()] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was imported,
            # we record the traceback and notify the scheduler of the failure.
            traceback_str = utils.format_error_message(traceback.format_exc())
            # Log the error message.
            utils.push_error_to_driver(
                self.worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id.id(),
                    "function_name": function_name
                })
        else:
            # TODO(rkn): Why is the below line necessary?
            function.__module__ = module
            self.worker.function_execution_info[driver_id][
                function_id.id()] = (FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self.redis_client.rpush(b"FunctionTable:" + function_id.id(),
                                    self.worker.worker_id)

    def fetch_and_execute_function_to_run(self, key):
        """Run an arbitrary function on the worker."""
        from ray.worker import SCRIPT_MODE, SILENT_MODE
        driver_id, serialized_function = self.redis_client.hmget(
            key, ["driver_id", "function"])

        if (self.worker.mode in [SCRIPT_MODE, SILENT_MODE]
                and driver_id != self.worker.task_driver_id.id()):
            # This export was from a different driver and there's no need for
            # this driver to import it.
            return

        try:
            # Deserialize the function.
            function = pickle.loads(serialized_function)
            # Run the function.
            function({"worker": self.worker})
        except Exception:
            # If an exception was thrown when the function was run, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = traceback.format_exc()
            # Log the error message.
            name = function.__name__ if ("function" in locals() and hasattr(
                function, "__name__")) else ""
            utils.push_error_to_driver(
                self.worker,
                ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={"name": name})
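A note on the _run loop above: it subscribes to the "__keyspace@0__:Exports" notification channel before scanning the existing "Exports" list, so an export published between the scan and the listen loop is never lost; the num_imported counter then keeps a key that shows up in both passes from being processed twice. A minimal, self-contained sketch of the same pattern with redis-py (illustrative only: print() stands in for _process_key, there is no lock, and it assumes a local Redis server with keyspace notifications enabled):

import redis

r = redis.StrictRedis()
p = r.pubsub()
# Subscribe first, then read the backlog; anything pushed in between shows
# up both in the backlog and as a notification, so nothing is lost.
p.subscribe("__keyspace@0__:Exports")

num_imported = 0
for key in r.lrange("Exports", 0, -1):
    num_imported += 1
    print("backlog import:", key)

for msg in p.listen():
    if msg["type"] == "subscribe":
        continue  # the subscription confirmation, not a real event
    # Re-read the list length and import only entries we haven't seen, so a
    # notification for an already-imported key is a no-op.
    for i in range(num_imported, r.llen("Exports")):
        num_imported += 1
        print("live import:", r.lindex("Exports", i))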
@@ -31,6 +31,7 @@ import ray.signature
import ray.local_scheduler
import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray.utils import (
    binary_to_hex,
    check_oversized_pickle,
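The single added line in the hunk above (from ray import import_thread) slots into an import block that the commit also reorders. Per the commit message, imports are now grouped google-style: standard library first, then third-party, then package-local, with `import x` and `from x import y` sorted together inside each group. An annotated sketch of that convention, using the new file's own header:

# Standard-library imports.
import threading
import traceback

# Third-party imports.
import redis

# Package-local imports; plain and `from` imports grouped and sorted together.
import ray
from ray import ray_constants
from ray import cloudpickle as pickle
from ray import utils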
@@ -1984,175 +1985,6 @@ def print_error_messages(worker):
        pass


def fetch_and_register_remote_function(key, worker=global_worker):
    """Import a remote function."""
    (driver_id, function_id_str, function_name, serialized_function,
     num_return_vals, module, resources,
     max_calls) = worker.redis_client.hmget(key, [
         "driver_id", "function_id", "name", "function", "num_return_vals",
         "module", "resources", "max_calls"
     ])
    function_id = ray.ObjectID(function_id_str)
    function_name = ray.utils.decode(function_name)
    max_calls = int(max_calls)
    module = ray.utils.decode(module)

    # This is a placeholder in case the function can't be unpickled. This will
    # be overwritten if the function is successfully registered.
    def f():
        raise Exception("This function was not imported properly.")

    worker.function_execution_info[driver_id][function_id.id()] = (
        FunctionExecutionInfo(
            function=f, function_name=function_name, max_calls=max_calls))
    worker.num_task_executions[driver_id][function_id.id()] = 0

    try:
        function = pickle.loads(serialized_function)
    except Exception:
        # If an exception was thrown when the remote function was imported, we
        # record the traceback and notify the scheduler of the failure.
        traceback_str = ray.utils.format_error_message(traceback.format_exc())
        # Log the error message.
        ray.utils.push_error_to_driver(
            worker,
            ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
            traceback_str,
            driver_id=driver_id,
            data={
                "function_id": function_id.id(),
                "function_name": function_name
            })
    else:
        # TODO(rkn): Why is the below line necessary?
        function.__module__ = module
        worker.function_execution_info[driver_id][function_id.id()] = (
            FunctionExecutionInfo(
                function=function,
                function_name=function_name,
                max_calls=max_calls))
        # Add the function to the function table.
        worker.redis_client.rpush(b"FunctionTable:" + function_id.id(),
                                  worker.worker_id)


def fetch_and_execute_function_to_run(key, worker=global_worker):
    """Run an arbitrary function on the worker."""
    driver_id, serialized_function = worker.redis_client.hmget(
        key, ["driver_id", "function"])

    if (worker.mode in [SCRIPT_MODE, SILENT_MODE]
            and driver_id != worker.task_driver_id.id()):
        # This export was from a different driver and there's no need for this
        # driver to import it.
        return

    try:
        # Deserialize the function.
        function = pickle.loads(serialized_function)
        # Run the function.
        function({"worker": worker})
    except Exception:
        # If an exception was thrown when the function was run, we record the
        # traceback and notify the scheduler of the failure.
        traceback_str = traceback.format_exc()
        # Log the error message.
        name = function.__name__ if ("function" in locals()
                                     and hasattr(function, "__name__")) else ""
        ray.utils.push_error_to_driver(
            worker,
            ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
            traceback_str,
            driver_id=driver_id,
            data={"name": name})


def import_thread(worker, mode):
    worker.import_pubsub_client = worker.redis_client.pubsub()
    # Exports that are published after the call to
    # import_pubsub_client.subscribe and before the call to
    # import_pubsub_client.listen will still be processed in the loop.
    worker.import_pubsub_client.subscribe("__keyspace@0__:Exports")
    # Keep track of the number of imports that we've imported.
    num_imported = 0

    # Get the exports that occurred before the call to subscribe.
    with worker.lock:
        export_keys = worker.redis_client.lrange("Exports", 0, -1)
        for key in export_keys:
            num_imported += 1

            # Handle the driver case first.
            if mode != WORKER_MODE:
                if key.startswith(b"FunctionsToRun"):
                    with profile("fetch_and_run_function", worker=worker):
                        fetch_and_execute_function_to_run(key, worker=worker)
                # Continue because FunctionsToRun are the only things that the
                # driver should import.
                continue

            if key.startswith(b"RemoteFunction"):
                with profile("register_remote_function", worker=worker):
                    fetch_and_register_remote_function(key, worker=worker)
            elif key.startswith(b"FunctionsToRun"):
                with profile("fetch_and_run_function", worker=worker):
                    fetch_and_execute_function_to_run(key, worker=worker)
            elif key.startswith(b"ActorClass"):
                # Keep track of the fact that this actor class has been
                # exported so that we know it is safe to turn this worker into
                # an actor of that class.
                worker.imported_actor_classes.add(key)
            else:
                raise Exception("This code should be unreachable.")

    try:
        for msg in worker.import_pubsub_client.listen():
            with worker.lock:
                if msg["type"] == "subscribe":
                    continue
                assert msg["data"] == b"rpush"
                num_imports = worker.redis_client.llen("Exports")
                assert num_imports >= num_imported
                for i in range(num_imported, num_imports):
                    num_imported += 1
                    key = worker.redis_client.lindex("Exports", i)

                    # Handle the driver case first.
                    if mode != WORKER_MODE:
                        if key.startswith(b"FunctionsToRun"):
                            with profile(
                                    "fetch_and_run_function", worker=worker):
                                fetch_and_execute_function_to_run(
                                    key, worker=worker)
                        # Continue because FunctionsToRun are the only things
                        # that the driver should import.
                        continue

                    if key.startswith(b"RemoteFunction"):
                        with profile(
                                "register_remote_function", worker=worker):
                            fetch_and_register_remote_function(
                                key, worker=worker)
                    elif key.startswith(b"FunctionsToRun"):
                        with profile("fetch_and_run_function", worker=worker):
                            fetch_and_execute_function_to_run(
                                key, worker=worker)
                    elif key.startswith(b"ActorClass"):
                        # Keep track of the fact that this actor class has been
                        # exported so that we know it is safe to turn this
                        # worker into an actor of that class.
                        worker.imported_actor_classes.add(key)

                    # TODO(rkn): We may need to bring back the case of fetching
                    # actor classes here.
                    else:
                        raise Exception("This code should be unreachable.")
    except redis.ConnectionError:
        # When Redis terminates the listen call will throw a ConnectionError,
        # which we catch here.
        pass


def connect(info,
            object_id_seed=None,
            mode=WORKER_MODE,
@@ -2361,14 +2193,8 @@ def connect(info,
    # it must be run before we export all of the cached remote functions.
    _initialize_serialization()

    # Start a thread to import exports from the driver or from other workers.
    # Note that the driver also has an import thread, which is used only to
    # import custom class definitions from calls to register_custom_serializer
    # that happen under the hood on workers.
    t = threading.Thread(target=import_thread, args=(worker, mode))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()
    # Start the import thread.
    import_thread.ImportThread(worker, mode).start()

    # If this is a driver running in SCRIPT_MODE, start a thread to print error
    # messages asynchronously in the background. Ideally the scheduler would
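Net effect in connect(): four lines of thread bookkeeping collapse to one call, and the daemon flag now lives inside ImportThread.start(), so any future caller gets the exit-with-the-main-thread behavior for free. A hypothetical smoke test of the new entry point (illustrative only; connect() normally does this for you, and it assumes a worker object that connect() has already populated with a redis_client and lock):

from ray import import_thread
from ray.worker import global_worker, SCRIPT_MODE

# Hypothetical driver-side use; in practice ray.init()/connect() starts this.
import_thread.ImportThread(global_worker, SCRIPT_MODE).start()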