[Core] Fix maximum_startup_concurrency caused by AnnounceWorkerPort (#10853)

* Fix maximum_startup_concurrency caused by AnnounceWorkerPort

* Address comment

* Update src/ray/raylet/worker_pool.h

Co-authored-by: Eric Liang <ekhliang@gmail.com>

Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
Kai Yang 2020-09-24 11:27:44 +08:00 committed by GitHub
parent 52e1495e30
commit b251a445dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 20 deletions

View file

@ -1330,6 +1330,7 @@ void NodeManager::ProcessAnnounceWorkerPortMessage(
int port = message->port();
worker->Connect(port);
if (is_worker) {
worker_pool_.OnWorkerStarted(worker);
HandleWorkerAvailable(worker->Connection());
}
}

View file

@ -437,14 +437,16 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
RAY_CHECK(worker);
auto &state = GetStateForLanguage(worker->GetLanguage());
auto process = Process::FromPid(pid);
auto it = state.starting_worker_processes.find(Process::FromPid(pid));
if (it == state.starting_worker_processes.end()) {
RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid;
if (state.starting_worker_processes.count(process) == 0) {
RAY_LOG(WARNING) << "Received a register request from an unknown worker "
<< process.GetId();
Status status = Status::Invalid("Unknown worker");
send_reply_callback(status, /*port=*/0);
return status;
}
worker->SetProcess(process);
// The port that this worker's gRPC server should listen on. 0 if the worker
// should bind on a random port.
@ -457,20 +459,8 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port
<< ", worker_type: " << rpc::WorkerType_Name(worker->GetWorkerType());
worker->SetAssignedPort(port);
worker->SetProcess(it->first);
it->second--;
if (it->second == 0) {
state.starting_worker_processes.erase(it);
// We may have slots to start more workers now.
TryStartIOWorkers(worker->GetLanguage(), state);
}
RAY_CHECK(worker->GetProcess().GetId() == pid);
state.registered_workers.insert(worker);
if (worker->GetWorkerType() == rpc::WorkerType::IO_WORKER) {
state.registered_io_workers.insert(worker);
state.num_starting_io_workers--;
}
if (RayConfig::instance().enable_multi_tenancy()) {
auto dedicated_workers_it = state.worker_pids_to_assigned_jobs.find(pid);
@ -490,11 +480,38 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
worker->AssignJobId(job_id);
// We don't call state.worker_pids_to_assigned_jobs.erase(job_id) here
// because we allow multi-workers per worker process.
}
// Send the reply immediately for worker registrations.
send_reply_callback(Status::OK(), port);
return Status::OK();
}
void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker) {
auto &state = GetStateForLanguage(worker->GetLanguage());
const auto &process = worker->GetProcess();
RAY_CHECK(process.IsValid());
auto it = state.starting_worker_processes.find(process);
if (it != state.starting_worker_processes.end()) {
it->second--;
if (it->second == 0) {
state.starting_worker_processes.erase(it);
// We may have slots to start more workers now.
TryStartIOWorkers(worker->GetLanguage(), state);
}
}
if (worker->GetWorkerType() == rpc::WorkerType::IO_WORKER) {
state.registered_io_workers.insert(worker);
state.num_starting_io_workers--;
}
if (RayConfig::instance().enable_multi_tenancy()) {
// This is a workaround to finish driver registration after all initial workers are
// registered to Raylet if and only if Raylet is started by a Python driver and the
// job config is not set in `ray.init(...)`.
if (first_job_ == job_id && worker->GetLanguage() == Language::PYTHON) {
if (first_job_ == worker->GetAssignedJobId() &&
worker->GetLanguage() == Language::PYTHON) {
if (++first_job_registered_python_worker_count_ ==
first_job_driver_wait_num_python_workers_) {
if (first_job_send_register_client_reply_to_driver_) {
@ -504,10 +521,6 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
}
}
}
// Send the reply immediately for worker registrations.
send_reply_callback(Status::OK(), port);
return Status::OK();
}
Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver,

View file

@ -124,6 +124,12 @@ class WorkerPool : public WorkerPoolInterface {
Status RegisterWorker(const std::shared_ptr<WorkerInterface> &worker, pid_t pid,
std::function<void(Status, int)> send_reply_callback);
/// To be invoked when a worker is started. This method should be called when the worker
/// announces its port.
///
/// \param[in] worker The worker which is started.
void OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker);
/// Register a new driver.
///
/// \param[in] worker The driver to be registered.

View file

@ -248,6 +248,7 @@ TEST_P(WorkerPoolTest, HandleWorkerRegistration) {
// Check that we cannot lookup the worker before it's registered.
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr);
RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), [](Status, int) {}));
worker_pool_->OnWorkerStarted(worker);
// Check that we can lookup the worker after it's registered.
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker);
}
@ -441,6 +442,51 @@ TEST_P(WorkerPoolTest, PopWorkerMultiTenancy) {
}
}
TEST_P(WorkerPoolTest, MaximumStartupConcurrency) {
auto task_spec = ExampleTaskSpec();
std::vector<Process> started_processes;
// Try to pop some workers. Some worker processes will be started.
for (int i = 0; i < MAXIMUM_STARTUP_CONCURRENCY; i++) {
auto worker = worker_pool_->PopWorker(task_spec);
RAY_CHECK(!worker);
auto last_process = worker_pool_->LastStartedWorkerProcess();
RAY_CHECK(last_process.IsValid());
started_processes.push_back(last_process);
}
// Can't start a new worker process at this point.
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting());
RAY_CHECK(!worker_pool_->PopWorker(task_spec));
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting());
std::vector<std::shared_ptr<WorkerInterface>> workers;
// Call `RegisterWorker` to emulate worker registration.
for (const auto &process : started_processes) {
auto worker = CreateWorker(Process());
RAY_CHECK_OK(
worker_pool_->RegisterWorker(worker, process.GetId(), [](Status, int) {}));
// Calling `RegisterWorker` won't affect the counter of starting worker processes.
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting());
workers.push_back(worker);
}
// Can't start a new worker process at this point.
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting());
RAY_CHECK(!worker_pool_->PopWorker(task_spec));
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkerProcessesStarting());
// Call `OnWorkerStarted` to emulate worker port announcement.
for (size_t i = 0; i < workers.size(); i++) {
worker_pool_->OnWorkerStarted(workers[i]);
// Calling `OnWorkerStarted` will affect the counter of starting worker processes.
ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY - i - 1,
worker_pool_->NumWorkerProcessesStarting());
}
ASSERT_EQ(0, worker_pool_->NumWorkerProcessesStarting());
}
INSTANTIATE_TEST_CASE_P(WorkerPoolMultiTenancyTest, WorkerPoolTest,
::testing::Values(true, false));