diff --git a/.travis.yml b/.travis.yml index 8416fa138..bd0fd929b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -54,7 +54,7 @@ matrix: - cd .. # Run Python linting, ignore dict vs {} (C408), others are defaults - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504 - - .travis/yapf.sh --all + - .travis/format.sh --all - os: linux dist: trusty @@ -185,6 +185,7 @@ install: - ./src/ray/raylet/lineage_cache_test - ./src/ray/raylet/task_dependency_manager_test - ./src/ray/raylet/reconstruction_policy_test + - ./src/ray/raylet/client_connection_test - ./src/ray/util/logging_test --gtest_filter=PrintLogTest* - ./src/ray/util/signal_test diff --git a/.travis/yapf.sh b/.travis/format.sh similarity index 83% rename from .travis/yapf.sh rename to .travis/format.sh index d90aec895..ca92d5196 100755 --- a/.travis/yapf.sh +++ b/.travis/format.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails set -eo pipefail @@ -51,6 +53,13 @@ format_changed() { git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" fi + + if which clang-format >/dev/null; then + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \ + clang-format -i + fi + fi } # Format all files, and print the diff to stdout for travis. 
diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index eaa479429..3347d1cfd 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -1,5 +1,6 @@ #include "client_connection.h" +#include #include #include "common.h" @@ -18,9 +19,19 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, return boost_to_ray_status(error); } +template +std::shared_ptr> ServerConnection::Create( + boost::asio::basic_stream_socket &&socket) { + std::shared_ptr> self(new ServerConnection(std::move(socket))); + return self; +} + template ServerConnection::ServerConnection(boost::asio::basic_stream_socket &&socket) - : socket_(std::move(socket)) {} + : socket_(std::move(socket)), + async_write_max_messages_(1), + async_write_queue_(), + async_write_in_flight_(false) {} template Status ServerConnection::WriteBuffer( @@ -78,11 +89,80 @@ ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, message_buffers.push_back(boost::asio::buffer(&type, sizeof(type))); message_buffers.push_back(boost::asio::buffer(&length, sizeof(length))); message_buffers.push_back(boost::asio::buffer(message, length)); - // Write the message and then wait for more messages. - // TODO(swang): Does this need to be an async write? 
return WriteBuffer(message_buffers); } +template +void ServerConnection::WriteMessageAsync( + int64_t type, int64_t length, const uint8_t *message, + const std::function &handler) { + auto write_buffer = std::unique_ptr(new AsyncWriteBuffer()); + write_buffer->write_version = RayConfig::instance().ray_protocol_version(); + write_buffer->write_type = type; + write_buffer->write_length = length; + write_buffer->write_message.resize(length); + write_buffer->write_message.assign(message, message + length); + write_buffer->handler = handler; + + auto size = async_write_queue_.size(); + auto size_is_power_of_two = (size & (size - 1)) == 0; + if (size > 100 && size_is_power_of_two) { + RAY_LOG(WARNING) << "ServerConnection has " << size << " buffered async writes"; + } + + async_write_queue_.push_back(std::move(write_buffer)); + + if (!async_write_in_flight_) { + DoAsyncWrites(); + } +} + +template +void ServerConnection::DoAsyncWrites() { + // Make sure we were not writing to the socket. + RAY_CHECK(!async_write_in_flight_); + async_write_in_flight_ = true; + + // Do an async write of everything currently in the queue to the socket. 
+ std::vector message_buffers; + int num_messages = 0; + for (const auto &write_buffer : async_write_queue_) { + message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version, + sizeof(write_buffer->write_version))); + message_buffers.push_back( + boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type))); + message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length, + sizeof(write_buffer->write_length))); + message_buffers.push_back(boost::asio::buffer(write_buffer->write_message)); + num_messages++; + if (num_messages >= async_write_max_messages_) { + break; + } + } + auto this_ptr = this->shared_from_this(); + boost::asio::async_write( + ServerConnection::socket_, message_buffers, + [this, this_ptr, num_messages](const boost::system::error_code &error, + size_t bytes_transferred) { + ray::Status status = ray::Status::OK(); + if (error.value() != boost::system::errc::errc_t::success) { + status = boost_to_ray_status(error); + } + // Call the handlers for the written messages. + for (int i = 0; i < num_messages; i++) { + auto write_buffer = std::move(async_write_queue_.front()); + write_buffer->handler(status); + async_write_queue_.pop_front(); + } + // We finished writing, so mark that we're no longer doing an async write. + async_write_in_flight_ = false; + // If there is more to write, try to write the rest. 
+ if (!async_write_queue_.empty()) { + DoAsyncWrites(); + } + }); +} + template std::shared_ptr> ClientConnection::Create( ClientHandler &client_handler, MessageHandler &message_handler, @@ -122,8 +202,8 @@ void ClientConnection::ProcessMessages() { header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_))); boost::asio::async_read( ServerConnection::socket_, header, - boost::bind(&ClientConnection::ProcessMessageHeader, this->shared_from_this(), - boost::asio::placeholders::error)); + boost::bind(&ClientConnection::ProcessMessageHeader, + shared_ClientConnection_from_this(), boost::asio::placeholders::error)); } template @@ -143,8 +223,8 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code & // Wait for the message to be read. boost::asio::async_read( ServerConnection::socket_, boost::asio::buffer(read_message_), - boost::bind(&ClientConnection::ProcessMessage, this->shared_from_this(), - boost::asio::placeholders::error)); + boost::bind(&ClientConnection::ProcessMessage, + shared_ClientConnection_from_this(), boost::asio::placeholders::error)); } template @@ -154,7 +234,7 @@ void ClientConnection::ProcessMessage(const boost::system::error_code &error) } uint64_t start_ms = current_time_ms(); - message_handler_(this->shared_from_this(), read_type_, read_message_.data()); + message_handler_(shared_ClientConnection_from_this(), read_type_, read_message_.data()); uint64_t interval = current_time_ms() - start_ms; if (interval > RayConfig::instance().handler_warning_timeout_ms()) { RAY_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " << read_type_ diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 20b232c33..83c9849d9 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -1,6 +1,7 @@ #ifndef RAY_COMMON_CLIENT_CONNECTION_H #define RAY_COMMON_CLIENT_CONNECTION_H +#include #include #include @@ -26,10 +27,14 @@ ray::Status 
TcpConnect(boost::asio::ip::tcp::socket &socket, /// A generic type representing a client connection to a server. This typename /// can be used to write messages synchronously to the server. template -class ServerConnection { +class ServerConnection : public std::enable_shared_from_this> { public: - /// Create a connection to the server. - ServerConnection(boost::asio::basic_stream_socket &&socket); + /// Allocate a new server connection. + /// + /// \param socket A reference to the server socket. + /// \return std::shared_ptr. + static std::shared_ptr> Create( + boost::asio::basic_stream_socket &&socket); /// Write a message to the client. /// @@ -39,6 +44,15 @@ class ServerConnection { /// \return Status. ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message); + /// Write a message to the client asynchronously. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \param handler A callback to run on write completion. + void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, + const std::function &handler); + /// Write a buffer to this connection. /// /// \param buffer The buffer. @@ -52,9 +66,42 @@ class ServerConnection { void ReadBuffer(const std::vector &buffer, boost::system::error_code &ec); + /// Closes the socket for this connection. + void Close() { + boost::system::error_code ec; + socket_.close(ec); + } + protected: + /// A protected constructor for a server connection. Use Create() instead. + ServerConnection(boost::asio::basic_stream_socket &&socket); + + /// A message that is queued for writing asynchronously. + struct AsyncWriteBuffer { + int64_t write_version; + int64_t write_type; + uint64_t write_length; + std::vector write_message; + std::function handler; + }; + /// The socket connection to the server. boost::asio::basic_stream_socket socket_; + + /// Max number of messages to write out at once. 
+ const int async_write_max_messages_; + + /// List of pending messages to write. + std::list> async_write_queue_; + + /// Whether we are in the middle of an async write. + bool async_write_in_flight_; + + private: + /// Asynchronously flushes the write queue. While async writes are running, the flag + /// async_write_in_flight_ will be set. This should only be called when no async writes + /// are currently in flight. + void DoAsyncWrites(); }; template @@ -72,9 +119,10 @@ using MessageHandler = /// writing messages to the client, like in ServerConnection, this typename can /// also be used to process messages asynchronously from client. template -class ClientConnection : public ServerConnection, - public std::enable_shared_from_this> { +class ClientConnection : public ServerConnection { public: + using std::enable_shared_from_this>::shared_from_this; + /// Allocate a new node client connection. /// /// \param new_client_handler A reference to the client handler. @@ -85,6 +133,10 @@ class ClientConnection : public ServerConnection, ClientHandler &new_client_handler, MessageHandler &message_handler, boost::asio::basic_stream_socket &&socket, const std::string &debug_label); + std::shared_ptr> shared_ClientConnection_from_this() { + return std::static_pointer_cast>(shared_from_this()); + } + /// \return The ClientID of the remote client. 
const ClientID &GetClientID(); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index e6674fbf1..84be11066 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -266,35 +266,31 @@ void ObjectManager::PullEstablishConnection(const ObjectID &object_id, } connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE, client_id, async_conn); - Status pull_send_status = PullSendRequest(object_id, async_conn); - if (!pull_send_status.ok()) { - CheckIOError(pull_send_status, "Pull"); - } + PullSendRequest(object_id, async_conn); }, []() { RAY_LOG(ERROR) << "Failed to establish connection with remote object manager."; }); } else { - status = PullSendRequest(object_id, conn); - if (!status.ok()) { - CheckIOError(status, "Pull"); - } + PullSendRequest(object_id, conn); } } -ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id, - std::shared_ptr &conn) { +void ObjectManager::PullSendRequest(const ObjectID &object_id, + std::shared_ptr &conn) { flatbuffers::FlatBufferBuilder fbb; auto message = object_manager_protocol::CreatePullRequestMessage( fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary())); fbb.Finish(message); - Status status = conn->WriteMessage( + conn->WriteMessageAsync( static_cast(object_manager_protocol::MessageType::PullRequest), - fbb.GetSize(), fbb.GetBufferPointer()); - if (status.ok()) { - connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn); - } - return status; + fbb.GetSize(), fbb.GetBufferPointer(), [this, conn](ray::Status status) mutable { + if (status.ok()) { + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn); + } else { + CheckIOError(status, "Pull"); + } + }); } void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id, diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 
11b5d7a6c..ef5b98a03 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -245,10 +245,10 @@ class ObjectManager : public ObjectManagerInterface { /// Executes on main_service_ thread. void PullEstablishConnection(const ObjectID &object_id, const ClientID &client_id); - /// Synchronously send a pull request via remote object manager connection. + /// Asynchronously send a pull request via remote object manager connection. /// Executes on main_service_ thread. - ray::Status PullSendRequest(const ObjectID &object_id, - std::shared_ptr &conn); + void PullSendRequest(const ObjectID &object_id, + std::shared_ptr &conn); std::shared_ptr CreateSenderConnection( ConnectionPool::ConnectionType type, RemoteConnectionInfo info); diff --git a/src/ray/object_manager/object_manager_client_connection.cc b/src/ray/object_manager/object_manager_client_connection.cc index c612e1703..dadfd72ce 100644 --- a/src/ray/object_manager/object_manager_client_connection.cc +++ b/src/ray/object_manager/object_manager_client_connection.cc @@ -11,7 +11,7 @@ std::shared_ptr SenderConnection::Create( Status status = TcpConnect(socket, ip, port); if (status.ok()) { std::shared_ptr conn = - std::make_shared(std::move(socket)); + TcpServerConnection::Create(std::move(socket)); return std::make_shared(std::move(conn), client_id); } else { return nullptr; diff --git a/src/ray/object_manager/object_manager_client_connection.h b/src/ray/object_manager/object_manager_client_connection.h index 1c8661b0d..b3a03102a 100644 --- a/src/ray/object_manager/object_manager_client_connection.h +++ b/src/ray/object_manager/object_manager_client_connection.h @@ -16,6 +16,7 @@ namespace ray { +// TODO(ekl) this class can be replaced with a plain ClientConnection class SenderConnection : public boost::enable_shared_from_this { public: /// Create a connection for sending data to other object managers. 
@@ -44,6 +45,17 @@ class SenderConnection : public boost::enable_shared_from_this return conn_->WriteMessage(type, length, message); } + /// Write a message to the client asynchronously. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \param handler A callback to run on write completion. + void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, + const std::function &handler) { + conn_->WriteMessageAsync(type, length, message, handler); + } + /// Write a buffer to this connection. /// /// \param buffer The buffer. diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 79233965a..5b580e4a2 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -32,6 +32,7 @@ ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASM ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc new file mode 100644 index 000000000..a68a6535c --- /dev/null +++ b/src/ray/raylet/client_connection_test.cc @@ -0,0 +1,155 @@ +#include +#include + +#include +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "ray/common/client_connection.h" + +namespace ray { +namespace 
raylet { + +class ClientConnectionTest : public ::testing::Test { + public: + ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_) { + boost::asio::local::connect_pair(in_, out_); + } + + protected: + boost::asio::io_service io_service_; + boost::asio::local::stream_protocol::socket in_; + boost::asio::local::stream_protocol::socket out_; +}; + +TEST_F(ClientConnectionTest, SimpleSyncWrite) { + const uint8_t arr[5] = {1, 2, 3, 4, 5}; + int num_messages = 0; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler message_handler = + [&arr, &num_messages](std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + ASSERT_TRUE(!std::memcmp(arr, message, 5)); + num_messages += 1; + }; + + auto conn1 = LocalClientConnection::Create(client_handler, message_handler, + std::move(in_), "conn1"); + + auto conn2 = LocalClientConnection::Create(client_handler, message_handler, + std::move(out_), "conn2"); + + RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr)); + RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr)); + conn1->ProcessMessages(); + conn2->ProcessMessages(); + io_service_.run(); + ASSERT_EQ(num_messages, 2); +} + +TEST_F(ClientConnectionTest, SimpleAsyncWrite) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + const uint8_t msg2[5] = {4, 4, 4, 4, 4}; + const uint8_t msg3[5] = {8, 8, 8, 8, 8}; + int num_messages = 0; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + std::shared_ptr reader = NULL; + + MessageHandler message_handler = + [&msg1, &msg2, &msg3, &num_messages, &reader]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + if (num_messages == 0) { + ASSERT_TRUE(!std::memcmp(msg1, message, 5)); + } else if (num_messages == 1) { + ASSERT_TRUE(!std::memcmp(msg2, message, 5)); + } else { + ASSERT_TRUE(!std::memcmp(msg3, message, 5)); + 
} + num_messages += 1; + if (num_messages < 3) { + reader->ProcessMessages(); + } + }; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + reader = LocalClientConnection::Create(client_handler, message_handler, std::move(out_), + "reader"); + + std::function callback = [](const ray::Status &status) { + RAY_CHECK_OK(status); + }; + + writer->WriteMessageAsync(0, 5, msg1, callback); + writer->WriteMessageAsync(0, 5, msg2, callback); + writer->WriteMessageAsync(0, 5, msg3, callback); + reader->ProcessMessages(); + io_service_.run(); + ASSERT_EQ(num_messages, 3); +} + +TEST_F(ClientConnectionTest, SimpleAsyncError) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + std::function callback = [](const ray::Status &status) { + ASSERT_TRUE(!status.ok()); + }; + + writer->Close(); + writer->WriteMessageAsync(0, 5, msg1, callback); + io_service_.run(); +} + +TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + std::function callback = + [writer](const ray::Status &status) { + static_cast(writer); + ASSERT_TRUE(status.ok()); + }; + writer->WriteMessageAsync(0, 5, msg1, callback); + io_service_.run(); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git 
a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c4e75dbb8..be1974468 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -359,7 +359,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { } // The client is connected. - auto server_conn = TcpServerConnection(std::move(socket)); + auto server_conn = TcpServerConnection::Create(std::move(socket)); remote_server_connections_.emplace(client_id, std::move(server_conn)); ResourceSet resources_total(client_data.resources_total_label, @@ -1304,56 +1304,59 @@ void NodeManager::AssignTask(Task &task) { auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(resource_id_set_flatbuf)); fbb.Finish(message); - auto status = worker->Connection()->WriteMessage( + worker->Connection()->WriteMessageAsync( static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), - fbb.GetBufferPointer()); - if (status.ok()) { - // We successfully assigned the task to the worker. - worker->AssignTaskId(spec.TaskId()); - worker->AssignDriverId(spec.DriverId()); - // If the task was an actor task, then record this execution to guarantee - // consistency in the case of reconstruction. - if (spec.IsActorTask()) { - auto actor_entry = actor_registry_.find(spec.ActorId()); - RAY_CHECK(actor_entry != actor_registry_.end()); - auto execution_dependency = actor_entry->second.GetExecutionDependency(); - // The execution dependency is initialized to the actor creation task's - // return value, and is subsequently updated to the assigned tasks' - // return values, so it should never be nil. - RAY_CHECK(!execution_dependency.is_nil()); - // Update the task's execution dependencies to reflect the actual - // execution order, to support deterministic reconstruction. - // NOTE(swang): The update of an actor task's execution dependencies is - // performed asynchronously. 
This means that if this node manager dies, - // we may lose updates that are in flight to the task table. We only - // guarantee deterministic reconstruction ordering for tasks whose - // updates are reflected in the task table. - task.SetExecutionDependencies({execution_dependency}); - // Extend the frontier to include the executing task. - actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject()); - } - // We started running the task, so the task is ready to write to GCS. - if (!lineage_cache_.AddReadyTask(task)) { - RAY_LOG(WARNING) - << "Task " << spec.TaskId() - << " already in lineage cache. This is most likely due to reconstruction."; - } - // Mark the task as running. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueRunningTasks(std::vector({task})); - // Notify the task dependency manager that we no longer need this task's - // object dependencies. - task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); - } else { - RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; - // We failed to send the task to the worker, so disconnect the worker. - ProcessDisconnectClientMessage(worker->Connection()); - // Queue this task for future assignment. The task will be assigned to a - // worker once one becomes available. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueReadyTasks(std::vector({task})); - DispatchTasks(); - } + fbb.GetBufferPointer(), [this, worker, task](ray::Status status) mutable { + if (status.ok()) { + auto spec = task.GetTaskSpecification(); + // We successfully assigned the task to the worker. + worker->AssignTaskId(spec.TaskId()); + worker->AssignDriverId(spec.DriverId()); + // If the task was an actor task, then record this execution to guarantee + // consistency in the case of reconstruction. 
+ if (spec.IsActorTask()) { + auto actor_entry = actor_registry_.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry_.end()); + auto execution_dependency = actor_entry->second.GetExecutionDependency(); + // The execution dependency is initialized to the actor creation task's + // return value, and is subsequently updated to the assigned tasks' + // return values, so it should never be nil. + RAY_CHECK(!execution_dependency.is_nil()); + // Update the task's execution dependencies to reflect the actual + // execution order, to support deterministic reconstruction. + // NOTE(swang): The update of an actor task's execution dependencies is + // performed asynchronously. This means that if this node manager dies, + // we may lose updates that are in flight to the task table. We only + // guarantee deterministic reconstruction ordering for tasks whose + // updates are reflected in the task table. + task.SetExecutionDependencies({execution_dependency}); + // Extend the frontier to include the executing task. + actor_entry->second.ExtendFrontier(spec.ActorHandleId(), + spec.ActorDummyObject()); + } + // We started running the task, so the task is ready to write to GCS. + if (!lineage_cache_.AddReadyTask(task)) { + RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache. " + "This is most likely due to " + "reconstruction."; + } + // Mark the task as running. + // (See design_docs/task_states.rst for the state transition diagram.) + local_queues_.QueueRunningTasks(std::vector({task})); + // Notify the task dependency manager that we no longer need this task's + // object dependencies. + task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); + } else { + RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; + // We failed to send the task to the worker, so disconnect the worker. + ProcessDisconnectClientMessage(worker->Connection()); + // Queue this task for future assignment. 
The task will be assigned to a + // worker once one becomes available. + // (See design_docs/task_states.rst for the state transition diagram.) + local_queues_.QueueReadyTasks(std::vector({task})); + DispatchTasks(); + } + }); } void NodeManager::FinishAssignedTask(Worker &worker) { @@ -1522,10 +1525,10 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, const ClientID &node_manager_id) { /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? - const TaskID task_id = task.GetTaskSpecification().TaskId(); - // Attempt to forward the task. - if (!ForwardTask(task, node_manager_id).ok()) { + ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) { + const TaskID task_id = task.GetTaskSpecification().TaskId(); + RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " << node_manager_id; // Mark the failed task as pending to let other raylets know that we still @@ -1564,10 +1567,11 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, ScheduleTasks(cluster_resource_map_); DispatchTasks(); } - } + }); } -ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) { +void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, + const std::function &on_error) { const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -1593,49 +1597,53 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) if (it == remote_server_connections_.end()) { // TODO(atumanov): caller must handle failure to ensure tasks are not lost. 
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - return ray::Status::IOError("NodeManager connection not found"); + on_error(ray::Status::IOError("NodeManager connection not found")); + return; } auto &server_conn = it->second; - auto status = server_conn.WriteMessage( + server_conn->WriteMessageAsync( static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer()); - if (status.ok()) { - // If we were able to forward the task, remove the forwarded task from the - // lineage cache since the receiving node is now responsible for writing - // the task to the GCS. - if (!lineage_cache_.RemoveWaitingTask(task_id)) { - RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage " - "cache. This is most likely due to " - "reconstruction."; - } - // Mark as forwarded so that the task and its lineage is not re-forwarded - // in the future to the receiving node. - lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - - // Notify the task dependency manager that we are no longer responsible - // for executing this task. - task_dependency_manager_.TaskCanceled(task_id); - // Preemptively push any local arguments to the receiving node. For now, we - // only do this with actor tasks, since actor tasks must be executed by a - // specific process and therefore have affinity to the receiving node. - if (spec.IsActorTask()) { - // Iterate through the object's arguments. NOTE(swang): We do not include - // the execution dependencies here since those cannot be transferred - // between nodes. - for (int i = 0; i < spec.NumArgs(); ++i) { - int count = spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - ObjectID argument_id = spec.ArgId(i, j); - // If the argument is local, then push it to the receiving node. 
- if (task_dependency_manager_.CheckObjectLocal(argument_id)) { - object_manager_.Push(argument_id, node_id); + fbb.GetBufferPointer(), + [this, on_error, task_id, node_id, spec](ray::Status status) { + if (status.ok()) { + // If we were able to forward the task, remove the forwarded task from the + // lineage cache since the receiving node is now responsible for writing + // the task to the GCS. + if (!lineage_cache_.RemoveWaitingTask(task_id)) { + RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage " + "cache. This is most likely due to " + "reconstruction."; } + // Mark as forwarded so that the task and its lineage is not re-forwarded + // in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); + + // Notify the task dependency manager that we are no longer responsible + // for executing this task. + task_dependency_manager_.TaskCanceled(task_id); + // Preemptively push any local arguments to the receiving node. For now, we + // only do this with actor tasks, since actor tasks must be executed by a + // specific process and therefore have affinity to the receiving node. + if (spec.IsActorTask()) { + // Iterate through the object's arguments. NOTE(swang): We do not include + // the execution dependencies here since those cannot be transferred + // between nodes. + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + for (int j = 0; j < count; j++) { + ObjectID argument_id = spec.ArgId(i, j); + // If the argument is local, then push it to the receiving node. 
+ if (task_dependency_manager_.CheckObjectLocal(argument_id)) { + object_manager_.Push(argument_id, node_id); + } + } + } + } + } else { + on_error(status); } - } - } - } - return status; + }); } } // namespace raylet diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index e3d2ca141..2e5d7605f 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -174,10 +174,10 @@ class NodeManager { /// /// \param task The task to forward. /// \param node_id The ID of the node to forward the task to. - /// \return A status indicating whether the forward succeeded or not. Note - /// that a status of OK is not a reliable indicator that the forward succeeded - /// or even that the remote node is still alive. - ray::Status ForwardTask(const Task &task, const ClientID &node_id); + /// \param on_error Callback to run on non-ok status. + void ForwardTask(const Task &task, const ClientID &node_id, + const std::function &on_error); + /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to /// "running" task state. void DispatchTasks(); @@ -352,7 +352,8 @@ class NodeManager { /// The lineage cache for the GCS object and task tables. LineageCache lineage_cache_; std::vector remote_clients_; - std::unordered_map remote_server_connections_; + std::unordered_map> + remote_server_connections_; /// A mapping from actor ID to registration information about that actor /// (including which node manager owns it). std::unordered_map actor_registry_;