[xray] Add error table and push error messages to driver through node manager. (#2256)

* Fix documentation indentation.

* Add error table to GCS and push error messages through node manager.

* Add type to error data.

* Linting

* Fix failure_test bug.

* Linting.

* Enable one more test.

* Attempt to fix doc building.

* Restructuring

* Fixes

* More fixes.

* Move current_time_ms function into util.h.
This commit is contained in:
Robert Nishihara 2018-06-20 21:29:28 -07:00 committed by Philipp Moritz
parent 6bf48f47bc
commit ff2217251f
27 changed files with 610 additions and 204 deletions

View file

@ -1,5 +1,6 @@
colorama
click
flatbuffers
funcsigs
mock
numpy

View file

@ -34,14 +34,24 @@ MOCK_MODULES = ["gym",
"tensorflow.python.util",
"ray.local_scheduler",
"ray.plasma",
"ray.core",
"ray.core.generated",
"ray.core.generated.DriverTableMessage",
"ray.core.generated.LocalSchedulerInfoMessage",
"ray.core.generated.ResultTableReply",
"ray.core.generated.SubscribeToDBClientTableReply",
"ray.core.generated.SubscribeToNotificationsReply",
"ray.core.generated.TaskInfo",
"ray.core.generated.TaskReply",
"ray.core.generated.ResultTableReply",
"ray.core.generated.TaskExecutionDependencies",
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ObjectTableData",
"ray.core.generated.ray.protocol.Task"]
"ray.core.generated.ray.protocol.Task",
"ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub",]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()

View file

@ -164,7 +164,7 @@ def save_and_log_checkpoint(worker, actor):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
@ -188,7 +188,7 @@ def restore_and_log_checkpoint(worker, actor):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
@ -330,7 +330,7 @@ def fetch_and_register_actor(actor_class_key, worker):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(
worker.redis_client,
worker,
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
traceback_str,
driver_id,
@ -402,7 +402,7 @@ def export_actor_class(class_id, Class, actor_method_names,
.format(actor_class_info["class_name"],
len(actor_class_info["class"])))
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id.id())

View file

@ -8,20 +8,9 @@ import sys
import time
import unittest
import ray.gcs_utils
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
@ -194,7 +183,7 @@ class TestGlobalStateStore(unittest.TestCase):
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (SubscribeToNotificationsReply.
notification_object = (ray.gcs_utils.SubscribeToNotificationsReply.
GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
@ -208,7 +197,8 @@ class TestGlobalStateStore(unittest.TestCase):
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
p.psubscribe("{}manager_id1".format(
ray.gcs_utils.OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
data_size, "hash1", "manager_id2")
# Receive the acknowledgement message.
@ -252,8 +242,9 @@ class TestGlobalStateStore(unittest.TestCase):
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(
message, 0)
result_table_reply = (
ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
message, 0))
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
@ -315,12 +306,13 @@ class TestGlobalStateStore(unittest.TestCase):
# make sure somebody will get a notification (checked in the redis
# module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
def check_task_reply(message, task_args, updated=False):
(task_status, local_scheduler_id, execution_dependencies_string,
spillback_count, task_spec) = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
task_reply_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
@ -409,7 +401,8 @@ class TestGlobalStateStore(unittest.TestCase):
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
notification_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
message, 0)
self.assertEqual(notification_object.TaskId(), task_args[0])
self.assertEqual(notification_object.State(), task_args[1])
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
@ -422,32 +415,34 @@ class TestGlobalStateStore(unittest.TestCase):
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.punsubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
prefix=ray.gcs_utils.TASK_PREFIX,
local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
prefix=ray.gcs_utils.TASK_PREFIX,
local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)

View file

@ -12,41 +12,10 @@ import sys
import time
import ray
import ray.gcs_utils
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
# Import flatbuffer bindings.
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
from ray.core.generated.TaskExecutionDependencies import \
TaskExecutionDependencies
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
# These prefixes must be kept up-to-date with the definitions in
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
FUNCTION_PREFIX = "RemoteFunction:"
OBJECT_CHANNEL_PREFIX = "OC:"
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK = 2
TablePrefix_RAYLET_TASK_string = "TASK"
TablePrefix_CLIENT = 3
TablePrefix_CLIENT_string = "CLIENT"
TablePrefix_OBJECT = 4
TablePrefix_OBJECT_string = "OBJECT"
# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
TASK_STATUS_WAITING = 1
@ -231,8 +200,9 @@ class GlobalState(object):
result_table_response = self._execute_command(
object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
result_table_message = (
ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0))
result = {
"ManagerIDs": manager_ids,
@ -245,12 +215,14 @@ class GlobalState(object):
else:
# Use the raylet code path.
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_OBJECT, "", object_id.id())
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "",
object_id.id())
result = []
gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
for i in range(gcs_entry.EntriesLength()):
entry = ObjectTableData.GetRootAsObjectTableData(
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(i), 0)
object_info = {
"DataSize": entry.ObjectSize(),
@ -279,19 +251,22 @@ class GlobalState(object):
else:
# Return the entire object table.
if not self.use_raylet:
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(
ray.gcs_utils.OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(
ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set([
key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys
key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):]
for key in object_info_keys
] + [
key[len(OBJECT_LOCATION_PREFIX):]
key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):]
for key in object_location_keys
])
else:
object_keys = self.redis_client.keys(
TablePrefix_OBJECT_string + ":*")
ray.gcs_utils.TablePrefix_OBJECT_string + "*")
object_ids_binary = {
key[len(TablePrefix_OBJECT_string + ":"):]
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
@ -320,7 +295,7 @@ class GlobalState(object):
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(
task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec = ray.local_scheduler.task_from_string(task_spec)
@ -343,7 +318,8 @@ class GlobalState(object):
}
execution_dependencies_message = (
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
ray.gcs_utils.TaskExecutionDependencies.
GetRootAsTaskExecutionDependencies(
task_table_message.ExecutionDependencies(), 0))
execution_dependencies = [
ray.ObjectID(
@ -371,15 +347,17 @@ class GlobalState(object):
else:
# Use the raylet code path.
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_RAYLET_TASK, "", task_id.id())
gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "",
task_id.id())
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
info = []
for i in range(gcs_entries.EntriesLength()):
task_table_message = Task.GetRootAsTask(
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
task_table_message = Task.GetRootAsTask(
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(0), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
@ -432,15 +410,16 @@ class GlobalState(object):
return self._task_table(task_id)
else:
if not self.use_raylet:
task_table_keys = self._keys(TASK_PREFIX + "*")
task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*")
task_ids_binary = [
key[len(TASK_PREFIX):] for key in task_table_keys
key[len(ray.gcs_utils.TASK_PREFIX):]
for key in task_table_keys
]
else:
task_table_keys = self.redis_client.keys(
TablePrefix_RAYLET_TASK_string + ":*")
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
key[len(TablePrefix_RAYLET_TASK_string + ":"):]
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
@ -458,7 +437,8 @@ class GlobalState(object):
function.
"""
self._check_connected()
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
function_table_keys = self.redis_client.keys(
ray.gcs_utils.FUNCTION_PREFIX + "*")
results = {}
for key in function_table_keys:
info = self.redis_client.hgetall(key)
@ -478,7 +458,8 @@ class GlobalState(object):
"""
self._check_connected()
if not self.use_raylet:
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
db_client_keys = self.redis_client.keys(
ray.gcs_utils.DB_CLIENT_PREFIX + "*")
node_info = {}
for key in db_client_keys:
client_info = self.redis_client.hgetall(key)
@ -520,13 +501,16 @@ class GlobalState(object):
# This is the raylet code path.
NIL_CLIENT_ID = 20 * b"\xff"
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_CLIENT, "", NIL_CLIENT_ID)
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
NIL_CLIENT_ID)
node_info = []
gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
for i in range(gcs_entry.EntriesLength()):
client = ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0)
client = (
ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
resources = {
client.ResourcesTotalLabel(i).decode("ascii"):
@ -1146,3 +1130,64 @@ class GlobalState(object):
resources[key] += value
return dict(resources)
def _error_messages(self, job_id):
"""Get the error messages for a specific job.
Args:
job_id: The ID of the job to get the errors for.
Returns:
A list of the error messages for this job.
"""
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
job_id.id())
# If there are no errors, return early.
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
error_messages = []
for i in range(gcs_entries.EntriesLength()):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
gcs_entries.Entries(i), 0)
error_message = {
"type": error_data.Type().decode("ascii"),
"message": error_data.ErrorMessage().decode("ascii"),
"timestamp": error_data.Timestamp(),
}
error_messages.append(error_message)
return error_messages
def error_messages(self, job_id=None):
"""Get the error messages for all jobs or a specific job.
Args:
job_id: The specific job to get the errors for. If this is None,
then this method retrieves the errors for all jobs.
Returns:
A dictionary mapping job ID to a list of the error messages for
that job.
"""
if not self.use_raylet:
raise Exception("The error_messages method is only supported in "
"the raylet code path.")
if job_id is not None:
return self._error_messages(job_id)
error_table_keys = self.redis_client.keys(
ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
job_ids = [
key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
for key in error_table_keys
]
return {
binary_to_hex(job_id): self._error_messages(ray.ObjectID(job_id))
for job_id in job_ids
}

84
python/ray/gcs_utils.py Normal file
View file

@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import flatbuffers
from ray.core.generated.ResultTableReply import ResultTableReply
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskExecutionDependencies import \
TaskExecutionDependencies
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.core.generated.TaskInfo import TaskInfo
import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
__all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
FUNCTION_PREFIX = "RemoteFunction:"
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_ERROR_INFO_string = "ERROR_INFO"
def construct_error_message(error_type, message, timestamp):
"""Construct a serialized ErrorTableData object.
Args:
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns:
The serialized object.
"""
builder = flatbuffers.Builder(0)
error_type_offset = builder.CreateString(error_type)
message_offset = builder.CreateString(message)
ray.core.generated.ErrorTableData.ErrorTableDataStart(builder)
ray.core.generated.ErrorTableData.ErrorTableDataAddType(
builder, error_type_offset)
ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage(
builder, message_offset)
ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp(
builder, timestamp)
error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd(
builder)
builder.Finish(error_data_offset)
return bytes(builder.Output())

View file

@ -29,11 +29,6 @@ NIL_WORKER_ID = 20 * b"\xff"
NIL_OBJECT_ID = 20 * b"\xff"
NIL_ACTOR_ID = 20 * b"\xff"
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))

View file

@ -9,20 +9,13 @@ import os
import time
from collections import Counter, defaultdict
import ray
import ray.cloudpickle as pickle
import ray.utils
import redis
# Import flatbuffer bindings.
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
import ray
from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
from ray.core.generated.TaskInfo import TaskInfo
import ray.cloudpickle as pickle
import ray.gcs_utils
import ray.utils
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
from ray.worker import NIL_ACTOR_ID
@ -259,7 +252,7 @@ class Monitor(object):
the associated state in the state tables should be handled by the
caller.
"""
notification_object = (SubscribeToDBClientTableReply.
notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply.
GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
@ -285,8 +278,8 @@ class Monitor(object):
def local_scheduler_info_handler(self, unused_channel, data):
"""Handle a local scheduler heartbeat from Redis."""
message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
data, 0)
message = (ray.gcs_utils.LocalSchedulerInfoMessage.
GetRootAsLocalSchedulerInfoMessage(data, 0))
num_resources = message.DynamicResourcesLength()
static_resources = {}
dynamic_resources = {}
@ -308,9 +301,10 @@ class Monitor(object):
def xray_heartbeat_handler(self, unused_channel, data):
"""Handle an xray heartbeat message from Redis."""
gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
heartbeat_data = gcs_entries.Entries(0)
message = HeartbeatTableData.GetRootAsHeartbeatTableData(
message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData(
heartbeat_data, 0)
num_resources = message.ResourcesAvailableLabelLength()
static_resources = {}
@ -363,7 +357,8 @@ class Monitor(object):
# driver. Use a cursor in order not to block the redis shards.
for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
entry = redis.hgetall(key)
task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo(
entry[b"TaskSpec"], 0)
if driver_id != task_info.DriverId():
# Ignore tasks that aren't from this driver.
continue
@ -475,7 +470,8 @@ class Monitor(object):
This releases any GPU resources that were reserved for that driver in
Redis.
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
message = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage(
data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed.".format(
binary_to_hex(driver_id)))

View file

@ -5,6 +5,8 @@ from __future__ import print_function
import os
import ray
def env_integer(key, default):
if key in os.environ:
@ -12,6 +14,9 @@ def env_integer(key, default):
return default
ID_SIZE = 20
NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\x00")
# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print an warning.
PICKLE_OBJECT_WARNING_SIZE = 10**7

View file

@ -7,9 +7,12 @@ import hashlib
import numpy as np
import os
import sys
import time
import uuid
import ray.gcs_utils
import ray.local_scheduler
import ray.ray_constants as ray_constants
ERROR_KEY_PREFIX = b"Error:"
DRIVER_ID_LENGTH = 20
@ -45,7 +48,7 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def push_error_to_driver(redis_client,
def push_error_to_driver(worker,
error_type,
message,
driver_id=None,
@ -53,7 +56,7 @@ def push_error_to_driver(redis_client,
"""Push an error message to the driver to be printed in the background.
Args:
redis_client: The redis client to use.
worker: The worker to use.
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
@ -63,15 +66,65 @@ def push_error_to_driver(redis_client,
will be serialized with json and stored in Redis.
"""
if driver_id is None:
driver_id = DRIVER_ID_LENGTH * b"\x00"
driver_id = ray_constants.NIL_JOB_ID.id()
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
if not worker.use_raylet:
worker.redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
worker.redis_client.rpush("ErrorKeys", error_key)
else:
worker.local_scheduler_client.push_error(
ray.ObjectID(driver_id), error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client,
use_raylet,
error_type,
message,
driver_id=None,
data=None):
"""Push an error message to the driver to be printed in the background.
Normally the push_error_to_driver function should be used. However, in some
instances, the local scheduler client is not available, e.g., because the
error happens in Python before the driver or worker has connected to the
backend processes.
Args:
redis_client: The redis client to use.
use_raylet: True if we are using the Raylet code path and false
otherwise.
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
driver_id: The ID of the driver to push the error message to. If this
is None, then the message will be pushed to all drivers.
data: This should be a dictionary mapping strings to strings. It
will be serialized with json and stored in Redis.
"""
if driver_id is None:
driver_id = ray_constants.NIL_JOB_ID.id()
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
if not use_raylet:
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
else:
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(
error_type, message, time.time())
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
def is_cython(obj):

View file

@ -22,6 +22,7 @@ import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.state as state
import ray.gcs_utils
import ray.remote_function
import ray.serialization as serialization
import ray.services as services
@ -31,9 +32,6 @@ import ray.plasma
import ray.ray_constants as ray_constants
from ray.utils import random_string, binary_to_hex, is_cython
# Import flatbuffer bindings.
from ray.core.generated.ClientTableData import ClientTableData
SCRIPT_MODE = 0
WORKER_MODE = 1
PYTHON_MODE = 2
@ -415,7 +413,7 @@ class Worker(object):
"may be a bug.")
if not warning_sent:
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@ -663,7 +661,7 @@ class Worker(object):
"large array or other object.".format(
function_name, len(pickled_function)))
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@ -726,7 +724,7 @@ class Worker(object):
.format(function.__name__,
len(pickled_function)))
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@ -781,7 +779,7 @@ class Worker(object):
"Ray.")
if not warning_sent:
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
warning_message,
driver_id=driver_id)
@ -942,7 +940,7 @@ class Worker(object):
self._store_outputs_in_objstore(return_object_ids, failure_objects)
# Log the error message.
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
driver_id=self.task_driver_id.id(),
@ -1200,6 +1198,11 @@ def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
check_main_thread()
if worker.use_raylet:
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
errors = []
for error_key in error_keys:
@ -1291,9 +1294,8 @@ def get_address_info_from_redis_helper(redis_address,
if not use_raylet:
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys(
"{}*".format(REDIS_CLIENT_TABLE_PREFIX))
client_keys = redis_client.keys("{}*".format(
ray.gcs_utils.DB_CLIENT_PREFIX))
# Filter to live clients on the same node and do some basic checking.
plasma_managers = []
local_schedulers = []
@ -1350,11 +1352,11 @@ def get_address_info_from_redis_helper(redis_address,
else:
# In the raylet code path, all client data is stored in a zset at the
# key for the nil client.
client_key = b"CLIENT:" + NIL_CLIENT_ID
client_key = b"CLIENT" + NIL_CLIENT_ID
clients = redis_client.zrange(client_key, 0, -1)
raylets = []
for client_message in clients:
client = ClientTableData.GetRootAsClientTableData(
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = client.NodeManagerAddress().decode(
"ascii")
@ -1819,6 +1821,71 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
def print_error_messages_raylet(worker):
"""Print error messages in the background on the driver.
This runs in a separate thread on the driver and prints error messages in
the background.
"""
if not worker.use_raylet:
raise Exception("This function is specific to the raylet code path.")
worker.error_message_pubsub_client = worker.redis_client.pubsub(
ignore_subscribe_messages=True)
# Exports that are published after the call to
# error_message_pubsub_client.subscribe and before the call to
# error_message_pubsub_client.listen will still be processed in the loop.
# Really we should just subscribe to the errors for this specific job.
# However, currently all errors seem to be published on the same channel.
error_pubsub_channel = str(
ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii")
worker.error_message_pubsub_client.subscribe(error_pubsub_channel)
# worker.error_message_pubsub_client.psubscribe("*")
# Keep a set of all the error messages that we've seen so far in order to
# avoid printing the same error message repeatedly. This is especially
# important when running a script inside of a tool like screen where
# scrolling is difficult.
old_error_messages = set()
# Get the exports that occurred before the call to subscribe.
with worker.lock:
error_messages = global_state.error_messages(worker.task_driver_id)
for error_message in error_messages:
if error_message not in old_error_messages:
print(error_message)
old_error_messages.add(error_message)
else:
print("Suppressing duplicate error message.")
try:
for msg in worker.error_message_pubsub_client.listen():
gcs_entry = state.GcsTableEntry.GetRootAsGcsTableEntry(
msg["data"], 0)
assert gcs_entry.EntriesLength() == 1
error_data = state.ErrorTableData.GetRootAsErrorTableData(
gcs_entry.Entries(0), 0)
NIL_JOB_ID = 20 * b"\x00"
job_id = error_data.JobId()
if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]:
continue
error_message = error_data.ErrorMessage().decode("ascii")
if error_message not in old_error_messages:
print(error_message)
old_error_messages.add(error_message)
else:
print("Suppressing duplicate error message.")
except redis.ConnectionError:
# When Redis terminates the listen call will throw a ConnectionError,
# which we catch here.
pass
def print_error_messages(worker):
"""Print error messages in the background on the driver.
@ -1907,7 +1974,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
@ -1952,7 +2019,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
name = function.__name__ if ("function" in locals()
and hasattr(function, "__name__")) else ""
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
@ -2111,8 +2178,9 @@ def connect(info,
raise e
elif mode == WORKER_MODE:
traceback_str = traceback.format_exc()
ray.utils.push_error_to_driver(
ray.utils.push_error_to_driver_through_redis(
worker.redis_client,
worker.use_raylet,
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
driver_id=None)
@ -2237,13 +2305,11 @@ def connect(info,
driver_task.execution_dependencies_string(), 0,
ray.local_scheduler.task_to_string(driver_task))
else:
TablePubsub_RAYLET_TASK = 2
# TODO(rkn): When we shard the GCS in xray, we will need to change
# this to use _execute_command.
global_state.redis_client.execute_command(
"RAY.TABLE_ADD", state.TablePrefix_RAYLET_TASK,
TablePubsub_RAYLET_TASK,
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().id(),
driver_task._serialized_raylet_task())
@ -2271,7 +2337,11 @@ def connect(info,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
t = threading.Thread(target=print_error_messages, args=(worker, ))
if not worker.use_raylet:
t = threading.Thread(target=print_error_messages, args=(worker, ))
else:
t = threading.Thread(
target=print_error_messages_raylet, args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True

View file

@ -69,10 +69,11 @@ if __name__ == "__main__":
ray.worker.global_worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = ray.services.create_redis_client(args.redis_address)
ray.utils.push_error_to_driver(
redis_client, "worker_crash", traceback_str, driver_id=None)
ray.worker.global_worker,
"worker_crash",
traceback_str,
driver_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to

View file

@ -61,14 +61,6 @@ extern RedisChainModule module;
return RedisModule_ReplyWithError(ctx, (MESSAGE)); \
}
// NOTE(swang): The order of prefixes here must match the TablePrefix enum
// defined in src/ray/gcs/format/gcs.fbs.
static const char *table_prefixes[] = {
NULL, "TASK:", "TASK:", "CLIENT:",
"OBJECT:", "ACTOR:", "FUNCTION:", "TASK_RECONSTRUCTION:",
"HEARTBEAT:",
};
/// Parse a Redis string into a TablePubsub channel.
TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
long long pubsub_channel_long;
@ -128,8 +120,8 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
<< "This table has no prefix registered";
RAY_CHECK(prefix >= TablePrefix::MIN && prefix <= TablePrefix::MAX)
<< "Prefix must be a valid TablePrefix";
return OpenPrefixedKey(ctx, table_prefixes[static_cast<long long>(prefix)],
keyname, mode, mutated_key_str);
return OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode,
mutated_key_str);
}
RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,

View file

@ -286,6 +286,29 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
return Py_BuildValue("(OO)", py_found, py_remaining);
}
static PyObject *PyLocalSchedulerClient_push_error(PyObject *self,
PyObject *args) {
JobID job_id;
const char *type;
int type_length;
const char *error_message;
int error_message_length;
double timestamp;
if (!PyArg_ParseTuple(args, "O&s#s#d", &PyObjectToUniqueID, &job_id, &type,
&type_length, &error_message, &error_message_length,
&timestamp)) {
return NULL;
}
local_scheduler_push_error(reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
job_id, std::string(type, type_length),
std::string(error_message, error_message_length),
timestamp);
Py_RETURN_NONE;
}
static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
@ -313,6 +336,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
{"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
"Wait for a list of objects to be created."},
{"push_error", (PyCFunction) PyLocalSchedulerClient_push_error,
METH_VARARGS, "Push an error message to the relevant driver."},
{NULL} /* Sentinel */
};

View file

@ -306,3 +306,19 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
free(reply);
return result;
}
void local_scheduler_push_error(LocalSchedulerConnection *conn,
const JobID &job_id,
const std::string &type,
const std::string &error_message,
double timestamp) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreatePushErrorRequest(
fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
fbb.CreateString(error_message), timestamp);
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(
ray::protocol::MessageType::PushErrorRequest),
fbb.GetSize(), fbb.GetBufferPointer());
}

View file

@ -211,4 +211,18 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
int64_t timeout_milliseconds,
bool wait_local);
/// Push an error to the relevant driver.
///
/// \param conn The connection information.
/// \param The ID of the job that the error is for.
/// \param The type of the error.
/// \param The error message.
/// \param The timestamp of the error.
/// \return Void.
void local_scheduler_push_error(LocalSchedulerConnection *conn,
const JobID &job_id,
const std::string &type,
const std::string &error_message,
double timestamp);
#endif

View file

@ -15,6 +15,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_ty
raylet_task_table_.reset(new raylet::TaskTable(context_, this, command_type));
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
heartbeat_table_.reset(new HeartbeatTable(context_, this));
error_table_.reset(new ErrorTable(context_, this));
command_type_ = command_type;
}
@ -74,6 +75,9 @@ FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
ClassTable &AsyncGcsClient::class_table() { return *class_table_; }
HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
ErrorTable &AsyncGcsClient::error_table() { return *error_table_; }
} // namespace gcs
} // namespace ray

View file

@ -57,7 +57,7 @@ class RAY_EXPORT AsyncGcsClient {
TaskReconstructionLog &task_reconstruction_log();
ClientTable &client_table();
HeartbeatTable &heartbeat_table();
inline ErrorTable &error_table();
ErrorTable &error_table();
// We also need something to export generic code to run on workers from the
// driver (to set the PYTHONPATH)
@ -78,6 +78,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ActorTable> actor_table_;
std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
std::unique_ptr<HeartbeatTable> heartbeat_table_;
std::unique_ptr<ErrorTable> error_table_;
std::unique_ptr<ClientTable> client_table_;
std::shared_ptr<RedisContext> context_;
std::unique_ptr<RedisAsioClient> asio_async_client_;

View file

@ -14,6 +14,7 @@ enum TablePrefix:int {
FUNCTION,
TASK_RECONSTRUCTION,
HEARTBEAT,
ERROR_INFO,
}
// The channel that Add operations to the Table should be published on, if any.
@ -24,7 +25,8 @@ enum TablePubsub:int {
CLIENT,
OBJECT,
ACTOR,
HEARTBEAT
HEARTBEAT,
ERROR_INFO,
}
table GcsTableEntry {
@ -103,6 +105,14 @@ table ActorTableData {
}
table ErrorTableData {
// The ID of the job that the error is for.
job_id: string;
// The type of the error.
type: string;
// The error message.
error_message: string;
// The timestamp of the error message.
timestamp: double;
}
table CustomSerializerData {

View file

@ -183,6 +183,19 @@ Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id
done);
}
Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp) {
auto data = std::make_shared<ErrorTableDataT>();
data->job_id = job_id.binary();
data->type = type;
data->error_message = error_message;
data->timestamp = timestamp;
return Append(job_id, job_id, data, [](ray::gcs::AsyncGcsClient *client,
const JobID &id, const ErrorTableDataT &data) {
RAY_LOG(DEBUG) << "Error message pushed callback";
});
}
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
client_added_callback_ = callback;
// Call the callback for any added clients that are cached.
@ -333,6 +346,7 @@ template class Table<TaskID, TaskTableData>;
template class Log<ActorID, ActorTableData>;
template class Log<TaskID, TaskReconstructionData>;
template class Table<ClientID, HeartbeatTableData>;
template class Log<JobID, ErrorTableData>;
template class Log<UniqueID, ClientTableData>;
} // namespace gcs

View file

@ -95,7 +95,7 @@ class Log : virtual public PubsubInterface<ID> {
/// \param id The ID of the data that is added to the GCS.
/// \param data Data to append to the log.
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// GCS.
/// \return Status
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
@ -108,10 +108,9 @@ class Log : virtual public PubsubInterface<ID> {
/// \param data Data to append to the log.
/// \param done Callback that is called if the data was appended to the log.
/// \param failure Callback that is called if the data was not appended to
/// the log because the log length did not match the given
/// `log_length`.
/// the log because the log length did not match the given `log_length`.
/// \param log_length The number of entries that the log must have for the
/// append to succeed.
/// append to succeed.
/// \return Status
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done, const WriteCallback &failure,
@ -122,7 +121,7 @@ class Log : virtual public PubsubInterface<ID> {
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is looked up in the GCS.
/// \param lookup Callback that is called after lookup. If the callback is
/// called with an empty vector, then there was no data at the key.
/// called with an empty vector, then there was no data at the key.
/// \return Status
Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup);
@ -133,15 +132,14 @@ class Log : virtual public PubsubInterface<ID> {
///
/// \param job_id The ID of the job (= driver).
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each Add to the table will be received. Else, only
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// message for each Add to the table will be received. Else, only
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// \param subscribe Callback that is called on each received message. If the
/// callback is called with an empty vector, then there was no data at
/// the key.
/// callback is called with an empty vector, then there was no data at the key.
/// \param done Callback that is called when subscription is complete and we
/// are ready to receive messages.
/// are ready to receive messages.
/// \return Status
Status Subscribe(const JobID &job_id, const ClientID &client_id,
const Callback &subscribe, const SubscriptionCallback &done);
@ -158,8 +156,8 @@ class Log : virtual public PubsubInterface<ID> {
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the key to request notifications for.
/// \param client_id The client who is requesting notifications. Before
/// notifications can be requested, a call to `Subscribe` to this
/// table with the same `client_id` must complete successfully.
/// notifications can be requested, a call to `Subscribe` to this
/// table with the same `client_id` must complete successfully.
/// \return Status
Status RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id);
@ -241,7 +239,7 @@ class Table : private Log<ID, Data>,
/// \param id The ID of the data that is added to the GCS.
/// \param data Data that is added to the GCS.
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
@ -251,9 +249,9 @@ class Table : private Log<ID, Data>,
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is looked up in the GCS.
/// \param lookup Callback that is called after lookup if there was data the
/// key.
/// key.
/// \param failure Callback that is called after lookup if there was no data
/// at the key.
/// at the key.
/// \return Status
Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
const FailureCallback &failure);
@ -366,10 +364,10 @@ class TaskTable : public Table<TaskID, TaskTableData> {
///
/// \param task_id The task ID of the task entry to update.
/// \param test_state_bitmask The bitmask to apply to the task entry's current
/// scheduling state. The update happens if and only if the current
/// scheduling state AND-ed with the bitmask is greater than 0.
/// scheduling state. The update happens if and only if the current
/// scheduling state AND-ed with the bitmask is greater than 0.
/// \param update_state The value to update the task entry's scheduling state
/// with, if the current state matches test_state_bitmask.
/// with, if the current state matches test_state_bitmask.
/// \param callback Function to be called when database returns result.
/// \return Status
Status TestAndUpdate(const JobID &job_id, const TaskID &id,
@ -397,16 +395,14 @@ class TaskTable : public Table<TaskID, TaskTableData> {
/// task's local scheduler ID.
///
/// \param local_scheduler_id The db_client_id of the local scheduler whose
/// events we want to listen to. If you want to subscribe to updates
/// from
/// all local schedulers, pass in NIL_ID.
/// events we want to listen to. If you want to subscribe to updates from
/// all local schedulers, pass in NIL_ID.
/// \param subscribe_callback Callback that will be called when the task table
/// is
/// updated.
/// is updated.
/// \param state_filter Events we want to listen to. Can have values from the
/// enum "scheduling_state" in task.h.
/// TODO(pcm): Make it possible to combine these using flags like
/// TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
/// enum "scheduling_state" in task.h.
/// TODO(pcm): Make it possible to combine these using flags like
/// TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
/// \param callback Function to be called when database returns result.
/// \return Status
Status SubscribeToTask(const JobID &job_id, const ClientID &local_scheduler_id,
@ -422,7 +418,28 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
SchedulingState update_state,
const TaskTable::TestAndUpdateCallback &callback);
using ErrorTable = Table<TaskID, ErrorTableData>;
class ErrorTable : private Log<JobID, ErrorTableData> {
public:
ErrorTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
pubsub_channel_ = TablePubsub::ERROR_INFO;
prefix_ = TablePrefix::ERROR_INFO;
};
/// Push an error message for a specific job.
///
/// TODO(rkn): We need to make sure that the errors are unique because
/// duplicate messages currently cause failures (the GCS doesn't allow it).
///
/// \param job_id The ID of the job that generated the error. If the error
/// should be pushed to all jobs, then this should be nil.
/// \param type The type of the error.
/// \param error_message The error message to push.
/// \param timestamp The timestamp of the error.
/// \return Status.
Status PushErrorToDriver(const JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp);
};
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
@ -467,7 +484,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
/// and begins subscription to client table notifications.
///
/// \param Information about the connecting client. This must have the
/// same client_id as the one set in the client table.
/// same client_id as the one set in the client table.
/// \return Status
ray::Status Connect(const ClientTableDataT &local_client);
@ -499,7 +516,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
///
/// \param client The client to get information about.
/// \return A reference to the requested client. If the client is not in the
/// cache, then an entry with a nil ClientID will be returned.
/// cache, then an entry with a nil ClientID will be returned.
const ClientTableDataT &GetClient(const ClientID &client);
/// Get the local client's ID.

View file

@ -58,7 +58,10 @@ enum MessageType:int {
WaitRequest,
// The response message to WaitRequest; replies with the objects found and objects
// remaining.
WaitReply
WaitReply,
// Push an error to the relevant driver. This is sent from a worker to the
// node manager.
PushErrorRequest,
}
table TaskExecutionSpecification {
@ -154,3 +157,15 @@ table WaitReply {
// List of object ids not found.
remaining: [string];
}
// This struct is the same as ErrorTableData.
table PushErrorRequest {
// The ID of the job that the error is for.
job_id: string;
// The type of the error.
type: string;
// The error message.
error_message: string;
// The timestamp of the error message.
timestamp: double;
}

View file

@ -3,6 +3,7 @@
#include "common_protocol.h"
#include "local_scheduler/format/local_scheduler_generated.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/util/util.h"
namespace {
@ -372,11 +373,28 @@ void NodeManager::ProcessClientMessage(
// This if statement distinguishes workers from drivers.
if (worker) {
// TODO(swang): Handle the case where the worker is killed while
// executing a task. Clean up the assigned task's resources, return an
// error to the driver.
// RAY_CHECK(worker->GetAssignedTaskId().is_nil())
// << "Worker died while executing task: " << worker->GetAssignedTaskId();
// Handle the case where the worker is killed while executing a task.
// Clean up the assigned task's resources, push an error to the driver.
const TaskID &task_id = worker->GetAssignedTaskId();
if (!task_id.is_nil()) {
auto const &running_tasks = local_queues_.GetRunningTasks();
// TODO(rkn): This is too heavyweight just to get the task's driver ID.
auto const it = std::find_if(
running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) {
return task.GetTaskSpecification().TaskId() == task_id;
});
RAY_CHECK(running_tasks.size() != 0);
RAY_CHECK(it != running_tasks.end());
JobID job_id = it->GetTaskSpecification().DriverId();
// TODO(rkn): Define this constant somewhere else.
std::string type = "worker_died";
std::ostringstream error_message;
error_message << "A worker died or was killed while executing task " << task_id
<< ".";
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
job_id, type, error_message.str(), current_time_ms()));
}
worker_pool_.DisconnectWorker(worker);
const ClientID &client_id = gcs_client_->client_table().GetLocalClientId();
@ -521,6 +539,17 @@ void NodeManager::ProcessClientMessage(
});
RAY_CHECK_OK(status);
} break;
case protocol::MessageType::PushErrorRequest: {
auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
JobID job_id = from_flatbuf(*message->job_id());
auto const &type = string_from_flatbuf(*message->type());
auto const &error_message = string_from_flatbuf(*message->error_message());
double timestamp = message->timestamp();
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
timestamp));
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;

View file

@ -1,6 +1,7 @@
install(FILES
logging.h
macros.h
util.h
visibility.h
DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/ray/util"
)

19
src/ray/util/util.h Normal file
View file

@ -0,0 +1,19 @@
#ifndef RAY_UTIL_UTIL_H
#define RAY_UTIL_UTIL_H
#include <chrono>
/// Return the number of milliseconds since the Unix epoch.
///
/// TODO(rkn): This function appears in multiple places. It should be
/// deduplicated.
///
/// \return The number of milliseconds since the Unix epoch.
int64_t current_time_ms() {
std::chrono::milliseconds ms_since_epoch =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch());
return ms_since_epoch.count();
}
#endif // RAY_UTIL_UTIL_H

View file

@ -269,9 +269,6 @@ class WorkerDeath(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testWorkerRaisingException(self):
ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
@ -287,9 +284,6 @@ class WorkerDeath(unittest.TestCase):
wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)
self.assertEqual(len(ray.error_info()), 2)
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
def testWorkerDying(self):
ray.init(num_workers=0, driver_mode=ray.SILENT_MODE)
@ -303,7 +297,7 @@ class WorkerDeath(unittest.TestCase):
wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)
self.assertEqual(len(ray.error_info()), 1)
self.assertIn("died or was killed while executing the task",
self.assertIn("died or was killed while executing",
ray.error_info()[0]["message"])
@unittest.skipIf(

View file

@ -2243,7 +2243,7 @@ class GlobalStateAPI(unittest.TestCase):
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
worker_info = ray.global_state.workers()
self.assertEqual(len(worker_info), num_workers)
assert len(worker_info) >= num_workers
for worker_id, info in worker_info.items():
self.assertIn("node_ip_address", info)
self.assertIn("local_scheduler_socket", info)