move import_thread to a separate file (#2349)

* move import_thread to a separate file

* sort imports

* group imports regardless of `from`

* re-organize imoprts based on google style

* Update import_thread.py

* fix event_type names in profile statement

* unify duplicate code
This commit is contained in:
Hao Chen 2018-07-13 12:26:24 +08:00 committed by Philipp Moritz
parent ebf4070d88
commit d6af50785e
2 changed files with 187 additions and 177 deletions

184
python/ray/import_thread.py Normal file
View file

@ -0,0 +1,184 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import traceback
import redis
import ray
from ray import ray_constants
from ray import cloudpickle as pickle
from ray import utils
class ImportThread(object):
"""A thread used to import exports from the driver or other workers.
Note:
The driver also has an import thread, which is used only to
import custom class definitions from calls to register_custom_serializer
that happen under the hood on workers.
Attributes:
worker: the worker object in this process.
mode: worker mode
redis_client: the redis client used to query exports.
"""
def __init__(self, worker, mode):
self.worker = worker
self.mode = mode
self.redis_client = worker.redis_client
def start(self):
"""Start the import thread."""
t = threading.Thread(target=self._run)
# Making the thread a daemon causes it to exit
# when the main thread exits.
t.daemon = True
t.start()
def _run(self):
import_pubsub_client = self.redis_client.pubsub()
# Exports that are published after the call to
# import_pubsub_client.subscribe and before the call to
# import_pubsub_client.listen will still be processed in the loop.
import_pubsub_client.subscribe("__keyspace@0__:Exports")
# Keep track of the number of imports that we've imported.
num_imported = 0
# Get the exports that occurred before the call to subscribe.
with self.worker.lock:
export_keys = self.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
self._process_key(key)
try:
for msg in import_pubsub_client.listen():
with self.worker.lock:
if msg["type"] == "subscribe":
continue
assert msg["data"] == b"rpush"
num_imports = self.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = self.redis_client.lindex("Exports", i)
self._process_key(key)
except redis.ConnectionError:
# When Redis terminates the listen call will throw a
# ConnectionError, which we catch here.
pass
def _process_key(self, key):
"""Process the given export key from redis."""
from ray.worker import profile, WORKER_MODE
# Handle the driver case first.
if self.mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=self.worker):
self.fetch_and_execute_function_to_run(key)
# Return because FunctionsToRun are the only things that
# the driver should import.
return
if key.startswith(b"RemoteFunction"):
with profile("register_remote_function", worker=self.worker):
self.fetch_and_register_remote_function(key)
elif key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=self.worker):
self.fetch_and_execute_function_to_run(key)
elif key.startswith(b"ActorClass"):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this worker
# into an actor of that class.
self.worker.imported_actor_classes.add(key)
# TODO(rkn): We may need to bring back the case of
# fetching actor classes here.
else:
raise Exception("This code should be unreachable.")
def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
from ray.worker import FunctionExecutionInfo
(driver_id, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self.redis_client.hmget(key, [
"driver_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
function_name = utils.decode(function_name)
max_calls = int(max_calls)
module = utils.decode(module)
# This is a placeholder in case the function can't be unpickled. This
# will be overwritten if the function is successfully registered.
def f():
raise Exception("This function was not imported properly.")
self.worker.function_execution_info[driver_id][function_id.id()] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
self.worker.num_task_executions[driver_id][function_id.id()] = 0
try:
function = pickle.loads(serialized_function)
except Exception:
# If an exception was thrown when the remote function was imported,
# we record the traceback and notify the scheduler of the failure.
traceback_str = utils.format_error_message(traceback.format_exc())
# Log the error message.
utils.push_error_to_driver(
self.worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
data={
"function_id": function_id.id(),
"function_name": function_name
})
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
self.worker.function_execution_info[driver_id][
function_id.id()] = (FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=max_calls))
# Add the function to the function table.
self.redis_client.rpush(b"FunctionTable:" + function_id.id(),
self.worker.worker_id)
def fetch_and_execute_function_to_run(self, key):
"""Run on arbitrary function on the worker."""
from ray.worker import SCRIPT_MODE, SILENT_MODE
driver_id, serialized_function = self.redis_client.hmget(
key, ["driver_id", "function"])
if (self.worker.mode in [SCRIPT_MODE, SILENT_MODE]
and driver_id != self.worker.task_driver_id.id()):
# This export was from a different driver and there's no need for
# this driver to import it.
return
try:
# Deserialize the function.
function = pickle.loads(serialized_function)
# Run the function.
function({"worker": self.worker})
except Exception:
# If an exception was thrown when the function was run, we record
# the traceback and notify the scheduler of the failure.
traceback_str = traceback.format_exc()
# Log the error message.
name = function.__name__ if ("function" in locals() and hasattr(
function, "__name__")) else ""
utils.push_error_to_driver(
self.worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
data={"name": name})

View file

@ -31,6 +31,7 @@ import ray.signature
import ray.local_scheduler
import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray.utils import (
binary_to_hex,
check_oversized_pickle,
@ -1984,175 +1985,6 @@ def print_error_messages(worker):
pass
def fetch_and_register_remote_function(key, worker=global_worker):
"""Import a remote function."""
(driver_id, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = worker.redis_client.hmget(key, [
"driver_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
function_name = ray.utils.decode(function_name)
max_calls = int(max_calls)
module = ray.utils.decode(module)
# This is a placeholder in case the function can't be unpickled. This will
# be overwritten if the function is successfully registered.
def f():
raise Exception("This function was not imported properly.")
worker.function_execution_info[driver_id][function_id.id()] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
worker.num_task_executions[driver_id][function_id.id()] = 0
try:
function = pickle.loads(serialized_function)
except Exception:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
data={
"function_id": function_id.id(),
"function_name": function_name
})
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.function_execution_info[driver_id][function_id.id()] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=max_calls))
# Add the function to the function table.
worker.redis_client.rpush(b"FunctionTable:" + function_id.id(),
worker.worker_id)
def fetch_and_execute_function_to_run(key, worker=global_worker):
"""Run on arbitrary function on the worker."""
driver_id, serialized_function = worker.redis_client.hmget(
key, ["driver_id", "function"])
if (worker.mode in [SCRIPT_MODE, SILENT_MODE]
and driver_id != worker.task_driver_id.id()):
# This export was from a different driver and there's no need for this
# driver to import it.
return
try:
# Deserialize the function.
function = pickle.loads(serialized_function)
# Run the function.
function({"worker": worker})
except Exception:
# If an exception was thrown when the function was run, we record the
# traceback and notify the scheduler of the failure.
traceback_str = traceback.format_exc()
# Log the error message.
name = function.__name__ if ("function" in locals()
and hasattr(function, "__name__")) else ""
ray.utils.push_error_to_driver(
worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
data={"name": name})
def import_thread(worker, mode):
worker.import_pubsub_client = worker.redis_client.pubsub()
# Exports that are published after the call to
# import_pubsub_client.subscribe and before the call to
# import_pubsub_client.listen will still be processed in the loop.
worker.import_pubsub_client.subscribe("__keyspace@0__:Exports")
# Keep track of the number of imports that we've imported.
num_imported = 0
# Get the exports that occurred before the call to subscribe.
with worker.lock:
export_keys = worker.redis_client.lrange("Exports", 0, -1)
for key in export_keys:
num_imported += 1
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the
# driver should import.
continue
if key.startswith(b"RemoteFunction"):
with profile("register_remote_function", worker=worker):
fetch_and_register_remote_function(key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
elif key.startswith(b"ActorClass"):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this worker into
# an actor of that class.
worker.imported_actor_classes.add(key)
else:
raise Exception("This code should be unreachable.")
try:
for msg in worker.import_pubsub_client.listen():
with worker.lock:
if msg["type"] == "subscribe":
continue
assert msg["data"] == b"rpush"
num_imports = worker.redis_client.llen("Exports")
assert num_imports >= num_imported
for i in range(num_imported, num_imports):
num_imported += 1
key = worker.redis_client.lindex("Exports", i)
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with profile(
"fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
# Continue because FunctionsToRun are the only things
# that the driver should import.
continue
if key.startswith(b"RemoteFunction"):
with profile(
"register_remote_function", worker=worker):
fetch_and_register_remote_function(
key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
elif key.startswith(b"ActorClass"):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this
# worker into an actor of that class.
worker.imported_actor_classes.add(key)
# TODO(rkn): We may need to bring back the case of fetching
# actor classes here.
else:
raise Exception("This code should be unreachable.")
except redis.ConnectionError:
# When Redis terminates the listen call will throw a ConnectionError,
# which we catch here.
pass
def connect(info,
object_id_seed=None,
mode=WORKER_MODE,
@ -2361,14 +2193,8 @@ def connect(info,
# it must be run before we export all of the cached remote functions.
_initialize_serialization()
# Start a thread to import exports from the driver or from other workers.
# Note that the driver also has an import thread, which is used only to
# import custom class definitions from calls to register_custom_serializer
# that happen under the hood on workers.
t = threading.Thread(target=import_thread, args=(worker, mode))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# Start the import thread
import_thread.ImportThread(worker, mode).start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error
# messages asynchronously in the background. Ideally the scheduler would