[Core][Pubsub][Importer] GCS pubsub for function manager & importer (#20804)

This PR allows using Ray pubsub for notifying worker importers that a new function / actor class needs to be imported.
This commit is contained in:
mwtian 2021-12-01 10:44:50 -08:00 committed by GitHub
parent 8c0bf41b17
commit 0467bc9df5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 191 additions and 25 deletions

View file

@ -156,8 +156,10 @@ class FunctionActorManager:
if self._worker.gcs_client.internal_kv_put( if self._worker.gcs_client.internal_kv_put(
holder, key, False, KV_NAMESPACE_FUNCTION_TABLE) > 0: holder, key, False, KV_NAMESPACE_FUNCTION_TABLE) > 0:
break break
# TODO(yic) Use gcs pubsub if self._worker.gcs_pubsub_enabled:
self._worker.redis_client.lpush("Exports", "a") self._worker.gcs_publisher.publish_function_key(key)
else:
self._worker.redis_client.lpush("Exports", "a")
def export(self, remote_function): def export(self, remote_function):
"""Pickle a remote function and export it to redis. """Pickle a remote function and export it to redis.

View file

@ -3,7 +3,7 @@ from collections import deque
import logging import logging
import random import random
import threading import threading
from typing import Tuple from typing import Optional, Tuple
import grpc import grpc
try: try:
@ -14,6 +14,7 @@ except ImportError:
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray._private.logging_utils as logging_utils import ray._private.logging_utils as logging_utils
from ray.core.generated.gcs_pb2 import ErrorTableData from ray.core.generated.gcs_pb2 import ErrorTableData
from ray.core.generated import dependency_pb2
from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2 from ray.core.generated import gcs_service_pb2
from ray.core.generated import pubsub_pb2 from ray.core.generated import pubsub_pb2
@ -60,6 +61,14 @@ class _PublisherBase:
log_json)) log_json))
]) ])
@staticmethod
def _create_function_key_request(key: bytes):
    """Builds the publish request for a single function key.

    Args:
        key: function key to publish; it is wrapped in a
            dependency_pb2.PythonFunction message.

    Returns:
        A gcs_service_pb2.GcsPublishRequest holding one PubMessage on
        RAY_PYTHON_FUNCTION_CHANNEL.
    """
    function_message = dependency_pb2.PythonFunction(key=key)
    pub_message = pubsub_pb2.PubMessage(
        channel_type=pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL,
        python_function_message=function_message)
    return gcs_service_pb2.GcsPublishRequest(pub_messages=[pub_message])
class _SubscriberBase: class _SubscriberBase:
def __init__(self): def __init__(self):
@ -101,6 +110,13 @@ class _SubscriberBase:
msg = queue.popleft() msg = queue.popleft()
return logging_utils.log_batch_proto_to_dict(msg.log_batch_message) return logging_utils.log_batch_proto_to_dict(msg.log_batch_message)
@staticmethod
def _pop_function_key(queue):
    """Removes the oldest message from `queue` and returns its function key.

    Returns:
        The `key` field of the popped message's python_function_message,
        or None when `queue` is empty.
    """
    if queue:
        return queue.popleft().python_function_message.key
    return None
class GcsPublisher(_PublisherBase): class GcsPublisher(_PublisherBase):
"""Publisher to GCS.""" """Publisher to GCS."""
@ -129,6 +145,11 @@ class GcsPublisher(_PublisherBase):
req = self._create_log_request(log_batch) req = self._create_log_request(log_batch)
self._stub.GcsPublish(req) self._stub.GcsPublish(req)
def publish_function_key(self, key: bytes) -> None:
    """Publishes a function key to GCS.

    Args:
        key: function key to send; it is packed into a publish request by
            _create_function_key_request before being handed to the stub.
    """
    request = self._create_function_key_request(key)
    self._stub.GcsPublish(request)
class _SyncSubscriber(_SubscriberBase): class _SyncSubscriber(_SubscriberBase):
def __init__( def __init__(
@ -216,6 +237,8 @@ class _SyncSubscriber(_SubscriberBase):
"""Closes the subscriber and its active subscription.""" """Closes the subscriber and its active subscription."""
# Mark close to terminate inflight polling and prevent future requests. # Mark close to terminate inflight polling and prevent future requests.
if self._close.is_set():
return
self._close.set() self._close.set()
req = self._unsubscribe_request(channels=[self._channel]) req = self._unsubscribe_request(channels=[self._channel])
try: try:
@ -281,7 +304,7 @@ class GcsLogSubscriber(_SyncSubscriber):
): ):
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel) super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]: def poll(self, timeout=None) -> Optional[dict]:
"""Polls for new log messages. """Polls for new log messages.
Returns: Returns:
@ -293,6 +316,41 @@ class GcsLogSubscriber(_SyncSubscriber):
return self._pop_log_batch(self._queue) return self._pop_log_batch(self._queue)
class GcsFunctionKeySubscriber(_SyncSubscriber):
    """Subscriber to function (and actor class) dependency keys. Thread safe.

    Usage example:
        subscriber = GcsFunctionKeySubscriber()
        # Subscribe to the function key channel.
        subscriber.subscribe()
        ...
        while running:
            key = subscriber.poll()
            ......
        # Unsubscribe from the function key channel.
        subscriber.close()
    """

    def __init__(
            self,
            address: Optional[str] = None,
            channel: Optional[grpc.Channel] = None,
    ):
        """Creates a subscriber bound to RAY_PYTHON_FUNCTION_CHANNEL.

        Args:
            address: GCS server address, forwarded to _SyncSubscriber.
            channel: existing gRPC channel, forwarded to _SyncSubscriber.
        """
        # NOTE(review): how `address` and `channel` interact (e.g. whether
        # exactly one must be given) is decided by _SyncSubscriber.
        super().__init__(pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL, address,
                         channel)

    def poll(self, timeout=None) -> Optional[bytes]:
        """Polls for new function key messages.

        Args:
            timeout: passed through to the underlying locked poll.

        Returns:
            A byte string of function key.
            None if polling times out or subscriber closed.
        """
        # Poll and pop under the same lock so concurrent callers cannot
        # interleave between filling the queue and consuming from it.
        with self._lock:
            self._poll_locked(timeout=timeout)
            return self._pop_function_key(self._queue)
class GcsAioPublisher(_PublisherBase): class GcsAioPublisher(_PublisherBase):
"""Publisher to GCS. Uses async io.""" """Publisher to GCS. Uses async io."""

View file

@ -33,8 +33,13 @@ class ImportThread:
def __init__(self, worker, mode, threads_stopped): def __init__(self, worker, mode, threads_stopped):
self.worker = worker self.worker = worker
self.mode = mode self.mode = mode
self.redis_client = worker.redis_client
self.gcs_client = worker.gcs_client self.gcs_client = worker.gcs_client
if worker.gcs_pubsub_enabled:
self.subscriber = worker.gcs_function_key_subscriber
self.subscriber.subscribe()
else:
self.subscriber = worker.redis_client.pubsub()
self.subscriber.subscribe("__keyspace@0__:Exports")
self.threads_stopped = threads_stopped self.threads_stopped = threads_stopped
self.imported_collision_identifiers = defaultdict(int) self.imported_collision_identifiers = defaultdict(int)
# Keep track of the number of imports that we've imported. # Keep track of the number of imports that we've imported.
@ -53,12 +58,6 @@ class ImportThread:
self.t.join() self.t.join()
def _run(self): def _run(self):
import_pubsub_client = self.redis_client.pubsub()
# Exports that are published after the call to
# import_pubsub_client.subscribe and before the call to
# import_pubsub_client.listen will still be processed in the loop.
import_pubsub_client.subscribe("__keyspace@0__:Exports")
try: try:
self._do_importing() self._do_importing()
while True: while True:
@ -66,18 +65,26 @@ class ImportThread:
if self.threads_stopped.is_set(): if self.threads_stopped.is_set():
return return
msg = import_pubsub_client.get_message() if self.worker.gcs_pubsub_enabled:
if msg is None: key = self.subscriber.poll()
self.threads_stopped.wait(timeout=0.01) if key is None:
continue # subscriber has closed.
if msg["type"] == "subscribe": break
continue else:
msg = self.subscriber.get_message()
if msg is None:
self.threads_stopped.wait(timeout=0.01)
continue
if msg["type"] == "subscribe":
continue
self._do_importing() self._do_importing()
except (OSError, redis.exceptions.ConnectionError, grpc.RpcError) as e: except (OSError, redis.exceptions.ConnectionError, grpc.RpcError) as e:
logger.error(f"ImportThread: {e}") logger.error(f"ImportThread: {e}")
finally: finally:
# Close the pubsub client to avoid leaking file descriptors. # Close the Redis / GCS subscriber to avoid leaking file
import_pubsub_client.close() # descriptors.
self.subscriber.close()
def _do_importing(self): def _do_importing(self):
while True: while True:

View file

@ -10,6 +10,8 @@ import numpy as np
import pytest import pytest
import ray.cluster_utils import ray.cluster_utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
GcsFunctionKeySubscriber
from ray._private.test_utils import ( from ray._private.test_utils import (
dicts_equal, dicts_equal,
wait_for_pid_to_exit, wait_for_pid_to_exit,
@ -323,10 +325,33 @@ def test_function_unique_export(ray_start_regular):
def g(): def g():
ray.get(f.remote()) ray.get(f.remote())
ray.get(g.remote()) if gcs_pubsub_enabled():
num_exports = ray.worker.global_worker.redis_client.llen("Exports") subscriber = GcsFunctionKeySubscriber(
ray.get([g.remote() for _ in range(5)]) channel=ray.worker.global_worker.gcs_channel.channel())
assert ray.worker.global_worker.redis_client.llen("Exports") == num_exports subscriber.subscribe()
ray.get(g.remote())
# Poll pubsub channel for messages generated from running task g().
num_exports = 0
while True:
key = subscriber.poll(timeout=1)
if key is None:
break
else:
num_exports += 1
print(f"num_exports after running g(): {num_exports}")
ray.get([g.remote() for _ in range(5)])
key = subscriber.poll(timeout=1)
assert key is None, f"Unexpected function key export: {key}"
else:
ray.get(g.remote())
num_exports = ray.worker.global_worker.redis_client.llen("Exports")
ray.get([g.remote() for _ in range(5)])
assert ray.worker.global_worker.redis_client.llen("Exports") == \
num_exports
@pytest.mark.skipif( @pytest.mark.skipif(

View file

@ -4,7 +4,8 @@ import threading
import ray import ray
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
from ray._private.gcs_pubsub import GcsPublisher, GcsErrorSubscriber, \ from ray._private.gcs_pubsub import GcsPublisher, GcsErrorSubscriber, \
GcsLogSubscriber, GcsAioPublisher, GcsAioSubscriber GcsLogSubscriber, GcsFunctionKeySubscriber, GcsAioPublisher, \
GcsAioSubscriber
from ray.core.generated.gcs_pb2 import ErrorTableData from ray.core.generated.gcs_pb2 import ErrorTableData
import pytest import pytest
@ -143,6 +144,34 @@ async def test_aio_publish_and_subscribe_logs(ray_start_regular):
await subscriber.close() await subscriber.close()
@pytest.mark.parametrize(
    "ray_start_regular", [{
        "_system_config": {
            "gcs_grpc_based_pubsub": True
        }
    }],
    indirect=True)
def test_publish_and_subscribe_function_keys(ray_start_regular):
    """Function keys published to GCS are received in publish order."""
    address_info = ray_start_regular
    redis_client = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis_client)

    subscriber = GcsFunctionKeySubscriber(address=gcs_server_addr)
    subscriber.subscribe()
    try:
        publisher = GcsPublisher(address=gcs_server_addr)
        publisher.publish_function_key(b"111")
        publisher.publish_function_key(b"222")

        assert subscriber.poll() == b"111"
        assert subscriber.poll() == b"222"
    finally:
        # Always unsubscribe, even when publishing or an assertion fails, so
        # a failing run does not leave the subscription open.
        subscriber.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_regular", [{ "ray_start_regular", [{
"_system_config": { "_system_config": {

View file

@ -28,7 +28,7 @@ import ray.serialization as serialization
import ray._private.gcs_utils as gcs_utils import ray._private.gcs_utils as gcs_utils
import ray._private.services as services import ray._private.services as services
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \ from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
GcsErrorSubscriber, GcsLogSubscriber GcsErrorSubscriber, GcsLogSubscriber, GcsFunctionKeySubscriber
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -1386,6 +1386,8 @@ def connect(node,
channel=worker.gcs_channel.channel()) channel=worker.gcs_channel.channel())
worker.gcs_log_subscriber = GcsLogSubscriber( worker.gcs_log_subscriber = GcsLogSubscriber(
channel=worker.gcs_channel.channel()) channel=worker.gcs_channel.channel())
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
channel=worker.gcs_channel.channel())
# Initialize some fields. # Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE): if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
@ -1589,6 +1591,7 @@ def disconnect(exiting_interpreter=False):
# in this disconnect method. # in this disconnect method.
worker.threads_stopped.set() worker.threads_stopped.set()
if worker.gcs_pubsub_enabled: if worker.gcs_pubsub_enabled:
worker.gcs_function_key_subscriber.close()
worker.gcs_error_subscriber.close() worker.gcs_error_subscriber.close()
worker.gcs_log_subscriber.close() worker.gcs_log_subscriber.close()
if hasattr(worker, "import_thread"): if hasattr(worker, "import_thread"):

View file

@ -79,6 +79,7 @@ void GcsServer::Start() {
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL, rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL, rpc::ChannelType::RAY_ERROR_INFO_CHANNEL,
rpc::ChannelType::RAY_LOG_CHANNEL, rpc::ChannelType::RAY_LOG_CHANNEL,
rpc::ChannelType::RAY_PYTHON_FUNCTION_CHANNEL,
}, },
/*periodical_runner=*/&pubsub_periodical_runner_, /*periodical_runner=*/&pubsub_periodical_runner_,
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; }, /*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },

View file

@ -57,6 +57,18 @@ python_grpc_compile(
deps = [":gcs_proto"], deps = [":gcs_proto"],
) )
# Function and class dependencies.
# dependency.proto defines the PythonFunction message; pubsub_proto lists
# ":dependency_proto" in its deps.
proto_library(
name = "dependency_proto",
srcs = ["dependency.proto"],
)
# Python code generation for dependency_proto.
# NOTE(review): presumably the source of the dependency_pb2 module imported
# by gcs_pubsub.py; confirm against the build outputs.
python_grpc_compile(
name = "dependency_py_proto",
deps = [":dependency_proto"],
)
# Text logging.
proto_library( proto_library(
name = "logging_proto", name = "logging_proto",
srcs = ["logging.proto"], srcs = ["logging.proto"],
@ -252,6 +264,7 @@ proto_library(
srcs = ["pubsub.proto"], srcs = ["pubsub.proto"],
deps = [ deps = [
":common_proto", ":common_proto",
":dependency_proto",
":gcs_proto", ":gcs_proto",
":logging_proto", ":logging_proto",
], ],

View file

@ -0,0 +1,24 @@
// Copyright 2021 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
option cc_enable_arenas = true;
package ray.rpc;
// Notifies workers to import a Python function / actor class.
// Carried as PubMessage.python_function_message on
// RAY_PYTHON_FUNCTION_CHANNEL.
message PythonFunction {
// Key to internal KV storing pickled Python function or actor class.
bytes key = 1;
}

View file

@ -18,6 +18,7 @@ option cc_enable_arenas = true;
package ray.rpc; package ray.rpc;
import "src/ray/protobuf/common.proto"; import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/dependency.proto";
import "src/ray/protobuf/gcs.proto"; import "src/ray/protobuf/gcs.proto";
import "src/ray/protobuf/logging.proto"; import "src/ray/protobuf/logging.proto";
@ -45,6 +46,8 @@ enum ChannelType {
RAY_ERROR_INFO_CHANNEL = 8; RAY_ERROR_INFO_CHANNEL = 8;
/// A channel for logs from various Ray components. /// A channel for logs from various Ray components.
RAY_LOG_CHANNEL = 9; RAY_LOG_CHANNEL = 9;
/// A channel for keys to pickled python functions and actor classes.
RAY_PYTHON_FUNCTION_CHANNEL = 10;
} }
/// ///
@ -68,6 +71,7 @@ message PubMessage {
WorkerDeltaData worker_delta_message = 11; WorkerDeltaData worker_delta_message = 11;
ErrorTableData error_info_message = 12; ErrorTableData error_info_message = 12;
LogBatch log_batch_message = 13; LogBatch log_batch_message = 13;
PythonFunction python_function_message = 14;
// The message that indicates the given key id is not available anymore. // The message that indicates the given key id is not available anymore.
FailureMessage failure_message = 6; FailureMessage failure_message = 6;