mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Core][Pubsub][Importer] GCS pubsub for function manager & importer (#20804)
This PR allows using Ray pubsub for notifying worker importers that a new function / actor class needs to be imported.
This commit is contained in:
parent
8c0bf41b17
commit
0467bc9df5
10 changed files with 191 additions and 25 deletions
|
@ -156,8 +156,10 @@ class FunctionActorManager:
|
|||
if self._worker.gcs_client.internal_kv_put(
|
||||
holder, key, False, KV_NAMESPACE_FUNCTION_TABLE) > 0:
|
||||
break
|
||||
# TODO(yic) Use gcs pubsub
|
||||
self._worker.redis_client.lpush("Exports", "a")
|
||||
if self._worker.gcs_pubsub_enabled:
|
||||
self._worker.gcs_publisher.publish_function_key(key)
|
||||
else:
|
||||
self._worker.redis_client.lpush("Exports", "a")
|
||||
|
||||
def export(self, remote_function):
|
||||
"""Pickle a remote function and export it to redis.
|
||||
|
|
|
@ -3,7 +3,7 @@ from collections import deque
|
|||
import logging
|
||||
import random
|
||||
import threading
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import grpc
|
||||
try:
|
||||
|
@ -14,6 +14,7 @@ except ImportError:
|
|||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.logging_utils as logging_utils
|
||||
from ray.core.generated.gcs_pb2 import ErrorTableData
|
||||
from ray.core.generated import dependency_pb2
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated import pubsub_pb2
|
||||
|
@ -60,6 +61,14 @@ class _PublisherBase:
|
|||
log_json))
|
||||
])
|
||||
|
||||
@staticmethod
def _create_function_key_request(key: bytes):
    """Build the GCS publish request that carries one function key.

    Args:
        key: Function key, wrapped in a ``PythonFunction`` message.

    Returns:
        A ``GcsPublishRequest`` holding a single ``PubMessage`` addressed to
        ``RAY_PYTHON_FUNCTION_CHANNEL``.
    """
    function_message = dependency_pb2.PythonFunction(key=key)
    pub_message = pubsub_pb2.PubMessage(
        channel_type=pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL,
        python_function_message=function_message)
    return gcs_service_pb2.GcsPublishRequest(pub_messages=[pub_message])
|
||||
|
||||
|
||||
class _SubscriberBase:
|
||||
def __init__(self):
|
||||
|
@ -101,6 +110,13 @@ class _SubscriberBase:
|
|||
msg = queue.popleft()
|
||||
return logging_utils.log_batch_proto_to_dict(msg.log_batch_message)
|
||||
|
||||
@staticmethod
def _pop_function_key(queue):
    """Pop the oldest buffered message and return its function key.

    Args:
        queue: Deque of received pubsub messages; the leftmost entry is
            consumed.

    Returns:
        The ``key`` of the popped message's ``python_function_message``,
        or None when the queue is empty.
    """
    if not queue:
        return None
    return queue.popleft().python_function_message.key
|
||||
|
||||
|
||||
class GcsPublisher(_PublisherBase):
|
||||
"""Publisher to GCS."""
|
||||
|
@ -129,6 +145,11 @@ class GcsPublisher(_PublisherBase):
|
|||
req = self._create_log_request(log_batch)
|
||||
self._stub.GcsPublish(req)
|
||||
|
||||
def publish_function_key(self, key: bytes) -> None:
    """Publish a function key to GCS.

    Args:
        key: Function key; it is wrapped into a publish request by
            ``_create_function_key_request`` and sent through the stub's
            ``GcsPublish`` call.
    """
    self._stub.GcsPublish(self._create_function_key_request(key))
|
||||
|
||||
|
||||
class _SyncSubscriber(_SubscriberBase):
|
||||
def __init__(
|
||||
|
@ -216,6 +237,8 @@ class _SyncSubscriber(_SubscriberBase):
|
|||
"""Closes the subscriber and its active subscription."""
|
||||
|
||||
# Mark close to terminate inflight polling and prevent future requests.
|
||||
if self._close.is_set():
|
||||
return
|
||||
self._close.set()
|
||||
req = self._unsubscribe_request(channels=[self._channel])
|
||||
try:
|
||||
|
@ -281,7 +304,7 @@ class GcsLogSubscriber(_SyncSubscriber):
|
|||
):
|
||||
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
|
||||
|
||||
def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
|
||||
def poll(self, timeout=None) -> Optional[dict]:
|
||||
"""Polls for new log messages.
|
||||
|
||||
Returns:
|
||||
|
@ -293,6 +316,41 @@ class GcsLogSubscriber(_SyncSubscriber):
|
|||
return self._pop_log_batch(self._queue)
|
||||
|
||||
|
||||
class GcsFunctionKeySubscriber(_SyncSubscriber):
    """Subscriber to function (and actor class) dependency keys. Thread safe.

    Usage example:
        subscriber = GcsFunctionKeySubscriber()
        # Subscribe to the function key channel.
        subscriber.subscribe()
        ...
        while running:
            key = subscriber.poll()
            ......
        # Unsubscribe from the function key channel.
        subscriber.close()
    """

    def __init__(
            self,
            address: Optional[str] = None,
            channel: Optional[grpc.Channel] = None,
    ):
        """Bind the subscriber to ``RAY_PYTHON_FUNCTION_CHANNEL``.

        Args:
            address: Forwarded unchanged to the base subscriber.
            channel: Forwarded unchanged to the base subscriber.
        """
        super().__init__(pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL, address,
                         channel)

    def poll(self, timeout=None) -> Optional[bytes]:
        """Polls for new function key messages.

        Args:
            timeout: Forwarded to the underlying poll; None means no
                timeout is passed on.

        Returns:
            A byte string of function key.
            None if polling times out or subscriber closed.
        """
        # Poll and pop under the same lock so concurrent callers cannot
        # interleave between filling the queue and consuming from it.
        with self._lock:
            self._poll_locked(timeout=timeout)
            return self._pop_function_key(self._queue)
|
||||
|
||||
|
||||
class GcsAioPublisher(_PublisherBase):
|
||||
"""Publisher to GCS. Uses async io."""
|
||||
|
||||
|
|
|
@ -33,8 +33,13 @@ class ImportThread:
|
|||
def __init__(self, worker, mode, threads_stopped):
|
||||
self.worker = worker
|
||||
self.mode = mode
|
||||
self.redis_client = worker.redis_client
|
||||
self.gcs_client = worker.gcs_client
|
||||
if worker.gcs_pubsub_enabled:
|
||||
self.subscriber = worker.gcs_function_key_subscriber
|
||||
self.subscriber.subscribe()
|
||||
else:
|
||||
self.subscriber = worker.redis_client.pubsub()
|
||||
self.subscriber.subscribe("__keyspace@0__:Exports")
|
||||
self.threads_stopped = threads_stopped
|
||||
self.imported_collision_identifiers = defaultdict(int)
|
||||
# Keep track of the number of imports that we've imported.
|
||||
|
@ -53,12 +58,6 @@ class ImportThread:
|
|||
self.t.join()
|
||||
|
||||
def _run(self):
|
||||
import_pubsub_client = self.redis_client.pubsub()
|
||||
# Exports that are published after the call to
|
||||
# import_pubsub_client.subscribe and before the call to
|
||||
# import_pubsub_client.listen will still be processed in the loop.
|
||||
import_pubsub_client.subscribe("__keyspace@0__:Exports")
|
||||
|
||||
try:
|
||||
self._do_importing()
|
||||
while True:
|
||||
|
@ -66,18 +65,26 @@ class ImportThread:
|
|||
if self.threads_stopped.is_set():
|
||||
return
|
||||
|
||||
msg = import_pubsub_client.get_message()
|
||||
if msg is None:
|
||||
self.threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
if msg["type"] == "subscribe":
|
||||
continue
|
||||
if self.worker.gcs_pubsub_enabled:
|
||||
key = self.subscriber.poll()
|
||||
if key is None:
|
||||
# subscriber has closed.
|
||||
break
|
||||
else:
|
||||
msg = self.subscriber.get_message()
|
||||
if msg is None:
|
||||
self.threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
if msg["type"] == "subscribe":
|
||||
continue
|
||||
|
||||
self._do_importing()
|
||||
except (OSError, redis.exceptions.ConnectionError, grpc.RpcError) as e:
|
||||
logger.error(f"ImportThread: {e}")
|
||||
finally:
|
||||
# Close the pubsub client to avoid leaking file descriptors.
|
||||
import_pubsub_client.close()
|
||||
# Close the Redis / GCS subscriber to avoid leaking file
|
||||
# descriptors.
|
||||
self.subscriber.close()
|
||||
|
||||
def _do_importing(self):
|
||||
while True:
|
||||
|
|
|
@ -10,6 +10,8 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import ray.cluster_utils
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
|
||||
GcsFunctionKeySubscriber
|
||||
from ray._private.test_utils import (
|
||||
dicts_equal,
|
||||
wait_for_pid_to_exit,
|
||||
|
@ -323,10 +325,33 @@ def test_function_unique_export(ray_start_regular):
|
|||
def g():
|
||||
ray.get(f.remote())
|
||||
|
||||
ray.get(g.remote())
|
||||
num_exports = ray.worker.global_worker.redis_client.llen("Exports")
|
||||
ray.get([g.remote() for _ in range(5)])
|
||||
assert ray.worker.global_worker.redis_client.llen("Exports") == num_exports
|
||||
if gcs_pubsub_enabled():
|
||||
subscriber = GcsFunctionKeySubscriber(
|
||||
channel=ray.worker.global_worker.gcs_channel.channel())
|
||||
subscriber.subscribe()
|
||||
|
||||
ray.get(g.remote())
|
||||
|
||||
# Poll pubsub channel for messages generated from running task g().
|
||||
num_exports = 0
|
||||
while True:
|
||||
key = subscriber.poll(timeout=1)
|
||||
if key is None:
|
||||
break
|
||||
else:
|
||||
num_exports += 1
|
||||
print(f"num_exports after running g(): {num_exports}")
|
||||
|
||||
ray.get([g.remote() for _ in range(5)])
|
||||
|
||||
key = subscriber.poll(timeout=1)
|
||||
assert key is None, f"Unexpected function key export: {key}"
|
||||
else:
|
||||
ray.get(g.remote())
|
||||
num_exports = ray.worker.global_worker.redis_client.llen("Exports")
|
||||
ray.get([g.remote() for _ in range(5)])
|
||||
assert ray.worker.global_worker.redis_client.llen("Exports") == \
|
||||
num_exports
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
|
|
@ -4,7 +4,8 @@ import threading
|
|||
import ray
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
from ray._private.gcs_pubsub import GcsPublisher, GcsErrorSubscriber, \
|
||||
GcsLogSubscriber, GcsAioPublisher, GcsAioSubscriber
|
||||
GcsLogSubscriber, GcsFunctionKeySubscriber, GcsAioPublisher, \
|
||||
GcsAioSubscriber
|
||||
from ray.core.generated.gcs_pb2 import ErrorTableData
|
||||
import pytest
|
||||
|
||||
|
@ -143,6 +144,34 @@ async def test_aio_publish_and_subscribe_logs(ray_start_regular):
|
|||
await subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "ray_start_regular", [{
        "_system_config": {
            "gcs_grpc_based_pubsub": True
        }
    }],
    indirect=True)
def test_publish_and_subscribe_function_keys(ray_start_regular):
    """Function keys published to GCS are polled back in publish order."""
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)

    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    subscriber = GcsFunctionKeySubscriber(address=gcs_server_addr)
    subscriber.subscribe()
    try:
        publisher = GcsPublisher(address=gcs_server_addr)
        publisher.publish_function_key(b"111")
        publisher.publish_function_key(b"222")

        assert subscriber.poll() == b"111"
        assert subscriber.poll() == b"222"
    finally:
        # Close even when publishing or an assertion fails, so the active
        # subscription is not left open for the rest of the test session.
        subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"_system_config": {
|
||||
|
|
|
@ -28,7 +28,7 @@ import ray.serialization as serialization
|
|||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray._private.services as services
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher, \
|
||||
GcsErrorSubscriber, GcsLogSubscriber
|
||||
GcsErrorSubscriber, GcsLogSubscriber, GcsFunctionKeySubscriber
|
||||
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
||||
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
||||
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
||||
|
@ -1386,6 +1386,8 @@ def connect(node,
|
|||
channel=worker.gcs_channel.channel())
|
||||
worker.gcs_log_subscriber = GcsLogSubscriber(
|
||||
channel=worker.gcs_channel.channel())
|
||||
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
|
||||
channel=worker.gcs_channel.channel())
|
||||
|
||||
# Initialize some fields.
|
||||
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
|
||||
|
@ -1589,6 +1591,7 @@ def disconnect(exiting_interpreter=False):
|
|||
# in this disconnect method.
|
||||
worker.threads_stopped.set()
|
||||
if worker.gcs_pubsub_enabled:
|
||||
worker.gcs_function_key_subscriber.close()
|
||||
worker.gcs_error_subscriber.close()
|
||||
worker.gcs_log_subscriber.close()
|
||||
if hasattr(worker, "import_thread"):
|
||||
|
|
|
@ -79,6 +79,7 @@ void GcsServer::Start() {
|
|||
rpc::ChannelType::GCS_WORKER_DELTA_CHANNEL,
|
||||
rpc::ChannelType::RAY_ERROR_INFO_CHANNEL,
|
||||
rpc::ChannelType::RAY_LOG_CHANNEL,
|
||||
rpc::ChannelType::RAY_PYTHON_FUNCTION_CHANNEL,
|
||||
},
|
||||
/*periodical_runner=*/&pubsub_periodical_runner_,
|
||||
/*get_time_ms=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; },
|
||||
|
|
|
@ -57,6 +57,18 @@ python_grpc_compile(
|
|||
deps = [":gcs_proto"],
|
||||
)
|
||||
|
||||
# Function and class dependencies.
|
||||
proto_library(
|
||||
name = "dependency_proto",
|
||||
srcs = ["dependency.proto"],
|
||||
)
|
||||
|
||||
python_grpc_compile(
|
||||
name = "dependency_py_proto",
|
||||
deps = [":dependency_proto"],
|
||||
)
|
||||
|
||||
# Text logging.
|
||||
proto_library(
|
||||
name = "logging_proto",
|
||||
srcs = ["logging.proto"],
|
||||
|
@ -252,6 +264,7 @@ proto_library(
|
|||
srcs = ["pubsub.proto"],
|
||||
deps = [
|
||||
":common_proto",
|
||||
":dependency_proto",
|
||||
":gcs_proto",
|
||||
":logging_proto",
|
||||
],
|
||||
|
|
24
src/ray/protobuf/dependency.proto
Normal file
24
src/ray/protobuf/dependency.proto
Normal file
|
@ -0,0 +1,24 @@
|
|||
// Copyright 2021 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
// Notifies workers to import a Python function / actor class.
// Carried in PubMessage.python_function_message on the
// RAY_PYTHON_FUNCTION_CHANNEL pubsub channel.
message PythonFunction {
  // Key to internal KV storing pickled Python function or actor class.
  bytes key = 1;
}
|
|
@ -18,6 +18,7 @@ option cc_enable_arenas = true;
|
|||
package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
import "src/ray/protobuf/dependency.proto";
|
||||
import "src/ray/protobuf/gcs.proto";
|
||||
import "src/ray/protobuf/logging.proto";
|
||||
|
||||
|
@ -45,6 +46,8 @@ enum ChannelType {
|
|||
RAY_ERROR_INFO_CHANNEL = 8;
|
||||
/// A channel for logs from various Ray components.
|
||||
RAY_LOG_CHANNEL = 9;
|
||||
/// A channel for keys to pickled python functions and actor classes.
|
||||
RAY_PYTHON_FUNCTION_CHANNEL = 10;
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -68,6 +71,7 @@ message PubMessage {
|
|||
WorkerDeltaData worker_delta_message = 11;
|
||||
ErrorTableData error_info_message = 12;
|
||||
LogBatch log_batch_message = 13;
|
||||
PythonFunction python_function_message = 14;
|
||||
|
||||
// The message that indicates the given key id is not available anymore.
|
||||
FailureMessage failure_message = 6;
|
||||
|
|
Loading…
Add table
Reference in a new issue