[xray] Adds a driver table. (#2289)

This PR adds a driver table for the new GCS, which enables the cleanup the monitor performs when it detects a driver's death: the dead driver's task and object entries are removed from the GCS.

Some tests in `monitor_test.py` are restored, but Redis sharding support for xray is needed before the remaining tests can be re-enabled.
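
As an overview for reviewers, here is a minimal, self-contained model of the control flow this PR wires up (illustration only, not Ray code): the raylet appends an `is_dead` entry to the driver table, the GCS publishes that entry, and the monitor removes the dead driver's task and object entries.

```python
from collections import defaultdict

driver_table = defaultdict(list)   # driver_id -> appended entries
task_table = {b"t1": b"driver_a",  # task_id -> owning driver_id
              b"t2": b"driver_b"}
subscribers = []

def append_driver_data(driver_id, is_dead):
    # Models DriverTable::AppendDriverData plus the GCS pubsub publish.
    entry = {"driver_id": driver_id, "is_dead": is_dead}
    driver_table[driver_id].append(entry)
    for handler in subscribers:
        handler(entry)

def driver_removed_handler(entry):
    # Models Monitor.xray_driver_removed_handler and its cleanup pass.
    if not entry["is_dead"]:
        return
    dead_tasks = [t for t, d in task_table.items()
                  if d == entry["driver_id"]]
    for task_id in dead_tasks:
        del task_table[task_id]

subscribers.append(driver_removed_handler)
append_driver_data(b"driver_a", is_dead=True)
assert b"t1" not in task_table and b"t2" in task_table
```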
Melih Elibol 2018-08-09 02:41:40 -04:00 committed by Robert Nishihara
parent df7ee7ff1e
commit 8ae82180b4
20 changed files with 230 additions and 24 deletions

.gitignore

@@ -18,7 +18,7 @@
# Files generated by flatc should be ignored
/src/common/format/*.py
/src/common/format/*_generated.h
/src/plasma/format/*_generated.h
/src/plasma/format/
/src/local_scheduler/format/*_generated.h
/src/ray/gcs/format/*_generated.h
/src/ray/object_manager/format/*_generated.h

.travis.yml

@@ -148,7 +148,7 @@ matrix:
# - pytest test/component_failures_test.py
- python test/multi_node_test.py
- python -m pytest test/recursion_test.py
# - pytest test/monitor_test.py
- pytest test/monitor_test.py
- python -m pytest test/cython_test.py
- python -m pytest test/credis_test.py

doc/source/conf.py

@@ -47,6 +47,7 @@ MOCK_MODULES = ["gym",
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.DriverTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
"ray.core.generated.ObjectTableData",

python/ray/experimental/state.py

@@ -169,7 +169,7 @@ class GlobalState(object):
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
result.extend(list(client.scan_iter(match=pattern)))
return result
def _object_table(self, object_id):
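
The `keys` → `scan_iter` switch above is not cosmetic: `KEYS` blocks the Redis server while it walks the entire keyspace, whereas `SCAN` iterates incrementally. A minimal redis-py sketch of the replacement pattern (assumes a reachable local Redis; the match pattern is illustrative):

```python
import redis

client = redis.StrictRedis(host="localhost", port=6379)
# client.keys(pattern) returns everything in one blocking call;
# scan_iter yields matches in batches without blocking the server.
matching_keys = list(client.scan_iter(match=b"TASK*"))
```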

python/ray/gcs_utils.py

@@ -24,6 +24,7 @@ from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
@@ -34,9 +35,9 @@ __all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData",
"HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix",
"TablePubsub", "construct_error_message"
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"DriverTableData", "ProfileTableData", "ObjectTableData", "Task",
"TablePrefix", "TablePubsub", "construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in

python/ray/local_scheduler/__init__.py

@@ -3,12 +3,12 @@ from __future__ import division
from __future__ import print_function
from ray.core.src.local_scheduler.liblocal_scheduler_library_python import (
Task, LocalSchedulerClient, ObjectID, check_simple_value, task_from_string,
task_to_string, _config, common_error)
Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id,
task_from_string, task_to_string, _config, common_error)
from .local_scheduler_services import start_local_scheduler
__all__ = [
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
"common_error"
"compute_task_id", "task_from_string", "task_to_string",
"start_local_scheduler", "_config", "common_error"
]

python/ray/monitor.py

@@ -37,6 +37,9 @@ DRIVER_DEATH_CHANNEL = b"driver_deaths"
XRAY_HEARTBEAT_CHANNEL = str(
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")
# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
@@ -496,6 +499,87 @@ class Monitor(object):
self._clean_up_entries_for_driver(driver_id)
def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
Removes control-state entries of all tasks and task return
objects belonging to the driver.
Args:
driver_id: The driver id.
"""
xray_task_table_prefix = (
ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
xray_object_table_prefix = (
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
task_table_objects = self.state.task_table()
driver_id_hex = binary_to_hex(driver_id)
driver_task_id_bins = set()
for task_id_hex in task_table_objects:
if len(task_table_objects[task_id_hex]) == 0:
continue
task_table_object = task_table_objects[task_id_hex][0]["TaskSpec"]
task_driver_id_hex = task_table_object["DriverID"]
if driver_id_hex != task_driver_id_hex:
# Ignore tasks that aren't from this driver.
continue
driver_task_id_bins.add(hex_to_binary(task_id_hex))
# Get objects associated with the driver.
object_table_objects = self.state.object_table()
driver_object_id_bins = set()
for object_id, object_table_object in object_table_objects.items():
assert len(object_table_object) > 0
task_id_bin = ray.local_scheduler.compute_task_id(object_id).id()
if task_id_bin in driver_task_id_bins:
driver_object_id_bins.add(object_id.id())
def to_shard_index(id_bin):
return binary_to_object_id(id_bin).redis_shard_hash() % len(
self.state.redis_clients)
# Form the redis keys to delete.
sharded_keys = [[] for _ in range(len(self.state.redis_clients))]
for task_id_bin in driver_task_id_bins:
sharded_keys[to_shard_index(task_id_bin)].append(
xray_task_table_prefix + task_id_bin)
for object_id_bin in driver_object_id_bins:
sharded_keys[to_shard_index(object_id_bin)].append(
xray_object_table_prefix + object_id_bin)
# Remove with best effort.
for shard_index in range(len(sharded_keys)):
keys = sharded_keys[shard_index]
if len(keys) == 0:
continue
redis = self.state.redis_clients[shard_index]
num_deleted = redis.delete(*keys)
log.info("Removed {} dead redis entries of the driver"
" from redis shard {}.".format(num_deleted, shard_index))
if num_deleted != len(keys):
log.warning("Failed to remove {} relevant redis entries"
" from redis shard {}.".format(
len(keys) - num_deleted, shard_index))
def xray_driver_removed_handler(self, unused_channel, data):
"""Handle a notification that a driver has been removed.
Args:
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
driver_data = gcs_entries.Entries(0)
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
driver_id = message.DriverId()
log.info("XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._xray_clean_up_entries_for_driver(driver_id)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -537,6 +621,9 @@ class Monitor(object):
elif channel == XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
else:
raise Exception("This code should be unreachable.")
@@ -582,7 +669,7 @@ class Monitor(object):
max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
num_flushed = self.redis_shard.execute_command(
"HEAD.FLUSH {}".format(max_entries_to_flush))
log.info('num_flushed {}'.format(num_flushed))
log.info("num_flushed {}".format(num_flushed))
# This flushes event log and log files.
ray.experimental.flush_redis_unsafe(self.redis)
@@ -601,6 +688,7 @@ class Monitor(object):
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(XRAY_DRIVER_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.
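
One detail of `_xray_clean_up_entries_for_driver` worth calling out: keys are grouped by shard before deletion, so each shard receives a single `DEL` covering all of its keys. A self-contained sketch of that routing, with a plain `hash` standing in for `redis_shard_hash` (illustration only, not Ray code):

```python
NUM_SHARDS = 3  # hypothetical shard count

def to_shard_index(key: bytes) -> int:
    # Stand-in for binary_to_object_id(key).redis_shard_hash() % num_shards.
    return hash(key) % NUM_SHARDS

def shard_keys(keys):
    # Group keys per shard so each shard gets one bulk delete, as in
    # _xray_clean_up_entries_for_driver above.
    sharded = [[] for _ in range(NUM_SHARDS)]
    for key in keys:
        sharded[to_shard_index(key)].append(key)
    return sharded

keys = [b"TASK:" + bytes([i]) * 20 for i in range(5)]
for shard_index, bucket in enumerate(shard_keys(keys)):
    print(shard_index, bucket)
```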

src/common/lib/python/common_extension.cc

@@ -907,3 +907,12 @@ PyObject *check_simple_value(PyObject *self, PyObject *args) {
}
Py_RETURN_FALSE;
}
PyObject *compute_task_id(PyObject *self, PyObject *args) {
ObjectID object_id;
if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, &object_id)) {
return NULL;
}
TaskID task_id = ray::ComputeTaskId(object_id);
return PyObjectID_make(task_id);
}

src/common/lib/python/common_extension.h

@@ -56,6 +56,7 @@ int PyObjectToUniqueID(PyObject *object, ray::ObjectID *object_id);
PyObject *PyObjectID_make(ray::ObjectID object_id);
PyObject *check_simple_value(PyObject *self, PyObject *args);
PyObject *compute_task_id(PyObject *self, PyObject *args);
PyObject *PyTask_to_string(PyObject *, PyObject *args);
PyObject *PyTask_from_string(PyObject *, PyObject *args);

src/local_scheduler/lib/python/local_scheduler_extension.cc

@@ -493,6 +493,8 @@ static PyTypeObject PyLocalSchedulerClientType = {
static PyMethodDef local_scheduler_methods[] = {
{"check_simple_value", check_simple_value, METH_VARARGS,
"Should the object be passed by value?"},
{"compute_task_id", compute_task_id, METH_VARARGS,
"Return the task ID of an object ID."},
{"task_from_string", PyTask_from_string, METH_VARARGS,
"Creates a Python PyTask object from a string representation of "
"TaskSpec."},

src/ray/gcs/client.cc

@@ -17,6 +17,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_ty
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
task_lease_table_.reset(new TaskLeaseTable(context_, this));
heartbeat_table_.reset(new HeartbeatTable(context_, this));
driver_table_.reset(new DriverTable(primary_context_, this));
error_table_.reset(new ErrorTable(primary_context_, this));
profile_table_.reset(new ProfileTable(context_, this));
command_type_ = command_type;
@@ -88,6 +89,8 @@ HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
ErrorTable &AsyncGcsClient::error_table() { return *error_table_; }
DriverTable &AsyncGcsClient::driver_table() { return *driver_table_; }
ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; }
} // namespace gcs

src/ray/gcs/client.h

@@ -60,6 +60,7 @@ class RAY_EXPORT AsyncGcsClient {
ClientTable &client_table();
HeartbeatTable &heartbeat_table();
ErrorTable &error_table();
DriverTable &driver_table();
ProfileTable &profile_table();
// We also need something to export generic code to run on workers from the
@@ -92,6 +93,7 @@
std::unique_ptr<RedisAsioClient> asio_subscribe_client_;
// The following context writes everything to the primary shard
std::shared_ptr<RedisContext> primary_context_;
std::unique_ptr<DriverTable> driver_table_;
std::unique_ptr<RedisAsioClient> asio_async_auxiliary_client_;
std::unique_ptr<RedisAsioClient> asio_subscribe_auxiliary_client_;
CommandType command_type_;

src/ray/gcs/format/gcs.fbs

@@ -4,6 +4,7 @@ enum Language:int {
JAVA = 2
}
// These indexes are mapped to strings in ray_redis_module.cc.
enum TablePrefix:int {
UNUSED = 0,
TASK,
@@ -15,6 +16,7 @@ enum TablePrefix:int {
TASK_RECONSTRUCTION,
HEARTBEAT,
ERROR_INFO,
DRIVER,
PROFILE,
TASK_LEASE,
}
@@ -30,6 +32,7 @@ enum TablePubsub:int {
HEARTBEAT,
ERROR_INFO,
TASK_LEASE,
DRIVER,
}
table GcsTableEntry {
@@ -202,3 +205,10 @@ table TaskLeaseData {
// The period that the lease is active for.
timeout: long;
}
table DriverTableData {
// The driver ID.
driver_id: string;
// Whether it's dead.
is_dead: bool;
}
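
For reference, a sketch of round-tripping the new `DriverTableData` table from Python. This is an illustration, not code from the PR; it assumes the flatc-generated module registered above as `ray.core.generated.DriverTableData`, and the `*Start`/`*Add*`/`*End` builder names follow flatc's usual Python conventions rather than appearing in this diff:

```python
import flatbuffers
from ray.core.generated import DriverTableData as driver_table_data

builder = flatbuffers.Builder(0)
driver_id_offset = builder.CreateString(20 * b"\x01")  # 20-byte driver ID
driver_table_data.DriverTableDataStart(builder)
driver_table_data.DriverTableDataAddDriverId(builder, driver_id_offset)
driver_table_data.DriverTableDataAddIsDead(builder, True)
builder.Finish(driver_table_data.DriverTableDataEnd(builder))

buf = bytes(builder.Output())
message = driver_table_data.DriverTableData.GetRootAsDriverTableData(buf, 0)
assert message.IsDead()
```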

src/ray/gcs/tables.cc

@@ -266,6 +266,17 @@ Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events
});
}
Status DriverTable::AppendDriverData(const JobID &driver_id, bool is_dead) {
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = driver_id.binary();
data->is_dead = is_dead;
return Append(driver_id, driver_id, data,
[](ray::gcs::AsyncGcsClient *client, const JobID &id,
const DriverTableDataT &data) {
RAY_LOG(DEBUG) << "Driver entry added callback";
});
}
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
client_added_callback_ = callback;
// Call the callback for any added clients that are cached.
@@ -425,6 +436,7 @@ template class Table<TaskID, TaskLeaseData>;
template class Table<ClientID, HeartbeatTableData>;
template class Log<JobID, ErrorTableData>;
template class Log<UniqueID, ClientTableData>;
template class Log<JobID, DriverTableData>;
template class Log<UniqueID, ProfileTableData>;
} // namespace gcs

src/ray/gcs/tables.h

@@ -317,6 +317,23 @@ class HeartbeatTable : public Table<ClientID, HeartbeatTableData> {
virtual ~HeartbeatTable() {}
};
class DriverTable : public Log<JobID, DriverTableData> {
public:
DriverTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
pubsub_channel_ = TablePubsub::DRIVER;
prefix_ = TablePrefix::DRIVER;
};
virtual ~DriverTable() {}
/// Appends driver data to the driver table.
///
/// \param driver_id The driver id.
/// \param is_dead Whether the driver is dead.
/// \return The return status.
Status AppendDriverData(const JobID &driver_id, bool is_dead);
};
class FunctionTable : public Table<ObjectID, FunctionTableData> {
public:
FunctionTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)

src/ray/object_manager/object_manager.cc

@@ -25,8 +25,10 @@ ObjectManager::ObjectManager(asio::io_service &main_service,
RAY_CHECK(config_.max_sends > 0);
RAY_CHECK(config_.max_receives > 0);
main_service_ = &main_service;
store_notification_.SubscribeObjAdded(
[this](const ObjectInfoT &object_info) { NotifyDirectoryObjectAdd(object_info); });
store_notification_.SubscribeObjAdded([this](const ObjectInfoT &object_info) {
NotifyDirectoryObjectAdd(object_info);
HandleUnfulfilledPushRequests(object_info);
});
store_notification_.SubscribeObjDeleted(
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
StartIOService();
@@ -49,8 +51,10 @@
RAY_CHECK(config_.max_receives > 0);
// TODO(hme) Client ID is never set with this constructor.
main_service_ = &main_service;
store_notification_.SubscribeObjAdded(
[this](const ObjectInfoT &object_info) { NotifyDirectoryObjectAdd(object_info); });
store_notification_.SubscribeObjAdded([this](const ObjectInfoT &object_info) {
NotifyDirectoryObjectAdd(object_info);
HandleUnfulfilledPushRequests(object_info);
});
store_notification_.SubscribeObjDeleted(
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
StartIOService();
@@ -89,6 +93,10 @@ void ObjectManager::NotifyDirectoryObjectAdd(const ObjectInfoT &object_info) {
local_objects_[object_id] = object_info;
ray::Status status =
object_directory_->ReportObjectAdded(object_id, client_id_, object_info);
}
void ObjectManager::HandleUnfulfilledPushRequests(const ObjectInfoT &object_info) {
ObjectID object_id = ObjectID::from_binary(object_info.object_id);
// Handle the unfulfilled_push_requests_ which contains the push request that is not
// completed due to unsatisfied local objects.
auto iter = unfulfilled_push_requests_.find(object_id);
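
`HandleUnfulfilledPushRequests` implements a park-and-replay pattern: a push requested before the object exists locally is queued, then replayed when the store signals that the object was added. A Python illustration of the same pattern (not Ray code; names are illustrative):

```python
from collections import defaultdict

unfulfilled_push_requests = defaultdict(list)  # object_id -> waiting clients
local_objects = set()

def push(object_id, client_id):
    if object_id in local_objects:
        send(object_id, client_id)
    else:
        # Park the request until the object store reports the object.
        unfulfilled_push_requests[object_id].append(client_id)

def handle_object_added(object_id):
    local_objects.add(object_id)
    # Replay any pushes that arrived before the object was available.
    for client_id in unfulfilled_push_requests.pop(object_id, []):
        send(object_id, client_id)

def send(object_id, client_id):
    print("pushing", object_id, "to", client_id)

push("obj1", "client_a")     # parked: obj1 is not local yet
handle_object_added("obj1")  # replays the parked push
```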

src/ray/object_manager/object_manager.h

@@ -266,6 +266,10 @@ class ObjectManager : public ObjectManagerInterface {
/// Register object remove with directory.
void NotifyDirectoryObjectDeleted(const ObjectID &object_id);
/// Handle any push requests that were made before an object was available.
/// This is invoked when an "object added" notification is received from the store.
void HandleUnfulfilledPushRequests(const ObjectInfoT &object_info);
/// Part of an asynchronous sequence of Pull methods.
/// Uses an existing connection or creates a connection to ClientID.
/// Executes on main_service_ thread.

src/ray/raylet/node_manager.cc

@@ -193,12 +193,33 @@ ray::Status NodeManager::RegisterGcs() {
RAY_LOG(DEBUG) << "heartbeat table subscription done callback called.";
}));
// Subscribe to driver table updates.
const auto driver_table_handler = [this](
gcs::AsyncGcsClient *client, const ClientID &client_id,
const std::vector<DriverTableDataT> &driver_data) {
HandleDriverTableUpdate(client_id, driver_data);
};
RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), UniqueID::nil(),
driver_table_handler, nullptr));
// Start sending heartbeats to the GCS.
Heartbeat();
return ray::Status::OK();
}
void NodeManager::HandleDriverTableUpdate(
const ClientID &id, const std::vector<DriverTableDataT> &driver_data) {
for (const auto &entry : driver_data) {
RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id)
<< " " << entry.is_dead;
if (entry.is_dead) {
// TODO: Implement cleanup on driver death. For reference,
// see handle_driver_removed_callback in local_scheduler.cc
}
}
}
void NodeManager::Heartbeat() {
RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat.";
auto &heartbeat_table = gcs_client_->heartbeat_table();
@@ -449,6 +470,7 @@
switch (static_cast<protocol::MessageType>(message_type)) {
case protocol::MessageType::RegisterClientRequest: {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
client->SetClientID(from_flatbuf(*message->client_id()));
auto worker = std::make_shared<Worker>(message->worker_pid(), client);
if (message->is_worker()) {
// Register the new worker.
@@ -543,6 +565,8 @@ void NodeManager::ProcessClientMessage(
DispatchTasks();
} else {
// The client is a driver.
RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(),
/*is_dead=*/true));
const std::shared_ptr<Worker> driver = worker_pool_.GetRegisteredDriver(client);
RAY_CHECK(driver);
auto driver_id = driver->GetAssignedTaskId();

src/ray/raylet/node_manager.h

@@ -144,6 +144,10 @@ class NodeManager {
/// accounting, but does not write to any global accounting in the GCS.
void HandleObjectMissing(const ObjectID &object_id);
/// Handles updates to driver table.
void HandleDriverTableUpdate(const ClientID &id,
const std::vector<DriverTableDataT> &driver_data);
boost::asio::io_service &io_service_;
ObjectManager &object_manager_;
/// A Plasma object store client. This is used exclusively for creating new

test/monitor_test.py

@@ -41,11 +41,18 @@ class MonitorTest(unittest.TestCase):
if (0, 1) != summary_start[:2]:
success.value = False
max_attempts_before_failing = 100
# Two new objects.
ray.get(ray.put(1111))
ray.get(ray.put(1111))
if (2, 1, summary_start[2]) != StateSummary():
success.value = False
attempts = 0
while (2, 1, summary_start[2]) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
@ray.remote
def f():
@@ -53,12 +60,22 @@
return 1111 # A returned object as well.
# 1 new function.
if (2, 1, summary_start[2] + 1) != StateSummary():
success.value = False
attempts = 0
while (2, 1, summary_start[2] + 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
ray.get(f.remote())
if (4, 2, summary_start[2] + 1) != StateSummary():
success.value = False
attempts = 0
while (4, 2, summary_start[2] + 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
success.value = False
break
ray.shutdown()
@@ -67,7 +84,7 @@
driver.start()
# Wait for client to exit.
driver.join()
time.sleep(5)
time.sleep(3)
# Just make sure Driver() is run and succeeded. Note(rkn), if the below
# assertion starts failing, then the issue may be that the summary
@@ -85,13 +102,16 @@
subprocess.Popen(["ray", "stop"]).wait()
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
os.environ.get("RAY_USE_NEW_GCS", False),
"Failing with the new GCS API.")
def testCleanupOnDriverExitSingleRedisShard(self):
self._testCleanupOnDriverExit(num_redis_shards=1)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
@unittest.skipIf(
os.environ.get("RAY_USE_NEW_GCS", False),
"Hanging with the new GCS API.")
def testCleanupOnDriverExitManyRedisShards(self):
self._testCleanupOnDriverExit(num_redis_shards=5)
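
The test changes above replace one-shot equality checks with bounded retry loops. The pattern generalizes to a small helper like the following (illustrative, not part of the PR):

```python
import time

def wait_for_condition(condition, max_attempts=100, delay_s=0.1):
    """Poll `condition` until it returns True or attempts run out."""
    for _ in range(max_attempts):
        if condition():
            return True
        time.sleep(delay_s)
    return False

# e.g. wait_for_condition(lambda: (2, 1, n) == StateSummary())
```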