[xray] Implementing Gcs sharding (#2409)

Basically a re-implementation of #2281, with the modifications from #2298 (a fix of #2334, for rebasing issues).
[+] Implement sharding for gcs tables.
[+] Keep ClientTable and ErrorTable managed by the primary_shard. TaskTable is also managed by the primary_shard for now, until a good hashing scheme for tasks is implemented.
[+] Move AsyncGcsClient's initialization into Connect function.
[-] Move GetRedisShard and bool sharding from RedisContext's connect into AsyncGcsClient. This may make the interface cleaner.
This commit is contained in:
Yucong He 2018-08-31 15:54:30 -07:00 committed by Robert Nishihara
parent eda6ebb87d
commit 5b45f0bdff
19 changed files with 269 additions and 235 deletions

View file

@ -132,11 +132,6 @@ GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
"global_scheduler", node_ip_address,
std::vector<std::string>());
db_attach(state->db, loop, false);
RAY_CHECK_OK(state->gcs_client.Connect(
std::string(redis_primary_addr), redis_primary_port, /*sharding=*/true));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
RAY_CHECK_OK(state->gcs_client.primary_context()->AttachToEventLoop(loop));
state->policy_state = GlobalSchedulerPolicyState_init();
return state;
}

View file

@ -51,8 +51,6 @@ typedef struct {
event_loop *loop;
/** The global state store database. */
DBHandle *db;
/** The handle to the GCS (modern version of the above). */
ray::gcs::AsyncGcsClient gcs_client;
/** A hash table mapping local scheduler ID to the local schedulers that are
* connected to Redis. */
std::unordered_map<DBClientID, LocalScheduler> local_schedulers;

View file

@ -351,11 +351,6 @@ LocalSchedulerState *LocalSchedulerState_init(
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
"local_scheduler", node_ip_address, db_connect_args);
db_attach(state->db, loop, false);
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
redis_primary_port, true));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
RAY_CHECK_OK(state->gcs_client.primary_context()->AttachToEventLoop(loop));
} else {
state->db = NULL;
}

View file

@ -60,8 +60,6 @@ struct LocalSchedulerState {
std::unordered_map<ActorID, ActorMapEntry> actor_mapping;
/** The handle to the database. */
DBHandle *db;
/** The handle to the GCS (modern version of the above). */
ray::gcs::AsyncGcsClient gcs_client;
/** The Plasma client. */
plasma::PlasmaClient *plasma_conn;
/** State for the scheduling algorithm. */

View file

@ -215,8 +215,6 @@ struct PlasmaManagerState {
* other plasma stores. */
std::unordered_map<std::string, ClientConnection *> manager_connections;
DBHandle *db;
/** The handle to the GCS (modern version of the above). */
ray::gcs::AsyncGcsClient gcs_client;
/** Our address. */
const char *addr;
/** Our port. */
@ -490,13 +488,6 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
state->db = db_connect(std::string(redis_primary_addr), redis_primary_port,
"plasma_manager", manager_addr, db_connect_args);
db_attach(state->db, state->loop, false);
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
redis_primary_port,
/*sharding=*/true));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(state->loop));
RAY_CHECK_OK(
state->gcs_client.primary_context()->AttachToEventLoop(state->loop));
} else {
state->db = NULL;
RAY_LOG(DEBUG) << "No db connection specified";

View file

@ -2,51 +2,152 @@
#include "ray/gcs/redis_context.h"
/// Query the primary Redis shard for the locations of all Redis shards.
///
/// First waits for the `NumRedisShards` entry, then waits until the
/// `RedisShards` list holds that many entries. Each step sleeps between
/// attempts and gives up after RayConfig's redis_db_connect_retries()
/// attempts, at which point the process is aborted through RAY_CHECK. A
/// failed Redis command or a malformed reply also aborts through RAY_CHECK.
///
/// \param context Synchronous hiredis context used to issue the commands.
/// \param addresses Output: the host of every shard is appended to it.
/// \param ports Output: the port of every shard is appended to it, in the
/// same order as `addresses`.
static void GetRedisShards(redisContext *context, std::vector<std::string> &addresses,
                           std::vector<int> &ports) {
  // Get the total number of Redis shards in the system.
  int num_attempts = 0;
  redisReply *reply = nullptr;
  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
    // Try to read the number of Redis shards from the primary shard. If the
    // entry is present, exit.
    reply = reinterpret_cast<redisReply *>(redisCommand(context, "GET NumRedisShards"));
    // redisCommand yields NULL when the command itself could not be executed;
    // fail loudly instead of dereferencing a null reply.
    RAY_CHECK(reply != nullptr) << "Redis command failed: " << context->errstr;
    if (reply->type != REDIS_REPLY_NIL) {
      break;
    }
    // Sleep for a little, and try again if the entry isn't there yet.
    freeReplyObject(reply);
    reply = nullptr;
    usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
    num_attempts++;
  }
  RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries())
      << "No entry found for NumRedisShards";
  RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type "
                                               << reply->type << " for NumRedisShards";
  int num_redis_shards = atoi(reply->str);
  RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, "
                                   << "found " << num_redis_shards;
  freeReplyObject(reply);
  reply = nullptr;

  // Get the addresses of all of the Redis shards.
  num_attempts = 0;
  // Remember the size of the last reply so the failure message below never
  // has to read from a reply that has already been freed.
  size_t num_found = 0;
  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
    // Try to read the Redis shard locations from the primary shard. If we find
    // that all of them are present, exit.
    reply =
        reinterpret_cast<redisReply *>(redisCommand(context, "LRANGE RedisShards 0 -1"));
    RAY_CHECK(reply != nullptr) << "Redis command failed: " << context->errstr;
    num_found = reply->elements;
    if (static_cast<int>(num_found) == num_redis_shards) {
      break;
    }
    // Sleep for a little, and try again if not all Redis shard addresses have
    // been added yet.
    freeReplyObject(reply);
    reply = nullptr;
    usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
    num_attempts++;
  }
  RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries())
      << "Expected " << num_redis_shards << " Redis shard addresses, found "
      << num_found;

  // Parse the Redis shard addresses, each stored as "<address>:<port>".
  for (size_t i = 0; i < reply->elements; ++i) {
    RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING);
    std::string addr;
    std::stringstream ss(reply->element[i]->str);
    getline(ss, addr, ':');
    int port = 0;
    ss >> port;
    // Reject entries without a parsable port instead of recording garbage.
    RAY_CHECK(!ss.fail()) << "Malformed Redis shard address: "
                          << reply->element[i]->str;
    addresses.push_back(addr);
    ports.push_back(port);
  }
  freeReplyObject(reply);
}
namespace ray {
namespace gcs {
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_type) {
context_ = std::make_shared<RedisContext>();
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
const ClientID &client_id, CommandType command_type,
bool is_test_client = false) {
primary_context_ = std::make_shared<RedisContext>();
client_table_.reset(new ClientTable(primary_context_, this, client_id));
object_table_.reset(new ObjectTable(context_, this, command_type));
actor_table_.reset(new ActorTable(context_, this));
task_table_.reset(new TaskTable(context_, this, command_type));
raylet_task_table_.reset(new raylet::TaskTable(context_, this, command_type));
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
task_lease_table_.reset(new TaskLeaseTable(context_, this));
heartbeat_table_.reset(new HeartbeatTable(context_, this));
driver_table_.reset(new DriverTable(primary_context_, this));
error_table_.reset(new ErrorTable(primary_context_, this));
profile_table_.reset(new ProfileTable(context_, this));
RAY_CHECK_OK(primary_context_->Connect(address, port, /*sharding=*/true));
if (!is_test_client) {
// Moving sharding into the constructor means that sharding is enabled by default.
// This design decision may be worth another look.
std::vector<std::string> addresses;
std::vector<int> ports;
GetRedisShards(primary_context_->sync_context(), addresses, ports);
if (addresses.size() == 0 || ports.size() == 0) {
addresses.push_back(address);
ports.push_back(port);
}
// Populate shard_contexts.
for (size_t i = 0; i < addresses.size(); ++i) {
shard_contexts_.push_back(std::make_shared<RedisContext>());
}
RAY_CHECK(shard_contexts_.size() == addresses.size());
for (size_t i = 0; i < addresses.size(); ++i) {
RAY_CHECK_OK(
shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true));
}
} else {
shard_contexts_.push_back(std::make_shared<RedisContext>());
RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true));
}
client_table_.reset(new ClientTable({primary_context_}, this, client_id));
error_table_.reset(new ErrorTable({primary_context_}, this));
driver_table_.reset(new DriverTable({primary_context_}, this));
// Tables below would be sharded.
object_table_.reset(new ObjectTable(shard_contexts_, this, command_type));
actor_table_.reset(new ActorTable(shard_contexts_, this));
task_table_.reset(new TaskTable(shard_contexts_, this, command_type));
raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type));
task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this));
task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this));
heartbeat_table_.reset(new HeartbeatTable(shard_contexts_, this));
profile_table_.reset(new ProfileTable(shard_contexts_, this));
command_type_ = command_type;
// TODO(swang): Call the client table's Connect() method here. To do this,
// we need to make sure that we are attached to an event loop first. This
// currently isn't possible because the aeEventLoop, which we use for
// testing, requires us to connect to Redis first.
}
#if RAY_USE_NEW_GCS
// Use of kChain currently only applies to Table::Add which affects only the
// task table, and when RAY_USE_NEW_GCS is set at compile time.
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id)
: AsyncGcsClient(client_id, CommandType::kChain) {}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
const ClientID &client_id, bool is_test_client = false)
: AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client) {}
#else
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id)
: AsyncGcsClient(client_id, CommandType::kRegular) {}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
const ClientID &client_id, bool is_test_client = false)
: AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client) {}
#endif // RAY_USE_NEW_GCS
AsyncGcsClient::AsyncGcsClient(CommandType command_type)
: AsyncGcsClient(ClientID::from_random(), command_type) {}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
CommandType command_type)
: AsyncGcsClient(address, port, ClientID::from_random(), command_type) {}
AsyncGcsClient::AsyncGcsClient() : AsyncGcsClient(ClientID::from_random()) {}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
CommandType command_type, bool is_test_client)
: AsyncGcsClient(address, port, ClientID::from_random(), command_type,
is_test_client) {}
Status AsyncGcsClient::Connect(const std::string &address, int port, bool sharding) {
RAY_RETURN_NOT_OK(context_->Connect(address, port, sharding));
RAY_RETURN_NOT_OK(primary_context_->Connect(address, port, /*sharding=*/false));
// TODO(swang): Call the client table's Connect() method here. To do this,
// we need to make sure that we are attached to an event loop first. This
// currently isn't possible because the aeEventLoop, which we use for
// testing, requires us to connect to Redis first.
return Status::OK();
}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port)
: AsyncGcsClient(address, port, ClientID::from_random()) {}
AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client)
: AsyncGcsClient(address, port, ClientID::from_random(), is_test_client) {}
Status Attach(plasma::EventLoop &event_loop) {
// TODO(pcm): Implement this via
@ -55,9 +156,14 @@ Status Attach(plasma::EventLoop &event_loop) {
}
Status AsyncGcsClient::Attach(boost::asio::io_service &io_service) {
asio_async_client_.reset(new RedisAsioClient(io_service, context_->async_context()));
asio_subscribe_client_.reset(
new RedisAsioClient(io_service, context_->subscribe_context()));
// Take care of sharding contexts.
RAY_CHECK(shard_asio_async_clients_.empty()) << "Attach shall be called only once";
for (std::shared_ptr<RedisContext> context : shard_contexts_) {
shard_asio_async_clients_.emplace_back(
new RedisAsioClient(io_service, context->async_context()));
shard_asio_subscribe_clients_.emplace_back(
new RedisAsioClient(io_service, context->subscribe_context()));
}
asio_async_auxiliary_client_.reset(
new RedisAsioClient(io_service, primary_context_->async_context()));
asio_subscribe_auxiliary_client_.reset(

View file

@ -24,21 +24,22 @@ class RAY_EXPORT AsyncGcsClient {
/// Attach() must be called. To read and write from the GCS tables requires a
/// further call to Connect() to the client table.
///
/// \param client_id The ID to assign to the client.
/// \param command_type GCS command type. If CommandType::kChain, chain-replicated
/// versions of the tables might be used, if available.
AsyncGcsClient(const ClientID &client_id, CommandType command_type);
AsyncGcsClient(const ClientID &client_id);
AsyncGcsClient(CommandType command_type);
AsyncGcsClient();
/// Connect to the GCS.
///
/// \param address The GCS IP address.
/// \param port The GCS port.
/// \param sharding If true, use sharded redis for the GCS.
/// \return Status.
Status Connect(const std::string &address, int port, bool sharding);
/// \param client_id The ID to assign to the client.
/// \param command_type GCS command type. If CommandType::kChain, chain-replicated
/// versions of the tables might be used, if available.
AsyncGcsClient(const std::string &address, int port, const ClientID &client_id,
CommandType command_type, bool is_test_client);
AsyncGcsClient(const std::string &address, int port, const ClientID &client_id,
bool is_test_client);
AsyncGcsClient(const std::string &address, int port, CommandType command_type);
AsyncGcsClient(const std::string &address, int port, CommandType command_type,
bool is_test_client);
AsyncGcsClient(const std::string &address, int port);
AsyncGcsClient(const std::string &address, int port, bool is_test_client);
/// Attach this client to a plasma event loop. Note that only
/// one event loop should be attached at a time.
Status Attach(plasma::EventLoop &event_loop);
@ -71,7 +72,7 @@ class RAY_EXPORT AsyncGcsClient {
Status GetExport(const std::string &driver_id, int64_t export_index,
const GetExportCallback &done_callback);
std::shared_ptr<RedisContext> context() { return context_; }
std::vector<std::shared_ptr<RedisContext>> shard_contexts() { return shard_contexts_; }
std::shared_ptr<RedisContext> primary_context() { return primary_context_; }
private:
@ -88,9 +89,9 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ProfileTable> profile_table_;
std::unique_ptr<ClientTable> client_table_;
// The following contexts write to the data shard
std::shared_ptr<RedisContext> context_;
std::unique_ptr<RedisAsioClient> asio_async_client_;
std::unique_ptr<RedisAsioClient> asio_subscribe_client_;
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_async_clients_;
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_subscribe_clients_;
// The following context writes everything to the primary shard
std::shared_ptr<RedisContext> primary_context_;
std::unique_ptr<DriverTable> driver_table_;

View file

@ -28,9 +28,8 @@ static inline void flushall_redis(void) {
class TestGcs : public ::testing::Test {
public:
TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) {
client_ = std::make_shared<gcs::AsyncGcsClient>(command_type_);
RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379, /*sharding=*/false));
client_ = std::make_shared<gcs::AsyncGcsClient>("127.0.0.1", 6379, command_type_,
/*is_test_client=*/true);
job_id_ = JobID::from_random();
}
@ -60,7 +59,10 @@ class TestGcsWithAe : public TestGcs {
public:
TestGcsWithAe(CommandType command_type) : TestGcs(command_type) {
loop_ = aeCreateEventLoop(1024);
RAY_CHECK_OK(client_->context()->AttachToEventLoop(loop_));
RAY_CHECK_OK(client_->primary_context()->AttachToEventLoop(loop_));
for (auto &context : client_->shard_contexts()) {
RAY_CHECK_OK(context->AttachToEventLoop(loop_));
}
}
TestGcsWithAe() : TestGcsWithAe(CommandType::kRegular) {}

View file

@ -135,69 +135,6 @@ RedisContext::~RedisContext() {
}
}
/// Query the primary Redis shard for the locations of all Redis shards.
///
/// First waits for the `NumRedisShards` entry, then waits until the
/// `RedisShards` list holds that many entries. Each step sleeps between
/// attempts and gives up after RayConfig's redis_db_connect_retries()
/// attempts, at which point the process is aborted through RAY_CHECK. A
/// failed Redis command or a malformed reply also aborts through RAY_CHECK.
///
/// \param context Synchronous hiredis context used to issue the commands.
/// \param addresses Output: the host of every shard is appended to it.
/// \param ports Output: the port of every shard is appended to it, in the
/// same order as `addresses`.
static void GetRedisShards(redisContext *context, std::vector<std::string> *addresses,
                           std::vector<int> *ports) {
  // Get the total number of Redis shards in the system.
  int num_attempts = 0;
  redisReply *reply = nullptr;
  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
    // Try to read the number of Redis shards from the primary shard. If the
    // entry is present, exit.
    reply = reinterpret_cast<redisReply *>(redisCommand(context, "GET NumRedisShards"));
    // redisCommand yields NULL when the command itself could not be executed;
    // fail loudly instead of dereferencing a null reply.
    RAY_CHECK(reply != nullptr) << "Redis command failed: " << context->errstr;
    if (reply->type != REDIS_REPLY_NIL) {
      break;
    }
    // Sleep for a little, and try again if the entry isn't there yet.
    freeReplyObject(reply);
    reply = nullptr;
    usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
    num_attempts++;
  }
  RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries())
      << "No entry found for NumRedisShards";
  RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type "
                                               << reply->type << " for NumRedisShards";
  int num_redis_shards = atoi(reply->str);
  RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, "
                                   << "found " << num_redis_shards;
  freeReplyObject(reply);
  reply = nullptr;

  // Get the addresses of all of the Redis shards.
  num_attempts = 0;
  // Remember the size of the last reply so the failure message below never
  // has to read from a reply that has already been freed.
  size_t num_found = 0;
  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
    // Try to read the Redis shard locations from the primary shard. If we find
    // that all of them are present, exit.
    reply =
        reinterpret_cast<redisReply *>(redisCommand(context, "LRANGE RedisShards 0 -1"));
    RAY_CHECK(reply != nullptr) << "Redis command failed: " << context->errstr;
    num_found = reply->elements;
    if (static_cast<int>(num_found) == num_redis_shards) {
      break;
    }
    // Sleep for a little, and try again if not all Redis shard addresses have
    // been added yet.
    freeReplyObject(reply);
    reply = nullptr;
    usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
    num_attempts++;
  }
  RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries())
      << "Expected " << num_redis_shards << " Redis shard addresses, found "
      << num_found;

  // Parse the Redis shard addresses, each stored as "<address>:<port>".
  for (size_t i = 0; i < reply->elements; ++i) {
    RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING);
    std::string addr;
    std::stringstream ss(reply->element[i]->str);
    getline(ss, addr, ':');
    int port = 0;
    ss >> port;
    // Reject entries without a parsable port instead of recording garbage.
    RAY_CHECK(!ss.fail()) << "Malformed Redis shard address: "
                          << reply->element[i]->str;
    addresses->push_back(addr);
    ports->push_back(port);
  }
  freeReplyObject(reply);
}
Status RedisContext::Connect(const std::string &address, int port, bool sharding) {
int connection_attempts = 0;
context_ = redisConnect(address.c_str(), port);
@ -223,31 +160,17 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding
REDIS_CHECK_ERROR(context_, reply);
freeReplyObject(reply);
std::string redis_address;
int redis_port;
if (sharding) {
// Get the redis data shard
std::vector<std::string> addresses;
std::vector<int> ports;
GetRedisShards(context_, &addresses, &ports);
redis_address = addresses[0];
redis_port = ports[0];
} else {
redis_address = address;
redis_port = port;
}
// Connect to async context
async_context_ = redisAsyncConnect(redis_address.c_str(), redis_port);
async_context_ = redisAsyncConnect(address.c_str(), port);
if (async_context_ == nullptr || async_context_->err) {
RAY_LOG(FATAL) << "Could not establish connection to redis " << redis_address << ":"
<< redis_port;
RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":"
<< port;
}
// Connect to subscribe context
subscribe_context_ = redisAsyncConnect(redis_address.c_str(), redis_port);
subscribe_context_ = redisAsyncConnect(address.c_str(), port);
if (subscribe_context_ == nullptr || subscribe_context_->err) {
RAY_LOG(FATAL) << "Could not establish subscribe connection to redis "
<< redis_address << ":" << redis_port;
RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " << address
<< ":" << port;
}
return Status::OK();
}

View file

@ -88,6 +88,7 @@ class RedisContext {
/// \return Status.
Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
const RedisCallback &redisCallback, int64_t *out_callback_index);
redisContext *sync_context() { return context_; }
redisAsyncContext *async_context() { return async_context_; }
redisAsyncContext *subscribe_context() { return subscribe_context_; };

View file

@ -47,9 +47,9 @@ Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync(GetLogAppendCommand(command_type_), id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, std::move(callback));
return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@ -71,9 +71,9 @@ Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync(GetLogAppendCommand(command_type_), id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, std::move(callback), log_length);
return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, std::move(callback), log_length);
}
template <typename ID, typename Data>
@ -96,8 +96,8 @@ Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &
return true;
};
std::vector<uint8_t> nil;
return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_,
pubsub_channel_, std::move(callback));
return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@ -136,8 +136,12 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
// more subscription messages.
return false;
};
return context_->SubscribeAsync(client_id, pubsub_channel_, std::move(callback),
&subscribe_callback_index_);
subscribe_callback_index_ = 1;
for (auto &context : shard_contexts_) {
RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback,
&subscribe_callback_index_));
}
return Status::OK();
}
template <typename ID, typename Data>
@ -145,8 +149,9 @@ Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) {
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client requested notifications on a key before Subscribe completed";
return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(),
client_id.size(), prefix_, pubsub_channel_, nullptr);
return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id,
client_id.data(), client_id.size(), prefix_,
pubsub_channel_, nullptr);
}
template <typename ID, typename Data>
@ -154,8 +159,9 @@ Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
const ClientID &client_id) {
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client canceled notifications on a key before Subscribe completed";
return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(),
client_id.size(), prefix_, pubsub_channel_, nullptr);
return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id,
client_id.data(), client_id.size(), prefix_,
pubsub_channel_, nullptr);
}
template <typename ID, typename Data>
@ -170,8 +176,9 @@ Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync(GetTableAddCommand(command_type_), id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_, std::move(callback));
return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>

View file

@ -99,8 +99,8 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
AsyncGcsClient *client;
};
Log(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: context_(context),
Log(const std::vector<std::shared_ptr<RedisContext>> &contexts, AsyncGcsClient *client)
: shard_contexts_(contexts),
client_(client),
pubsub_channel_(TablePubsub::NO_PUBLISH),
prefix_(TablePrefix::UNUSED),
@ -190,8 +190,12 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
const ClientID &client_id);
protected:
std::shared_ptr<RedisContext> GetRedisContext(const ID &id) {
static std::hash<ray::UniqueID> index;
return shard_contexts_[index(id) % shard_contexts_.size()];
}
/// The connection to the GCS.
std::shared_ptr<RedisContext> context_;
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
/// The GCS client.
AsyncGcsClient *client_;
/// The pubsub channel to subscribe to for notifications about keys in this
@ -245,8 +249,9 @@ class Table : private Log<ID, Data>,
/// request and receive notifications.
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
Table(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log<ID, Data>(context, client) {}
Table(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log<ID, Data>(contexts, client) {}
using Log<ID, Data>::RequestNotifications;
using Log<ID, Data>::CancelNotifications;
@ -296,24 +301,26 @@ class Table : private Log<ID, Data>,
const SubscriptionCallback &done);
protected:
using Log<ID, Data>::context_;
using Log<ID, Data>::shard_contexts_;
using Log<ID, Data>::client_;
using Log<ID, Data>::pubsub_channel_;
using Log<ID, Data>::prefix_;
using Log<ID, Data>::command_type_;
using Log<ID, Data>::GetRedisContext;
};
class ObjectTable : public Log<ObjectID, ObjectTableData> {
public:
ObjectTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
pubsub_channel_ = TablePubsub::OBJECT;
prefix_ = TablePrefix::OBJECT;
};
ObjectTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
gcs::CommandType command_type)
: ObjectTable(context, client) {
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client, gcs::CommandType command_type)
: ObjectTable(contexts, client) {
command_type_ = command_type;
};
@ -322,8 +329,9 @@ class ObjectTable : public Log<ObjectID, ObjectTableData> {
class HeartbeatTable : public Table<ClientID, HeartbeatTableData> {
public:
HeartbeatTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
HeartbeatTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Table(contexts, client) {
pubsub_channel_ = TablePubsub::HEARTBEAT;
prefix_ = TablePrefix::HEARTBEAT;
}
@ -332,8 +340,9 @@ class HeartbeatTable : public Table<ClientID, HeartbeatTableData> {
class DriverTable : public Log<JobID, DriverTableData> {
public:
DriverTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
DriverTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
pubsub_channel_ = TablePubsub::DRIVER;
prefix_ = TablePrefix::DRIVER;
};
@ -349,8 +358,9 @@ class DriverTable : public Log<JobID, DriverTableData> {
class FunctionTable : public Table<ObjectID, FunctionTableData> {
public:
FunctionTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
FunctionTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Table(contexts, client) {
pubsub_channel_ = TablePubsub::NO_PUBLISH;
prefix_ = TablePrefix::FUNCTION;
};
@ -361,8 +371,9 @@ using ClassTable = Table<ClassID, ClassTableData>;
// TODO(swang): Set the pubsub channel for the actor table.
class ActorTable : public Log<ActorID, ActorTableData> {
public:
ActorTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
ActorTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
pubsub_channel_ = TablePubsub::ACTOR;
prefix_ = TablePrefix::ACTOR;
}
@ -370,17 +381,18 @@ class ActorTable : public Log<ActorID, ActorTableData> {
class TaskReconstructionLog : public Log<TaskID, TaskReconstructionData> {
public:
TaskReconstructionLog(const std::shared_ptr<RedisContext> &context,
TaskReconstructionLog(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(context, client) {
: Log(contexts, client) {
prefix_ = TablePrefix::TASK_RECONSTRUCTION;
}
};
class TaskLeaseTable : public Table<TaskID, TaskLeaseData> {
public:
TaskLeaseTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
TaskLeaseTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Table(contexts, client) {
pubsub_channel_ = TablePubsub::TASK_LEASE;
prefix_ = TablePrefix::TASK_LEASE;
}
@ -397,7 +409,8 @@ class TaskLeaseTable : public Table<TaskID, TaskLeaseData> {
std::vector<std::string> args = {"PEXPIRE",
EnumNameTablePrefix(prefix_) + id.binary(),
std::to_string(data->timeout)};
return context_->RunArgvAsync(args);
return GetRedisContext(id)->RunArgvAsync(args);
}
};
@ -405,15 +418,16 @@ namespace raylet {
class TaskTable : public Table<TaskID, ray::protocol::Task> {
public:
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
TaskTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Table(contexts, client) {
pubsub_channel_ = TablePubsub::RAYLET_TASK;
prefix_ = TablePrefix::RAYLET_TASK;
}
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
gcs::CommandType command_type)
: TaskTable(context, client) {
TaskTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client, gcs::CommandType command_type)
: TaskTable(contexts, client) {
command_type_ = command_type;
};
};
@ -422,15 +436,16 @@ class TaskTable : public Table<TaskID, ray::protocol::Task> {
class TaskTable : public Table<TaskID, TaskTableData> {
public:
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
TaskTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Table(contexts, client) {
pubsub_channel_ = TablePubsub::TASK;
prefix_ = TablePrefix::TASK;
};
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
gcs::CommandType command_type)
: TaskTable(context, client) {
TaskTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client, gcs::CommandType command_type)
: TaskTable(contexts, client) {
command_type_ = command_type;
}
@ -466,9 +481,11 @@ class TaskTable : public Table<TaskID, TaskTableData> {
};
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, redisCallback));
for (auto context : shard_contexts_) {
RAY_RETURN_NOT_OK(context->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, redisCallback));
}
return Status::OK();
}
@ -504,8 +521,9 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
class ErrorTable : private Log<JobID, ErrorTableData> {
public:
ErrorTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
ErrorTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
pubsub_channel_ = TablePubsub::ERROR_INFO;
prefix_ = TablePrefix::ERROR_INFO;
};
@ -528,8 +546,9 @@ class ErrorTable : private Log<JobID, ErrorTableData> {
class ProfileTable : private Log<UniqueID, ProfileTableData> {
public:
ProfileTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
/// Construct a ProfileTable over a list of Redis contexts.
///
/// \param contexts The Redis contexts, passed through to the Log base
/// constructor.
/// \param client The GCS client, passed through to the Log base constructor.
ProfileTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Log(contexts, client) {
// Only the table prefix is set here; this constructor does not assign
// pubsub_channel_.
prefix_ = TablePrefix::PROFILE;
};
@ -574,9 +593,9 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
public:
using ClientTableCallback = std::function<void(
AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
const ClientID &client_id)
: Log(context, client),
ClientTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client, const ClientID &client_id)
: Log(contexts, client),
// We set the client log's key equal to nil so that all instances of
// ClientTable have the same key.
client_log_key_(UniqueID::nil()),

View file

@ -43,7 +43,6 @@ class MockServer {
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379, /*sharding=*/false));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
@ -130,7 +129,8 @@ class TestObjectManagerBase : public ::testing::Test {
int push_timeout_ms = 10000;
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_id_1;
om_config_1.pull_timeout_ms = pull_timeout_ms;
@ -141,7 +141,8 @@ class TestObjectManagerBase : public ::testing::Test {
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_id_2;
om_config_2.pull_timeout_ms = pull_timeout_ms;

View file

@ -34,7 +34,6 @@ class MockServer {
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379, /*sharding=*/false));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
@ -115,7 +114,8 @@ class TestObjectManagerBase : public ::testing::Test {
push_timeout_ms = 1000;
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_id_1;
om_config_1.pull_timeout_ms = pull_timeout_ms;
@ -126,7 +126,8 @@ class TestObjectManagerBase : public ::testing::Test {
server1.reset(new MockServer(main_service, om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_id_2;
om_config_2.pull_timeout_ms = pull_timeout_ms;

View file

@ -89,7 +89,7 @@ int main(int argc, char *argv[]) {
<< "object_chunk_size = " << object_manager_config.object_chunk_size;
// initialize mock gcs & object directory
auto gcs_client = std::make_shared<ray::gcs::AsyncGcsClient>();
auto gcs_client = std::make_shared<ray::gcs::AsyncGcsClient>(redis_address, redis_port);
RAY_LOG(DEBUG) << "Initializing GCS client "
<< gcs_client->client_table().GetLocalClientId();

View file

@ -15,10 +15,9 @@ namespace raylet {
/// the client table, which broadcasts the event to all other Raylets.
Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_address,
int redis_port)
: gcs_client_(),
: gcs_client_(redis_address, redis_port),
num_heartbeats_timeout_(RayConfig::instance().num_heartbeats_timeout()),
heartbeat_timer_(io_service) {
RAY_CHECK_OK(gcs_client_.Connect(redis_address, redis_port, /*sharding=*/true));
RAY_CHECK_OK(gcs_client_.Attach(io_service));
}

View file

@ -54,7 +54,8 @@ class TestObjectManagerBase : public ::testing::Test {
std::string store_sock_2 = StartStore("2");
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_sock_1;
om_config_1.push_timeout_ms = 10000;
@ -63,7 +64,8 @@ class TestObjectManagerBase : public ::testing::Test {
GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(
new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true));
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_sock_2;
om_config_2.push_timeout_ms = 10000;

View file

@ -54,7 +54,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
const std::string &redis_address, int redis_port,
boost::asio::io_service &io_service,
const NodeManagerConfig &node_manager_config) {
RAY_RETURN_NOT_OK(gcs_client_->Connect(redis_address, redis_port, /*sharding=*/true));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();

View file

@ -1295,11 +1295,7 @@ class APITestSharded(APITest):
if kwargs is None:
kwargs = {}
kwargs["start_ray_local"] = True
if os.environ.get("RAY_USE_XRAY") == "1":
print("XRAY currently supports only a single Redis shard.")
kwargs["num_redis_shards"] = 1
else:
kwargs["num_redis_shards"] = 20
kwargs["num_redis_shards"] = 20
kwargs["redirect_output"] = True
ray.worker._init(**kwargs)