diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a17057524..f5fded5be 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1101,12 +1101,21 @@ void NodeManager::ProcessRegisterClientRequestMessage( worker->AssignTaskId(driver_task_id); rpc::JobConfig job_config; job_config.ParseFromString(message->serialized_job_config()->str()); - Status status = worker_pool_.RegisterDriver(worker, job_config, send_reply_callback); - if (status.ok()) { - auto job_data_ptr = gcs::CreateJobTableData(job_id, /*is_dead*/ false, - worker_ip_address, pid, job_config); - RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr)); - } + + // Send the reply callback only after registration fully completes at the GCS. + auto cb = [this, worker_ip_address, pid, job_id, job_config, + send_reply_callback = std::move(send_reply_callback)](const Status &status, + int assigned_port) { + if (status.ok()) { + auto job_data_ptr = gcs::CreateJobTableData(job_id, /*is_dead*/ false, + worker_ip_address, pid, job_config); + RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd( + job_data_ptr, + [send_reply_callback = std::move(send_reply_callback), assigned_port]( + Status status) { send_reply_callback(status, assigned_port); })); + } + }; + RAY_UNUSED(worker_pool_.RegisterDriver(worker, job_config, std::move(cb))); } }