Use prctl(PR_SET_PDEATHSIG) on Linux instead of reaper (#7150)

Author: mehrdadn
Date:   2020-03-03 09:45:42 -08:00 (committed by GitHub)
parent f5b1062ed9
commit 4d42664b2a
4 changed files with 244 additions and 46 deletions
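
For readers unfamiliar with the mechanism: prctl(PR_SET_PDEATHSIG) lets a Linux child process ask the kernel to deliver it a signal when its parent dies, which is what makes a separate reaper process unnecessary on Linux. Below is a minimal, self-contained sketch of that mechanism; it is not part of this diff, and the helper name and the "sleep" command are illustrative only.

    # Sketch: tie a child's lifetime to its parent's via prctl(PR_SET_PDEATHSIG).
    import ctypes
    import signal
    import subprocess
    import sys

    PR_SET_PDEATHSIG = 1  # constant from <linux/prctl.h>

    def _die_with_parent():
        # Runs in the child between fork() and exec() (subprocess preexec_fn).
        libc = ctypes.CDLL(None, use_errno=True)
        if libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0:
            raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed")

    if sys.platform.startswith("linux"):
        child = subprocess.Popen(["sleep", "1000"], preexec_fn=_die_with_parent)
        # If the parent now dies for any reason, the kernel SIGKILLs the child,
        # so no separate reaper process is needed to clean it up.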


@@ -66,6 +66,8 @@ class Node:
         self._register_shutdown_hooks()
         self.head = head
+        self.kernel_fate_share = (spawn_reaper
+                                  and ray.utils.detect_fate_sharing_support())
         self.all_processes = {}
         # Try to get node IP address with the parameters.
@@ -154,7 +156,7 @@ class Node:
             # raylet starts.
             self._ray_params.node_manager_port = self._get_unused_port()
-        if not connect_only and spawn_reaper:
+        if not connect_only and spawn_reaper and not self.kernel_fate_share:
             self.start_reaper_process()
         # Start processes.
@@ -413,7 +415,9 @@ class Node:
        This must be the first process spawned and should only be called when
        ray processes should be cleaned up if this process dies.
        """
-        process_info = ray.services.start_reaper()
+        assert not self.kernel_fate_share, (
+            "a reaper should not be used with kernel fate-sharing")
+        process_info = ray.services.start_reaper(fate_share=False)
         assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes
         if process_info is not None:
             self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [
@@ -438,7 +442,8 @@ class Node:
             redis_max_clients=self._ray_params.redis_max_clients,
             redirect_worker_output=True,
             password=self._ray_params.redis_password,
-            include_java=self._ray_params.include_java)
+            include_java=self._ray_params.include_java,
+            fate_share=self.kernel_fate_share)
         assert (
             ray_constants.PROCESS_TYPE_REDIS_SERVER not in self.all_processes)
         self.all_processes[ray_constants.PROCESS_TYPE_REDIS_SERVER] = (
@@ -452,7 +457,8 @@ class Node:
             self._logs_dir,
             stdout_file=stdout_file,
             stderr_file=stderr_file,
-            redis_password=self._ray_params.redis_password)
+            redis_password=self._ray_params.redis_password,
+            fate_share=self.kernel_fate_share)
         assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes
         self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [
             process_info
@@ -465,7 +471,8 @@ class Node:
             self.redis_address,
             stdout_file=stdout_file,
             stderr_file=stderr_file,
-            redis_password=self._ray_params.redis_password)
+            redis_password=self._ray_params.redis_password,
+            fate_share=self.kernel_fate_share)
         assert ray_constants.PROCESS_TYPE_REPORTER not in self.all_processes
         if process_info is not None:
             self.all_processes[ray_constants.PROCESS_TYPE_REPORTER] = [
@@ -488,7 +495,8 @@ class Node:
             self._temp_dir,
             stdout_file=stdout_file,
             stderr_file=stderr_file,
-            redis_password=self._ray_params.redis_password)
+            redis_password=self._ray_params.redis_password,
+            fate_share=self.kernel_fate_share)
         assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes
         if process_info is not None:
             self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [
@@ -506,7 +514,8 @@ class Node:
             stderr_file=stderr_file,
             plasma_directory=self._ray_params.plasma_directory,
             huge_pages=self._ray_params.huge_pages,
-            plasma_store_socket_name=self._plasma_store_socket_name)
+            plasma_store_socket_name=self._plasma_store_socket_name,
+            fate_share=self.kernel_fate_share)
         assert (
             ray_constants.PROCESS_TYPE_PLASMA_STORE not in self.all_processes)
         self.all_processes[ray_constants.PROCESS_TYPE_PLASMA_STORE] = [
@@ -522,7 +531,8 @@ class Node:
             stdout_file=stdout_file,
             stderr_file=stderr_file,
             redis_password=self._ray_params.redis_password,
-            config=self._config)
+            config=self._config,
+            fate_share=self.kernel_fate_share)
         assert (
             ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes)
         self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [
@@ -559,7 +569,8 @@ class Node:
             include_java=self._ray_params.include_java,
             java_worker_options=self._ray_params.java_worker_options,
             load_code_from_local=self._ray_params.load_code_from_local,
-            use_pickle=self._ray_params.use_pickle)
+            use_pickle=self._ray_params.use_pickle,
+            fate_share=self.kernel_fate_share)
         assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
         self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
@@ -581,7 +592,8 @@ class Node:
             stdout_file=stdout_file,
             stderr_file=stderr_file,
             autoscaling_config=self._ray_params.autoscaling_config,
-            redis_password=self._ray_params.redis_password)
+            redis_password=self._ray_params.redis_password,
+            fate_share=self.kernel_fate_share)
         assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes
         self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info]
@@ -593,7 +605,8 @@ class Node:
             stdout_file=stdout_file,
             stderr_file=stderr_file,
             redis_password=self._ray_params.redis_password,
-            config=self._config)
+            config=self._config,
+            fate_share=self.kernel_fate_share)
         assert (ray_constants.PROCESS_TYPE_RAYLET_MONITOR not in
                 self.all_processes)
         self.all_processes[ray_constants.PROCESS_TYPE_RAYLET_MONITOR] = [


@@ -329,7 +329,8 @@ def start_ray_process(command,
                       use_tmux=False,
                       stdout_file=None,
                       stderr_file=None,
-                      pipe_stdin=False):
+                      pipe_stdin=False,
+                      fate_share=None):
     """Start one of the Ray processes.
     TODO(rkn): We need to figure out how these commands interact. For example,
@@ -358,6 +359,8 @@ def start_ray_process(command,
             no redirection should happen, then this should be None.
         pipe_stdin: If true, subprocess.PIPE will be passed to the process as
             stdin.
+        fate_share: If true, the child will be killed if its parent (us) dies.
+            Note that this functionality must be supported, or it is an error.
     Returns:
         Information about the process that was started including a handle to
@@ -439,12 +442,18 @@ def start_ray_process(command,
         # version, and tmux 2.1)
         command = ["tmux", "new-session", "-d", "{}".format(" ".join(command))]
-    # Block sigint for spawned processes so they aren't killed by the SIGINT
-    # propagated from the shell on Ctrl-C so we can handle KeyboardInterrupts
-    # in interactive sessions. This is only supported in Python 3.3 and above.
-    def block_sigint():
+    if fate_share is None:
+        logger.warning("fate_share= should be passed to start_ray_process()")
+    if fate_share:
+        assert ray.utils.detect_fate_sharing_support(), (
+            "kernel-level fate-sharing must only be specified if "
+            "detect_fate_sharing_support() has returned True")
+
+    def preexec_fn():
         import signal
         signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
+        if fate_share and sys.platform.startswith("linux"):
+            ray.utils.set_kill_on_parent_death_linux()
     process = subprocess.Popen(
@@ -453,7 +462,10 @@ def start_ray_process(command,
         stdout=stdout_file,
         stderr=stderr_file,
         stdin=subprocess.PIPE if pipe_stdin else None,
-        preexec_fn=block_sigint)
+        preexec_fn=preexec_fn if sys.platform != "win32" else None)
+
+    if fate_share and sys.platform == "win32":
+        ray.utils.set_kill_child_on_death_win32(process)
     return ProcessInfo(
         process=process,
@@ -569,7 +581,7 @@ def check_version_info(redis_client):
         logger.warning(error_message)
-def start_reaper():
+def start_reaper(fate_share=None):
     """Start the reaper process.
     This is a lightweight process that simply
@@ -585,8 +597,9 @@ def start_reaper():
     # process that started us.
     try:
         os.setpgrp()
-    except OSError as e:
-        if e.errno == errno.EPERM and os.getpgrp() == os.getpid():
+    except (AttributeError, OSError) as e:
+        errcode = e.errno if isinstance(e, OSError) else None
+        if errcode == errno.EPERM and os.getpgrp() == os.getpid():
             # Nothing to do; we're already a session leader.
             pass
         else:
@@ -600,7 +613,10 @@ def start_reaper():
         os.path.dirname(os.path.abspath(__file__)), "ray_process_reaper.py")
     command = [sys.executable, "-u", reaper_filepath]
     process_info = start_ray_process(
-        command, ray_constants.PROCESS_TYPE_REAPER, pipe_stdin=True)
+        command,
+        ray_constants.PROCESS_TYPE_REAPER,
+        pipe_stdin=True,
+        fate_share=fate_share)
     return process_info
@@ -614,7 +630,8 @@ def start_redis(node_ip_address,
                 redirect_worker_output=False,
                 password=None,
                 use_credis=None,
-                include_java=False):
+                include_java=False,
+                fate_share=None):
     """Start the Redis global state store.
     Args:
@@ -698,7 +715,8 @@ def start_redis(node_ip_address,
             # primary Redis shard.
             redis_max_memory=None,
             stdout_file=redis_stdout_file,
-            stderr_file=redis_stderr_file)
+            stderr_file=redis_stderr_file,
+            fate_share=fate_share)
         processes.append(p)
         redis_address = address(node_ip_address, port)
@@ -803,7 +821,8 @@ def _start_redis_instance(executable,
                           stdout_file=None,
                           stderr_file=None,
                           password=None,
-                          redis_max_memory=None):
+                          redis_max_memory=None,
+                          fate_share=None):
     """Start a single Redis server.
     Notes:
@@ -869,7 +888,8 @@ def _start_redis_instance(executable,
         command,
         ray_constants.PROCESS_TYPE_REDIS_SERVER,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     time.sleep(0.1)
     # Check if Redis successfully started (or at least if it the executable
     # did not exit within 0.1 seconds).
@@ -942,7 +962,8 @@ def start_log_monitor(redis_address,
                       logs_dir,
                       stdout_file=None,
                       stderr_file=None,
-                      redis_password=None):
+                      redis_password=None,
+                      fate_share=None):
     """Start a log monitor process.
     Args:
@@ -970,14 +991,16 @@ def start_log_monitor(redis_address,
         command,
         ray_constants.PROCESS_TYPE_LOG_MONITOR,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
 def start_reporter(redis_address,
                    stdout_file=None,
                    stderr_file=None,
-                   redis_password=None):
+                   redis_password=None,
+                   fate_share=None):
     """Start a reporter process.
     Args:
@@ -1004,7 +1027,8 @@ def start_reporter(redis_address,
         command,
         ray_constants.PROCESS_TYPE_REPORTER,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1014,7 +1038,8 @@ def start_dashboard(require_webui,
                     temp_dir,
                     stdout_file=None,
                     stderr_file=None,
-                    redis_password=None):
+                    redis_password=None,
+                    fate_share=None):
     """Start a dashboard process.
     Args:
@@ -1077,7 +1102,8 @@ def start_dashboard(require_webui,
         command,
         ray_constants.PROCESS_TYPE_DASHBOARD,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     dashboard_url = "{}:{}".format(
         host if host != "0.0.0.0" else get_node_ip_address(), port)
@@ -1093,7 +1119,8 @@ def start_gcs_server(redis_address,
                      stdout_file=None,
                      stderr_file=None,
                      redis_password=None,
-                     config=None):
+                     config=None,
+                     fate_share=None):
     """Start a gcs server.
     Args:
         redis_address (str): The address that the Redis server is listening on.
@@ -1123,7 +1150,8 @@ def start_gcs_server(redis_address,
         command,
         ray_constants.PROCESS_TYPE_GCS_SERVER,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1146,7 +1174,8 @@ def start_raylet(redis_address,
                  include_java=False,
                  java_worker_options=None,
                  load_code_from_local=False,
-                 use_pickle=False):
+                 use_pickle=False,
+                 fate_share=None):
     """Start a raylet, which is a combined local scheduler and object manager.
     Args:
@@ -1275,7 +1304,8 @@ def start_raylet(redis_address,
         use_valgrind_profiler=use_profiler,
         use_perftools_profiler=("RAYLET_PERFTOOLS_PATH" in os.environ),
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1437,7 +1467,8 @@ def _start_plasma_store(plasma_store_memory,
                         stderr_file=None,
                         plasma_directory=None,
                         huge_pages=False,
-                        socket_name=None):
+                        socket_name=None,
+                        fate_share=None):
     """Start a plasma store process.
     Args:
@@ -1491,7 +1522,8 @@ def _start_plasma_store(plasma_store_memory,
         use_valgrind=use_valgrind,
         use_valgrind_profiler=use_profiler,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1500,7 +1532,8 @@ def start_plasma_store(resource_spec,
                       stderr_file=None,
                       plasma_directory=None,
                       huge_pages=False,
-                       plasma_store_socket_name=None):
+                       plasma_store_socket_name=None,
+                       fate_share=None):
     """This method starts an object store process.
     Args:
@@ -1541,7 +1574,8 @@ def start_plasma_store(resource_spec,
         stderr_file=stderr_file,
         plasma_directory=plasma_directory,
         huge_pages=huge_pages,
-        socket_name=plasma_store_socket_name)
+        socket_name=plasma_store_socket_name,
+        fate_share=fate_share)
     return process_info
@@ -1553,7 +1587,8 @@ def start_worker(node_ip_address,
                  worker_path,
                  temp_dir,
                  stdout_file=None,
-                 stderr_file=None):
+                 stderr_file=None,
+                 fate_share=None):
     """This method starts a worker process.
     Args:
@@ -1584,7 +1619,8 @@ def start_worker(node_ip_address,
         command,
         ray_constants.PROCESS_TYPE_WORKER,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1592,7 +1628,8 @@ def start_monitor(redis_address,
                   stdout_file=None,
                   stderr_file=None,
                   autoscaling_config=None,
-                  redis_password=None):
+                  redis_password=None,
+                  fate_share=None):
     """Run a process to monitor the other processes.
     Args:
@@ -1621,7 +1658,8 @@ def start_monitor(redis_address,
         command,
         ray_constants.PROCESS_TYPE_MONITOR,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info
@@ -1629,7 +1667,8 @@ def start_raylet_monitor(redis_address,
                          stdout_file=None,
                          stderr_file=None,
                          redis_password=None,
-                         config=None):
+                         config=None,
+                         fate_share=None):
     """Run a process to monitor the other processes.
     Args:
@@ -1661,5 +1700,6 @@ def start_raylet_monitor(redis_address,
         command,
         ray_constants.PROCESS_TYPE_RAYLET_MONITOR,
         stdout_file=stdout_file,
-        stderr_file=stderr_file)
+        stderr_file=stderr_file,
+        fate_share=fate_share)
     return process_info


@@ -176,7 +176,8 @@ def test_worker_plasma_store_failure(ray_start_cluster_head):
     cluster.wait_for_nodes()
     worker.kill_reporter()
     worker.kill_plasma_store()
-    worker.kill_reaper()
+    if ray_constants.PROCESS_TYPE_REAPER in worker.all_processes:
+        worker.kill_reaper()
     worker.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process.wait()
     assert not worker.any_processes_alive(), worker.live_processes()


@@ -6,6 +6,7 @@ import logging
 import numpy as np
 import os
 import six
+import subprocess
 import sys
 import threading
 import time
@@ -18,6 +19,15 @@ import psutil
 logger = logging.getLogger(__name__)
+# Linux can bind child processes' lifetimes to that of their parents via prctl.
+# prctl support is detected dynamically once, and assumed thereafter.
+linux_prctl = None
+
+# Windows can bind processes' lifetimes to that of kernel-level "job objects".
+# We keep a global job object to tie its lifetime to that of our own process.
+win32_job = None
+win32_AssignProcessToJobObject = None
 def _random_string():
     id_hash = hashlib.sha1()
@@ -496,6 +506,140 @@ def is_main_thread():
     return threading.current_thread().getName() == "MainThread"


+def detect_fate_sharing_support_win32():
+    global win32_job, win32_AssignProcessToJobObject
+    if win32_job is None and sys.platform == "win32":
+        import ctypes
+        try:
+            from ctypes.wintypes import BOOL, DWORD, HANDLE, LPVOID, LPCWSTR
+            kernel32 = ctypes.WinDLL("kernel32")
+            kernel32.CreateJobObjectW.argtypes = (LPVOID, LPCWSTR)
+            kernel32.CreateJobObjectW.restype = HANDLE
+            sijo_argtypes = (HANDLE, ctypes.c_int, LPVOID, DWORD)
+            kernel32.SetInformationJobObject.argtypes = sijo_argtypes
+            kernel32.SetInformationJobObject.restype = BOOL
+            kernel32.AssignProcessToJobObject.argtypes = (HANDLE, HANDLE)
+            kernel32.AssignProcessToJobObject.restype = BOOL
+        except (AttributeError, TypeError, ImportError):
+            kernel32 = None
+        job = kernel32.CreateJobObjectW(None, None) if kernel32 else None
+        job = subprocess.Handle(job) if job else job
+        if job:
+            from ctypes.wintypes import DWORD, LARGE_INTEGER, ULARGE_INTEGER
+
+            class JOBOBJECT_BASIC_LIMIT_INFORMATION(ctypes.Structure):
+                _fields_ = [
+                    ("PerProcessUserTimeLimit", LARGE_INTEGER),
+                    ("PerJobUserTimeLimit", LARGE_INTEGER),
+                    ("LimitFlags", DWORD),
+                    ("MinimumWorkingSetSize", ctypes.c_size_t),
+                    ("MaximumWorkingSetSize", ctypes.c_size_t),
+                    ("ActiveProcessLimit", DWORD),
+                    ("Affinity", ctypes.c_size_t),
+                    ("PriorityClass", DWORD),
+                    ("SchedulingClass", DWORD),
+                ]
+
+            class IO_COUNTERS(ctypes.Structure):
+                _fields_ = [
+                    ("ReadOperationCount", ULARGE_INTEGER),
+                    ("WriteOperationCount", ULARGE_INTEGER),
+                    ("OtherOperationCount", ULARGE_INTEGER),
+                    ("ReadTransferCount", ULARGE_INTEGER),
+                    ("WriteTransferCount", ULARGE_INTEGER),
+                    ("OtherTransferCount", ULARGE_INTEGER),
+                ]
+
+            class JOBOBJECT_EXTENDED_LIMIT_INFORMATION(ctypes.Structure):
+                _fields_ = [
+                    ("BasicLimitInformation",
+                     JOBOBJECT_BASIC_LIMIT_INFORMATION),
+                    ("IoInfo", IO_COUNTERS),
+                    ("ProcessMemoryLimit", ctypes.c_size_t),
+                    ("JobMemoryLimit", ctypes.c_size_t),
+                    ("PeakProcessMemoryUsed", ctypes.c_size_t),
+                    ("PeakJobMemoryUsed", ctypes.c_size_t),
+                ]
+
+            # Defined in <WinNT.h>; also available here:
+            # https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-setinformationjobobject
+            JobObjectExtendedLimitInformation = 9
+            JOB_OBJECT_LIMIT_BREAKAWAY_OK = 0x00000800
+            JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000
+            buf = JOBOBJECT_EXTENDED_LIMIT_INFORMATION()
+            buf.BasicLimitInformation.LimitFlags = (
+                JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
+                | JOB_OBJECT_LIMIT_BREAKAWAY_OK)
+            infoclass = JobObjectExtendedLimitInformation
+            if not kernel32.SetInformationJobObject(
+                    job, infoclass, ctypes.byref(buf), ctypes.sizeof(buf)):
+                job = None
+        win32_AssignProcessToJobObject = (kernel32.AssignProcessToJobObject
+                                          if kernel32 is not None else False)
+        win32_job = job if job else False
+    return bool(win32_job)
+
+
+def detect_fate_sharing_support_linux():
+    global linux_prctl
+    if linux_prctl is None and sys.platform.startswith("linux"):
+        try:
+            from ctypes import c_int, c_ulong, CDLL
+            prctl = CDLL(None).prctl
+            prctl.restype = c_int
+            prctl.argtypes = [c_int, c_ulong, c_ulong, c_ulong, c_ulong]
+        except (AttributeError, TypeError):
+            prctl = None
+        linux_prctl = prctl if prctl else False
+    return bool(linux_prctl)
+
+
+def detect_fate_sharing_support():
+    result = None
+    if sys.platform == "win32":
+        result = detect_fate_sharing_support_win32()
+    elif sys.platform.startswith("linux"):
+        result = detect_fate_sharing_support_linux()
+    return result
+
+
+def set_kill_on_parent_death_linux():
+    """Ensures this process dies if its parent dies (fate-sharing).
+
+    Linux-only. Must be called in preexec_fn (i.e. by the child).
+    """
+    if detect_fate_sharing_support_linux():
+        import signal
+        PR_SET_PDEATHSIG = 1
+        if linux_prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0:
+            import ctypes
+            raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed")
+    else:
+        assert False, "PR_SET_PDEATHSIG used despite being unavailable"
+
+
+def set_kill_child_on_death_win32(child_proc):
+    """Ensures the child process dies if this process dies (fate-sharing).
+
+    Windows-only. Must be called by the parent, after spawning the child.
+
+    Args:
+        child_proc: The subprocess.Popen or subprocess.Handle object.
+    """
+    if isinstance(child_proc, subprocess.Popen):
+        child_proc = child_proc._handle
+    assert isinstance(child_proc, subprocess.Handle)
+    if detect_fate_sharing_support_win32():
+        if not win32_AssignProcessToJobObject(win32_job, int(child_proc)):
+            import ctypes
+            raise OSError(ctypes.get_last_error(),
+                          "AssignProcessToJobObject() failed")
+    else:
+        assert False, "AssignProcessToJobObject used despite being unavailable"
+
+
 def try_make_directory_shared(directory_path):
     try:
         os.chmod(directory_path, 0o0777)
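
On Windows there is no PR_SET_PDEATHSIG, so the code above uses a kill-on-close job object instead: the parent creates the job once, and every child assigned to it is terminated by the kernel when the last handle to the job is closed, i.e. when the parent exits. An illustrative parent-side use of the helpers added above; the spawned command is a placeholder, not part of this diff.

    import subprocess
    import sys

    import ray.utils

    if sys.platform == "win32" and ray.utils.detect_fate_sharing_support():
        child = subprocess.Popen(
            [sys.executable, "-c", "import time; time.sleep(1000)"])
        # Assign the child to the global kill-on-close job object; when this
        # (parent) process exits and the job handle is closed, the kernel
        # terminates the child.
        ray.utils.set_kill_child_on_death_win32(child)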