[Remove Redis Pubsub 2/n] clean up remaining Redis references in gcs_utils.py (#23233)

Continue to clean up Redis and related references, for:
- `gcs_utils.py`
- `log_monitor.py`
- `publish_error_to_driver()`
mwtian 2022-03-16 19:34:57 -07:00 committed by GitHub
parent b350fe9ee8
commit 391901f86b
12 changed files with 46 additions and 192 deletions

View file

@@ -191,11 +191,8 @@ if __name__ == "__main__":
         # Something went wrong, so push an error to all drivers.
         gcs_publisher = GcsPublisher(args.gcs_address)
         ray._private.utils.publish_error_to_driver(
             ray_constants.DASHBOARD_DIED_ERROR,
             message,
-            None,
-            None,
-            gcs_publisher,
+            gcs_publisher=gcs_publisher,
         )

View file

@@ -23,9 +23,5 @@ public class SystemConfig {
     return val;
   }
 
-  public static boolean bootstrapWithGcs() {
-    return ((Boolean) SystemConfig.get("bootstrap_with_gcs")).booleanValue();
-  }
-
   private static native String nativeGetSystemConfig(String key);
 }

View file

@@ -58,16 +58,6 @@ __all__ = [
     "PlacementGroupTableData",
 ]
 
-LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
-
-# Actor pub/sub updates
-RAY_ACTOR_PUBSUB_PATTERN = "ACTOR:*".encode("ascii")
-
-RAY_ERROR_PUBSUB_PATTERN = "ERROR_INFO:*".encode("ascii")
-
-# These prefixes must be kept up-to-date with the TablePrefix enum in
-# gcs.proto.
-TablePrefix_ACTOR_string = "ACTOR"
 
 WORKER = 0
 DRIVER = 1
@@ -78,8 +68,6 @@ _MAX_MESSAGE_LENGTH = 512 * 1024 * 1024
 _GRPC_KEEPALIVE_TIME_MS = 60 * 1000
 # Keepalive should be replied < 60s
 _GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
-# Max retries to get GCS address from Redis server
-_MAX_GET_GCS_SERVER_ADDRESS_RETRIES = 60
 # Also relying on these defaults:
 # grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
@@ -93,32 +81,6 @@ _GRPC_OPTIONS = [
 ]
 
-def use_gcs_for_bootstrap():
-    from ray._raylet import Config
-
-    return Config.bootstrap_with_gcs()
-
-
-def get_gcs_address_from_redis(redis) -> str:
-    """Reads GCS address from redis.
-
-    Args:
-        redis: Redis client to fetch GCS address.
-
-    Returns:
-        GCS address string.
-    """
-    count = 0
-    while count < _MAX_GET_GCS_SERVER_ADDRESS_RETRIES:
-        gcs_address = redis.get("GcsServerAddress")
-        if gcs_address is None:
-            logger.debug("Failed to look up gcs address through redis, retrying.")
-            time.sleep(1)
-            count += 1
-            continue
-        return gcs_address.decode()
-
-    raise RuntimeError("Failed to look up gcs address through redis")
 
 def create_gcs_channel(address: str, aio=False):
     """Returns a GRPC channel to GCS.
@@ -160,27 +122,15 @@ def _auto_reconnect(f):
 class GcsChannel:
-    def __init__(
-        self, redis_client=None, gcs_address: Optional[str] = None, aio: bool = False
-    ):
-        if redis_client is None and gcs_address is None:
-            raise ValueError("One of `redis_client` or `gcs_address` has to be set")
-        if redis_client is not None and gcs_address is not None:
-            raise ValueError("Only one of `redis_client` or `gcs_address` can be set")
-        self._redis_client = redis_client
+    def __init__(self, gcs_address: Optional[str] = None, aio: bool = False):
         self._gcs_address = gcs_address
         self._aio = aio
 
     def connect(self):
         # GCS server uses a cached port, so it should use the same port after
-        # restarting, whether in Redis or GCS bootstrapping mode. This means
-        # GCS address should stay the same for the lifetime of the Ray cluster.
-        if self._gcs_address is None:
-            assert self._redis_client is not None
-            gcs_address = get_gcs_address_from_redis(self._redis_client)
-        else:
-            gcs_address = self._gcs_address
-
-        self._channel = create_gcs_channel(gcs_address, self._aio)
+        # restarting. This means GCS address should stay the same for the
+        # lifetime of the Ray cluster.
+        self._channel = create_gcs_channel(self._gcs_address, self._aio)
 
     def channel(self):
         return self._channel
@@ -294,15 +244,3 @@ class GcsClient:
                 f"Failed to list prefix {prefix} "
                 f"due to error {reply.status.message}"
             )
-
-    @staticmethod
-    def create_from_redis(redis_cli):
-        return GcsClient(GcsChannel(redis_client=redis_cli))
-
-    @staticmethod
-    def connect_to_gcs_by_redis_address(redis_address, redis_password):
-        from ray._private.services import create_redis_client
-
-        return GcsClient.create_from_redis(
-            create_redis_client(redis_address, redis_password)
-        )

View file

@@ -83,11 +83,10 @@ class LogMonitor:
        lines (judged by an increase in file size since the last time the file
        was opened).
     4. Then we will loop through the open files and see if there are any new
-       lines in the file. If so, we will publish them to Redis.
+       lines in the file. If so, we will publish them to Ray pubsub.
 
     Attributes:
-        host (str): The hostname of this machine. Used to improve the log
-            messages published to Redis.
+        host (str): The hostname of this machine, for grouping log messages.
         logs_dir (str): The directory that the log files are in.
         log_filenames (set): This is the set of filenames of all files in
             open_file_infos and closed_file_infos.
@@ -98,7 +97,7 @@
             false otherwise.
     """
 
-    def __init__(self, logs_dir, redis_address, gcs_address, redis_password=None):
+    def __init__(self, logs_dir, gcs_address):
         """Initialize the log monitor object."""
         self.ip = services.get_node_ip_address()
         self.logs_dir = logs_dir
@@ -266,7 +265,7 @@
         self.closed_file_infos += files_with_no_updates
 
     def check_log_files_and_publish_updates(self):
-        """Get any changes to the log files and push updates to Redis.
+        """Gets updates to the log files and publishes them.
 
         Returns:
             True if anything was published and false otherwise.
@@ -360,8 +359,9 @@
     def run(self):
         """Run the log monitor.
 
-        This will query Redis once every second to check if there are new log
-        files to monitor. It will also store those log files in Redis.
+        This will scan the file system once every LOG_NAME_UPDATE_INTERVAL_S to
+        check if there are new log files to monitor. It will also publish new
+        log lines.
         """
         total_log_files = 0
         last_updated = time.time()
@@ -383,21 +383,11 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description=("Parse Redis server for the " "log monitor to connect " "to.")
+        description=("Parse GCS server address for the log monitor to connect to.")
     )
     parser.add_argument(
         "--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
     )
-    parser.add_argument(
-        "--redis-address", required=True, type=str, help="The address to use for Redis."
-    )
-    parser.add_argument(
-        "--redis-password",
-        required=False,
-        type=str,
-        default=None,
-        help="the password to use for Redis",
-    )
     parser.add_argument(
         "--logging-level",
         required=False,
@@ -455,12 +445,7 @@
         backup_count=args.logging_rotate_backup_count,
     )
 
-    log_monitor = LogMonitor(
-        args.logs_dir,
-        args.redis_address,
-        args.gcs_address,
-        redis_password=args.redis_password,
-    )
+    log_monitor = LogMonitor(args.logs_dir, args.gcs_address)
 
     try:
         log_monitor.run()
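
Note: the log monitor now bootstraps from GCS alone. A minimal usage sketch of the slimmed-down constructor (the logs path and address are illustrative assumptions):

    from ray._private.log_monitor import LogMonitor

    # Redis address/password parameters are gone; only logs_dir and
    # gcs_address remain.
    monitor = LogMonitor("/tmp/ray/session_latest/logs", "127.0.0.1:6379")
    monitor.run()  # scans for log files and publishes new lines via GCS pubsub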

View file

@@ -1232,10 +1232,8 @@ def _start_redis_instance(
 def start_log_monitor(
-    redis_address,
-    gcs_address,
     logs_dir,
-    redis_password=None,
+    gcs_address,
     fate_share=None,
     max_bytes=0,
     backup_count=0,
@@ -1244,9 +1242,10 @@
     """Start a log monitor process.
 
     Args:
-        redis_address (str): The address of the Redis instance.
         logs_dir (str): The directory of logging files.
-        redis_password (str): The password of the redis server.
+        gcs_address (str): GCS address for pubsub.
+        fate_share (bool): Whether to share fate between log_monitor
+            and this process.
         max_bytes (int): Log rotation parameter. Corresponding to
             RotatingFileHandler's maxBytes.
         backup_count (int): Log rotation parameter. Corresponding to
@@ -1263,11 +1262,10 @@
         sys.executable,
         "-u",
         log_monitor_filepath,
-        f"--redis-address={redis_address}",
         f"--logs-dir={logs_dir}",
+        f"--gcs-address={gcs_address}",
         f"--logging-rotate-bytes={max_bytes}",
         f"--logging-rotate-backup-count={backup_count}",
-        f"--gcs-address={gcs_address}",
     ]
if redirect_logging:
# Avoid hanging due to fd inheritance.
@@ -1285,8 +1283,6 @@
         # Inherit stdout/stderr streams.
         stdout_file = None
         stderr_file = None
-    if redis_password:
-        command.append(f"--redis-password={redis_password}")
     process_info = start_ray_process(
         command,
         ray_constants.PROCESS_TYPE_LOG_MONITOR,

View file

@@ -2,7 +2,6 @@ import asyncio
 import io
 import fnmatch
 import os
-import json
 import pathlib
 import subprocess
 import sys
@@ -57,12 +56,7 @@
 def make_global_state_accessor(address_info):
-    if not gcs_utils.use_gcs_for_bootstrap():
-        gcs_options = GcsClientOptions.from_redis_address(
-            address_info["redis_address"], ray.ray_constants.REDIS_DEFAULT_PASSWORD
-        )
-    else:
-        gcs_options = GcsClientOptions.from_gcs_address(address_info["gcs_address"])
+    gcs_options = GcsClientOptions.from_gcs_address(address_info["gcs_address"])
     global_state_accessor = GlobalStateAccessor(gcs_options)
     global_state_accessor.connect()
     return global_state_accessor
@@ -578,14 +572,14 @@ def get_non_head_nodes(cluster):
 def init_error_pubsub():
-    """Initialize redis error info pub/sub"""
+    """Initialize error info pub/sub"""
     s = GcsErrorSubscriber(address=ray.worker.global_worker.gcs_client.address)
     s.subscribe()
     return s
 
 
 def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
-    """Gets errors from GCS / Redis subscriber.
+    """Gets errors from GCS subscriber.
 
     Returns maximum `num` error strings within `timeout`.
     Only returns errors of `error_type` if specified.
@@ -593,18 +587,10 @@ def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
     deadline = time.time() + timeout
     msgs = []
     while time.time() < deadline and len(msgs) < num:
-        if isinstance(subscriber, GcsErrorSubscriber):
-            _, error_data = subscriber.poll(timeout=deadline - time.time())
-            if not error_data:
-                # Timed out before any data is received.
-                break
-        else:
-            msg = subscriber.get_message()
-            if msg is None:
-                time.sleep(0.01)
-                continue
-            pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
-            error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
+        _, error_data = subscriber.poll(timeout=deadline - time.time())
+        if not error_data:
+            # Timed out before any data is received.
+            break
         if error_type is None or error_type == error_data.type:
             msgs.append(error_data)
         else:
@@ -614,7 +600,7 @@
 def init_log_pubsub():
-    """Initialize redis error info pub/sub"""
+    """Initialize log pub/sub"""
     s = GcsLogSubscriber(address=ray.worker.global_worker.gcs_client.address)
     s.subscribe()
     return s
@@ -630,18 +616,10 @@ def get_log_data(
     deadline = time.time() + timeout
     msgs = []
     while time.time() < deadline and len(msgs) < num:
-        if isinstance(subscriber, GcsLogSubscriber):
-            logs_data = subscriber.poll(timeout=deadline - time.time())
-            if not logs_data:
-                # Timed out before any data is received.
-                break
-        else:
-            msg = subscriber.get_message()
-            if msg is None:
-                time.sleep(0.01)
-                continue
-            logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        logs_data = subscriber.poll(timeout=deadline - time.time())
+        if not logs_data:
+            # Timed out before any data is received.
+            break
         if job_id and job_id != logs_data["job"]:
             continue
         if matcher and all(not matcher(line) for line in logs_data["lines"]):
@@ -657,7 +635,7 @@ def get_log_message(
     job_id: Optional[str] = None,
     matcher=None,
 ) -> List[List[str]]:
-    """Gets log lines through GCS / Redis subscriber.
+    """Gets log lines through GCS subscriber.
 
     Returns maximum `num` of log messages, within `timeout`.
@@ -687,7 +665,7 @@ def get_log_batch(
     job_id: Optional[str] = None,
     matcher=None,
 ) -> List[str]:
-    """Gets log batches through GCS / Redis subscriber.
+    """Gets log batches through GCS subscriber.
 
     Returns maximum `num` batches of logs. Each batch is a dict that includes
     metadata such as `pid`, `job_id`, and `lines` of log messages.
@@ -698,18 +676,10 @@ def get_log_batch(
     deadline = time.time() + timeout
     batches = []
     while time.time() < deadline and len(batches) < num:
-        if isinstance(subscriber, GcsLogSubscriber):
-            logs_data = subscriber.poll(timeout=deadline - time.time())
-            if not logs_data:
-                # Timed out before any data is received.
-                break
-        else:
-            msg = subscriber.get_message()
-            if msg is None:
-                time.sleep(0.01)
-                continue
-            logs_data = json.loads(ray._private.utils.decode(msg["data"]))
+        logs_data = subscriber.poll(timeout=deadline - time.time())
+        if not logs_data:
+            # Timed out before any data is received.
+            break
         if job_id and job_id != logs_data["job"]:
             continue
         if matcher and not matcher(logs_data):
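
Note: since only GCS subscribers remain, these helpers reduce to a plain poll loop. A minimal sketch of how a test might use them now (assumes a running Ray instance; the num/timeout keyword names follow the signatures in this file):

    import ray
    from ray._private.test_utils import init_log_pubsub, get_log_message

    ray.init()
    subscriber = init_log_pubsub()  # GcsLogSubscriber under the hood
    # poll() either yields data or times out; the Redis get_message() branch is gone
    lines = get_log_message(subscriber, num=1, timeout=10)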

View file

@@ -28,7 +28,6 @@ from pathlib import Path
 import numpy as np
 
 import ray
-import ray._private.gcs_utils as gcs_utils
 import ray.ray_constants as ray_constants
 from ray._private.gcs_pubsub import construct_error_message
 from ray._private.tls_utils import load_certs_from_env
@@ -118,7 +117,10 @@ def push_error_to_driver(worker, error_type, message, job_id=None):
 def publish_error_to_driver(
-    error_type, message, job_id=None, redis_client=None, gcs_publisher=None
+    error_type,
+    message,
+    gcs_publisher,
+    job_id=None,
 ):
     """Push an error message to the driver to be printed in the background.
@@ -131,27 +133,15 @@
         error_type (str): The type of the error.
         message (str): The message that will be printed in the background
             on the driver.
+        gcs_publisher: The GCS publisher to use.
         job_id: The ID of the driver to push the error message to. If this
             is None, then the message will be pushed to all drivers.
-        redis_client: The redis client to use.
-        gcs_publisher: The GCS publisher to use. If specified, ignores
-            redis_client.
     """
     if job_id is None:
         job_id = ray.JobID.nil()
     assert isinstance(job_id, ray.JobID)
     error_data = construct_error_message(job_id, error_type, message, time.time())
-    if gcs_publisher:
-        gcs_publisher.publish_error(job_id.hex().encode(), error_data)
-    elif redis_client:
-        pubsub_msg = gcs_utils.PubSubMessage()
-        pubsub_msg.id = job_id.binary()
-        pubsub_msg.data = error_data.SerializeToString()
-        redis_client.publish(
-            "ERROR_INFO:" + job_id.hex(), pubsub_msg.SerializeToString()
-        )
-    else:
-        raise ValueError("One of redis_client and gcs_publisher needs to be specified!")
+    gcs_publisher.publish_error(job_id.hex().encode(), error_data)
 
 
 def random_string():
@@ -1232,7 +1222,7 @@ def internal_kv_get_with_retry(gcs_client, key, namespace, num_retries=20):
         if result is not None:
             break
         else:
-            logger.debug(f"Fetched {key}=None from redis. Retrying.")
+            logger.debug(f"Fetched {key}=None from KV. Retrying.")
             time.sleep(2)
     if not result:
         raise RuntimeError(
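
Note: publish_error_to_driver now requires a GCS publisher; the Redis fallback and its ValueError are gone. A minimal call sketch (the GCS address is an illustrative assumption):

    import ray._private.utils
    import ray.ray_constants as ray_constants
    from ray._private.gcs_pubsub import GcsPublisher

    publisher = GcsPublisher("127.0.0.1:6379")  # hypothetical GCS address
    ray._private.utils.publish_error_to_driver(
        ray_constants.MONITOR_DIED_ERROR,
        "example error message",
        gcs_publisher=publisher,
        job_id=None,  # None broadcasts the error to all drivers
    )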

View file

@@ -458,7 +458,6 @@ class Monitor:
             publish_error_to_driver(
                 ray_constants.MONITOR_DIED_ERROR,
                 message,
-                redis_client=None,
                 gcs_publisher=gcs_publisher,
             )

View file

@@ -103,11 +103,3 @@ cdef class Config:
     @staticmethod
     def record_ref_creation_sites():
         return RayConfig.instance().record_ref_creation_sites()
-
-    @staticmethod
-    def gcs_grpc_based_pubsub():
-        return RayConfig.instance().gcs_grpc_based_pubsub()
-
-    @staticmethod
-    def bootstrap_with_gcs():
-        return RayConfig.instance().bootstrap_with_gcs()

View file

@@ -879,10 +879,8 @@ class Node:
     def start_log_monitor(self):
         """Start the log monitor."""
         process_info = ray._private.services.start_log_monitor(
-            self.redis_address,
-            self.gcs_address,
             self._logs_dir,
-            redis_password=self._ray_params.redis_password,
+            self.gcs_address,
             fate_share=self.kernel_fate_share,
             max_bytes=self.max_bytes,
             backup_count=self.backup_count,

View file

@@ -1404,7 +1404,7 @@ def connect(
     startup_token=0,
     ray_debugger_external=False,
 ):
-    """Connect this worker to the raylet, to Plasma, and to Redis.
+    """Connect this worker to the raylet, to Plasma, and to GCS.
 
     Args:
         node (ray.node.Node): The node to connect.
@@ -1434,10 +1434,6 @@ def connect(
     except io.UnsupportedOperation:
         pass  # ignore
 
-    # Create a Redis client to primary.
-    # The Redis client can safely be shared between threads. However,
-    # that is not true of Redis pubsub clients. See the documentation at
-    # https://github.com/andymccurdy/redis-py#thread-safety.
     worker.gcs_client = node.get_gcs_client()
     assert worker.gcs_client is not None
     _initialize_internal_kv(worker.gcs_client)
@@ -1489,8 +1485,6 @@ def connect(
         ray._private.utils.publish_error_to_driver(
             ray_constants.VERSION_MISMATCH_PUSH_ERROR,
             traceback_str,
-            job_id=None,
-            redis_client=worker.redis_client,
             gcs_publisher=worker.gcs_publisher,
         )

View file

@@ -39,7 +39,6 @@ class ClusterManager(abc.ABC):
         # Add flags for redisless Ray
         self.cluster_env.setdefault("env_vars", {})
         self.cluster_env["env_vars"]["MATCH_AUTOSCALER_AND_RAY_IMAGES"] = "1"
-        self.cluster_env["env_vars"]["RAY_bootstrap_with_gcs"] = "1"
         self.cluster_env["env_vars"]["RAY_gcs_storage"] = "memory"
         self.cluster_env["env_vars"]["RAY_USAGE_STATS_ENABLED"] = "1"
         self.cluster_env["env_vars"]["RAY_USAGE_STATS_SOURCE"] = "nightly-tests"