Make local scheduler client thread-safe (#2386)

* Make local scheduler client thread-safe for python

* lock write_messages

* remove allow-threads

* fix linter

* rename _write_message to do_write_message
This commit is contained in:
Hao Chen 2018-07-14 07:19:00 +08:00 committed by Philipp Moritz
parent 62f84d2f07
commit c1575e98c1
4 changed files with 76 additions and 37 deletions

View file

@ -251,7 +251,7 @@ int write_bytes(int fd, uint8_t *cursor, size_t length) {
return 0;
}
int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
int64_t version = RayConfig::instance().ray_protocol_version();
int closed;
closed = write_bytes(fd, (uint8_t *) &version, sizeof(version));
@ -273,6 +273,19 @@ int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
return 0;
}
int write_message(int fd,
int64_t type,
int64_t length,
uint8_t *bytes,
std::mutex *mutex) {
if (mutex != NULL) {
std::unique_lock<std::mutex> guard(*mutex);
return do_write_message(fd, type, length, bytes);
} else {
return do_write_message(fd, type, length, bytes);
}
}
int read_bytes(int fd, uint8_t *cursor, size_t length) {
ssize_t nbytes = 0;
/* Termination condition: EOF or read 'length' bytes total. */
@ -388,8 +401,8 @@ disconnected:
void write_log_message(int fd, const char *message) {
/* Account for the \0 at the end of the string. */
write_message(fd, static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
strlen(message) + 1, (uint8_t *) message);
do_write_message(fd, static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
strlen(message) + 1, (uint8_t *) message);
}
char *read_log_message(int fd) {

View file

@ -4,6 +4,7 @@
#include <stdint.h>
#include <stdlib.h>
#include <mutex>
#include <vector>
struct aeEventLoop;
@ -121,10 +122,16 @@ int accept_client(int socket_fd);
* @param type The type of the message to send.
* @param length The size in bytes of the bytes parameter.
* @param bytes The address of the message to send.
* @param mutex If not NULL, the whole write operation will be locked
* with this mutex, otherwise do nothing.
* @return int Whether there was an error while writing. 0 corresponds to
* success and -1 corresponds to an error (errno will be set).
*/
int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes);
int write_message(int fd,
int64_t type,
int64_t length,
uint8_t *bytes,
std::mutex *mutex = NULL);
/**
* Read a sequence of bytes written by write_message from a file descriptor.

View file

@ -31,7 +31,7 @@ LocalSchedulerConnection *LocalSchedulerConnection_init(
/* Register the process ID with the local scheduler. */
int success = write_message(
result->conn, static_cast<int64_t>(MessageType::RegisterClientRequest),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &result->write_mutex);
RAY_CHECK(success == 0) << "Unable to register worker with local scheduler";
return result;
@ -47,7 +47,7 @@ void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) {
auto message = ray::local_scheduler::protocol::CreateDisconnectClient(fbb);
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(MessageType::DisconnectClient),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
void local_scheduler_log_event(LocalSchedulerConnection *conn,
@ -63,7 +63,7 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn,
fbb, key_string, value_string, timestamp);
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
void local_scheduler_submit(LocalSchedulerConnection *conn,
@ -78,7 +78,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn,
fbb, execution_dependencies, task_spec);
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
void local_scheduler_submit_raylet(
@ -91,19 +91,22 @@ void local_scheduler_submit_raylet(
fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb));
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
int64_t *task_size) {
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
NULL);
int64_t type;
int64_t reply_size;
uint8_t *reply;
/* Receive a task from the local scheduler. This will block until the local
* scheduler gives this client a task. */
read_message(conn->conn, &type, &reply_size, &reply);
{
std::unique_lock<std::mutex> guard(conn->mutex);
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
NULL, &conn->write_mutex);
/* Receive a task from the local scheduler. This will block until the local
* scheduler gives this client a task. */
read_message(conn->conn, &type, &reply_size, &reply);
}
if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) {
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
exit(1);
@ -139,14 +142,17 @@ TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
// the raylet and non-raylet code paths.
TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn,
int64_t *task_size) {
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
NULL);
int64_t type;
int64_t reply_size;
uint8_t *reply;
// Receive a task from the local scheduler. This will block until the local
// scheduler gives this client a task.
read_message(conn->conn, &type, &reply_size, &reply);
{
std::unique_lock<std::mutex> guard(conn->mutex);
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
NULL, &conn->write_mutex);
// Receive a task from the local scheduler. This will block until the local
// scheduler gives this client a task.
read_message(conn->conn, &type, &reply_size, &reply);
}
if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) {
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
exit(1);
@ -197,7 +203,7 @@ TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn,
void local_scheduler_task_done(LocalSchedulerConnection *conn) {
write_message(conn->conn, static_cast<int64_t>(MessageType::TaskDone), 0,
NULL);
NULL, &conn->write_mutex);
}
void local_scheduler_reconstruct_objects(
@ -211,18 +217,18 @@ void local_scheduler_reconstruct_objects(
fbb.Finish(message);
write_message(conn->conn,
static_cast<int64_t>(MessageType::ReconstructObjects),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
/* TODO(swang): Propagate the error. */
}
void local_scheduler_log_message(LocalSchedulerConnection *conn) {
write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage),
0, NULL);
0, NULL, &conn->write_mutex);
}
void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) {
write_message(conn->conn, static_cast<int64_t>(MessageType::NotifyUnblocked),
0, NULL);
0, NULL, &conn->write_mutex);
}
void local_scheduler_put_object(LocalSchedulerConnection *conn,
@ -234,7 +240,7 @@ void local_scheduler_put_object(LocalSchedulerConnection *conn,
fbb.Finish(message);
write_message(conn->conn, static_cast<int64_t>(MessageType::PutObject),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
const std::vector<uint8_t> local_scheduler_get_actor_frontier(
@ -244,13 +250,16 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
auto message = ray::local_scheduler::protocol::CreateGetActorFrontierRequest(
fbb, to_flatbuf(fbb, actor_id));
fbb.Finish(message);
write_message(conn->conn,
static_cast<int64_t>(MessageType::GetActorFrontierRequest),
fbb.GetSize(), fbb.GetBufferPointer());
int64_t type;
std::vector<uint8_t> reply;
read_vector(conn->conn, &type, reply);
{
std::unique_lock<std::mutex> guard(conn->mutex);
write_message(conn->conn,
static_cast<int64_t>(MessageType::GetActorFrontierRequest),
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
read_vector(conn->conn, &type, reply);
}
if (static_cast<CommonMessageType>(type) ==
CommonMessageType::DISCONNECT_CLIENT) {
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
@ -264,7 +273,8 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier) {
write_message(conn->conn, static_cast<int64_t>(MessageType::SetActorFrontier),
frontier.size(), const_cast<uint8_t *>(frontier.data()));
frontier.size(), const_cast<uint8_t *>(frontier.data()),
&conn->write_mutex);
}
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
@ -279,14 +289,17 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds,
wait_local);
fbb.Finish(message);
write_message(conn->conn,
static_cast<int64_t>(ray::protocol::MessageType::WaitRequest),
fbb.GetSize(), fbb.GetBufferPointer());
// Read result.
int64_t type;
int64_t reply_size;
uint8_t *reply;
read_message(conn->conn, &type, &reply_size, &reply);
{
std::unique_lock<std::mutex> guard(conn->mutex);
write_message(conn->conn,
static_cast<int64_t>(ray::protocol::MessageType::WaitRequest),
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
// Read result.
read_message(conn->conn, &type, &reply_size, &reply);
}
RAY_CHECK(static_cast<ray::protocol::MessageType>(type) ==
ray::protocol::MessageType::WaitReply);
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply);
@ -320,7 +333,7 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn,
write_message(conn->conn, static_cast<int64_t>(
ray::protocol::MessageType::PushErrorRequest),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
void local_scheduler_push_profile_events(
@ -334,5 +347,5 @@ void local_scheduler_push_profile_events(
write_message(conn->conn,
static_cast<int64_t>(
ray::protocol::MessageType::PushProfileEventsRequest),
fbb.GetSize(), fbb.GetBufferPointer());
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}

View file

@ -1,6 +1,8 @@
#ifndef LOCAL_SCHEDULER_CLIENT_H
#define LOCAL_SCHEDULER_CLIENT_H
#include <mutex>
#include "common/task.h"
#include "local_scheduler_shared.h"
#include "ray/raylet/task_spec.h"
@ -19,6 +21,10 @@ struct LocalSchedulerConnection {
/// of that resource allocated for this worker.
std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>>
resource_ids_;
/// A mutex to protect stateful operations of the local scheduler client.
std::mutex mutex;
/// A mutext to protect write operations of the local scheduler client.
std::mutex write_mutex;
};
/**