[Python Worker] add feature flag to support forking from workers (#23260)

Make sure Python dependencies can be imported on demand, without the background importer thread. Use cases are:

If the pubsub notification for a new export is lost, importing can still be done.
The background importer thread can be disabled without affecting Ray's functionality.
Add a feature flag to support forking from Python workers, by:

Enabling fork support in gRPC.
Disabling the importer thread and leaving only the main thread in the Python worker. The importer thread will not run after forking anyway.
This commit is contained in:
mwtian 2022-03-18 14:47:18 -07:00 committed by GitHub
parent 91aa5c4060
commit 909cdea3cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 117 additions and 24 deletions

View file

@ -399,15 +399,14 @@ class FunctionActorManager:
break break
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
warning_message = ( warning_message = (
"This worker was asked to execute a " "This worker was asked to execute a function "
"function that it does not have " f"that has not been registered ({function_descriptor}, "
f"registered ({function_descriptor}, "
f"node={self._worker.node_ip_address}, " f"node={self._worker.node_ip_address}, "
f"worker_id={self._worker.worker_id.hex()}, " f"worker_id={self._worker.worker_id.hex()}, "
f"pid={os.getpid()}). You may have to restart " f"pid={os.getpid()}). You may have to restart Ray."
"Ray."
) )
if not warning_sent: if not warning_sent:
logger.error(warning_message)
ray._private.utils.push_error_to_driver( ray._private.utils.push_error_to_driver(
self._worker, self._worker,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
@ -415,6 +414,9 @@ class FunctionActorManager:
job_id=job_id, job_id=job_id,
) )
warning_sent = True warning_sent = True
# Try importing in case the worker did not get notified, or the
# importer thread did not run.
self._worker.import_thread._do_importing()
time.sleep(0.001) time.sleep(0.001)
def _publish_actor_class_to_key(self, key, actor_class_info): def _publish_actor_class_to_key(self, key, actor_class_info):

View file

@ -38,8 +38,11 @@ class ImportThread:
self.exception_type = grpc.RpcError self.exception_type = grpc.RpcError
self.threads_stopped = threads_stopped self.threads_stopped = threads_stopped
self.imported_collision_identifiers = defaultdict(int) self.imported_collision_identifiers = defaultdict(int)
self.t = None
# Keep track of the number of imports that we've imported. # Keep track of the number of imports that we've imported.
self.num_imported = 0 self.num_imported = 0
# Protect writes to self.num_imported.
self._lock = threading.Lock()
def start(self): def start(self):
"""Start the import thread.""" """Start the import thread."""
@ -51,7 +54,8 @@ class ImportThread:
def join_import_thread(self): def join_import_thread(self):
"""Wait for the thread to exit.""" """Wait for the thread to exit."""
self.t.join() if self.t:
self.t.join()
def _run(self): def _run(self):
try: try:
@ -74,17 +78,18 @@ class ImportThread:
def _do_importing(self): def _do_importing(self):
while True: while True:
export_key = ray._private.function_manager.make_export_key( with self._lock:
self.num_imported + 1, self.worker.current_job_id export_key = ray._private.function_manager.make_export_key(
) self.num_imported + 1, self.worker.current_job_id
key = self.gcs_client.internal_kv_get( )
export_key, ray_constants.KV_NAMESPACE_FUNCTION_TABLE key = self.gcs_client.internal_kv_get(
) export_key, ray_constants.KV_NAMESPACE_FUNCTION_TABLE
if key is not None: )
self._process_key(key) if key is not None:
self.num_imported += 1 self._process_key(key)
else: self.num_imported += 1
break else:
break
def _get_import_info_for_collision_detection(self, key): def _get_import_info_for_collision_detection(self, key):
"""Retrieve the collision identifier, type, and name of the import.""" """Retrieve the collision identifier, type, and name of the import."""

View file

@ -66,3 +66,5 @@ cdef extern from "ray/common/ray_config.h" nogil:
c_bool gcs_grpc_based_pubsub() const c_bool gcs_grpc_based_pubsub() const
c_bool bootstrap_with_gcs() const c_bool bootstrap_with_gcs() const
c_bool start_python_importer_thread() const

View file

@ -103,3 +103,7 @@ cdef class Config:
@staticmethod @staticmethod
def record_ref_creation_sites(): def record_ref_creation_sites():
return RayConfig.instance().record_ref_creation_sites() return RayConfig.instance().record_ref_creation_sites()
@staticmethod
def start_python_importer_thread():
return RayConfig.instance().start_python_importer_thread()

View file

@ -124,6 +124,7 @@ def test_function_unique_export(ray_start_regular):
else: else:
num_exports += 1 num_exports += 1
print(f"num_exports after running g(): {num_exports}") print(f"num_exports after running g(): {num_exports}")
assert num_exports > 0, "Function export notification is not received"
ray.get([g.remote() for _ in range(5)]) ray.get([g.remote() for _ in range(5)])
@ -131,6 +132,67 @@ def test_function_unique_export(ray_start_regular):
assert key is None, f"Unexpected function key export: {key}" assert key is None, f"Unexpected function key export: {key}"
def test_function_import_without_importer_thread(shutdown_only):
"""Test that without background importer thread, dependencies can still be
imported in workers."""
ray.init(
_system_config={
"start_python_importer_thread": False,
},
)
@ray.remote
def f():
import threading
assert threading.get_ident() == threading.main_thread().ident
# Make sure the importer thread is not running.
for thread in threading.enumerate():
assert "import" not in thread.name
@ray.remote
def g():
ray.get(f.remote())
ray.get(g.remote())
ray.get([g.remote() for _ in range(5)])
@pytest.mark.skipif(
sys.platform == "win32",
reason="Fork is only supported on *nix systems.",
)
def test_fork_support(shutdown_only):
"""Test that fork support works."""
ray.init(
_system_config={
"support_fork": True,
},
)
@ray.remote
def pool_factorial():
import multiprocessing
import math
ctx = multiprocessing.get_context("fork")
with ctx.Pool(processes=4) as pool:
return sum(pool.map(math.factorial, range(8)))
@ray.remote
def g():
import threading
assert threading.get_ident() == threading.main_thread().ident
# Make sure this is the only Python thread, because forking does not
# work well under multi-threading.
assert threading.active_count() == 1
return ray.get(pool_factorial.remote())
assert ray.get(g.remote()) == 5914
@pytest.mark.skipif( @pytest.mark.skipif(
sys.platform not in ["win32", "darwin"], sys.platform not in ["win32", "darwin"],
reason="Only listen on localhost by default on mac and windows.", reason="Only listen on localhost by default on mac and windows.",

View file

@ -1585,7 +1585,8 @@ def connect(
worker.import_thread = import_thread.ImportThread( worker.import_thread = import_thread.ImportThread(
worker, mode, worker.threads_stopped worker, mode, worker.threads_stopped
) )
worker.import_thread.start() if ray._raylet.Config.start_python_importer_thread():
worker.import_thread.start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error # If this is a driver running in SCRIPT_MODE, start a thread to print error
# messages asynchronously in the background. Ideally the scheduler would # messages asynchronously in the background. Ideally the scheduler would

View file

@ -318,6 +318,15 @@ RAY_CONFIG(uint32_t, task_retry_delay_ms, 0)
/// Duration to wait between retrying to kill a task. /// Duration to wait between retrying to kill a task.
RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000) RAY_CONFIG(uint32_t, cancellation_retry_ms, 2000)
/// Whether to start a background thread to import Python dependencies eagerly.
/// When set to false, Python dependencies will still be imported, but only
/// when they are needed.
RAY_CONFIG(bool, start_python_importer_thread, true)
/// Determines whether forking in Ray actors / tasks is supported.
/// Note that this only enables forking in workers, but not drivers.
RAY_CONFIG(bool, support_fork, false)
/// Maximum timeout for GCS reconnection in seconds. /// Maximum timeout for GCS reconnection in seconds.
/// Each reconnection ping will be retried every 1 second. /// Each reconnection ping will be retried every 1 second.
RAY_CONFIG(int32_t, gcs_rpc_server_reconnect_timeout_s, 60) RAY_CONFIG(int32_t, gcs_rpc_server_reconnect_timeout_s, 60)

View file

@ -386,6 +386,14 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
absl::Base64Escape(RayConfig::instance().object_spilling_config())); absl::Base64Escape(RayConfig::instance().object_spilling_config()));
} }
if (language == Language::PYTHON) {
worker_command_args.push_back("--startup-token=" +
std::to_string(worker_startup_token_counter_));
} else if (language == Language::CPP) {
worker_command_args.push_back("--startup_token=" +
std::to_string(worker_startup_token_counter_));
}
ProcessEnvironment env; ProcessEnvironment env;
if (!IsIOWorkerType(worker_type)) { if (!IsIOWorkerType(worker_type)) {
// We pass the job ID to worker processes via an environment variable, so we don't // We pass the job ID to worker processes via an environment variable, so we don't
@ -459,12 +467,12 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
env.insert({"SPT_NOENV", "1"}); env.insert({"SPT_NOENV", "1"});
} }
if (language == Language::PYTHON) { if (RayConfig::instance().support_fork()) {
worker_command_args.push_back("--startup-token=" + // Support forking in gRPC.
std::to_string(worker_startup_token_counter_)); env.insert({"GRPC_ENABLE_FORK_SUPPORT", "True"});
} else if (language == Language::CPP) { env.insert({"GRPC_POLL_STRATEGY", "poll"});
worker_command_args.push_back("--startup_token=" + // Make sure only the main thread is running in Python workers.
std::to_string(worker_startup_token_counter_)); env.insert({"RAY_start_python_importer_thread", "0"});
} }
// Start a process and measure the startup time. // Start a process and measure the startup time.