[xray] All messages on main asio event loop should be written asynchronously (#3023)

* copy over ref code

* wip async writes

* compiles

* fix error handling

* add test

* amend

* fix test

* clang fmgt

* clang format

* wip

* yapf

* rename format script

* test error

* clangfmt

* add test to list

* warn

* ref test

* fix test

* comment

* add capture

* Update client_connection.cc

* wip

* fix compile
Eric Liang 2018-10-18 21:56:22 -07:00 committed by Stephanie Wang
parent fa469783d8
commit 9d23fa03c9
12 changed files with 446 additions and 131 deletions
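
The core change is sketched below: rather than blocking the main asio event loop on synchronous socket writes, ServerConnection now keeps a queue of pending messages and flushes it with chained boost::asio::async_write calls (see WriteMessageAsync and DoAsyncWrites in client_connection.cc further down). The following is a minimal, self-contained illustration of that pattern, not the Ray API: the names QueuedWriter, WriteAsync, and Flush are hypothetical, error handling is reduced to the raw boost error code, and the batching of up to async_write_max_messages_ messages per write is omitted for brevity.

#include <cstdint>
#include <deque>
#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include <boost/asio.hpp>

// Illustrative stand-in for ServerConnection's async write queue.
class QueuedWriter : public std::enable_shared_from_this<QueuedWriter> {
 public:
  explicit QueuedWriter(boost::asio::local::stream_protocol::socket socket)
      : socket_(std::move(socket)) {}

  // Queue a message and kick off a flush if no write is currently in flight.
  void WriteAsync(std::vector<uint8_t> message,
                  std::function<void(const boost::system::error_code &)> handler) {
    queue_.push_back({std::move(message), std::move(handler)});
    if (!write_in_flight_) {
      Flush();
    }
  }

 private:
  struct Pending {
    std::vector<uint8_t> message;
    std::function<void(const boost::system::error_code &)> handler;
  };

  // Write the front of the queue. The completion handler pops the entry,
  // invokes its callback, and chains another flush if more messages arrived
  // while this write was in flight. Capturing shared_from_this() keeps the
  // writer alive until the handler runs, as the real code does with this_ptr.
  void Flush() {
    write_in_flight_ = true;
    auto self = shared_from_this();
    boost::asio::async_write(
        socket_, boost::asio::buffer(queue_.front().message),
        [this, self](const boost::system::error_code &error, std::size_t /*bytes*/) {
          auto pending = std::move(queue_.front());
          queue_.pop_front();
          pending.handler(error);
          write_in_flight_ = false;
          if (!queue_.empty()) {
            Flush();
          }
        });
  }

  boost::asio::local::stream_protocol::socket socket_;
  std::deque<Pending> queue_;
  bool write_in_flight_ = false;
};

int main() {
  boost::asio::io_service io_service;
  boost::asio::local::stream_protocol::socket in(io_service), out(io_service);
  boost::asio::local::connect_pair(in, out);

  auto writer = std::make_shared<QueuedWriter>(std::move(in));
  const std::vector<uint8_t> payload = {1, 2, 3, 4, 5};
  // Both calls return immediately; the event loop performs the writes in order.
  for (int i = 0; i < 2; i++) {
    writer->WriteAsync(payload, [](const boost::system::error_code &error) {
      std::cout << "write finished: " << error.message() << std::endl;
    });
  }
  io_service.run();
  return 0;
}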

View file

@ -54,7 +54,7 @@ matrix:
- cd ..
# Run Python linting, ignore dict vs {} (C408), others are defaults
- flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504
- .travis/yapf.sh --all
- .travis/format.sh --all
- os: linux
dist: trusty
@ -185,6 +185,7 @@ install:
- ./src/ray/raylet/lineage_cache_test
- ./src/ray/raylet/task_dependency_manager_test
- ./src/ray/raylet/reconstruction_policy_test
- ./src/ray/raylet/client_connection_test
- ./src/ray/util/logging_test --gtest_filter=PrintLogTest*
- ./src/ray/util/signal_test

View file

@ -1,4 +1,6 @@
#!/usr/bin/env bash
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
@ -51,6 +53,13 @@ format_changed() {
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
if which clang-format >/dev/null; then
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
clang-format -i
fi
fi
}
# Format all files, and print the diff to stdout for travis.

View file

@ -1,5 +1,6 @@
#include "client_connection.h"
#include <stdio.h>
#include <boost/bind.hpp>
#include "common.h"
@ -18,9 +19,19 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
return boost_to_ray_status(error);
}
template <class T>
std::shared_ptr<ServerConnection<T>> ServerConnection<T>::Create(
boost::asio::basic_stream_socket<T> &&socket) {
std::shared_ptr<ServerConnection<T>> self(new ServerConnection(std::move(socket)));
return self;
}
template <class T>
ServerConnection<T>::ServerConnection(boost::asio::basic_stream_socket<T> &&socket)
: socket_(std::move(socket)) {}
: socket_(std::move(socket)),
async_write_max_messages_(1),
async_write_queue_(),
async_write_in_flight_(false) {}
template <class T>
Status ServerConnection<T>::WriteBuffer(
@ -78,11 +89,80 @@ ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
message_buffers.push_back(boost::asio::buffer(message, length));
// Write the message and then wait for more messages.
// TODO(swang): Does this need to be an async write?
return WriteBuffer(message_buffers);
}
template <class T>
void ServerConnection<T>::WriteMessageAsync(
int64_t type, int64_t length, const uint8_t *message,
const std::function<void(const ray::Status &)> &handler) {
auto write_buffer = std::unique_ptr<AsyncWriteBuffer>(new AsyncWriteBuffer());
write_buffer->write_version = RayConfig::instance().ray_protocol_version();
write_buffer->write_type = type;
write_buffer->write_length = length;
write_buffer->write_message.resize(length);
write_buffer->write_message.assign(message, message + length);
write_buffer->handler = handler;
auto size = async_write_queue_.size();
auto size_is_power_of_two = (size & (size - 1)) == 0;
if (size > 100 && size_is_power_of_two) {
RAY_LOG(WARNING) << "ServerConnection has " << size << " buffered async writes";
}
async_write_queue_.push_back(std::move(write_buffer));
if (!async_write_in_flight_) {
DoAsyncWrites();
}
}
template <class T>
void ServerConnection<T>::DoAsyncWrites() {
// Make sure we were not writing to the socket.
RAY_CHECK(!async_write_in_flight_);
async_write_in_flight_ = true;
// Do an async write of everything currently in the queue to the socket.
std::vector<boost::asio::const_buffer> message_buffers;
int num_messages = 0;
for (const auto &write_buffer : async_write_queue_) {
message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version,
sizeof(write_buffer->write_version)));
message_buffers.push_back(
boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type)));
message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length,
sizeof(write_buffer->write_length)));
message_buffers.push_back(boost::asio::buffer(write_buffer->write_message));
num_messages++;
if (num_messages >= async_write_max_messages_) {
break;
}
}
auto this_ptr = this->shared_from_this();
boost::asio::async_write(
ServerConnection<T>::socket_, message_buffers,
[this, this_ptr, num_messages](const boost::system::error_code &error,
size_t bytes_transferred) {
ray::Status status = ray::Status::OK();
if (error.value() != boost::system::errc::errc_t::success) {
status = boost_to_ray_status(error);
}
// Call the handlers for the written messages.
for (int i = 0; i < num_messages; i++) {
auto write_buffer = std::move(async_write_queue_.front());
write_buffer->handler(status);
async_write_queue_.pop_front();
}
// We finished writing, so mark that we're no longer doing an async write.
async_write_in_flight_ = false;
// If there is more to write, try to write the rest.
if (!async_write_queue_.empty()) {
DoAsyncWrites();
}
});
}
template <class T>
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
@ -122,8 +202,8 @@ void ClientConnection<T>::ProcessMessages() {
header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
boost::asio::async_read(
ServerConnection<T>::socket_, header,
boost::bind(&ClientConnection<T>::ProcessMessageHeader, this->shared_from_this(),
boost::asio::placeholders::error));
boost::bind(&ClientConnection<T>::ProcessMessageHeader,
shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}
template <class T>
@ -143,8 +223,8 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
// Wait for the message to be read.
boost::asio::async_read(
ServerConnection<T>::socket_, boost::asio::buffer(read_message_),
boost::bind(&ClientConnection<T>::ProcessMessage, this->shared_from_this(),
boost::asio::placeholders::error));
boost::bind(&ClientConnection<T>::ProcessMessage,
shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}
template <class T>
@ -154,7 +234,7 @@ void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error)
}
uint64_t start_ms = current_time_ms();
message_handler_(this->shared_from_this(), read_type_, read_message_.data());
message_handler_(shared_ClientConnection_from_this(), read_type_, read_message_.data());
uint64_t interval = current_time_ms() - start_ms;
if (interval > RayConfig::instance().handler_warning_timeout_ms()) {
RAY_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " << read_type_
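
A note on the queue-depth warning in WriteMessageAsync above: it throttles itself with a power-of-two check, since (size & (size - 1)) == 0 holds only when size is a power of two. Once the queue exceeds 100 entries, the warning therefore fires at exponentially spaced sizes (128, 256, 512, ...) instead of on every enqueue. A small standalone check of that arithmetic, with a hypothetical helper name IsPowerOfTwo:

#include <cstdint>
#include <iostream>

// True for 1, 2, 4, 8, ... (and for 0, which the size > 100 guard rules out).
bool IsPowerOfTwo(uint64_t n) { return (n & (n - 1)) == 0; }

int main() {
  for (uint64_t size = 1; size <= 1024; size++) {
    if (size > 100 && IsPowerOfTwo(size)) {
      std::cout << "would warn at queue size " << size << std::endl;  // 128 256 512 1024
    }
  }
  return 0;
}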

View file

@ -1,6 +1,7 @@
#ifndef RAY_COMMON_CLIENT_CONNECTION_H
#define RAY_COMMON_CLIENT_CONNECTION_H
#include <list>
#include <memory>
#include <boost/asio.hpp>
@ -26,10 +27,14 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
/// A generic type representing a client connection to a server. This typename
/// can be used to write messages synchronously to the server.
template <typename T>
class ServerConnection {
class ServerConnection : public std::enable_shared_from_this<ServerConnection<T>> {
public:
/// Create a connection to the server.
ServerConnection(boost::asio::basic_stream_socket<T> &&socket);
/// Allocate a new server connection.
///
/// \param socket A reference to the server socket.
/// \return std::shared_ptr<ServerConnection>.
static std::shared_ptr<ServerConnection<T>> Create(
boost::asio::basic_stream_socket<T> &&socket);
/// Write a message to the client.
///
@ -39,6 +44,15 @@ class ServerConnection {
/// \return Status.
ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message);
/// Write a message to the client asynchronously.
///
/// \param type The message type (e.g., a flatbuffer enum).
/// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer.
/// \param handler A callback to run on write completion.
void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message,
const std::function<void(const ray::Status &)> &handler);
/// Write a buffer to this connection.
///
/// \param buffer The buffer.
@ -52,9 +66,42 @@ class ServerConnection {
void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
boost::system::error_code &ec);
/// Shuts down socket for this connection.
void Close() {
boost::system::error_code ec;
socket_.close(ec);
}
protected:
/// A private constructor for a server connection.
ServerConnection(boost::asio::basic_stream_socket<T> &&socket);
/// A message that is queued for writing asynchronously.
struct AsyncWriteBuffer {
int64_t write_version;
int64_t write_type;
uint64_t write_length;
std::vector<uint8_t> write_message;
std::function<void(const ray::Status &)> handler;
};
/// The socket connection to the server.
boost::asio::basic_stream_socket<T> socket_;
/// Max number of messages to write out at once.
const int async_write_max_messages_;
/// List of pending messages to write.
std::list<std::unique_ptr<AsyncWriteBuffer>> async_write_queue_;
/// Whether we are in the middle of an async write.
bool async_write_in_flight_;
private:
/// Asynchronously flushes the write queue. While async writes are running, the flag
/// async_write_in_flight_ will be set. This should only be called when no async writes
/// are currently in flight.
void DoAsyncWrites();
};
template <typename T>
@ -72,9 +119,10 @@ using MessageHandler =
/// writing messages to the client, like in ServerConnection, this typename can
/// also be used to process messages asynchronously from client.
template <typename T>
class ClientConnection : public ServerConnection<T>,
public std::enable_shared_from_this<ClientConnection<T>> {
class ClientConnection : public ServerConnection<T> {
public:
using std::enable_shared_from_this<ServerConnection<T>>::shared_from_this;
/// Allocate a new node client connection.
///
/// \param new_client_handler A reference to the client handler.
@ -85,6 +133,10 @@ class ClientConnection : public ServerConnection<T>,
ClientHandler<T> &new_client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label);
std::shared_ptr<ClientConnection<T>> shared_ClientConnection_from_this() {
return std::static_pointer_cast<ClientConnection<T>>(shared_from_this());
}
/// \return The ClientID of the remote client.
const ClientID &GetClientID();
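
A note on the shared_ClientConnection_from_this() helper declared above: ClientConnection previously inherited enable_shared_from_this<ClientConnection<T>> directly, but now that ServerConnection itself derives from enable_shared_from_this<ServerConnection<T>> (so WriteMessageAsync can keep the connection alive), the derived class recovers a correctly typed pointer by downcasting the base's shared_from_this(). A minimal sketch of that idiom with hypothetical names Base and Derived, assuming single, non-virtual inheritance as in this header:

#include <memory>

struct Base : std::enable_shared_from_this<Base> {
  virtual ~Base() = default;
};

struct Derived : Base {
  // Equivalent of shared_ClientConnection_from_this(): shared_from_this()
  // yields a shared_ptr<Base>, so downcast it to the derived type.
  std::shared_ptr<Derived> shared_derived_from_this() {
    return std::static_pointer_cast<Derived>(shared_from_this());
  }
};

int main() {
  auto d = std::make_shared<Derived>();
  std::shared_ptr<Derived> same = d->shared_derived_from_this();
  return (same == d) ? 0 : 1;  // shares ownership with the original pointer
}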

View file

@ -266,35 +266,31 @@ void ObjectManager::PullEstablishConnection(const ObjectID &object_id,
}
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE,
client_id, async_conn);
Status pull_send_status = PullSendRequest(object_id, async_conn);
if (!pull_send_status.ok()) {
CheckIOError(pull_send_status, "Pull");
}
PullSendRequest(object_id, async_conn);
},
[]() {
RAY_LOG(ERROR) << "Failed to establish connection with remote object manager.";
});
} else {
status = PullSendRequest(object_id, conn);
if (!status.ok()) {
CheckIOError(status, "Pull");
}
PullSendRequest(object_id, conn);
}
}
ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> &conn) {
void ObjectManager::PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> &conn) {
flatbuffers::FlatBufferBuilder fbb;
auto message = object_manager_protocol::CreatePullRequestMessage(
fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
fbb.Finish(message);
Status status = conn->WriteMessage(
conn->WriteMessageAsync(
static_cast<int64_t>(object_manager_protocol::MessageType::PullRequest),
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn);
}
return status;
fbb.GetSize(), fbb.GetBufferPointer(), [this, conn](ray::Status status) mutable {
if (status.ok()) {
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn);
} else {
CheckIOError(status, "Pull");
}
});
}
void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id,
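
In the PullSendRequest change above, the completion handler captures the SenderConnection shared_ptr by value, so the connection stays alive until the async write actually finishes even though the surrounding function returns immediately. The sketch below illustrates only that lifetime point with hypothetical names (FakeConnection, pending_completion); it is not the object manager code.

#include <functional>
#include <iostream>
#include <memory>

struct FakeConnection {
  ~FakeConnection() { std::cout << "connection destroyed" << std::endl; }
};

int main() {
  std::function<void()> pending_completion;
  {
    auto conn = std::make_shared<FakeConnection>();
    // Capture the shared_ptr by value, as the PullSendRequest lambda does.
    pending_completion = [conn]() {
      std::cout << "write completed, connection still alive (use_count="
                << conn.use_count() << ")" << std::endl;
    };
  }  // the caller's reference is gone, but the captured copy still owns it
  pending_completion();  // prints before "connection destroyed"
  return 0;
}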

View file

@ -245,10 +245,10 @@ class ObjectManager : public ObjectManagerInterface {
/// Executes on main_service_ thread.
void PullEstablishConnection(const ObjectID &object_id, const ClientID &client_id);
/// Synchronously send a pull request via remote object manager connection.
/// Asynchronously send a pull request via remote object manager connection.
/// Executes on main_service_ thread.
ray::Status PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> &conn);
void PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> &conn);
std::shared_ptr<SenderConnection> CreateSenderConnection(
ConnectionPool::ConnectionType type, RemoteConnectionInfo info);

View file

@ -11,7 +11,7 @@ std::shared_ptr<SenderConnection> SenderConnection::Create(
Status status = TcpConnect(socket, ip, port);
if (status.ok()) {
std::shared_ptr<TcpServerConnection> conn =
std::make_shared<TcpServerConnection>(std::move(socket));
TcpServerConnection::Create(std::move(socket));
return std::make_shared<SenderConnection>(std::move(conn), client_id);
} else {
return nullptr;

View file

@ -16,6 +16,7 @@
namespace ray {
// TODO(ekl) this class can be replaced with a plain ClientConnection
class SenderConnection : public boost::enable_shared_from_this<SenderConnection> {
public:
/// Create a connection for sending data to other object managers.
@ -44,6 +45,17 @@ class SenderConnection : public boost::enable_shared_from_this<SenderConnection>
return conn_->WriteMessage(type, length, message);
}
/// Write a message to the client asynchronously.
///
/// \param type The message type (e.g., a flatbuffer enum).
/// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer.
/// \param handler A callback to run on write completion.
void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message,
const std::function<void(const ray::Status &)> &handler) {
conn_->WriteMessageAsync(type, length, message, handler);
}
/// Write a buffer to this connection.
///
/// \param buffer The buffer.

View file

@ -32,6 +32,7 @@ ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASM
ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})

View file

@ -0,0 +1,155 @@
#include <list>
#include <memory>
#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/common/client_connection.h"
namespace ray {
namespace raylet {
class ClientConnectionTest : public ::testing::Test {
public:
ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_) {
boost::asio::local::connect_pair(in_, out_);
}
protected:
boost::asio::io_service io_service_;
boost::asio::local::stream_protocol::socket in_;
boost::asio::local::stream_protocol::socket out_;
};
TEST_F(ClientConnectionTest, SimpleSyncWrite) {
const uint8_t arr[5] = {1, 2, 3, 4, 5};
int num_messages = 0;
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[](LocalClientConnection &client) {};
MessageHandler<boost::asio::local::stream_protocol> message_handler =
[&arr, &num_messages](std::shared_ptr<LocalClientConnection> client,
int64_t message_type, const uint8_t *message) {
ASSERT_TRUE(!std::memcmp(arr, message, 5));
num_messages += 1;
};
auto conn1 = LocalClientConnection::Create(client_handler, message_handler,
std::move(in_), "conn1");
auto conn2 = LocalClientConnection::Create(client_handler, message_handler,
std::move(out_), "conn2");
RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr));
RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr));
conn1->ProcessMessages();
conn2->ProcessMessages();
io_service_.run();
ASSERT_EQ(num_messages, 2);
}
TEST_F(ClientConnectionTest, SimpleAsyncWrite) {
const uint8_t msg1[5] = {1, 2, 3, 4, 5};
const uint8_t msg2[5] = {4, 4, 4, 4, 4};
const uint8_t msg3[5] = {8, 8, 8, 8, 8};
int num_messages = 0;
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[](LocalClientConnection &client) {};
MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {};
std::shared_ptr<LocalClientConnection> reader = NULL;
MessageHandler<boost::asio::local::stream_protocol> message_handler =
[&msg1, &msg2, &msg3, &num_messages, &reader](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {
if (num_messages == 0) {
ASSERT_TRUE(!std::memcmp(msg1, message, 5));
} else if (num_messages == 1) {
ASSERT_TRUE(!std::memcmp(msg2, message, 5));
} else {
ASSERT_TRUE(!std::memcmp(msg3, message, 5));
}
num_messages += 1;
if (num_messages < 3) {
reader->ProcessMessages();
}
};
auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
reader = LocalClientConnection::Create(client_handler, message_handler, std::move(out_),
"reader");
std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
RAY_CHECK_OK(status);
};
writer->WriteMessageAsync(0, 5, msg1, callback);
writer->WriteMessageAsync(0, 5, msg2, callback);
writer->WriteMessageAsync(0, 5, msg3, callback);
reader->ProcessMessages();
io_service_.run();
ASSERT_EQ(num_messages, 3);
}
TEST_F(ClientConnectionTest, SimpleAsyncError) {
const uint8_t msg1[5] = {1, 2, 3, 4, 5};
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[](LocalClientConnection &client) {};
MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {};
auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
ASSERT_TRUE(!status.ok());
};
writer->Close();
writer->WriteMessageAsync(0, 5, msg1, callback);
io_service_.run();
}
TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) {
const uint8_t msg1[5] = {1, 2, 3, 4, 5};
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[](LocalClientConnection &client) {};
MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {};
auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
std::function<void(const ray::Status &)> callback =
[writer](const ray::Status &status) {
static_cast<void>(writer);
ASSERT_TRUE(status.ok());
};
writer->WriteMessageAsync(0, 5, msg1, callback);
io_service_.run();
}
} // namespace raylet
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View file

@ -359,7 +359,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
}
// The client is connected.
auto server_conn = TcpServerConnection(std::move(socket));
auto server_conn = TcpServerConnection::Create(std::move(socket));
remote_server_connections_.emplace(client_id, std::move(server_conn));
ResourceSet resources_total(client_data.resources_total_label,
@ -1304,56 +1304,59 @@ void NodeManager::AssignTask(Task &task) {
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
fbb.CreateVector(resource_id_set_flatbuf));
fbb.Finish(message);
auto status = worker->Connection()->WriteMessage(
worker->Connection()->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::ExecuteTask), fbb.GetSize(),
fbb.GetBufferPointer());
if (status.ok()) {
// We successfully assigned the task to the worker.
worker->AssignTaskId(spec.TaskId());
worker->AssignDriverId(spec.DriverId());
// If the task was an actor task, then record this execution to guarantee
// consistency in the case of reconstruction.
if (spec.IsActorTask()) {
auto actor_entry = actor_registry_.find(spec.ActorId());
RAY_CHECK(actor_entry != actor_registry_.end());
auto execution_dependency = actor_entry->second.GetExecutionDependency();
// The execution dependency is initialized to the actor creation task's
// return value, and is subsequently updated to the assigned tasks'
// return values, so it should never be nil.
RAY_CHECK(!execution_dependency.is_nil());
// Update the task's execution dependencies to reflect the actual
// execution order, to support deterministic reconstruction.
// NOTE(swang): The update of an actor task's execution dependencies is
// performed asynchronously. This means that if this node manager dies,
// we may lose updates that are in flight to the task table. We only
// guarantee deterministic reconstruction ordering for tasks whose
// updates are reflected in the task table.
task.SetExecutionDependencies({execution_dependency});
// Extend the frontier to include the executing task.
actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject());
}
// We started running the task, so the task is ready to write to GCS.
if (!lineage_cache_.AddReadyTask(task)) {
RAY_LOG(WARNING)
<< "Task " << spec.TaskId()
<< " already in lineage cache. This is most likely due to reconstruction.";
}
// Mark the task as running.
// (See design_docs/task_states.rst for the state transition diagram.)
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
// Notify the task dependency manager that we no longer need this task's
// object dependencies.
task_dependency_manager_.UnsubscribeDependencies(spec.TaskId());
} else {
RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client";
// We failed to send the task to the worker, so disconnect the worker.
ProcessDisconnectClientMessage(worker->Connection());
// Queue this task for future assignment. The task will be assigned to a
// worker once one becomes available.
// (See design_docs/task_states.rst for the state transition diagram.)
local_queues_.QueueReadyTasks(std::vector<Task>({task}));
DispatchTasks();
}
fbb.GetBufferPointer(), [this, worker, task](ray::Status status) mutable {
if (status.ok()) {
auto spec = task.GetTaskSpecification();
// We successfully assigned the task to the worker.
worker->AssignTaskId(spec.TaskId());
worker->AssignDriverId(spec.DriverId());
// If the task was an actor task, then record this execution to guarantee
// consistency in the case of reconstruction.
if (spec.IsActorTask()) {
auto actor_entry = actor_registry_.find(spec.ActorId());
RAY_CHECK(actor_entry != actor_registry_.end());
auto execution_dependency = actor_entry->second.GetExecutionDependency();
// The execution dependency is initialized to the actor creation task's
// return value, and is subsequently updated to the assigned tasks'
// return values, so it should never be nil.
RAY_CHECK(!execution_dependency.is_nil());
// Update the task's execution dependencies to reflect the actual
// execution order, to support deterministic reconstruction.
// NOTE(swang): The update of an actor task's execution dependencies is
// performed asynchronously. This means that if this node manager dies,
// we may lose updates that are in flight to the task table. We only
// guarantee deterministic reconstruction ordering for tasks whose
// updates are reflected in the task table.
task.SetExecutionDependencies({execution_dependency});
// Extend the frontier to include the executing task.
actor_entry->second.ExtendFrontier(spec.ActorHandleId(),
spec.ActorDummyObject());
}
// We started running the task, so the task is ready to write to GCS.
if (!lineage_cache_.AddReadyTask(task)) {
RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache. "
"This is most likely due to "
"reconstruction.";
}
// Mark the task as running.
// (See design_docs/task_states.rst for the state transition diagram.)
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
// Notify the task dependency manager that we no longer need this task's
// object dependencies.
task_dependency_manager_.UnsubscribeDependencies(spec.TaskId());
} else {
RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client";
// We failed to send the task to the worker, so disconnect the worker.
ProcessDisconnectClientMessage(worker->Connection());
// Queue this task for future assignment. The task will be assigned to a
// worker once one becomes available.
// (See design_docs/task_states.rst for the state transition diagram.)
local_queues_.QueueReadyTasks(std::vector<Task>({task}));
DispatchTasks();
}
});
}
void NodeManager::FinishAssignedTask(Worker &worker) {
@ -1522,10 +1525,10 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
const ClientID &node_manager_id) {
/// TODO(rkn): Should we check that the node manager is remote and not local?
/// TODO(rkn): Should we check if the remote node manager is known to be dead?
const TaskID task_id = task.GetTaskSpecification().TaskId();
// Attempt to forward the task.
if (!ForwardTask(task, node_manager_id).ok()) {
ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) {
const TaskID task_id = task.GetTaskSpecification().TaskId();
RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager "
<< node_manager_id;
// Mark the failed task as pending to let other raylets know that we still
@ -1564,10 +1567,11 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
ScheduleTasks(cluster_resource_map_);
DispatchTasks();
}
}
});
}
ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) {
void NodeManager::ForwardTask(const Task &task, const ClientID &node_id,
const std::function<void(const ray::Status &)> &on_error) {
const auto &spec = task.GetTaskSpecification();
auto task_id = spec.TaskId();
@ -1593,49 +1597,53 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id)
if (it == remote_server_connections_.end()) {
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
return ray::Status::IOError("NodeManager connection not found");
on_error(ray::Status::IOError("NodeManager connection not found"));
return;
}
auto &server_conn = it->second;
auto status = server_conn.WriteMessage(
server_conn->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(),
fbb.GetBufferPointer());
if (status.ok()) {
// If we were able to forward the task, remove the forwarded task from the
// lineage cache since the receiving node is now responsible for writing
// the task to the GCS.
if (!lineage_cache_.RemoveWaitingTask(task_id)) {
RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage "
"cache. This is most likely due to "
"reconstruction.";
}
// Mark as forwarded so that the task and its lineage is not re-forwarded
// in the future to the receiving node.
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
// Notify the task dependency manager that we are no longer responsible
// for executing this task.
task_dependency_manager_.TaskCanceled(task_id);
// Preemptively push any local arguments to the receiving node. For now, we
// only do this with actor tasks, since actor tasks must be executed by a
// specific process and therefore have affinity to the receiving node.
if (spec.IsActorTask()) {
// Iterate through the object's arguments. NOTE(swang): We do not include
// the execution dependencies here since those cannot be transferred
// between nodes.
for (int i = 0; i < spec.NumArgs(); ++i) {
int count = spec.ArgIdCount(i);
for (int j = 0; j < count; j++) {
ObjectID argument_id = spec.ArgId(i, j);
// If the argument is local, then push it to the receiving node.
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
object_manager_.Push(argument_id, node_id);
fbb.GetBufferPointer(),
[this, on_error, task_id, node_id, spec](ray::Status status) {
if (status.ok()) {
// If we were able to forward the task, remove the forwarded task from the
// lineage cache since the receiving node is now responsible for writing
// the task to the GCS.
if (!lineage_cache_.RemoveWaitingTask(task_id)) {
RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage "
"cache. This is most likely due to "
"reconstruction.";
}
// Mark as forwarded so that the task and its lineage is not re-forwarded
// in the future to the receiving node.
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
// Notify the task dependency manager that we are no longer responsible
// for executing this task.
task_dependency_manager_.TaskCanceled(task_id);
// Preemptively push any local arguments to the receiving node. For now, we
// only do this with actor tasks, since actor tasks must be executed by a
// specific process and therefore have affinity to the receiving node.
if (spec.IsActorTask()) {
// Iterate through the object's arguments. NOTE(swang): We do not include
// the execution dependencies here since those cannot be transferred
// between nodes.
for (int i = 0; i < spec.NumArgs(); ++i) {
int count = spec.ArgIdCount(i);
for (int j = 0; j < count; j++) {
ObjectID argument_id = spec.ArgId(i, j);
// If the argument is local, then push it to the receiving node.
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
object_manager_.Push(argument_id, node_id);
}
}
}
}
} else {
on_error(status);
}
}
}
}
return status;
});
}
} // namespace raylet
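
The AssignTask and ForwardTask rewrites above move the post-write bookkeeping into completion lambdas. One detail worth calling out is the `mutable` on the AssignTask handler: `task` is captured by value, and the callback presumably mutates that captured copy via SetExecutionDependencies, which a non-mutable lambda would reject. A tiny standalone illustration, with a hypothetical Counter type standing in for Task:

#include <iostream>

struct Counter {
  int value = 0;
  void Increment() { ++value; }  // non-const member, like a setter on Task
};

int main() {
  Counter c;
  // `mutable` lets the lambda modify its by-value copy of c; without it,
  // calling Increment() on the copy would not compile.
  auto on_done = [c]() mutable {
    c.Increment();
    std::cout << "callback's copy: " << c.value << std::endl;  // 1
  };
  on_done();
  std::cout << "original: " << c.value << std::endl;  // still 0
  return 0;
}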

View file

@ -174,10 +174,10 @@ class NodeManager {
///
/// \param task The task to forward.
/// \param node_id The ID of the node to forward the task to.
/// \return A status indicating whether the forward succeeded or not. Note
/// that a status of OK is not a reliable indicator that the forward succeeded
/// or even that the remote node is still alive.
ray::Status ForwardTask(const Task &task, const ClientID &node_id);
/// \param on_error Callback to run on non-ok status.
void ForwardTask(const Task &task, const ClientID &node_id,
const std::function<void(const ray::Status &)> &on_error);
/// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to
/// "running" task state.
void DispatchTasks();
@ -352,7 +352,8 @@ class NodeManager {
/// The lineage cache for the GCS object and task tables.
LineageCache lineage_cache_;
std::vector<ClientID> remote_clients_;
std::unordered_map<ClientID, TcpServerConnection> remote_server_connections_;
std::unordered_map<ClientID, std::shared_ptr<TcpServerConnection>>
remote_server_connections_;
/// A mapping from actor ID to registration information about that actor
/// (including which node manager owns it).
std::unordered_map<ActorID, ActorRegistration> actor_registry_;