Check version info in ray start for non-head nodes. (#1264)
* Check version info in ray start for non-head nodes.
* Small fix.
* Fix
* Push error to all drivers when worker has version mismatch.
* Linting
* Linting
* Fix
* Unify methods.
* Fix bug.
This commit is contained in:
parent 2c0d5544ac
commit c1496b8111
6 changed files with 103 additions and 88 deletions
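
The core of this change is a new module-level helper, ray.utils.push_error_to_driver, which replaces both the Worker.push_error_to_driver method and the push_error_to_all_drivers helper in default_worker.py (see the utils.py and worker.py hunks below). A minimal usage sketch — the Redis address and the 20-byte driver ID here are illustrative, not taken from the diff:

    import redis

    from ray.utils import push_error_to_driver

    # Illustrative address; in the diff the client comes from
    # services.create_redis_client(redis_address).
    redis_client = redis.StrictRedis(host="127.0.0.1", port=6379)

    # Push an error to a single driver (driver IDs are 20-byte strings,
    # per DRIVER_ID_LENGTH = 20 in utils.py; this particular ID is made up).
    push_error_to_driver(redis_client, "task", "something went wrong",
                         driver_id=b"\x01" * 20,
                         data={"function_name": "f"})

    # Omitting driver_id (it defaults to None) broadcasts to all drivers:
    # internally None maps to a driver ID of 20 zero bytes.
    push_error_to_driver(redis_client, "version_mismatch",
                         "this node's version does not match the cluster")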
@@ -15,7 +15,8 @@ import ray.local_scheduler
 import ray.signature as signature
 import ray.worker
 from ray.utils import (binary_to_hex, FunctionProperties, random_string,
-                       release_gpus_in_use, select_local_scheduler, is_cython)
+                       release_gpus_in_use, select_local_scheduler, is_cython,
+                       push_error_to_driver)
 
 
 def random_actor_id():

@@ -252,9 +253,9 @@ def fetch_and_register_actor(actor_class_key, worker):
         # traceback and notify the scheduler of the failure.
         traceback_str = ray.worker.format_error_message(traceback.format_exc())
         # Log the error message.
-        worker.push_error_to_driver(driver_id, "register_actor_signatures",
-                                    traceback_str,
-                                    data={"actor_id": actor_id_str})
+        push_error_to_driver(worker.redis_client, "register_actor_signatures",
+                             traceback_str, driver_id,
+                             data={"actor_id": actor_id_str})
         # TODO(rkn): In the future, it might make sense to have the worker exit
         # here. However, currently that would lead to hanging if someone calls
         # ray.get on a method invoked on the actor.
@@ -3,16 +3,12 @@ from __future__ import division
 from __future__ import print_function
 
 import click
-import redis
 import subprocess
 
 import ray.services as services
 
 
-def check_no_existing_redis_clients(node_ip_address, redis_address):
-    redis_ip_address, redis_port = redis_address.split(":")
-    redis_client = redis.StrictRedis(host=redis_ip_address,
-                                     port=int(redis_port))
+def check_no_existing_redis_clients(node_ip_address, redis_client):
     # The client table prefix must be kept in sync with the file
     # "src/common/redis_module/ray_redis_module.cc" where it is defined.
     REDIS_CLIENT_TABLE_PREFIX = "CL:"

@@ -158,9 +154,18 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
             raise Exception("If --head is not passed in, the --no-ui flag is "
                             "not relevant.")
         redis_ip_address, redis_port = redis_address.split(":")
+
         # Wait for the Redis server to be started. And throw an exception if we
         # can't connect to it.
         services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
+
+        # Create a Redis client.
+        redis_client = services.create_redis_client(redis_address)
+
+        # Check that the version information on this node matches the version
+        # information that the cluster was started with.
+        services.check_version_info(redis_client)
+
         # Get the node IP address if one is not provided.
         if node_ip_address is None:
             node_ip_address = services.get_node_ip_address(redis_address)

@@ -168,7 +173,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
         # Check that there aren't already Redis clients with the same IP
         # address connected with this Redis instance. This raises an exception
         # if the Redis server already has clients on this node.
-        check_no_existing_redis_clients(node_ip_address, redis_address)
+        check_no_existing_redis_clients(node_ip_address, redis_client)
         address_info = services.start_ray_node(
             node_ip_address=node_ip_address,
             redis_address=redis_address,
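
Both new call sites above rely on services.check_version_info(redis_client), which is not itself defined in this diff. A hypothetical sketch of the contract those call sites assume — the Redis key name and the payload layout here are guesses, not taken from the source:

    import json
    import sys

    import ray


    def check_version_info(redis_client):
        # Hypothetical: assume the head node recorded its version info in
        # Redis when the cluster started (the key name is an assumption).
        stored = redis_client.get("VERSION_INFO")
        if stored is None:
            # Nothing recorded (e.g. an older head node); skip the check.
            return
        cluster_info = json.loads(stored.decode("ascii"))
        local_info = {"ray_version": ray.__version__,
                      "python_version": list(sys.version_info[:3])}
        if cluster_info != local_info:
            raise Exception("Version mismatch: the cluster was started with "
                            "{}, but this node is running {}.".format(
                                cluster_info, local_info))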
@@ -226,6 +226,21 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
             redis_client.rpush(log_file_list_key, log_file.name)
 
 
+def create_redis_client(redis_address):
+    """Create a Redis client.
+
+    Args:
+        redis_address: The IP address and port of the Redis server.
+
+    Returns:
+        A Redis client.
+    """
+    redis_ip_address, redis_port = redis_address.split(":")
+    # For this command to work, some other client (on the same machine
+    # as Redis) must have run "CONFIG SET protected-mode no".
+    return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
+
+
 def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
     """Wait for a Redis server to be available.
 
@@ -11,6 +11,37 @@ import sys
 
 import ray.local_scheduler
 
+ERROR_KEY_PREFIX = b"Error:"
+DRIVER_ID_LENGTH = 20
+
+
+def _random_string():
+    return np.random.bytes(20)
+
+
+def push_error_to_driver(redis_client, error_type, message, driver_id=None,
+                         data=None):
+    """Push an error message to the driver to be printed in the background.
+
+    Args:
+        redis_client: The redis client to use.
+        error_type (str): The type of the error.
+        message (str): The message that will be printed in the background
+            on the driver.
+        driver_id: The ID of the driver to push the error message to. If this
+            is None, then the message will be pushed to all drivers.
+        data: This should be a dictionary mapping strings to strings. It
+            will be serialized with json and stored in Redis.
+    """
+    if driver_id is None:
+        driver_id = DRIVER_ID_LENGTH * b"\x00"
+    error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
+    data = {} if data is None else data
+    redis_client.hmset(error_key, {"type": error_type,
+                                   "message": message,
+                                   "data": data})
+    redis_client.rpush("ErrorKeys", error_key)
+
+
 def is_cython(obj):
     """Check if an object is a Cython function or method"""
 
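
Given only the write path added above (an hmset of type/message/data per error key, plus an rpush onto the "ErrorKeys" list), a driver-side poller could read errors back roughly as follows. This is a hedged sketch for orientation; Ray's actual error listener is not part of this commit, and the address and driver ID are made up:

    import redis

    redis_client = redis.StrictRedis(host="127.0.0.1", port=6379)  # illustrative

    ERROR_KEY_PREFIX = b"Error:"
    DRIVER_ID_LENGTH = 20
    NIL_DRIVER_ID = DRIVER_ID_LENGTH * b"\x00"  # "all drivers" sentinel
    my_driver_id = b"\x01" * DRIVER_ID_LENGTH   # made-up driver ID

    for error_key in redis_client.lrange("ErrorKeys", 0, -1):
        # Keys have the form b"Error:" + <20-byte driver ID> + b":" + <20 random bytes>.
        start = len(ERROR_KEY_PREFIX)
        driver_part = error_key[start:start + DRIVER_ID_LENGTH]
        if driver_part in (my_driver_id, NIL_DRIVER_ID):
            error = redis_client.hgetall(error_key)
            print(error[b"type"], error[b"message"])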
@@ -407,9 +407,10 @@ class Worker(object):
                                    "object store. This may be fine, or it "
                                    "may be a bug.")
                 if not warning_sent:
-                    self.push_error_to_driver(self.task_driver_id.id(),
-                                              "wait_for_class",
-                                              warning_message)
+                    ray.utils.push_error_to_driver(
+                        self.redis_client, "wait_for_class",
+                        warning_message,
+                        driver_id=self.task_driver_id.id())
                 warning_sent = True
 
     def get_object(self, object_ids):

@@ -599,24 +600,6 @@ class Worker(object):
         # operations into a transaction (or by implementing a custom
         # command that does all three things).
 
-    def push_error_to_driver(self, driver_id, error_type, message, data=None):
-        """Push an error message to the driver to be printed in the background.
-
-        Args:
-            driver_id: The ID of the driver to push the error message to.
-            error_type (str): The type of the error.
-            message (str): The message that will be printed in the background
-                on the driver.
-            data: This should be a dictionary mapping strings to strings. It
-                will be serialized with json and stored in Redis.
-        """
-        error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string()
-        data = {} if data is None else data
-        self.redis_client.hmset(error_key, {"type": error_type,
-                                            "message": message,
-                                            "data": data})
-        self.redis_client.rpush("ErrorKeys", error_key)
-
     def _wait_for_function(self, function_id, driver_id, timeout=10):
         """Wait until the function to be executed is present on this worker.
 

@@ -651,9 +634,10 @@ class Worker(object):
                                    "registered. You may have to restart "
                                    "Ray.")
                 if not warning_sent:
-                    self.push_error_to_driver(driver_id,
-                                              "wait_for_function",
-                                              warning_message)
+                    ray.utils.push_error_to_driver(self.redis_client,
+                                                   "wait_for_function",
+                                                   warning_message,
+                                                   driver_id=driver_id)
                 warning_sent = True
             time.sleep(0.001)
 

@@ -808,10 +792,12 @@ class Worker(object):
                               range(len(return_object_ids))]
             self._store_outputs_in_objstore(return_object_ids, failure_objects)
             # Log the error message.
-            self.push_error_to_driver(self.task_driver_id.id(), "task",
-                                      str(failure_object),
-                                      data={"function_id": function_id.id(),
-                                            "function_name": function_name})
+            ray.utils.push_error_to_driver(self.redis_client,
+                                           "task",
+                                           str(failure_object),
+                                           driver_id=self.task_driver_id.id(),
+                                           data={"function_id": function_id.id(),
+                                                 "function_name": function_name})
 
     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
 

@@ -1552,10 +1538,12 @@ def fetch_and_register_remote_function(key, worker=global_worker):
         # record the traceback and notify the scheduler of the failure.
         traceback_str = format_error_message(traceback.format_exc())
         # Log the error message.
-        worker.push_error_to_driver(driver_id, "register_remote_function",
-                                    traceback_str,
-                                    data={"function_id": function_id.id(),
-                                          "function_name": function_name})
+        ray.utils.push_error_to_driver(worker.redis_client,
+                                       "register_remote_function",
+                                       traceback_str,
+                                       driver_id=driver_id,
+                                       data={"function_id": function_id.id(),
+                                             "function_name": function_name})
     else:
         # TODO(rkn): Why is the below line necessary?
         function.__module__ = module

@@ -1582,8 +1570,11 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
         # Log the error message.
         name = function.__name__ if ("function" in locals() and
                                      hasattr(function, "__name__")) else ""
-        worker.push_error_to_driver(driver_id, "function_to_run",
-                                    traceback_str, data={"name": name})
+        ray.utils.push_error_to_driver(worker.redis_client,
+                                       "function_to_run",
+                                       traceback_str,
+                                       driver_id=driver_id,
+                                       data={"name": name})
 
 
 def import_thread(worker, mode):

@@ -1714,9 +1705,19 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
     worker.redis_client = redis.StrictRedis(host=redis_ip_address,
                                             port=int(redis_port))
 
-    # Check that the version information matches the version information that
-    # the Ray cluster was started with.
-    ray.services.check_version_info(worker.redis_client)
+    # For drivers, check that the version information matches the version
+    # information that the Ray cluster was started with.
+    try:
+        ray.services.check_version_info(worker.redis_client)
+    except Exception as e:
+        if mode in [SCRIPT_MODE, SILENT_MODE]:
+            raise e
+        elif mode == WORKER_MODE:
+            traceback_str = traceback.format_exc()
+            ray.utils.push_error_to_driver(worker.redis_client,
+                                           "version_mismatch",
+                                           traceback_str,
+                                           driver_id=None)
 
     worker.lock = threading.Lock()
 
@ -4,8 +4,6 @@ from __future__ import print_function
|
|||
|
||||
import argparse
|
||||
import binascii
|
||||
import numpy as np
|
||||
import redis
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
|
@ -30,36 +28,6 @@ parser.add_argument("--reconstruct", action="store_true",
|
|||
"mode"))
|
||||
|
||||
|
||||
def random_string():
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
def create_redis_client(redis_address):
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine
|
||||
# as Redis) must have run "CONFIG SET protected-mode no".
|
||||
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
|
||||
|
||||
def push_error_to_all_drivers(redis_client, message, error_type):
|
||||
"""Push an error message to all drivers.
|
||||
|
||||
Args:
|
||||
redis_client: The redis client to use.
|
||||
message: The error message to push.
|
||||
error_type: The type of the error.
|
||||
"""
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
# Create a Redis client.
|
||||
redis_client.hmset(error_key, {"type": error_type,
|
||||
"message": message})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -80,13 +48,6 @@ if __name__ == "__main__":
|
|||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
|
||||
|
||||
try:
|
||||
ray.services.check_version_info(ray.worker.global_worker.redis_client)
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc()
|
||||
push_error_to_all_drivers(ray.worker.global_worker.redis_client,
|
||||
traceback_str, "version_mismatch")
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
crashed in an unanticipated way causing the main_loop to throw an exception,
|
||||
|
@ -103,8 +64,9 @@ if __name__ == "__main__":
|
|||
except Exception as e:
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
# Create a Redis client.
|
||||
redis_client = create_redis_client(args.redis_address)
|
||||
push_error_to_all_drivers(redis_client, traceback_str, "worker_crash")
|
||||
redis_client = ray.services.create_redis_client(args.redis_address)
|
||||
ray.utils.push_error_to_driver(redis_client, "worker_crash",
|
||||
traceback_str, driver_id=None)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
|