mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Make local scheduler client thread-safe (#2386)
* Make local scheduler client thread-safe for python * lock write_messages * remove allow-threads * fix linter * rename _write_message to do_write_message
This commit is contained in:
parent
62f84d2f07
commit
c1575e98c1
4 changed files with 76 additions and 37 deletions
|
@ -251,7 +251,7 @@ int write_bytes(int fd, uint8_t *cursor, size_t length) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
|
||||
int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
|
||||
int64_t version = RayConfig::instance().ray_protocol_version();
|
||||
int closed;
|
||||
closed = write_bytes(fd, (uint8_t *) &version, sizeof(version));
|
||||
|
@ -273,6 +273,19 @@ int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int write_message(int fd,
|
||||
int64_t type,
|
||||
int64_t length,
|
||||
uint8_t *bytes,
|
||||
std::mutex *mutex) {
|
||||
if (mutex != NULL) {
|
||||
std::unique_lock<std::mutex> guard(*mutex);
|
||||
return do_write_message(fd, type, length, bytes);
|
||||
} else {
|
||||
return do_write_message(fd, type, length, bytes);
|
||||
}
|
||||
}
|
||||
|
||||
int read_bytes(int fd, uint8_t *cursor, size_t length) {
|
||||
ssize_t nbytes = 0;
|
||||
/* Termination condition: EOF or read 'length' bytes total. */
|
||||
|
@ -388,8 +401,8 @@ disconnected:
|
|||
|
||||
void write_log_message(int fd, const char *message) {
|
||||
/* Account for the \0 at the end of the string. */
|
||||
write_message(fd, static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
|
||||
strlen(message) + 1, (uint8_t *) message);
|
||||
do_write_message(fd, static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
|
||||
strlen(message) + 1, (uint8_t *) message);
|
||||
}
|
||||
|
||||
char *read_log_message(int fd) {
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
struct aeEventLoop;
|
||||
|
@ -121,10 +122,16 @@ int accept_client(int socket_fd);
|
|||
* @param type The type of the message to send.
|
||||
* @param length The size in bytes of the bytes parameter.
|
||||
* @param bytes The address of the message to send.
|
||||
* @param mutex If not NULL, the whole write operation will be locked
|
||||
* with this mutex; if NULL, the message is written without locking.
|
||||
* @return int Whether there was an error while writing. 0 corresponds to
|
||||
* success and -1 corresponds to an error (errno will be set).
|
||||
*/
|
||||
int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes);
|
||||
int write_message(int fd,
|
||||
int64_t type,
|
||||
int64_t length,
|
||||
uint8_t *bytes,
|
||||
std::mutex *mutex = NULL);
|
||||
|
||||
/**
|
||||
* Read a sequence of bytes written by write_message from a file descriptor.
|
||||
|
|
|
@ -31,7 +31,7 @@ LocalSchedulerConnection *LocalSchedulerConnection_init(
|
|||
/* Register the process ID with the local scheduler. */
|
||||
int success = write_message(
|
||||
result->conn, static_cast<int64_t>(MessageType::RegisterClientRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &result->write_mutex);
|
||||
RAY_CHECK(success == 0) << "Unable to register worker with local scheduler";
|
||||
|
||||
return result;
|
||||
|
@ -47,7 +47,7 @@ void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) {
|
|||
auto message = ray::local_scheduler::protocol::CreateDisconnectClient(fbb);
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::DisconnectClient),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_log_event(LocalSchedulerConnection *conn,
|
||||
|
@ -63,7 +63,7 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn,
|
|||
fbb, key_string, value_string, timestamp);
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_submit(LocalSchedulerConnection *conn,
|
||||
|
@ -78,7 +78,7 @@ void local_scheduler_submit(LocalSchedulerConnection *conn,
|
|||
fbb, execution_dependencies, task_spec);
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_submit_raylet(
|
||||
|
@ -91,19 +91,22 @@ void local_scheduler_submit_raylet(
|
|||
fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb));
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
|
||||
int64_t *task_size) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
|
||||
NULL);
|
||||
int64_t type;
|
||||
int64_t reply_size;
|
||||
uint8_t *reply;
|
||||
/* Receive a task from the local scheduler. This will block until the local
|
||||
* scheduler gives this client a task. */
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(conn->mutex);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
|
||||
NULL, &conn->write_mutex);
|
||||
/* Receive a task from the local scheduler. This will block until the local
|
||||
* scheduler gives this client a task. */
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
}
|
||||
if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) {
|
||||
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
|
||||
exit(1);
|
||||
|
@ -139,14 +142,17 @@ TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
|
|||
// the raylet and non-raylet code paths.
|
||||
TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn,
|
||||
int64_t *task_size) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
|
||||
NULL);
|
||||
int64_t type;
|
||||
int64_t reply_size;
|
||||
uint8_t *reply;
|
||||
// Receive a task from the local scheduler. This will block until the local
|
||||
// scheduler gives this client a task.
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(conn->mutex);
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0,
|
||||
NULL, &conn->write_mutex);
|
||||
// Receive a task from the local scheduler. This will block until the local
|
||||
// scheduler gives this client a task.
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
}
|
||||
if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) {
|
||||
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
|
||||
exit(1);
|
||||
|
@ -197,7 +203,7 @@ TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn,
|
|||
|
||||
void local_scheduler_task_done(LocalSchedulerConnection *conn) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::TaskDone), 0,
|
||||
NULL);
|
||||
NULL, &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_reconstruct_objects(
|
||||
|
@ -211,18 +217,18 @@ void local_scheduler_reconstruct_objects(
|
|||
fbb.Finish(message);
|
||||
write_message(conn->conn,
|
||||
static_cast<int64_t>(MessageType::ReconstructObjects),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
/* TODO(swang): Propagate the error. */
|
||||
}
|
||||
|
||||
void local_scheduler_log_message(LocalSchedulerConnection *conn) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage),
|
||||
0, NULL);
|
||||
0, NULL, &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::NotifyUnblocked),
|
||||
0, NULL);
|
||||
0, NULL, &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_put_object(LocalSchedulerConnection *conn,
|
||||
|
@ -234,7 +240,7 @@ void local_scheduler_put_object(LocalSchedulerConnection *conn,
|
|||
fbb.Finish(message);
|
||||
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::PutObject),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
const std::vector<uint8_t> local_scheduler_get_actor_frontier(
|
||||
|
@ -244,13 +250,16 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
|
|||
auto message = ray::local_scheduler::protocol::CreateGetActorFrontierRequest(
|
||||
fbb, to_flatbuf(fbb, actor_id));
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn,
|
||||
static_cast<int64_t>(MessageType::GetActorFrontierRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
|
||||
int64_t type;
|
||||
std::vector<uint8_t> reply;
|
||||
read_vector(conn->conn, &type, reply);
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(conn->mutex);
|
||||
write_message(conn->conn,
|
||||
static_cast<int64_t>(MessageType::GetActorFrontierRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
|
||||
read_vector(conn->conn, &type, reply);
|
||||
}
|
||||
if (static_cast<CommonMessageType>(type) ==
|
||||
CommonMessageType::DISCONNECT_CLIENT) {
|
||||
RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection.";
|
||||
|
@ -264,7 +273,8 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
|
|||
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
|
||||
const std::vector<uint8_t> &frontier) {
|
||||
write_message(conn->conn, static_cast<int64_t>(MessageType::SetActorFrontier),
|
||||
frontier.size(), const_cast<uint8_t *>(frontier.data()));
|
||||
frontier.size(), const_cast<uint8_t *>(frontier.data()),
|
||||
&conn->write_mutex);
|
||||
}
|
||||
|
||||
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
|
||||
|
@ -279,14 +289,17 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
|
|||
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds,
|
||||
wait_local);
|
||||
fbb.Finish(message);
|
||||
write_message(conn->conn,
|
||||
static_cast<int64_t>(ray::protocol::MessageType::WaitRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
// Read result.
|
||||
int64_t type;
|
||||
int64_t reply_size;
|
||||
uint8_t *reply;
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(conn->mutex);
|
||||
write_message(conn->conn,
|
||||
static_cast<int64_t>(ray::protocol::MessageType::WaitRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
// Read result.
|
||||
read_message(conn->conn, &type, &reply_size, &reply);
|
||||
}
|
||||
RAY_CHECK(static_cast<ray::protocol::MessageType>(type) ==
|
||||
ray::protocol::MessageType::WaitReply);
|
||||
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply);
|
||||
|
@ -320,7 +333,7 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn,
|
|||
|
||||
write_message(conn->conn, static_cast<int64_t>(
|
||||
ray::protocol::MessageType::PushErrorRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
||||
void local_scheduler_push_profile_events(
|
||||
|
@ -334,5 +347,5 @@ void local_scheduler_push_profile_events(
|
|||
write_message(conn->conn,
|
||||
static_cast<int64_t>(
|
||||
ray::protocol::MessageType::PushProfileEventsRequest),
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
#ifndef LOCAL_SCHEDULER_CLIENT_H
|
||||
#define LOCAL_SCHEDULER_CLIENT_H
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include "common/task.h"
|
||||
#include "local_scheduler_shared.h"
|
||||
#include "ray/raylet/task_spec.h"
|
||||
|
@ -19,6 +21,10 @@ struct LocalSchedulerConnection {
|
|||
/// of that resource allocated for this worker.
|
||||
std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>>
|
||||
resource_ids_;
|
||||
/// A mutex to protect stateful operations of the local scheduler client.
|
||||
std::mutex mutex;
|
||||
/// A mutex to protect write operations of the local scheduler client.
|
||||
std::mutex write_mutex;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
Loading…
Add table
Reference in a new issue