[xray] All messages on main asio event loop should be written asynchronously (#3023)
* copy over ref code
* wip async writes
* compiles
* fix error handling
* add test
* amend
* fix test
* clang fmgt
* clang format
* wip
* yapf
* rename format script
* test error
* clangfmt
* add test to list
* warn
* ref test
* fix test
* comment
* add capture
* Update client_connection.cc
* wip
* fix compile
This commit is contained in:
parent fa469783d8
commit 9d23fa03c9

12 changed files with 446 additions and 131 deletions
@@ -54,7 +54,7 @@ matrix:
- cd ..
# Run Python linting, ignore dict vs {} (C408), others are defaults
- flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504
- .travis/yapf.sh --all
- .travis/format.sh --all

- os: linux
  dist: trusty

@@ -185,6 +185,7 @@ install:
- ./src/ray/raylet/lineage_cache_test
- ./src/ray/raylet/task_dependency_manager_test
- ./src/ray/raylet/reconstruction_policy_test
- ./src/ray/raylet/client_connection_test
- ./src/ray/util/logging_test --gtest_filter=PrintLogTest*
- ./src/ray/util/signal_test
@@ -1,4 +1,6 @@
#!/usr/bin/env bash
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.

# Cause the script to exit if a single command fails
set -eo pipefail

@@ -51,6 +53,13 @@ format_changed() {
    git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
        yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
    fi

    if which clang-format >/dev/null; then
        if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
            git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
                clang-format -i
        fi
    fi
}

# Format all files, and print the diff to stdout for travis.
@@ -1,5 +1,6 @@
#include "client_connection.h"

#include <stdio.h>
#include <boost/bind.hpp>

#include "common.h"

@@ -18,9 +19,19 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
  return boost_to_ray_status(error);
}

template <class T>
std::shared_ptr<ServerConnection<T>> ServerConnection<T>::Create(
    boost::asio::basic_stream_socket<T> &&socket) {
  std::shared_ptr<ServerConnection<T>> self(new ServerConnection(std::move(socket)));
  return self;
}

template <class T>
ServerConnection<T>::ServerConnection(boost::asio::basic_stream_socket<T> &&socket)
    : socket_(std::move(socket)) {}
    : socket_(std::move(socket)),
      async_write_max_messages_(1),
      async_write_queue_(),
      async_write_in_flight_(false) {}

template <class T>
Status ServerConnection<T>::WriteBuffer(
@@ -78,11 +89,80 @@ ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
  message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
  message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
  message_buffers.push_back(boost::asio::buffer(message, length));
  // Write the message and then wait for more messages.
  // TODO(swang): Does this need to be an async write?
  return WriteBuffer(message_buffers);
}

template <class T>
void ServerConnection<T>::WriteMessageAsync(
    int64_t type, int64_t length, const uint8_t *message,
    const std::function<void(const ray::Status &)> &handler) {
  auto write_buffer = std::unique_ptr<AsyncWriteBuffer>(new AsyncWriteBuffer());
  write_buffer->write_version = RayConfig::instance().ray_protocol_version();
  write_buffer->write_type = type;
  write_buffer->write_length = length;
  write_buffer->write_message.resize(length);
  write_buffer->write_message.assign(message, message + length);
  write_buffer->handler = handler;

  auto size = async_write_queue_.size();
  auto size_is_power_of_two = (size & (size - 1)) == 0;
  if (size > 100 && size_is_power_of_two) {
    RAY_LOG(WARNING) << "ServerConnection has " << size << " buffered async writes";
  }

  async_write_queue_.push_back(std::move(write_buffer));

  if (!async_write_in_flight_) {
    DoAsyncWrites();
  }
}

template <class T>
void ServerConnection<T>::DoAsyncWrites() {
  // Make sure we are not already writing to the socket.
  RAY_CHECK(!async_write_in_flight_);
  async_write_in_flight_ = true;

  // Do an async write of everything currently in the queue to the socket.
  std::vector<boost::asio::const_buffer> message_buffers;
  int num_messages = 0;
  for (const auto &write_buffer : async_write_queue_) {
    message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version,
                                                  sizeof(write_buffer->write_version)));
    message_buffers.push_back(
        boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type)));
    message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length,
                                                  sizeof(write_buffer->write_length)));
    message_buffers.push_back(boost::asio::buffer(write_buffer->write_message));
    num_messages++;
    if (num_messages >= async_write_max_messages_) {
      break;
    }
  }
  auto this_ptr = this->shared_from_this();
  boost::asio::async_write(
      ServerConnection<T>::socket_, message_buffers,
      [this, this_ptr, num_messages](const boost::system::error_code &error,
                                     size_t bytes_transferred) {
        ray::Status status = ray::Status::OK();
        if (error.value() != boost::system::errc::errc_t::success) {
          status = boost_to_ray_status(error);
        }
        // Call the handlers for the written messages.
        for (int i = 0; i < num_messages; i++) {
          auto write_buffer = std::move(async_write_queue_.front());
          write_buffer->handler(status);
          async_write_queue_.pop_front();
        }
        // We finished writing, so mark that we're no longer doing an async write.
        async_write_in_flight_ = false;
        // If there is more to write, try to write the rest.
        if (!async_write_queue_.empty()) {
          DoAsyncWrites();
        }
      });
}

template <class T>
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
    ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,

@@ -122,8 +202,8 @@ void ClientConnection<T>::ProcessMessages() {
  header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
  boost::asio::async_read(
      ServerConnection<T>::socket_, header,
      boost::bind(&ClientConnection<T>::ProcessMessageHeader, this->shared_from_this(),
                  boost::asio::placeholders::error));
      boost::bind(&ClientConnection<T>::ProcessMessageHeader,
                  shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}

template <class T>

@@ -143,8 +223,8 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
  // Wait for the message to be read.
  boost::asio::async_read(
      ServerConnection<T>::socket_, boost::asio::buffer(read_message_),
      boost::bind(&ClientConnection<T>::ProcessMessage, this->shared_from_this(),
                  boost::asio::placeholders::error));
      boost::bind(&ClientConnection<T>::ProcessMessage,
                  shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}

template <class T>

@@ -154,7 +234,7 @@ void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error)
  }

  uint64_t start_ms = current_time_ms();
  message_handler_(this->shared_from_this(), read_type_, read_message_.data());
  message_handler_(shared_ClientConnection_from_this(), read_type_, read_message_.data());
  uint64_t interval = current_time_ms() - start_ms;
  if (interval > RayConfig::instance().handler_warning_timeout_ms()) {
    RAY_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " << read_type_
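The DoAsyncWrites logic above batches up to async_write_max_messages_ queued messages into one boost::asio::async_write and only issues the next write after the previous completion handler has run. The following stand-alone sketch is not part of the commit: AsyncWriter, WriteAsync, and Run are illustrative names, and a deferred-callback queue stands in for the asio event loop, but it shows the same single-write-in-flight pattern using only the standard library.

#include <algorithm>
#include <deque>
#include <functional>
#include <iostream>
#include <string>

class AsyncWriter {
 public:
  using Handler = std::function<void(bool ok)>;

  // Queue a message; start draining only if no write is currently in flight.
  void WriteAsync(std::string message, Handler handler) {
    queue_.push_back({std::move(message), std::move(handler)});
    if (!write_in_flight_) {
      DoWrites();
    }
  }

  // Run deferred completion callbacks (a stand-in for io_service::run()).
  void Run() {
    while (!deferred_.empty()) {
      auto callback = std::move(deferred_.front());
      deferred_.pop_front();
      callback();
    }
  }

 private:
  struct Pending {
    std::string message;
    Handler handler;
  };

  void DoWrites() {
    write_in_flight_ = true;
    // Batch up to max_messages_ queued messages into one simulated write.
    const size_t batch = std::min(queue_.size(), max_messages_);
    deferred_.push_back([this, batch]() {
      // Completion: run the handler for each message in the batch, clear the
      // in-flight flag, and keep draining if more messages arrived meanwhile.
      for (size_t i = 0; i < batch; i++) {
        Pending pending = std::move(queue_.front());
        queue_.pop_front();
        pending.handler(true);
      }
      write_in_flight_ = false;
      if (!queue_.empty()) {
        DoWrites();
      }
    });
  }

  const size_t max_messages_ = 1;  // mirrors async_write_max_messages_
  std::deque<Pending> queue_;
  std::deque<std::function<void()>> deferred_;
  bool write_in_flight_ = false;
};

int main() {
  AsyncWriter writer;
  for (int i = 0; i < 3; i++) {
    writer.WriteAsync("msg" + std::to_string(i),
                      [i](bool ok) { std::cout << "wrote msg" << i << ": " << ok << "\n"; });
  }
  writer.Run();  // drains the queue one batch at a time
  return 0;
}

Queuing while a write is in flight only appends to the queue; the completion handler drains it, which is why WriteMessageAsync above calls DoAsyncWrites only when async_write_in_flight_ is false.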
@@ -1,6 +1,7 @@
#ifndef RAY_COMMON_CLIENT_CONNECTION_H
#define RAY_COMMON_CLIENT_CONNECTION_H

#include <list>
#include <memory>

#include <boost/asio.hpp>

@@ -26,10 +27,14 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
/// A generic type representing a client connection to a server. This typename
/// can be used to write messages synchronously to the server.
template <typename T>
class ServerConnection {
class ServerConnection : public std::enable_shared_from_this<ServerConnection<T>> {
 public:
  /// Create a connection to the server.
  ServerConnection(boost::asio::basic_stream_socket<T> &&socket);
  /// Allocate a new server connection.
  ///
  /// \param socket A reference to the server socket.
  /// \return std::shared_ptr<ServerConnection>.
  static std::shared_ptr<ServerConnection<T>> Create(
      boost::asio::basic_stream_socket<T> &&socket);

  /// Write a message to the client.
  ///

@@ -39,6 +44,15 @@ class ServerConnection {
  /// \return Status.
  ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message);

  /// Write a message to the client asynchronously.
  ///
  /// \param type The message type (e.g., a flatbuffer enum).
  /// \param length The size in bytes of the message.
  /// \param message A pointer to the message buffer.
  /// \param handler A callback to run on write completion.
  void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message,
                         const std::function<void(const ray::Status &)> &handler);

  /// Write a buffer to this connection.
  ///
  /// \param buffer The buffer.
@@ -52,9 +66,42 @@ class ServerConnection {
  void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
                  boost::system::error_code &ec);

  /// Shuts down the socket for this connection.
  void Close() {
    boost::system::error_code ec;
    socket_.close(ec);
  }

 protected:
  /// A protected constructor for a server connection.
  ServerConnection(boost::asio::basic_stream_socket<T> &&socket);

  /// A message that is queued for writing asynchronously.
  struct AsyncWriteBuffer {
    int64_t write_version;
    int64_t write_type;
    uint64_t write_length;
    std::vector<uint8_t> write_message;
    std::function<void(const ray::Status &)> handler;
  };

  /// The socket connection to the server.
  boost::asio::basic_stream_socket<T> socket_;

  /// Max number of messages to write out at once.
  const int async_write_max_messages_;

  /// List of pending messages to write.
  std::list<std::unique_ptr<AsyncWriteBuffer>> async_write_queue_;

  /// Whether we are in the middle of an async write.
  bool async_write_in_flight_;

 private:
  /// Asynchronously flushes the write queue. While async writes are running, the flag
  /// async_write_in_flight_ will be set. This should only be called when no async writes
  /// are currently in flight.
  void DoAsyncWrites();
};

template <typename T>

@@ -72,9 +119,10 @@ using MessageHandler =
/// writing messages to the client, like in ServerConnection, this typename can
/// also be used to process messages asynchronously from the client.
template <typename T>
class ClientConnection : public ServerConnection<T>,
                         public std::enable_shared_from_this<ClientConnection<T>> {
class ClientConnection : public ServerConnection<T> {
 public:
  using std::enable_shared_from_this<ServerConnection<T>>::shared_from_this;

  /// Allocate a new node client connection.
  ///
  /// \param new_client_handler A reference to the client handler.

@@ -85,6 +133,10 @@ class ClientConnection : public ServerConnection<T>,
      ClientHandler<T> &new_client_handler, MessageHandler<T> &message_handler,
      boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label);

  std::shared_ptr<ClientConnection<T>> shared_ClientConnection_from_this() {
    return std::static_pointer_cast<ClientConnection<T>>(shared_from_this());
  }

  /// \return The ClientID of the remote client.
  const ClientID &GetClientID();
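A note on the shared_ClientConnection_from_this() helper above: because enable_shared_from_this now lives on the ServerConnection base class, shared_from_this() yields a shared_ptr to the base, and the derived ClientConnection has to cast it back. A minimal stand-alone sketch of that mechanic (Base and Derived are generic stand-ins, not the Ray classes):

#include <cassert>
#include <memory>

class Base : public std::enable_shared_from_this<Base> {
 public:
  virtual ~Base() = default;
};

class Derived : public Base {
 public:
  std::shared_ptr<Derived> shared_Derived_from_this() {
    // shared_from_this() returns shared_ptr<Base>; static_pointer_cast recovers
    // the derived type while reusing the existing control block.
    return std::static_pointer_cast<Derived>(shared_from_this());
  }
};

int main() {
  auto connection = std::make_shared<Derived>();
  std::shared_ptr<Derived> self = connection->shared_Derived_from_this();
  assert(self == connection);            // same object
  assert(connection.use_count() == 2);   // shared ownership, no duplicate control block
  return 0;
}

Because static_pointer_cast reuses the existing control block, the handlers bound via shared_ClientConnection_from_this() in ProcessMessages share ownership with the original pointer and keep the connection alive while reads are outstanding.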
@@ -266,35 +266,31 @@ void ObjectManager::PullEstablishConnection(const ObjectID &object_id,
          }
          connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE,
                                          client_id, async_conn);
          Status pull_send_status = PullSendRequest(object_id, async_conn);
          if (!pull_send_status.ok()) {
            CheckIOError(pull_send_status, "Pull");
          }
          PullSendRequest(object_id, async_conn);
        },
        []() {
          RAY_LOG(ERROR) << "Failed to establish connection with remote object manager.";
        });
  } else {
    status = PullSendRequest(object_id, conn);
    if (!status.ok()) {
      CheckIOError(status, "Pull");
    }
    PullSendRequest(object_id, conn);
  }
}

ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
void ObjectManager::PullSendRequest(const ObjectID &object_id,
                                    std::shared_ptr<SenderConnection> &conn) {
  flatbuffers::FlatBufferBuilder fbb;
  auto message = object_manager_protocol::CreatePullRequestMessage(
      fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
  fbb.Finish(message);
  Status status = conn->WriteMessage(
  conn->WriteMessageAsync(
      static_cast<int64_t>(object_manager_protocol::MessageType::PullRequest),
      fbb.GetSize(), fbb.GetBufferPointer());
      fbb.GetSize(), fbb.GetBufferPointer(), [this, conn](ray::Status status) mutable {
        if (status.ok()) {
          connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn);
        } else {
          CheckIOError(status, "Pull");
        }
        return status;
      });
}

void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id,
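The [this, conn] capture in PullSendRequest above copies the shared_ptr into the completion handler, so the sender connection stays alive until the asynchronous write finishes even if the caller drops its own reference first. A small stand-alone sketch of that lifetime guarantee (Connection and the pending-callback vector are illustrative stand-ins for the real connection type and the asio handler queue):

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Connection {
  ~Connection() { std::cout << "connection destroyed" << std::endl; }
};

int main() {
  // Stand-in for the set of completion handlers held by the event loop.
  std::vector<std::function<void()>> pending;
  {
    auto conn = std::make_shared<Connection>();
    // Capturing the shared_ptr by value gives the handler shared ownership.
    pending.push_back([conn]() {
      std::cout << "write completed, use_count=" << conn.use_count() << std::endl;
    });
  }  // The caller's reference is gone here, but the connection is still alive.
  for (auto &callback : pending) {
    callback();
  }
  pending.clear();  // Only now is the captured reference dropped and the connection freed.
  return 0;
}

The CallbackWithSharedRefDoesNotLeakConnection test added later in this diff exercises the complementary concern: a handler that holds the connection must still let it be destroyed once the write has completed.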
@@ -245,9 +245,9 @@ class ObjectManager : public ObjectManagerInterface {
  /// Executes on main_service_ thread.
  void PullEstablishConnection(const ObjectID &object_id, const ClientID &client_id);

  /// Synchronously send a pull request via remote object manager connection.
  /// Asynchronously send a pull request via remote object manager connection.
  /// Executes on main_service_ thread.
  ray::Status PullSendRequest(const ObjectID &object_id,
  void PullSendRequest(const ObjectID &object_id,
                       std::shared_ptr<SenderConnection> &conn);

  std::shared_ptr<SenderConnection> CreateSenderConnection(
@@ -11,7 +11,7 @@ std::shared_ptr<SenderConnection> SenderConnection::Create(
  Status status = TcpConnect(socket, ip, port);
  if (status.ok()) {
    std::shared_ptr<TcpServerConnection> conn =
        std::make_shared<TcpServerConnection>(std::move(socket));
        TcpServerConnection::Create(std::move(socket));
    return std::make_shared<SenderConnection>(std::move(conn), client_id);
  } else {
    return nullptr;
@@ -16,6 +16,7 @@

namespace ray {

// TODO(ekl) this class can be replaced with a plain ClientConnection
class SenderConnection : public boost::enable_shared_from_this<SenderConnection> {
 public:
  /// Create a connection for sending data to other object managers.

@@ -44,6 +45,17 @@ class SenderConnection : public boost::enable_shared_from_this<SenderConnection>
    return conn_->WriteMessage(type, length, message);
  }

  /// Write a message to the client asynchronously.
  ///
  /// \param type The message type (e.g., a flatbuffer enum).
  /// \param length The size in bytes of the message.
  /// \param message A pointer to the message buffer.
  /// \param handler A callback to run on write completion.
  void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message,
                         const std::function<void(const ray::Status &)> &handler) {
    conn_->WriteMessageAsync(type, length, message, handler);
  }

  /// Write a buffer to this connection.
  ///
  /// \param buffer The buffer.
@@ -32,6 +32,7 @@ ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASM

ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})

ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})

src/ray/raylet/client_connection_test.cc (new file, 155 lines)
@@ -0,0 +1,155 @@
#include <list>
#include <memory>

#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ray/common/client_connection.h"

namespace ray {
namespace raylet {

class ClientConnectionTest : public ::testing::Test {
 public:
  ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_) {
    boost::asio::local::connect_pair(in_, out_);
  }

 protected:
  boost::asio::io_service io_service_;
  boost::asio::local::stream_protocol::socket in_;
  boost::asio::local::stream_protocol::socket out_;
};

TEST_F(ClientConnectionTest, SimpleSyncWrite) {
  const uint8_t arr[5] = {1, 2, 3, 4, 5};
  int num_messages = 0;

  ClientHandler<boost::asio::local::stream_protocol> client_handler =
      [](LocalClientConnection &client) {};

  MessageHandler<boost::asio::local::stream_protocol> message_handler =
      [&arr, &num_messages](std::shared_ptr<LocalClientConnection> client,
                            int64_t message_type, const uint8_t *message) {
        ASSERT_TRUE(!std::memcmp(arr, message, 5));
        num_messages += 1;
      };

  auto conn1 = LocalClientConnection::Create(client_handler, message_handler,
                                             std::move(in_), "conn1");

  auto conn2 = LocalClientConnection::Create(client_handler, message_handler,
                                             std::move(out_), "conn2");

  RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr));
  RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr));
  conn1->ProcessMessages();
  conn2->ProcessMessages();
  io_service_.run();
  ASSERT_EQ(num_messages, 2);
}

TEST_F(ClientConnectionTest, SimpleAsyncWrite) {
  const uint8_t msg1[5] = {1, 2, 3, 4, 5};
  const uint8_t msg2[5] = {4, 4, 4, 4, 4};
  const uint8_t msg3[5] = {8, 8, 8, 8, 8};
  int num_messages = 0;

  ClientHandler<boost::asio::local::stream_protocol> client_handler =
      [](LocalClientConnection &client) {};

  MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
      std::shared_ptr<LocalClientConnection> client, int64_t message_type,
      const uint8_t *message) {};

  std::shared_ptr<LocalClientConnection> reader = NULL;

  MessageHandler<boost::asio::local::stream_protocol> message_handler =
      [&msg1, &msg2, &msg3, &num_messages, &reader](
          std::shared_ptr<LocalClientConnection> client, int64_t message_type,
          const uint8_t *message) {
        if (num_messages == 0) {
          ASSERT_TRUE(!std::memcmp(msg1, message, 5));
        } else if (num_messages == 1) {
          ASSERT_TRUE(!std::memcmp(msg2, message, 5));
        } else {
          ASSERT_TRUE(!std::memcmp(msg3, message, 5));
        }
        num_messages += 1;
        if (num_messages < 3) {
          reader->ProcessMessages();
        }
      };

  auto writer = LocalClientConnection::Create(client_handler, noop_handler,
                                              std::move(in_), "writer");

  reader = LocalClientConnection::Create(client_handler, message_handler, std::move(out_),
                                         "reader");

  std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
    RAY_CHECK_OK(status);
  };

  writer->WriteMessageAsync(0, 5, msg1, callback);
  writer->WriteMessageAsync(0, 5, msg2, callback);
  writer->WriteMessageAsync(0, 5, msg3, callback);
  reader->ProcessMessages();
  io_service_.run();
  ASSERT_EQ(num_messages, 3);
}

TEST_F(ClientConnectionTest, SimpleAsyncError) {
  const uint8_t msg1[5] = {1, 2, 3, 4, 5};

  ClientHandler<boost::asio::local::stream_protocol> client_handler =
      [](LocalClientConnection &client) {};

  MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
      std::shared_ptr<LocalClientConnection> client, int64_t message_type,
      const uint8_t *message) {};

  auto writer = LocalClientConnection::Create(client_handler, noop_handler,
                                              std::move(in_), "writer");

  std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
    ASSERT_TRUE(!status.ok());
  };

  writer->Close();
  writer->WriteMessageAsync(0, 5, msg1, callback);
  io_service_.run();
}

TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) {
  const uint8_t msg1[5] = {1, 2, 3, 4, 5};

  ClientHandler<boost::asio::local::stream_protocol> client_handler =
      [](LocalClientConnection &client) {};

  MessageHandler<boost::asio::local::stream_protocol> noop_handler = [](
      std::shared_ptr<LocalClientConnection> client, int64_t message_type,
      const uint8_t *message) {};

  auto writer = LocalClientConnection::Create(client_handler, noop_handler,
                                              std::move(in_), "writer");

  std::function<void(const ray::Status &)> callback =
      [writer](const ray::Status &status) {
        static_cast<void>(writer);
        ASSERT_TRUE(status.ok());
      };
  writer->WriteMessageAsync(0, 5, msg1, callback);
  io_service_.run();
}

}  // namespace raylet

}  // namespace ray

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -359,7 +359,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
  }

  // The client is connected.
  auto server_conn = TcpServerConnection(std::move(socket));
  auto server_conn = TcpServerConnection::Create(std::move(socket));
  remote_server_connections_.emplace(client_id, std::move(server_conn));

  ResourceSet resources_total(client_data.resources_total_label,

@@ -1304,10 +1304,11 @@ void NodeManager::AssignTask(Task &task) {
  auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
                                              fbb.CreateVector(resource_id_set_flatbuf));
  fbb.Finish(message);
  auto status = worker->Connection()->WriteMessage(
  worker->Connection()->WriteMessageAsync(
      static_cast<int64_t>(protocol::MessageType::ExecuteTask), fbb.GetSize(),
      fbb.GetBufferPointer());
      fbb.GetBufferPointer(), [this, worker, task](ray::Status status) mutable {
        if (status.ok()) {
          auto spec = task.GetTaskSpecification();
          // We successfully assigned the task to the worker.
          worker->AssignTaskId(spec.TaskId());
          worker->AssignDriverId(spec.DriverId());

@@ -1330,13 +1331,14 @@
          // updates are reflected in the task table.
          task.SetExecutionDependencies({execution_dependency});
          // Extend the frontier to include the executing task.
          actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject());
          actor_entry->second.ExtendFrontier(spec.ActorHandleId(),
                                             spec.ActorDummyObject());
        }
        // We started running the task, so the task is ready to write to GCS.
        if (!lineage_cache_.AddReadyTask(task)) {
          RAY_LOG(WARNING)
              << "Task " << spec.TaskId()
              << " already in lineage cache. This is most likely due to reconstruction.";
          RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache. "
                                                          "This is most likely due to "
                                                          "reconstruction.";
        }
        // Mark the task as running.
        // (See design_docs/task_states.rst for the state transition diagram.)

@@ -1354,6 +1356,7 @@ void NodeManager::AssignTask(Task &task) {
          local_queues_.QueueReadyTasks(std::vector<Task>({task}));
          DispatchTasks();
        }
      });
}

void NodeManager::FinishAssignedTask(Worker &worker) {
@@ -1522,10 +1525,10 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
                                        const ClientID &node_manager_id) {
  /// TODO(rkn): Should we check that the node manager is remote and not local?
  /// TODO(rkn): Should we check if the remote node manager is known to be dead?
  // Attempt to forward the task.
  ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) {
    const TaskID task_id = task.GetTaskSpecification().TaskId();

    // Attempt to forward the task.
    if (!ForwardTask(task, node_manager_id).ok()) {
      RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager "
                    << node_manager_id;
      // Mark the failed task as pending to let other raylets know that we still

@@ -1564,10 +1567,11 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
        ScheduleTasks(cluster_resource_map_);
        DispatchTasks();
      }
    }
  });
}

ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) {
void NodeManager::ForwardTask(const Task &task, const ClientID &node_id,
                              const std::function<void(const ray::Status &)> &on_error) {
  const auto &spec = task.GetTaskSpecification();
  auto task_id = spec.TaskId();

@@ -1593,13 +1597,15 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id)
  if (it == remote_server_connections_.end()) {
    // TODO(atumanov): caller must handle failure to ensure tasks are not lost.
    RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
    return ray::Status::IOError("NodeManager connection not found");
    on_error(ray::Status::IOError("NodeManager connection not found"));
    return;
  }

  auto &server_conn = it->second;
  auto status = server_conn.WriteMessage(
  server_conn->WriteMessageAsync(
      static_cast<int64_t>(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(),
      fbb.GetBufferPointer());
      fbb.GetBufferPointer(),
      [this, on_error, task_id, node_id, spec](ray::Status status) {
        if (status.ok()) {
          // If we were able to forward the task, remove the forwarded task from the
          // lineage cache since the receiving node is now responsible for writing

@@ -1634,8 +1640,10 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id)
            }
          }
        }
        } else {
          on_error(status);
        }
        return status;
      });
}

}  // namespace raylet
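ForwardTask above changes from returning a ray::Status to reporting failures through an on_error callback, since the outcome of the write is only known once the asynchronous send completes. A stand-alone sketch of that callback-style error propagation (Status, SendAsync, and this ForwardTask are simplified stand-ins, not Ray's types or logic):

#include <functional>
#include <iostream>
#include <string>

// Simplified stand-in for ray::Status.
struct Status {
  bool ok_;
  std::string msg_;
  bool ok() const { return ok_; }
  static Status OK() { return {true, ""}; }
  static Status IOError(std::string msg) { return {false, std::move(msg)}; }
};

// Pretend asynchronous send that invokes its completion handler immediately.
void SendAsync(const std::string &payload,
               const std::function<void(const Status &)> &done) {
  done(payload.empty() ? Status::IOError("empty payload") : Status::OK());
}

// Errors surface through the on_error callback instead of a return value,
// because the result is only available when the asynchronous send finishes.
void ForwardTask(const std::string &task,
                 const std::function<void(const Status &)> &on_error) {
  SendAsync(task, [on_error](const Status &status) {
    if (!status.ok()) {
      on_error(status);
    }
  });
}

int main() {
  ForwardTask("", [](const Status &status) {
    std::cout << "forward failed: " << status.msg_ << std::endl;
  });
  ForwardTask("task-123", [](const Status &status) {
    std::cout << "this is only reached on failure" << std::endl;
  });
  return 0;
}

Callers such as ForwardTaskOrResubmit then pass their resubmit/cleanup logic as the callback instead of checking a returned status, which is exactly the change shown in the node manager hunks above.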
@@ -174,10 +174,10 @@ class NodeManager {
  ///
  /// \param task The task to forward.
  /// \param node_id The ID of the node to forward the task to.
  /// \return A status indicating whether the forward succeeded or not. Note
  /// that a status of OK is not a reliable indicator that the forward succeeded
  /// or even that the remote node is still alive.
  ray::Status ForwardTask(const Task &task, const ClientID &node_id);
  /// \param on_error A callback to run on a non-ok status.
  void ForwardTask(const Task &task, const ClientID &node_id,
                   const std::function<void(const ray::Status &)> &on_error);

  /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to
  /// "running" task state.
  void DispatchTasks();

@@ -352,7 +352,8 @@ class NodeManager {
  /// The lineage cache for the GCS object and task tables.
  LineageCache lineage_cache_;
  std::vector<ClientID> remote_clients_;
  std::unordered_map<ClientID, TcpServerConnection> remote_server_connections_;
  std::unordered_map<ClientID, std::shared_ptr<TcpServerConnection>>
      remote_server_connections_;
  /// A mapping from actor ID to registration information about that actor
  /// (including which node manager owns it).
  std::unordered_map<ActorID, ActorRegistration> actor_registry_;