Isolate function exports by job in separate queues (#20882)

Eric Liang 2021-12-21 16:19:00 -08:00 committed by GitHub
parent 7d861a2c58
commit 1db03862a7
5 changed files with 79 additions and 38 deletions


@@ -37,9 +37,13 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
 logger = logging.getLogger(__name__)
-def make_export_key(pos):
+def make_exports_prefix(job_id: bytes) -> bytes:
+    return b"IsolatedExports:" + job_id
+
+
+def make_export_key(pos: int, job_id: bytes) -> bytes:
     # big-endian for ordering in binary
-    return b"Exports:" + pos.to_bytes(8, "big")
+    return make_exports_prefix(job_id) + b":" + pos.to_bytes(8, "big")
 class FunctionActorManager:
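For illustration (not part of the diff): the new helpers above give every job its own export key space, and the big-endian position encoding keeps byte-wise key ordering identical to numeric ordering. A minimal standalone sketch with made-up job ids:

def make_exports_prefix(job_id: bytes) -> bytes:
    return b"IsolatedExports:" + job_id

def make_export_key(pos: int, job_id: bytes) -> bytes:
    # big-endian so that sorting the raw bytes sorts by position
    return make_exports_prefix(job_id) + b":" + pos.to_bytes(8, "big")

job_a, job_b = b"\x00\x01", b"\x00\x02"   # hypothetical job ids
keys = [make_export_key(pos, job_a) for pos in (1, 2, 10, 255, 256)]
assert keys == sorted(keys)               # binary order == export order within a job
assert make_export_key(1, job_a) != make_export_key(1, job_b)  # queues never collide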
@@ -150,16 +154,21 @@ class FunctionActorManager:
                                  self._worker.import_thread.num_imported)
         while True:
             self._num_exported += 1
-            holder = make_export_key(self._num_exported)
+            holder = make_export_key(self._num_exported,
+                                     self._worker.current_job_id.binary())
             # This step is atomic since internal kv is a single thread
             # atomic db.
             if self._worker.gcs_client.internal_kv_put(
                     holder, key, False, KV_NAMESPACE_FUNCTION_TABLE) > 0:
                 break
         # Notify all subscribers that there is a new function exported. Note
         # that the notification doesn't include any actual data.
         if self._worker.gcs_pubsub_enabled:
+            # TODO(mwtian) implement per-job notification here.
             self._worker.gcs_publisher.publish_function_key(key)
         else:
-            self._worker.redis_client.lpush("Exports", "a")
+            self._worker.redis_client.lpush(
+                make_exports_prefix(self._worker.current_job_id.binary()), "a")

     def export(self, remote_function):
         """Pickle a remote function and export it to redis.
@@ -218,14 +227,6 @@ class FunctionActorManager:
         (job_id_str, function_id_str, function_name, serialized_function,
          module, max_calls) = (vals.get(field) for field in fields)
-        if ray_constants.ISOLATE_EXPORTS and \
-                job_id_str != self._worker.current_job_id.binary():
-            # A worker only executes tasks from the assigned job.
-            # TODO(jjyao): If fetching unrelated remote functions
-            # becomes a perf issue, we can also consider having export
-            # queue per job.
-            return
         function_id = ray.FunctionID(function_id_str)
         job_id = ray.JobID(job_id_str)
         max_calls = int(max_calls)
@@ -555,20 +556,27 @@ class FunctionActorManager:
         """Load actor class from GCS."""
         key = (b"ActorClass:" + job_id.binary() + b":" +
                actor_creation_function_descriptor.function_id.binary())
-        # Wait for the actor class key to have been imported by the
-        # import thread. TODO(rkn): It shouldn't be possible to end
-        # up in an infinite loop here, but we should push an error to
-        # the driver if too much time is spent here.
-        while key not in self.imported_actor_classes:
-            try:
-                # If we're in the process of deserializing an ActorHandle
-                # and we hold the function_manager lock, we may be blocking
-                # the import_thread from loading the actor class. Use cv.wait
-                # to temporarily yield control to the import thread.
-                self.cv.wait()
-            except RuntimeError:
-                # We don't hold the function_manager lock, just sleep regularly
-                time.sleep(0.001)
+        # Only wait for the actor class if it was exported from the same job.
+        # It will hang if the job id mismatches, since we isolate actor class
+        # exports from the import thread. It's important to wait since this
+        # guarantees import order, though we fetch the actor class directly.
+        # Import order isn't important across jobs, as we only need to fetch
+        # the class for `ray.get_actor()`.
+        if job_id.binary() == self._worker.current_job_id.binary():
+            # Wait for the actor class key to have been imported by the
+            # import thread. TODO(rkn): It shouldn't be possible to end
+            # up in an infinite loop here, but we should push an error to
+            # the driver if too much time is spent here.
+            while key not in self.imported_actor_classes:
+                try:
+                    # If we're in the process of deserializing an ActorHandle
+                    # and we hold the function_manager lock, we may be blocking
+                    # the import_thread from loading the actor class. Use wait
+                    # to temporarily yield control to the import thread.
+                    self.cv.wait()
+                except RuntimeError:
+                    # We don't hold the function_manager lock, just sleep
+                    time.sleep(0.001)
         # Fetch raw data from GCS.
         vals = self._worker.gcs_client.internal_kv_get(
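The comment block above hinges on a subtle point: with per-job import queues, the import thread will never deliver an actor class exported by a different job, so waiting on it would hang; cross-job classes are instead fetched directly from the KV below. A hedged sketch of the same wait pattern with a plain threading.Condition (names are illustrative stand-ins, not Ray's API):

import threading
import time

imported_actor_classes = set()      # filled in by an import thread elsewhere
cv = threading.Condition()          # shared with that import thread

def wait_for_actor_class(key: bytes, same_job: bool) -> None:
    if not same_job:
        return  # don't wait: this job's import queue will never contain the key
    while key not in imported_actor_classes:
        try:
            # Condition.wait() releases the lock while blocked, letting the
            # import thread add the key; it raises RuntimeError if called
            # without holding the lock.
            cv.wait()
        except RuntimeError:
            time.sleep(0.001)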


@@ -39,7 +39,9 @@ class ImportThread:
             self.subscriber.subscribe()
         else:
             self.subscriber = worker.redis_client.pubsub()
-            self.subscriber.subscribe("__keyspace@0__:Exports")
+            self.subscriber.subscribe(
+                b"__keyspace@0__:" + ray._private.function_manager.
+                make_exports_prefix(self.worker.current_job_id.binary()))
         self.threads_stopped = threads_stopped
         self.imported_collision_identifiers = defaultdict(int)
         # Keep track of the number of imports that we've imported.
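For the Redis code path, the worker relies on keyspace notifications: an LPUSH onto a key K publishes an event on the channel `__keyspace@<db>__:K`, so subscribing to the per-job channel above means each import thread is only woken by its own job's exports. A small illustration with a made-up job id:

job_id = b"\x00\x01"  # hypothetical job id bytes

exports_key = b"IsolatedExports:" + job_id   # key the exporter LPUSHes "a" onto
channel = b"__keyspace@0__:" + exports_key   # channel this ImportThread subscribes to

assert channel == b"__keyspace@0__:IsolatedExports:" + job_id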
@@ -89,7 +91,7 @@ class ImportThread:
     def _do_importing(self):
         while True:
             export_key = ray._private.function_manager.make_export_key(
-                self.num_imported + 1)
+                self.num_imported + 1, self.worker.current_job_id.binary())
             key = self.gcs_client.internal_kv_get(
                 export_key, ray_constants.KV_NAMESPACE_FUNCTION_TABLE)
             if key is not None:
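Seen from the importer's side, each notification just means "poll again": the thread keeps asking the KV for position num_imported + 1 in its own job's queue until it gets a miss. A rough sketch of that loop (kv_get and process_export are hypothetical stand-ins for the GCS fetch and the real import logic):

def make_export_key(pos: int, job_id: bytes) -> bytes:
    # same key scheme as in the function manager hunks above
    return b"IsolatedExports:" + job_id + b":" + pos.to_bytes(8, "big")

def drain_exports(kv_get, process_export, job_id: bytes, num_imported: int) -> int:
    """Import everything currently queued for this job; return the new high-water mark."""
    while True:
        export_key = make_export_key(num_imported + 1, job_id)
        function_key = kv_get(export_key)
        if function_key is None:
            return num_imported  # caught up; sleep until the next notification
        process_export(function_key)
        num_imported += 1

# Example: a job with two queued exports.
queued = {make_export_key(1, b"\x00\x01"): b"f1", make_export_key(2, b"\x00\x01"): b"f2"}
seen = []
assert drain_exports(queued.get, seen.append, b"\x00\x01", 0) == 2 and seen == [b"f1", b"f2"]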
@@ -174,10 +176,6 @@ class ImportThread:
         if self.worker.mode == ray.SCRIPT_MODE:
             return
-        if ray_constants.ISOLATE_EXPORTS and \
-                job_id != self.worker.current_job_id.binary():
-            return
         try:
             # FunctionActorManager may call pickle.loads at the same time.
             # Importing the same module in different threads causes deadlock.


@@ -160,9 +160,6 @@ REPORTER_UPDATE_INTERVAL_MS = env_integer("REPORTER_UPDATE_INTERVAL_MS", 2500)
 # `services.py:wait_for_redis_to_start`.
 START_REDIS_WAIT_RETRIES = env_integer("RAY_START_REDIS_WAIT_RETRIES", 16)
-# Only unpickle and run exported functions from the same job if it's true.
-ISOLATE_EXPORTS = env_bool("RAY_ISOLATE_EXPORTS", True)
 LOGGER_FORMAT = (
     "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s")
 LOGGER_FORMAT_HELP = f"The logging format. default='{LOGGER_FORMAT}'"


@@ -34,6 +34,7 @@ py_test_module_list(
         "test_global_gc.py",
         "test_grpc_client_credentials.py",
         "test_iter.py",
+        "test_job.py",
         "test_joblib.py",
         "test_get_locations.py",
         "test_global_state.py",
@@ -121,7 +122,6 @@ py_test_module_list(
         "test_component_failures.py",
         "test_debug_tools.py",
         "test_distributed_sort.py",
-        "test_job.py",
         "test_kv.py",
         "test_microbenchmarks.py",
         "test_mini.py",


@@ -58,6 +58,44 @@ assert ray.get(lib.task.remote()) == {}
         subprocess.check_call([sys.executable, v2_driver])
+def test_export_queue_isolation(call_ray_start):
+    address = call_ray_start
+    driver_template = """
+import ray
+ray.init(address="{}")
+
+@ray.remote
+def f():
+    pass
+
+ray.get(f.remote())
+
+count = 0
+for k in ray.worker.global_worker.redis_client.keys():
+    if b"IsolatedExports:" + ray.get_runtime_context().job_id.binary() in k:
+        count += 1
+
+# Check exports aren't shared across the 5 jobs.
+assert count < 5, count
+"""
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        os.makedirs(os.path.join(tmpdir, "v1"))
+        v1_driver = os.path.join(tmpdir, "v1", "driver.py")
+        with open(v1_driver, "w") as f:
+            f.write(driver_template.format(address))
+
+        try:
+            subprocess.check_call([sys.executable, v1_driver])
+        except Exception:
+            # Ignore the first run, since it runs extra exports.
+            pass
+
+        # Further runs do not increase the num exports count.
+        for _ in range(5):
+            subprocess.check_call([sys.executable, v1_driver])
 def test_job_gc(call_ray_start):
     address = call_ray_start
@@ -176,7 +214,7 @@ ray.shutdown()
     assert finished["EndTime"] > finished["StartTime"] > 0, out
     lapsed = finished["EndTime"] - finished["StartTime"]
-    assert 0 < lapsed < 2000, f"Job should've taken ~1s, {finished}"
+    assert 0 < lapsed < 5000, f"Job should've taken ~1s, {finished}"
     assert running["StartTime"] > 0
     assert running["EndTime"] == 0
@@ -195,7 +233,7 @@ ray.shutdown()
     assert finished["EndTime"] > finished["StartTime"] > 0, f"{finished}"
     assert finished["EndTime"] == finished["Timestamp"]
     lapsed = finished["EndTime"] - finished["StartTime"]
-    assert 0 < lapsed < 2000, f"Job should've taken ~1s {finished}"
+    assert 0 < lapsed < 5000, f"Job should've taken ~1s {finished}"
     assert prev_running["EndTime"] > prev_running["StartTime"] > 0