ray/python/ray/import_thread.py

from collections import defaultdict
import logging
import threading
import traceback

import redis

import ray
from ray import ray_constants
from ray import cloudpickle as pickle
from ray import profiling
from ray import utils

logger = logging.getLogger(__name__)


class ImportThread:
"""A thread used to import exports from the driver or other workers.
Attributes:
worker: the worker object in this process.
mode: worker mode
redis_client: the redis client used to query exports.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
imported_collision_identifiers: This is a dictionary mapping collision
identifiers for the exported remote functions and actor classes to
the number of times that collision identifier has appeared. This is
used to provide good error messages when the same function or class
is exported many times.
"""

    def __init__(self, worker, mode, threads_stopped):
        self.worker = worker
        self.mode = mode
        self.redis_client = worker.redis_client
        self.threads_stopped = threads_stopped
        self.imported_collision_identifiers = defaultdict(int)

    def start(self):
        """Start the import thread."""
        self.t = threading.Thread(target=self._run, name="ray_import_thread")
        # Making the thread a daemon causes it to exit
        # when the main thread exits.
        self.t.daemon = True
        self.t.start()

    def join_import_thread(self):
        """Wait for the thread to exit."""
        self.t.join()

    def _run(self):
        import_pubsub_client = self.redis_client.pubsub()
        # Exports that are published after the call to
        # import_pubsub_client.subscribe and before the call to
        # import_pubsub_client.listen will still be processed in the loop.
        import_pubsub_client.subscribe("__keyspace@0__:Exports")
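        # "__keyspace@0__:Exports" is a Redis keyspace-notification channel:
        # Redis publishes the name of each write event (here "rpush") on it
        # whenever the "Exports" key in database 0 is modified. This relies
        # on the Redis server having keyspace notifications enabled (the
        # notify-keyspace-events setting).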
        # Keep track of the number of imports that we've imported.
        num_imported = 0
        try:
            # Get the exports that occurred before the call to subscribe.
            export_keys = self.redis_client.lrange("Exports", 0, -1)
            for key in export_keys:
                num_imported += 1
                self._process_key(key)
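            # Any export published between the subscribe call above and the
            # lrange call is also delivered on the pubsub channel; the
            # num_imported counter ensures each list entry is still
            # processed exactly once.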
            while True:
                # Exit if we received a signal that we should stop.
                if self.threads_stopped.is_set():
                    return
                msg = import_pubsub_client.get_message()
                if msg is None:
                    self.threads_stopped.wait(timeout=0.01)
                    continue
                if msg["type"] == "subscribe":
                    continue
                assert msg["data"] == b"rpush"
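                # A single notification does not map one-to-one to list
                # entries (several entries may have landed since the last
                # read), so re-read the list length and import only the
                # delta past num_imported.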
                num_imports = self.redis_client.llen("Exports")
                assert num_imports >= num_imported
                for i in range(num_imported, num_imports):
                    num_imported += 1
                    key = self.redis_client.lindex("Exports", i)
                    self._process_key(key)
        except (OSError, redis.exceptions.ConnectionError) as e:
            logger.error(f"ImportThread: {e}")
        finally:
            # Close the pubsub client to avoid leaking file descriptors.
            import_pubsub_client.close()

    def _get_import_info_for_collision_detection(self, key):
        """Retrieve the collision identifier, type, and name of the import."""
if key.startswith(b"RemoteFunction"):
collision_identifier, function_name = (self.redis_client.hmget(
key, ["collision_identifier", "function_name"]))
return (collision_identifier, ray.utils.decode(function_name),
"remote function")
elif key.startswith(b"ActorClass"):
collision_identifier, class_name = self.redis_client.hmget(
key, ["collision_identifier", "class_name"])
return collision_identifier, ray.utils.decode(class_name), "actor"

    def _process_key(self, key):
        """Process the given export key from redis."""
        if self.mode != ray.WORKER_MODE:
            # If the same remote function or actor definition appears to be
            # exported many times, then print a warning. We only issue this
            # warning from the driver so that it is only triggered once
            # instead of many times. TODO(rkn): We may want to push this to
            # the driver through Redis so that it can be displayed in the
            # dashboard more easily.
            if (key.startswith(b"RemoteFunction")
                    or key.startswith(b"ActorClass")):
                collision_identifier, name, import_type = (
                    self._get_import_info_for_collision_detection(key))
                self.imported_collision_identifiers[collision_identifier] += 1
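                # Warn exactly once, at the moment the count reaches the
                # threshold, rather than on every subsequent export.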
                if (self.imported_collision_identifiers[collision_identifier]
                        == ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD):
                    logger.warning(
                        "The %s '%s' has been exported %s times. It's "
                        "possible that this warning is accidental, but this "
                        "may indicate that the same remote function is being "
                        "defined repeatedly from within many tasks and "
                        "exported to all of the workers. This can be a "
                        "performance issue and can be resolved by defining "
                        "the remote function on the driver instead. See "
                        "https://github.com/ray-project/ray/issues/6240 for "
                        "more discussion.", import_type, name,
                        ray_constants.DUPLICATE_REMOTE_FUNCTION_THRESHOLD)

        if key.startswith(b"RemoteFunction"):
            # TODO(Alex): There's a race condition here if the worker is
            # shut down before the function finishes registering (the core
            # worker's global worker is unset before shutdown but is needed
            # for profiling).
            # with profiling.profile("register_remote_function"):
            (self.worker.function_actor_manager.
             fetch_and_register_remote_function(key))
        elif key.startswith(b"FunctionsToRun"):
            with profiling.profile("fetch_and_run_function"):
                self.fetch_and_execute_function_to_run(key)
        elif key.startswith(b"ActorClass"):
            # Keep track of the fact that this actor class has been
            # exported so that we know it is safe to turn this worker
            # into an actor of that class.
            self.worker.function_actor_manager.imported_actor_classes.add(key)
        # TODO(rkn): We may need to bring back the case of
        # fetching actor classes here.
        else:
            assert False, "This code should be unreachable."

    def fetch_and_execute_function_to_run(self, key):
        """Run an arbitrary function on the worker."""
        (job_id, serialized_function,
         run_on_other_drivers) = self.redis_client.hmget(
             key, ["job_id", "function", "run_on_other_drivers"])
        if (utils.decode(run_on_other_drivers) == "False"
                and self.worker.mode == ray.SCRIPT_MODE
                and job_id != self.worker.current_job_id.binary()):
            return

        try:
            # FunctionActorManager may call pickle.loads at the same time.
            # Importing the same module from different threads can deadlock,
            # so hold the manager's lock while deserializing.
            with self.worker.function_actor_manager.lock:
                # Deserialize the function.
                function = pickle.loads(serialized_function)
            # Run the function.
            function({"worker": self.worker})
        except Exception:
            # If an exception was thrown when the function was run, we
            # record the traceback and notify the scheduler of the failure.
            traceback_str = traceback.format_exc()
            # Log the error message.
            utils.push_error_to_driver(
                self.worker,
                ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
                traceback_str,
                job_id=ray.JobID(job_id))