[Remove Redis Pubsub 1/n] Remove enable_gcs_pubsub() (#23189)

GCS pubsub has been the default for a while, and there is little chance that we would need to revert to Redis pubsub in the future. This is the first step in removing Redis pubsub: removing the `enable_gcs_pubsub()` feature guard.
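Every hunk below is an instance of the same mechanical change: a call site that branched on the feature guard collapses to the GCS-only path. A minimal before/after sketch, distilled from the hunks rather than copied from any one file (the `make_error_subscriber_*` helpers and their parameters are illustrative, not Ray API):

    from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsErrorSubscriber

    def make_error_subscriber_before(gcs_address, redis_client, error_pubsub_pattern):
        # Before this commit: two transports, selected at runtime by the guard.
        if gcs_pubsub_enabled():
            subscriber = GcsErrorSubscriber(address=gcs_address)
            subscriber.subscribe()
        else:
            subscriber = redis_client.pubsub(ignore_subscribe_messages=True)
            subscriber.psubscribe(error_pubsub_pattern)
        return subscriber

    def make_error_subscriber_after(gcs_address):
        # After this commit: the GCS path runs unconditionally; the guard and
        # the whole Redis branch are deleted.
        subscriber = GcsErrorSubscriber(address=gcs_address)
        subscriber.subscribe()
        return subscriber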
parent 678d23fe42
commit 72ef9f91aa
15 changed files with 84 additions and 374 deletions
@@ -7,7 +7,6 @@ import traceback
 import ray
 import pytest
 import ray.dashboard.utils as dashboard_utils
 import ray._private.gcs_utils as gcs_utils
-import ray._private.gcs_pubsub as gcs_pubsub
 from ray.dashboard.tests.conftest import * # noqa
 from ray.dashboard.modules.actor import actor_consts

@@ -228,15 +227,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
     def handle_pub_messages(msgs, timeout, expect_num):
         start_time = time.time()
         while time.time() - start_time < timeout and len(msgs) < expect_num:
-            if gcs_pubsub.gcs_pubsub_enabled():
-                _, actor_data = sub.poll(timeout=timeout)
-            else:
-                msg = sub.get_message()
-                if msg is None:
-                    time.sleep(0.01)
-                    continue
-                pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
-                actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
+            _, actor_data = sub.poll(timeout=timeout)
             if actor_data is None:
                 continue
             msgs.append(actor_data)
@@ -5,7 +5,6 @@ import os
 import aiohttp.web

 import ray
-import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
 import ray.dashboard.utils as dashboard_utils
 import ray.dashboard.optional_utils as dashboard_optional_utils
 import ray.experimental.internal_kv as internal_kv

@@ -18,7 +17,7 @@ from ray.ray_constants import (
 )
 from ray.core.generated import reporter_pb2
 from ray.core.generated import reporter_pb2_grpc
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
+from ray._private.gcs_pubsub import GcsAioResourceUsageSubscriber
 from ray._private.metrics_agent import PrometheusServiceDiscoveryWriter
 from ray.dashboard.datacenter import DataSource

@@ -148,45 +147,24 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
         # Need daemon True to avoid dashboard hangs at exit.
         self.service_discovery.daemon = True
         self.service_discovery.start()
-        if gcs_pubsub_enabled():
-            gcs_addr = self._dashboard_head.gcs_address
-            subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
-            await subscriber.subscribe()
+        gcs_addr = self._dashboard_head.gcs_address
+        subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
+        await subscriber.subscribe()

-            while True:
-                try:
-                    # The key is b'RAY_REPORTER:{node id hex}',
-                    # e.g. b'RAY_REPORTER:2b4fbd...'
-                    key, data = await subscriber.poll()
-                    if key is None:
-                        continue
-                    data = json.loads(data)
-                    node_id = key.split(":")[-1]
-                    DataSource.node_physical_stats[node_id] = data
-                except Exception:
-                    logger.exception(
-                        "Error receiving node physical stats " "from reporter agent."
-                    )
-        else:
-            from aioredis.pubsub import Receiver
-
-            receiver = Receiver()
-            aioredis_client = self._dashboard_head.aioredis_client
-            reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
-            await aioredis_client.psubscribe(receiver.pattern(reporter_key))
-            logger.info(f"Subscribed to {reporter_key}")
-
-            async for sender, msg in receiver.iter():
-                try:
-                    key, data = msg
-                    data = json.loads(ray._private.utils.decode(data))
-                    key = key.decode("utf-8")
-                    node_id = key.split(":")[-1]
-                    DataSource.node_physical_stats[node_id] = data
-                except Exception:
-                    logger.exception(
-                        "Error receiving node physical stats " "from reporter agent."
-                    )
+        while True:
+            try:
+                # The key is b'RAY_REPORTER:{node id hex}',
+                # e.g. b'RAY_REPORTER:2b4fbd...'
+                key, data = await subscriber.poll()
+                if key is None:
+                    continue
+                data = json.loads(data)
+                node_id = key.split(":")[-1]
+                DataSource.node_physical_stats[node_id] = data
+            except Exception:
+                logger.exception(
+                    "Error receiving node physical stats from reporter agent."
+                )

     @staticmethod
     def is_minimal_module():
@@ -23,7 +23,6 @@ from ray._private.test_utils import (
     run_string_as_driver,
     wait_until_succeeded_without_exception,
 )
-from ray._private.gcs_pubsub import gcs_pubsub_enabled
 from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
 from ray.dashboard import dashboard
 import ray.dashboard.consts as dashboard_consts

@@ -701,13 +700,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):

     gcs_server_proc.kill()
     gcs_server_proc.wait()
-    if gcs_pubsub_enabled():
-        # When pubsub enabled, the exits comes from pubsub errored.
-        # TODO: Fix this exits logic for pubsub
-        assert dashboard_proc.wait(10) != 0
-    else:
-        # The dashboard exits by os._exit(-1)
-        assert dashboard_proc.wait(10) == 255
+
+    # The dashboard exits by os._exit(-1)
+    assert dashboard_proc.wait(10) == 255


 if __name__ == "__main__":
@@ -178,13 +178,8 @@ class FunctionActorManager:
                 break
         # Notify all subscribers that there is a new function exported. Note
         # that the notification doesn't include any actual data.
-        if self._worker.gcs_pubsub_enabled:
-            # TODO(mwtian) implement per-job notification here.
-            self._worker.gcs_publisher.publish_function_key(key)
-        else:
-            self._worker.redis_client.lpush(
-                make_exports_prefix(self._worker.current_job_id), "a"
-            )
+        # TODO(mwtian) implement per-job notification here.
+        self._worker.gcs_publisher.publish_function_key(key)

     def export(self, remote_function):
         """Pickle a remote function and export it to redis.
@@ -16,7 +16,6 @@ except ImportError:

 import ray._private.gcs_utils as gcs_utils
 import ray._private.logging_utils as logging_utils
-from ray._raylet import Config
 from ray.core.generated.gcs_pb2 import ErrorTableData
 from ray.core.generated import dependency_pb2
 from ray.core.generated import gcs_service_pb2_grpc

@@ -30,11 +29,6 @@ logger = logging.getLogger(__name__)
 MAX_GCS_PUBLISH_RETRIES = 60


-def gcs_pubsub_enabled():
-    """Checks whether GCS pubsub feature flag is enabled."""
-    return Config.gcs_grpc_based_pubsub()
-
-
 def construct_error_message(job_id, error_type, message, timestamp):
     """Construct an ErrorTableData object.

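For reference, the guard deleted above was the only Python-level switch; it simply consulted the `gcs_grpc_based_pubsub` system-config flag. That is also why the `@pytest.mark.parametrize` blocks pinning that flag disappear from the pubsub tests further down: with GCS pubsub always on, forcing the flag per test is a no-op. A sketch of the now-redundant pattern (`test_something` is a placeholder name, not a test from this diff):

    import pytest

    @pytest.mark.parametrize(
        "ray_start_regular",
        [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
        indirect=True,
    )
    def test_something(ray_start_regular):
        # The fixture would start Ray with GCS-based pubsub explicitly enabled.
        ...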
@@ -94,13 +94,9 @@ _GRPC_OPTIONS = [


 def use_gcs_for_bootstrap():
-    from ray._private.gcs_pubsub import gcs_pubsub_enabled
     from ray._raylet import Config

-    ret = Config.bootstrap_with_gcs()
-    if ret:
-        assert gcs_pubsub_enabled()
-    return ret
+    return Config.bootstrap_with_gcs()


 def get_gcs_address_from_redis(redis) -> str:
@@ -33,21 +33,9 @@ class ImportThread:
         self.worker = worker
         self.mode = mode
         self.gcs_client = worker.gcs_client
-        if worker.gcs_pubsub_enabled:
-            self.subscriber = worker.gcs_function_key_subscriber
-            self.subscriber.subscribe()
-            self.exception_type = grpc.RpcError
-        else:
-            import redis
-
-            self.subscriber = worker.redis_client.pubsub()
-            self.subscriber.subscribe(
-                b"__keyspace@0__:"
-                + ray._private.function_manager.make_exports_prefix(
-                    self.worker.current_job_id
-                )
-            )
-            self.exception_type = redis.exceptions.ConnectionError
+        self.subscriber = worker.gcs_function_key_subscriber
+        self.subscriber.subscribe()
+        self.exception_type = grpc.RpcError
         self.threads_stopped = threads_stopped
         self.imported_collision_identifiers = defaultdict(int)
         # Keep track of the number of imports that we've imported.

@@ -72,20 +60,10 @@ class ImportThread:
                 # Exit if we received a signal that we should stop.
                 if self.threads_stopped.is_set():
                     return
-
-                if self.worker.gcs_pubsub_enabled:
-                    key = self.subscriber.poll()
-                    if key is None:
-                        # subscriber has closed.
-                        break
-                else:
-                    msg = self.subscriber.get_message()
-                    if msg is None:
-                        self.threads_stopped.wait(timeout=0.01)
-                        continue
-                    if msg["type"] == "subscribe":
-                        continue
-
+                key = self.subscriber.poll()
+                if key is None:
+                    # subscriber has closed.
+                    break
                 self._do_importing()
         except (OSError, self.exception_type) as e:
             logger.error(f"ImportThread: {e}")
@@ -1,7 +1,6 @@
 import argparse
-import errno
 import glob
 import json
 import logging
 import logging.handlers
 import os

@@ -13,10 +12,9 @@ import traceback

 import ray.ray_constants as ray_constants
 import ray._private.gcs_pubsub as gcs_pubsub
-import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 import ray._private.utils
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
+from ray._private.gcs_pubsub import GcsPublisher
 from ray._private.ray_logging import setup_component_logger

 # Logger for this module. It should be configured at the entry point

@@ -91,7 +89,6 @@ class LogMonitor:
         host (str): The hostname of this machine. Used to improve the log
             messages published to Redis.
         logs_dir (str): The directory that the log files are in.
-        redis_client: A client used to communicate with the Redis server.
         log_filenames (set): This is the set of filenames of all files in
             open_file_infos and closed_file_infos.
         open_file_infos (list[LogFileInfo]): Info for all of the open files.

@@ -105,10 +102,7 @@ class LogMonitor:
         """Initialize the log monitor object."""
         self.ip = services.get_node_ip_address()
         self.logs_dir = logs_dir
-        self.redis_client = None
-        self.publisher = None
-        if gcs_pubsub.gcs_pubsub_enabled():
-            self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
+        self.publisher = gcs_pubsub.GcsPublisher(address=gcs_address)
         self.log_filenames = set()
         self.open_file_infos = []
         self.closed_file_infos = []

@@ -293,12 +287,7 @@ class LogMonitor:
                         "actor_name": file_info.actor_name,
                         "task_name": file_info.task_name,
                     }
-                    if self.publisher:
-                        self.publisher.publish_logs(data)
-                    else:
-                        self.redis_client.publish(
-                            gcs_utils.LOG_FILE_CHANNEL, json.dumps(data)
-                        )
+                    self.publisher.publish_logs(data)
                     anything_published = True
                     lines_to_publish = []

@@ -477,12 +466,7 @@ if __name__ == "__main__":
         log_monitor.run()
     except Exception as e:
         # Something went wrong, so push an error to all drivers.
-        redis_client = ray._private.services.create_redis_client(
-            args.redis_address, password=args.redis_password
-        )
-        gcs_publisher = None
-        if gcs_pubsub_enabled():
-            gcs_publisher = GcsPublisher(address=args.gcs_address)
+        gcs_publisher = GcsPublisher(address=args.gcs_address)
         traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
         message = (
             f"The log monitor on node {platform.node()} "

@@ -491,7 +475,6 @@ if __name__ == "__main__":
         ray._private.utils.publish_error_to_driver(
             ray_constants.LOG_MONITOR_DIED_ERROR,
             message,
-            redis_client=redis_client,
             gcs_publisher=gcs_publisher,
         )
         logger.error(message)
@@ -31,7 +31,6 @@ from ray.core.generated import gcs_pb2
 from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
 from ray._private.gcs_pubsub import (
-    gcs_pubsub_enabled,
     GcsErrorSubscriber,
     GcsLogSubscriber,
 )

@@ -218,7 +217,6 @@ def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "utf-8"):
     Returns:
         The script's output.
     """
-
     proc = subprocess.Popen(
         [sys.executable, "-"],
         stdin=subprocess.PIPE,

@@ -581,12 +579,8 @@ def get_non_head_nodes(cluster):

 def init_error_pubsub():
     """Initialize redis error info pub/sub"""
-    if gcs_pubsub_enabled():
-        s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
-        s.subscribe()
-    else:
-        s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
-        s.psubscribe(gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
+    s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
+    s.subscribe()
     return s


@@ -621,12 +615,8 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):

 def init_log_pubsub():
     """Initialize redis error info pub/sub"""
-    if gcs_pubsub_enabled():
-        s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
-        s.subscribe()
-    else:
-        s = ray.worker.global_worker.redis_client.pubsub(ignore_subscribe_messages=True)
-        s.psubscribe(gcs_utils.LOG_FILE_CHANNEL)
+    s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
+    s.subscribe()
     return s

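For orientation, tests consume these helpers roughly as below; `get_error_message`'s signature is taken from the hunk header above, and the error-type constant is one used later in this diff (the exact call is a sketch, not a line from the commit):

    import ray.ray_constants as ray_constants
    from ray._private.test_utils import init_error_pubsub, get_error_message

    s = init_error_pubsub()  # now always a GcsErrorSubscriber
    errors = get_error_message(
        s, num=1, error_type=ray_constants.DASHBOARD_AGENT_DIED_ERROR, timeout=20
    )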
@@ -8,7 +8,7 @@ import numpy as np
 import pytest

 import ray.cluster_utils
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsFunctionKeySubscriber
+from ray._private.gcs_pubsub import GcsFunctionKeySubscriber
 from ray._private.test_utils import wait_for_condition
 from ray.autoscaler._private.constants import RAY_PROCESSES
 from pathlib import Path

@@ -108,33 +108,27 @@ def test_function_unique_export(ray_start_regular):
     def g():
         ray.get(f.remote())

-    if gcs_pubsub_enabled():
-        subscriber = GcsFunctionKeySubscriber(
-            address=ray.worker.global_worker.gcs_client.address
-        )
-        subscriber.subscribe()
-
-        ray.get(g.remote())
-
-        # Poll pubsub channel for messages generated from running task g().
-        num_exports = 0
-        while True:
-            key = subscriber.poll(timeout=1)
-            if key is None:
-                break
-            else:
-                num_exports += 1
-        print(f"num_exports after running g(): {num_exports}")
-
-        ray.get([g.remote() for _ in range(5)])
-
-        key = subscriber.poll(timeout=1)
-        assert key is None, f"Unexpected function key export: {key}"
-    else:
-        ray.get(g.remote())
-        num_exports = ray.worker.global_worker.redis_client.llen("Exports")
-        ray.get([g.remote() for _ in range(5)])
-        assert ray.worker.global_worker.redis_client.llen("Exports") == num_exports
+    subscriber = GcsFunctionKeySubscriber(
+        address=ray.worker.global_worker.gcs_client.address
+    )
+    subscriber.subscribe()
+
+    ray.get(g.remote())
+
+    # Poll pubsub channel for messages generated from running task g().
+    num_exports = 0
+    while True:
+        key = subscriber.poll(timeout=1)
+        if key is None:
+            break
+        else:
+            num_exports += 1
+    print(f"num_exports after running g(): {num_exports}")

+    ray.get([g.remote() for _ in range(5)])
+
+    key = subscriber.poll(timeout=1)
+    assert key is None, f"Unexpected function key export: {key}"


 @pytest.mark.skipif(
@@ -11,7 +11,7 @@ import ray._private.utils
 import ray._private.gcs_utils as gcs_utils
 import ray.ray_constants as ray_constants
 from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError
-from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
+from ray._private.gcs_pubsub import GcsPublisher
 from ray._private.test_utils import (
     wait_for_condition,
     SignalActor,

@@ -69,21 +69,12 @@ def test_unhandled_errors(ray_start_regular):

 def test_publish_error_to_driver(ray_start_regular, error_pubsub):
     address_info = ray_start_regular
-    redis_client = None
-    gcs_publisher = None
-    if gcs_pubsub_enabled():
-        gcs_publisher = GcsPublisher(address=address_info["gcs_address"])
-    else:
-        redis_client = ray._private.services.create_redis_client(
-            address_info["redis_address"],
-            password=ray.ray_constants.REDIS_DEFAULT_PASSWORD,
-        )
+    gcs_publisher = GcsPublisher(address=address_info["gcs_address"])

     error_message = "Test error message"
     ray._private.utils.publish_error_to_driver(
         ray_constants.DASHBOARD_AGENT_DIED_ERROR,
         error_message,
-        redis_client=redis_client,
         gcs_publisher=gcs_publisher,
     )
     errors = get_error_message(
@@ -14,7 +14,6 @@ from ray.ray_constants import DEBUG_AUTOSCALING_ERROR
 import ray._private.utils
 import ray.ray_constants as ray_constants
 from ray.cluster_utils import cluster_not_supported
-import ray._private.gcs_pubsub as gcs_pubsub
 from ray._private.test_utils import (
     init_error_pubsub,
     get_error_message,

@@ -425,37 +424,6 @@ def test_fate_sharing(ray_start_cluster, use_actors, node_failure):
     test_process_failure(use_actors)


-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_rpc_server_reconnect_timeout_s": 100}}],
-    indirect=True,
-)
-@pytest.mark.skipif(
-    gcs_pubsub.gcs_pubsub_enabled(),
-    reason="Logs are streamed via GCS pubsub when it is enabled, so logs "
-    "cannot be delivered after GCS is killed.",
-)
-def test_gcs_server_failiure_report(ray_start_regular, log_pubsub):
-    # Get gcs server pid to send a signal.
-    all_processes = ray.worker._global_node.all_processes
-    gcs_server_process = all_processes["gcs_server"][0].process
-    gcs_server_pid = gcs_server_process.pid
-
-    # TODO(mwtian): make sure logs are delivered after GCS is restarted.
-    if sys.platform == "win32":
-        sig = 9
-    else:
-        sig = signal.SIGBUS
-    os.kill(gcs_server_pid, sig)
-    # wait for 30 seconds, for the 1st batch of logs.
-    batches = get_log_batch(log_pubsub, 1, timeout=30)
-    assert gcs_server_process.poll() is not None
-    if sys.platform != "win32":
-        # Windows signal handler does not run when process is terminated
-        assert len(batches) == 1
-        assert batches[0]["pid"] == "gcs_server", batches
-
-
 def test_list_named_actors_timeout(monkeypatch, shutdown_only):
     with monkeypatch.context() as m:
         # defer for 3s
@@ -1,7 +1,6 @@
 import sys

 import ray
-import ray._private.gcs_pubsub as gcs_pubsub
 import ray._private.gcs_utils as gcs_utils
 import pytest
 from ray._private.test_utils import (

@@ -63,11 +62,11 @@ def test_gcs_server_restart(ray_start_regular_with_external_redis):
     ],
     indirect=True,
 )
-@pytest.mark.skipif(
-    gcs_pubsub.gcs_pubsub_enabled(),
+@pytest.mark.skip(
     reason="GCS pubsub may lose messages after GCS restarts. Need to "
     "implement re-fetching state in GCS client.",
 )
+# TODO(mwtian): re-enable after fixing https://github.com/ray-project/ray/issues/22340
 def test_gcs_server_restart_during_actor_creation(
     ray_start_regular_with_external_redis,
 ):
@@ -15,11 +15,6 @@ from ray.core.generated.gcs_pb2 import ErrorTableData
 import pytest


-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 def test_publish_and_subscribe_error_info(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -40,11 +35,6 @@ def test_publish_and_subscribe_error_info(ray_start_regular):


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -64,11 +54,6 @@ async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
     await subscriber.close()


-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 def test_publish_and_subscribe_logs(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -96,11 +81,6 @@ def test_publish_and_subscribe_logs(ray_start_regular):


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 async def test_aio_publish_and_subscribe_logs(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -125,11 +105,6 @@ async def test_aio_publish_and_subscribe_logs(ray_start_regular):
     await subscriber.close()


-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 def test_publish_and_subscribe_function_keys(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -148,11 +123,6 @@ def test_publish_and_subscribe_function_keys(ray_start_regular):


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 async def test_aio_publish_and_subscribe_resource_usage(ray_start_regular):
     address_info = ray_start_regular
     gcs_server_addr = address_info["gcs_address"]

@@ -170,11 +140,6 @@ async def test_aio_publish_and_subscribe_resource_usage(ray_start_regular):
     await subscriber.close()


-@pytest.mark.parametrize(
-    "ray_start_regular",
-    [{"_system_config": {"gcs_grpc_based_pubsub": True}}],
-    indirect=True,
-)
 def test_two_subscribers(ray_start_regular):
     """Tests concurrently subscribing to two channels work."""

@@ -30,7 +30,6 @@ import ray._private.gcs_utils as gcs_utils
 import ray._private.services as services
 from ray.util.scheduling_strategies import SchedulingStrategyT
 from ray._private.gcs_pubsub import (
-    gcs_pubsub_enabled,
     GcsPublisher,
     GcsErrorSubscriber,
     GcsLogSubscriber,

@@ -453,26 +452,13 @@ class Worker:

     def print_logs(self):
         """Prints log messages from workers on all nodes in the same job."""
-        if self.gcs_pubsub_enabled:
-            import grpc
+        import grpc

-            subscriber = self.gcs_log_subscriber
-            subscriber.subscribe()
-            exception_type = grpc.RpcError
-        else:
-            import redis
-
-            subscriber = self.redis_client.pubsub(ignore_subscribe_messages=True)
-            subscriber.subscribe(gcs_utils.LOG_FILE_CHANNEL)
-            exception_type = redis.exceptions.ConnectionError
+        subscriber = self.gcs_log_subscriber
+        subscriber.subscribe()
+        exception_type = grpc.RpcError
         localhost = services.get_node_ip_address()
         try:
-            # Keep track of the number of consecutive log messages that have
-            # been received with no break in between. If this number grows
-            # continually, then the worker is probably not able to process the
-            # log messages as rapidly as they are coming in.
-            # This is meaningful only for Redis subscriber.
-            num_consecutive_messages_received = 0
             # Number of messages received from the last polling. When the batch
             # size exceeds 100 and keeps increasing, the worker and the user
             # probably will not be able to consume the log messages as rapidly

@@ -485,43 +471,21 @@ class Worker:
                 if self.threads_stopped.is_set():
                     return

-                if self.gcs_pubsub_enabled:
-                    msg = subscriber.poll()
-                else:
-                    msg = subscriber.get_message()
+                data = subscriber.poll()
                 # GCS subscriber only returns None on unavailability.
-                # Redis subscriber returns None when there is no new message.
-                if msg is None:
-                    num_consecutive_messages_received = 0
+                if data is None:
+                    last_polling_batch_size = 0
                     self.threads_stopped.wait(timeout=0.01)
                     continue

-                if self.gcs_pubsub_enabled:
-                    data = msg
-                else:
-                    data = json.loads(ray._private.utils.decode(msg["data"]))
-
                 # Don't show logs from other drivers.
                 if data["job"] and data["job"] != job_id_hex:
-                    num_consecutive_messages_received = 0
+                    last_polling_batch_size = 0
                     continue

                 data["localhost"] = localhost
                 global_worker_stdstream_dispatcher.emit(data)

-                if self.gcs_pubsub_enabled:
-                    lagging = (
-                        100 <= last_polling_batch_size < subscriber.last_batch_size
-                    )
-                    last_polling_batch_size = subscriber.last_batch_size
-                else:
-                    num_consecutive_messages_received += 1
-                    lagging = (
-                        num_consecutive_messages_received % 100 == 0
-                        and num_consecutive_messages_received > 0
-                    )
+                lagging = 100 <= last_polling_batch_size < subscriber.last_batch_size
                 if lagging:
                     logger.warning(
                         "The driver may not be able to keep up with the "

@@ -530,6 +494,8 @@ class Worker:
                         "'ray.init(log_to_driver=False)'."
                     )

+                last_polling_batch_size = subscriber.last_batch_size
+
         except (OSError, exception_type) as e:
             logger.error(f"print_logs: {e}")
         finally:

@@ -1366,72 +1332,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
     )


-def listen_error_messages_raylet(worker, threads_stopped):
-    """Listen to error messages in the background on the driver.
-
-    This runs in a separate thread on the driver and pushes (error, time)
-    tuples to the output queue.
-
-    Args:
-        worker: The worker class that this thread belongs to.
-        threads_stopped (threading.Event): A threading event used to signal to
-            the thread that it should exit.
-    """
-    import redis
-
-    worker.error_message_pubsub_client = worker.redis_client.pubsub(
-        ignore_subscribe_messages=True
-    )
-    # Exports that are published after the call to
-    # error_message_pubsub_client.subscribe and before the call to
-    # error_message_pubsub_client.listen will still be processed in the loop.
-
-    # Really we should just subscribe to the errors for this specific job.
-    # However, currently all errors seem to be published on the same channel.
-    error_pubsub_channel = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
-    worker.error_message_pubsub_client.psubscribe(error_pubsub_channel)
-
-    try:
-        if _internal_kv_initialized():
-            # Get any autoscaler errors that occurred before the call to
-            # subscribe.
-            error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR)
-            if error_message is not None:
-                logger.warning(error_message.decode())
-
-        while True:
-            # Exit if we received a signal that we should stop.
-            if threads_stopped.is_set():
-                return
-
-            msg = worker.error_message_pubsub_client.get_message()
-            if msg is None:
-                threads_stopped.wait(timeout=0.01)
-                continue
-            pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
-            error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
-            job_id = error_data.job_id
-            if job_id not in [
-                worker.current_job_id.binary(),
-                JobID.nil().binary(),
-            ]:
-                continue
-
-            error_message = error_data.error_message
-            if error_data.type == ray_constants.TASK_PUSH_ERROR:
-                # TODO(ekl) remove task push errors entirely now that we have
-                # the separate unhandled exception handler.
-                pass
-            else:
-                logger.warning(error_message)
-    except (OSError, redis.exceptions.ConnectionError) as e:
-        logger.error(f"listen_error_messages_raylet: {e}")
-    finally:
-        # Close the pubsub client to avoid leaking file descriptors.
-        worker.error_message_pubsub_client.close()
-
-
-def listen_error_messages_from_gcs(worker, threads_stopped):
+def listen_error_messages(worker, threads_stopped):
     """Listen to error messages in the background on the driver.

     This runs in a separate thread on the driver and pushes (error, time)

@@ -1476,7 +1377,7 @@ def listen_error_messages_from_gcs(worker, threads_stopped):
         else:
             logger.warning(error_message)
     except (OSError, ConnectionError) as e:
-        logger.error(f"listen_error_messages_from_gcs: {e}")
+        logger.error(f"listen_error_messages: {e}")


 @PublicAPI

@@ -1543,17 +1444,12 @@ def connect(
     ray.state.state._initialize_global_state(
         ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
     )
-    worker.gcs_pubsub_enabled = gcs_pubsub_enabled()
-    worker.gcs_publisher = None
-    if worker.gcs_pubsub_enabled:
-        worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
-        worker.gcs_error_subscriber = GcsErrorSubscriber(
-            address=worker.gcs_client.address
-        )
-        worker.gcs_log_subscriber = GcsLogSubscriber(address=worker.gcs_client.address)
-        worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
-            address=worker.gcs_client.address
-        )
+    worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
+    worker.gcs_error_subscriber = GcsErrorSubscriber(address=worker.gcs_client.address)
+    worker.gcs_log_subscriber = GcsLogSubscriber(address=worker.gcs_client.address)
+    worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
+        address=worker.gcs_client.address
+    )

     # Initialize some fields.
     if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):

@@ -1701,9 +1597,7 @@ def connect(
     # scheduler for new error messages.
     if mode == SCRIPT_MODE:
         worker.listener_thread = threading.Thread(
-            target=listen_error_messages_from_gcs
-            if worker.gcs_pubsub_enabled
-            else listen_error_messages_raylet,
+            target=listen_error_messages,
             name="ray_listen_error_messages",
             args=(worker, worker.threads_stopped),
         )

@@ -1775,10 +1669,9 @@ def disconnect(exiting_interpreter=False):
     # should be handled cleanly in the worker object's destructor and not
     # in this disconnect method.
     worker.threads_stopped.set()
-    if worker.gcs_pubsub_enabled:
-        worker.gcs_function_key_subscriber.close()
-        worker.gcs_error_subscriber.close()
-        worker.gcs_log_subscriber.close()
+    worker.gcs_function_key_subscriber.close()
+    worker.gcs_error_subscriber.close()
+    worker.gcs_log_subscriber.close()
     if hasattr(worker, "import_thread"):
         worker.import_thread.join_import_thread()
     if hasattr(worker, "listener_thread"):
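One behavioral detail from the print_logs hunks above: with the Redis path gone, the "driver may not be able to keep up" warning is driven purely by poll batch sizes, replacing the old consecutive-message counter. Restated as a standalone helper (the name and framing are ours; the expression is lifted from the diff):

    def is_lagging(last_polling_batch_size: int, current_batch_size: int) -> bool:
        # Warn once polls keep returning large (>= 100) and still-growing
        # batches, i.e. log messages arrive faster than the driver drains them.
        return 100 <= last_polling_batch_size < current_batch_size

After each emitted message, `last_polling_batch_size` is updated to `subscriber.last_batch_size`, so the warning fires when two successive polls both return at least 100 messages and the batch is still growing.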