mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[core] fix bug that actor tasks from reconstructed actor is ignored by scheduling queue (#7637)
This commit is contained in:
parent
1b90196bef
commit
a7a5d172b1
10 changed files with 384 additions and 35 deletions
|
@ -207,6 +207,76 @@ def test_actor_reconstruction_without_task(ray_start_regular):
|
|||
assert wait_for_condition(check_reconstructed)
|
||||
|
||||
|
||||
def test_caller_actor_reconstruction(ray_start_regular):
|
||||
"""Test tasks from a reconstructed actor can be correctly processed
|
||||
by the receiving actor."""
|
||||
|
||||
@ray.remote(max_reconstructions=1)
|
||||
class ReconstructableActor:
|
||||
"""An actor that will be reconstructed at most once."""
|
||||
|
||||
def __init__(self, actor):
|
||||
self.actor = actor
|
||||
|
||||
def increase(self):
|
||||
return ray.get(self.actor.increase.remote())
|
||||
|
||||
def get_pid(self):
|
||||
return os.getpid()
|
||||
|
||||
@ray.remote(max_reconstructions=1)
|
||||
class Actor:
|
||||
"""An actor that will be reconstructed at most once."""
|
||||
|
||||
def __init__(self):
|
||||
self.value = 0
|
||||
|
||||
def increase(self):
|
||||
self.value += 1
|
||||
return self.value
|
||||
|
||||
remote_actor = Actor.remote()
|
||||
actor = ReconstructableActor.remote(remote_actor)
|
||||
# Call increase 3 times
|
||||
for _ in range(3):
|
||||
ray.get(actor.increase.remote())
|
||||
|
||||
# kill the actor.
|
||||
# TODO(zhijunfu): use ray.kill instead.
|
||||
kill_actor(actor)
|
||||
|
||||
# Check that we can still call the actor.
|
||||
assert ray.get(actor.increase.remote()) == 4
|
||||
|
||||
|
||||
def test_caller_task_reconstruction(ray_start_regular):
|
||||
"""Test a retried task from a dead worker can be correctly processed
|
||||
by the receiving actor."""
|
||||
|
||||
@ray.remote(max_retries=5)
|
||||
def RetryableTask(actor):
|
||||
value = ray.get(actor.increase.remote())
|
||||
if value > 2:
|
||||
return value
|
||||
else:
|
||||
os._exit(0)
|
||||
|
||||
@ray.remote(max_reconstructions=1)
|
||||
class Actor:
|
||||
"""An actor that will be reconstructed at most once."""
|
||||
|
||||
def __init__(self):
|
||||
self.value = 0
|
||||
|
||||
def increase(self):
|
||||
self.value += 1
|
||||
return self.value
|
||||
|
||||
remote_actor = Actor.remote()
|
||||
|
||||
assert ray.get(RetryableTask.remote(remote_actor)) == 3
|
||||
|
||||
|
||||
def test_actor_reconstruction_on_node_failure(ray_start_cluster_head):
|
||||
"""Test actor reconstruction when node dies unexpectedly."""
|
||||
cluster = ray_start_cluster_head
|
||||
|
|
|
@ -72,13 +72,16 @@ class WorkerContext {
|
|||
|
||||
int GetNextPutIndex();
|
||||
|
||||
protected:
|
||||
// allow unit test to set.
|
||||
bool current_actor_is_direct_call_ = false;
|
||||
bool current_task_is_direct_call_ = false;
|
||||
|
||||
private:
|
||||
const WorkerType worker_type_;
|
||||
const WorkerID worker_id_;
|
||||
JobID current_job_id_;
|
||||
ActorID current_actor_id_;
|
||||
bool current_actor_is_direct_call_ = false;
|
||||
bool current_task_is_direct_call_ = false;
|
||||
int current_actor_max_concurrency_ = 1;
|
||||
bool current_actor_is_asyncio_ = false;
|
||||
|
||||
|
|
|
@ -140,9 +140,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
raylet_task_receiver_ =
|
||||
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
|
||||
worker_context_.GetWorkerID(), local_raylet_client_, execute_task));
|
||||
direct_task_receiver_ = std::unique_ptr<CoreWorkerDirectTaskReceiver>(
|
||||
new CoreWorkerDirectTaskReceiver(worker_context_, local_raylet_client_,
|
||||
task_execution_service_, execute_task));
|
||||
direct_task_receiver_ =
|
||||
std::unique_ptr<CoreWorkerDirectTaskReceiver>(new CoreWorkerDirectTaskReceiver(
|
||||
worker_context_, task_execution_service_, execute_task,
|
||||
[this] { return local_raylet_client_->TaskDone(); }));
|
||||
}
|
||||
|
||||
// Start RPC server after all the task receivers are properly initialized.
|
||||
|
@ -255,7 +256,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
future_resolver_.reset(new FutureResolver(memory_store_, client_factory));
|
||||
// Unfortunately the raylet client has to be constructed after the receivers.
|
||||
if (direct_task_receiver_ != nullptr) {
|
||||
direct_task_receiver_->Init(client_factory, rpc_address_);
|
||||
direct_task_receiver_->Init(client_factory, rpc_address_, local_raylet_client_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1144,11 +1145,16 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
|||
return_ids.pop_back();
|
||||
task_type = TaskType::ACTOR_CREATION_TASK;
|
||||
SetActorId(task_spec.ActorCreationId());
|
||||
// For an actor, set the timestamp as the time its creation task starts execution.
|
||||
SetCallerCreationTimestamp();
|
||||
RAY_LOG(INFO) << "Creating actor: " << task_spec.ActorCreationId();
|
||||
} else if (task_spec.IsActorTask()) {
|
||||
RAY_CHECK(return_ids.size() > 0);
|
||||
return_ids.pop_back();
|
||||
task_type = TaskType::ACTOR_TASK;
|
||||
} else {
|
||||
// For a non-actor task, set the timestamp as the time it starts execution.
|
||||
SetCallerCreationTimestamp();
|
||||
}
|
||||
|
||||
status = task_execution_callback_(
|
||||
|
@ -1563,4 +1569,9 @@ void CoreWorker::SetActorTitle(const std::string &title) {
|
|||
actor_title_ = title;
|
||||
}
|
||||
|
||||
/// Record "now" (system clock, milliseconds) as the caller creation timestamp
/// on the direct actor task submitter. The update is made while holding mutex_.
/// ExecuteTask calls this when an actor creation task or a non-actor task
/// starts executing.
void CoreWorker::SetCallerCreationTimestamp() {
|
||||
absl::MutexLock lock(&mutex_);
|
||||
direct_actor_submitter_->SetCallerCreationTimestamp(current_sys_time_ms());
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -120,6 +120,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
|
|||
|
||||
void SetActorTitle(const std::string &title);
|
||||
|
||||
void SetCallerCreationTimestamp();
|
||||
|
||||
/// Increase the reference count for this object ID.
|
||||
/// Increase the local reference count for this object ID. Should be called
|
||||
/// by the language frontend when a new reference is created.
|
||||
|
|
|
@ -75,9 +75,9 @@ TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
|
|||
return task;
|
||||
}
|
||||
|
||||
class DirectActorTransportTest : public ::testing::Test {
|
||||
class DirectActorSubmitterTest : public ::testing::Test {
|
||||
public:
|
||||
DirectActorTransportTest()
|
||||
DirectActorSubmitterTest()
|
||||
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
||||
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||
task_finisher_(std::make_shared<MockTaskFinisher>()),
|
||||
|
@ -91,7 +91,7 @@ class DirectActorTransportTest : public ::testing::Test {
|
|||
CoreWorkerDirectActorTaskSubmitter submitter_;
|
||||
};
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestSubmitTask) {
|
||||
TEST_F(DirectActorSubmitterTest, TestSubmitTask) {
|
||||
rpc::Address addr;
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
|
||||
|
@ -114,7 +114,7 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestDependencies) {
|
||||
TEST_F(DirectActorSubmitterTest, TestDependencies) {
|
||||
rpc::Address addr;
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
submitter_.ConnectActor(actor_id, addr);
|
||||
|
@ -142,7 +142,7 @@ TEST_F(DirectActorTransportTest, TestDependencies) {
|
|||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
|
||||
TEST_F(DirectActorSubmitterTest, TestOutOfOrderDependencies) {
|
||||
rpc::Address addr;
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
submitter_.ConnectActor(actor_id, addr);
|
||||
|
@ -171,7 +171,7 @@ TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
|
|||
ASSERT_EQ(worker_client_->callbacks.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(DirectActorTransportTest, TestActorFailure) {
|
||||
TEST_F(DirectActorSubmitterTest, TestActorFailure) {
|
||||
rpc::Address addr;
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
gcs::ActorTableData actor_data;
|
||||
|
@ -193,9 +193,182 @@ TEST_F(DirectActorTransportTest, TestActorFailure) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Test double for DependencyWaiterInterface. Its wait call performs no work
/// and immediately reports success, so the receiver under test can be built
/// without a raylet client.
class MockDependencyWaiterInterface : public DependencyWaiterInterface {
|
||||
public:
|
||||
// Ignores both arguments and always returns Status::OK().
virtual Status WaitForDirectActorCallArgs(const std::vector<ObjectID> &object_ids,
|
||||
int64_t tag) override {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
/// Build a minimal actor task spec for the receiver tests.
///
/// \param actor_id Actor the task targets.
/// \param counter Value stored as the task's actor counter.
/// \param caller_id Caller ID stored in the spec.
/// \return A spec with a nil task ID, type ACTOR_TASK and one return value.
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter,
|
||||
TaskID caller_id) {
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task.GetMutableMessage().set_caller_id(caller_id.Binary());
|
||||
task.GetMutableMessage().set_type(TaskType::ACTOR_TASK);
|
||||
task.GetMutableMessage().mutable_actor_task_spec()->set_actor_id(actor_id.Binary());
|
||||
task.GetMutableMessage().mutable_actor_task_spec()->set_actor_counter(counter);
|
||||
task.GetMutableMessage().set_num_returns(1);
|
||||
return task;
|
||||
}
|
||||
|
||||
/// Build a PushTaskRequest as a given caller incarnation would send it.
///
/// \param actor_id Actor the task targets.
/// \param counter Actor counter of the task; also used as the sequence number.
/// \param caller_worker_id Worker ID placed in the request's caller address.
/// \param caller_id Caller ID stored in the task spec.
/// \param caller_timestamp Value stored as the request's caller version.
/// \return The populated request.
rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, int64_t counter,
|
||||
WorkerID caller_worker_id,
|
||||
TaskID caller_id,
|
||||
int64_t caller_timestamp) {
|
||||
auto task_spec = CreateActorTaskHelper(actor_id, counter, caller_id);
|
||||
// Only the worker ID of the caller address is filled in.
rpc::Address rpc_address;
|
||||
rpc_address.set_worker_id(caller_worker_id.Binary());
|
||||
|
||||
rpc::PushTaskRequest request;
|
||||
request.mutable_caller_address()->CopyFrom(rpc_address);
|
||||
request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());
|
||||
request.set_caller_version(caller_timestamp);
|
||||
// The sequence number mirrors the actor counter stored in the task spec.
request.set_sequence_number(request.task_spec().actor_task_spec().actor_counter());
|
||||
request.set_client_processed_up_to(-1);
|
||||
return request;
|
||||
}
|
||||
|
||||
/// WorkerContext subclass for tests. Its constructor forwards to WorkerContext
/// and then sets current_actor_is_direct_call_ to true.
class MockWorkerContext : public WorkerContext {
|
||||
public:
|
||||
MockWorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
: WorkerContext(worker_type, job_id) {
|
||||
current_actor_is_direct_call_ = true;
|
||||
}
|
||||
};
|
||||
|
||||
class DirectActorReceiverTest : public ::testing::Test {
|
||||
public:
|
||||
DirectActorReceiverTest()
|
||||
: worker_context_(WorkerType::WORKER, JobID::FromInt(0)),
|
||||
worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
||||
dependency_client_(std::make_shared<MockDependencyWaiterInterface>()) {
|
||||
auto execute_task =
|
||||
std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
||||
receiver_ = std::unique_ptr<CoreWorkerDirectTaskReceiver>(
|
||||
new CoreWorkerDirectTaskReceiver(worker_context_, main_io_service_, execute_task,
|
||||
[] { return Status::OK(); }));
|
||||
receiver_->Init([&](const rpc::Address &addr) { return worker_client_; },
|
||||
rpc_address_, dependency_client_);
|
||||
}
|
||||
|
||||
Status MockExecuteTask(const TaskSpecification &task_spec,
|
||||
const std::shared_ptr<ResourceMappingType> &resource_ids,
|
||||
std::vector<std::shared_ptr<RayObject>> *return_objects,
|
||||
ReferenceCounter::ReferenceTableProto *borrowed_refs) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void StartIOService() { main_io_service_.run(); }
|
||||
|
||||
void StopIOService() { main_io_service_.stop(); }
|
||||
|
||||
std::unique_ptr<CoreWorkerDirectTaskReceiver> receiver_;
|
||||
|
||||
private:
|
||||
rpc::Address rpc_address_;
|
||||
MockWorkerContext worker_context_;
|
||||
boost::asio::io_service main_io_service_;
|
||||
std::shared_ptr<MockWorkerClient> worker_client_;
|
||||
std::shared_ptr<DependencyWaiterInterface> dependency_client_;
|
||||
};
|
||||
|
||||
TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) {
|
||||
rpc::Address addr;
|
||||
TaskID current_task_id = TaskID::Nil();
|
||||
ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
|
||||
WorkerID worker_id = WorkerID::FromRandom();
|
||||
TaskID caller_id =
|
||||
TaskID::ForActorTask(JobID::FromInt(0), current_task_id, 0, actor_id);
|
||||
|
||||
int64_t curr_timestamp = current_sys_time_ms();
|
||||
int64_t old_timestamp = curr_timestamp - 1000;
|
||||
int64_t new_timestamp = curr_timestamp + 1000;
|
||||
|
||||
int callback_count = 0;
|
||||
|
||||
// Push a task request with actor counter 0. This should succeed
|
||||
// on the receiver.
|
||||
{
|
||||
auto request =
|
||||
CreatePushTaskRequestHelper(actor_id, 0, worker_id, caller_id, curr_timestamp);
|
||||
rpc::PushTaskReply reply;
|
||||
auto reply_callback = [&callback_count](Status status, std::function<void()> success,
|
||||
std::function<void()> failure) {
|
||||
++callback_count;
|
||||
ASSERT_TRUE(status.ok());
|
||||
};
|
||||
receiver_->HandlePushTask(request, &reply, reply_callback);
|
||||
}
|
||||
|
||||
// Push a task request with actor counter 1. This should succeed
|
||||
// on the receiver.
|
||||
{
|
||||
auto request =
|
||||
CreatePushTaskRequestHelper(actor_id, 1, worker_id, caller_id, curr_timestamp);
|
||||
rpc::PushTaskReply reply;
|
||||
auto reply_callback = [&callback_count](Status status, std::function<void()> success,
|
||||
std::function<void()> failure) {
|
||||
++callback_count;
|
||||
ASSERT_TRUE(status.ok());
|
||||
};
|
||||
receiver_->HandlePushTask(request, &reply, reply_callback);
|
||||
}
|
||||
|
||||
// Create another request with the same caller id, but a different worker id,
|
||||
// and a newer timestamp. This simulates caller reconstruction.
|
||||
// Note that here the task request still has counter 0, which should be
|
||||
// ignored normally, but here it's from a different worker and with a newer
|
||||
// timestamp, in this case it should succeed.
|
||||
{
|
||||
auto worker_id = WorkerID::FromRandom();
|
||||
auto request =
|
||||
CreatePushTaskRequestHelper(actor_id, 0, worker_id, caller_id, new_timestamp);
|
||||
rpc::PushTaskReply reply;
|
||||
auto reply_callback = [&callback_count](Status status, std::function<void()> success,
|
||||
std::function<void()> failure) {
|
||||
++callback_count;
|
||||
ASSERT_TRUE(status.ok());
|
||||
};
|
||||
receiver_->HandlePushTask(request, &reply, reply_callback);
|
||||
}
|
||||
|
||||
// Push a task request with actor counter 1, but with a different worker id,
|
||||
// and an older timestamp. In this case the request should fail.
|
||||
{
|
||||
auto worker_id = WorkerID::FromRandom();
|
||||
auto request =
|
||||
CreatePushTaskRequestHelper(actor_id, 1, worker_id, caller_id, old_timestamp);
|
||||
rpc::PushTaskReply reply;
|
||||
auto reply_callback = [&callback_count](Status status, std::function<void()> success,
|
||||
std::function<void()> failure) {
|
||||
++callback_count;
|
||||
ASSERT_TRUE(!status.ok());
|
||||
};
|
||||
receiver_->HandlePushTask(request, &reply, reply_callback);
|
||||
}
|
||||
|
||||
StartIOService();
|
||||
|
||||
// Wait for all the callbacks to be invoked.
|
||||
auto condition_func = [&callback_count]() -> bool { return callback_count == 4; };
|
||||
|
||||
ASSERT_TRUE(WaitForCondition(condition_func, 10 * 1000));
|
||||
|
||||
StopIOService();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
// Test binary entry point: initializes googletest from the command line, sets
// up Ray logging at INFO level with an empty log directory, installs the
// failure signal handler, and runs every registered test.
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
|
||||
// Guard object given the log start/shutdown functions; presumably it shuts
// logging down when it goes out of scope -- confirm against InitShutdownRAII.
InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
|
||||
ray::RayLog::ShutDownRayLog, argv[0],
|
||||
ray::RayLogLevel::INFO,
|
||||
/*log_dir=*/"");
|
||||
ray::RayLog::InstallFailureSignalHandler();
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
|
|
@ -76,6 +76,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
|||
// fails, then the task data will be gone when the TaskManager attempts to
|
||||
// access the task.
|
||||
request->mutable_task_spec()->CopyFrom(task_spec.GetMessage());
|
||||
request->set_caller_version(caller_creation_timestamp_ms_);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
|
@ -201,9 +202,14 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c
|
|||
return (iter != rpc_clients_.end());
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskReceiver::Init(rpc::ClientFactoryFn client_factory,
|
||||
rpc::Address rpc_address) {
|
||||
waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_));
|
||||
/// Store the caller creation timestamp (milliseconds). SubmitTask copies this
/// value into each outgoing request as its caller version.
/// NOTE(review): the member is written here without taking mu_; confirm that
/// callers serialize this against SubmitTask.
void CoreWorkerDirectActorTaskSubmitter::SetCallerCreationTimestamp(int64_t timestamp) {
|
||||
caller_creation_timestamp_ms_ = timestamp;
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskReceiver::Init(
|
||||
rpc::ClientFactoryFn client_factory, rpc::Address rpc_address,
|
||||
std::shared_ptr<DependencyWaiterInterface> dependency_client) {
|
||||
waiter_.reset(new DependencyWaiterImpl(*dependency_client));
|
||||
rpc_address_ = rpc_address;
|
||||
client_factory_ = client_factory;
|
||||
}
|
||||
|
@ -286,7 +292,7 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
|
|||
// Tell raylet that an actor creation task has finished execution, so that
|
||||
// raylet can publish actor creation event to GCS, and mark this worker as
|
||||
// actor, thus if this worker dies later raylet will reconstruct the actor.
|
||||
RAY_CHECK_OK(local_raylet_client_->TaskDone());
|
||||
RAY_CHECK_OK(task_done_());
|
||||
}
|
||||
}
|
||||
if (status.IsSystemExit()) {
|
||||
|
@ -316,15 +322,50 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
|
|||
send_reply_callback(Status::Invalid("client cancelled stale rpc"), nullptr, nullptr);
|
||||
};
|
||||
|
||||
auto caller_worker_id = WorkerID::FromBinary(request.caller_address().worker_id());
|
||||
auto caller_version = request.caller_version();
|
||||
auto it = scheduling_queue_.find(task_spec.CallerId());
|
||||
if (it != scheduling_queue_.end()) {
|
||||
if (it->second.first.caller_worker_id != caller_worker_id) {
|
||||
// We received a request with the same caller ID, but from a different worker,
|
||||
// this indicates the caller (actor) is reconstructed.
|
||||
if (it->second.first.caller_creation_timestamp_ms < caller_version) {
|
||||
// The new request has a newer caller version, then remove the old entry
|
||||
// from scheduling queue since it's invalid now.
|
||||
RAY_LOG(INFO) << "Remove existing scheduling queue for caller "
|
||||
<< task_spec.CallerId() << " after receiving a "
|
||||
<< "request from a different worker ID with a newer "
|
||||
<< "version, old worker ID: " << it->second.first.caller_worker_id
|
||||
<< ", new worker ID" << caller_worker_id;
|
||||
scheduling_queue_.erase(task_spec.CallerId());
|
||||
it = scheduling_queue_.end();
|
||||
} else {
|
||||
// The existing caller has the newer version, this indicates the request
|
||||
// is from an old caller, which might be possible when network has problems.
|
||||
// In this case fail this request.
|
||||
RAY_LOG(WARNING) << "Ignoring request from an old caller because "
|
||||
<< "it has a smaller timestamp, old worker ID: "
|
||||
<< caller_worker_id << ", current worker ID"
|
||||
<< it->second.first.caller_worker_id;
|
||||
// Fail request with an old caller version.
|
||||
reject_callback();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (it == scheduling_queue_.end()) {
|
||||
SchedulingQueueTag tag;
|
||||
tag.caller_worker_id = caller_worker_id;
|
||||
tag.caller_creation_timestamp_ms = caller_version;
|
||||
auto result = scheduling_queue_.emplace(
|
||||
task_spec.CallerId(), std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
|
||||
task_main_io_service_, *waiter_, worker_context_)));
|
||||
task_spec.CallerId(),
|
||||
std::make_pair(tag, std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
|
||||
task_main_io_service_, *waiter_, worker_context_))));
|
||||
it = result.first;
|
||||
}
|
||||
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
|
||||
accept_callback, reject_callback, dependencies);
|
||||
it->second.second->Add(request.sequence_number(), request.client_processed_up_to(),
|
||||
accept_callback, reject_callback, dependencies);
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskReceiver::HandleDirectActorCallArgWaitComplete(
|
||||
|
|
|
@ -85,6 +85,9 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
/// \param[in] actor_id Actor ID.
|
||||
void DisconnectActor(const ActorID &actor_id, bool dead = false);
|
||||
|
||||
/// Set the timestamp for the caller.
|
||||
void SetCallerCreationTimestamp(int64_t timestamp);
|
||||
|
||||
private:
|
||||
/// Push a task to a remote actor via the given client.
|
||||
/// Note, this function doesn't return any error status code. If an error occurs while
|
||||
|
@ -161,6 +164,12 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
/// Used to complete tasks.
|
||||
std::shared_ptr<TaskFinisherInterface> task_finisher_;
|
||||
|
||||
/// Timestamp when the caller is created.
|
||||
/// - if this worker is an actor, this is set to the time that the actor creation
|
||||
/// task starts execution;
|
||||
/// - otherwise, it's set to the time that the current task starts execution.
|
||||
int64_t caller_creation_timestamp_ms_ = 0;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
|
||||
|
@ -195,14 +204,14 @@ class DependencyWaiter {
|
|||
|
||||
class DependencyWaiterImpl : public DependencyWaiter {
|
||||
public:
|
||||
DependencyWaiterImpl(raylet::RayletClient &local_raylet_client)
|
||||
: local_raylet_client_(local_raylet_client) {}
|
||||
DependencyWaiterImpl(DependencyWaiterInterface &dependency_client)
|
||||
: dependency_client_(dependency_client) {}
|
||||
|
||||
void Wait(const std::vector<ObjectID> &dependencies,
|
||||
std::function<void()> on_dependencies_available) override {
|
||||
auto tag = next_request_id_++;
|
||||
requests_[tag] = on_dependencies_available;
|
||||
local_raylet_client_.WaitForDirectActorCallArgs(dependencies, tag);
|
||||
dependency_client_.WaitForDirectActorCallArgs(dependencies, tag);
|
||||
}
|
||||
|
||||
/// Fulfills the callback stored by Wait().
|
||||
|
@ -216,7 +225,7 @@ class DependencyWaiterImpl : public DependencyWaiter {
|
|||
private:
|
||||
int64_t next_request_id_ = 0;
|
||||
std::unordered_map<int64_t, std::function<void()>> requests_;
|
||||
raylet::RayletClient &local_raylet_client_;
|
||||
DependencyWaiterInterface &dependency_client_;
|
||||
};
|
||||
|
||||
/// Wraps a thread-pool to block posts until the pool has free slots. This is used
|
||||
|
@ -253,6 +262,13 @@ class BoundedExecutor {
|
|||
boost::asio::thread_pool pool_;
|
||||
};
|
||||
|
||||
/// Identifies the caller incarnation that owns a scheduling queue. It is
/// stored next to each queue so that a request carrying the same caller ID
/// from a different worker can be compared by timestamp.
struct SchedulingQueueTag {
|
||||
/// Worker ID for the caller.
|
||||
WorkerID caller_worker_id;
|
||||
/// Timestamp for the caller, which is used as a version.
|
||||
int64_t caller_creation_timestamp_ms = 0;
|
||||
};
|
||||
|
||||
/// Used to ensure serial order of task execution per actor handle.
|
||||
/// See direct_actor.proto for a description of the ordering protocol.
|
||||
class SchedulingQueue {
|
||||
|
@ -409,17 +425,20 @@ class CoreWorkerDirectTaskReceiver {
|
|||
std::vector<std::shared_ptr<RayObject>> *return_objects,
|
||||
ReferenceCounter::ReferenceTableProto *borrower_refs)>;
|
||||
|
||||
using OnTaskDone = std::function<ray::Status()>;
|
||||
|
||||
CoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
|
||||
std::shared_ptr<raylet::RayletClient> &local_raylet_client,
|
||||
boost::asio::io_service &main_io_service,
|
||||
const TaskHandler &task_handler)
|
||||
const TaskHandler &task_handler,
|
||||
const OnTaskDone &task_done)
|
||||
: worker_context_(worker_context),
|
||||
local_raylet_client_(local_raylet_client),
|
||||
task_handler_(task_handler),
|
||||
task_main_io_service_(main_io_service) {}
|
||||
task_main_io_service_(main_io_service),
|
||||
task_done_(task_done) {}
|
||||
|
||||
/// Initialize this receiver. This must be called prior to use.
|
||||
void Init(rpc::ClientFactoryFn client_factory, rpc::Address rpc_address);
|
||||
void Init(rpc::ClientFactoryFn client_factory, rpc::Address rpc_address,
|
||||
std::shared_ptr<DependencyWaiterInterface> dependency_client);
|
||||
|
||||
/// Handle a `PushTask` request.
|
||||
///
|
||||
|
@ -446,18 +465,19 @@ class CoreWorkerDirectTaskReceiver {
|
|||
TaskHandler task_handler_;
|
||||
/// The IO event loop for running tasks on.
|
||||
boost::asio::io_service &task_main_io_service_;
|
||||
/// The callback function to be invoked when finishing a task.
|
||||
OnTaskDone task_done_;
|
||||
/// Factory for producing new core worker clients.
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
/// Address of our RPC server.
|
||||
rpc::Address rpc_address_;
|
||||
/// Reference to the core worker's raylet client. This is a pointer ref so that it
|
||||
/// can be initialized by core worker after this class is constructed.
|
||||
std::shared_ptr<raylet::RayletClient> &local_raylet_client_;
|
||||
/// Shared waiter for dependencies required by incoming tasks.
|
||||
std::unique_ptr<DependencyWaiterImpl> waiter_;
|
||||
/// Queue of pending requests per actor handle.
|
||||
/// TODO(ekl) GC these queues once the handle is no longer active.
|
||||
std::unordered_map<TaskID, std::unique_ptr<SchedulingQueue>> scheduling_queue_;
|
||||
std::unordered_map<TaskID,
|
||||
std::pair<SchedulingQueueTag, std::unique_ptr<SchedulingQueue>>>
|
||||
scheduling_queue_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -110,6 +110,13 @@ message PushTaskRequest {
|
|||
int64 client_processed_up_to = 5;
|
||||
// Resource mapping ids assigned to the worker executing the task.
|
||||
repeated ResourceMapEntry resource_mapping = 6;
|
||||
// The version of the caller. This is used to distinguish on-the-fly
|
||||
// requests from a caller before it dies, and requests from the reconstructed
|
||||
// caller, which might happen theoretically when network has issues.
|
||||
// - For an actor, this is set to the timestamp when the actor is created,
|
||||
// so it can be used to differentiate which is the new reconstructed actor.
|
||||
// - For a non-actor task, it's set to the timestamp the task starts execution.
|
||||
int64 caller_version = 7;
|
||||
}
|
||||
|
||||
message PushTaskReply {
|
||||
|
|
|
@ -74,6 +74,21 @@ class WorkerLeaseInterface {
|
|||
virtual ~WorkerLeaseInterface(){};
|
||||
};
|
||||
|
||||
/// Interface for waiting dependencies. Abstract for testing.
|
||||
class DependencyWaiterInterface {
|
||||
public:
|
||||
/// Wait for the given objects, asynchronously. The core worker is notified when
|
||||
/// the wait completes.
|
||||
///
|
||||
/// \param object_ids The objects to wait for.
|
||||
/// \param tag Value that will be sent to the core worker via gRPC on completion.
|
||||
/// \return ray::Status.
|
||||
virtual ray::Status WaitForDirectActorCallArgs(const std::vector<ObjectID> &object_ids,
|
||||
int64_t tag) = 0;
|
||||
|
||||
virtual ~DependencyWaiterInterface(){};
|
||||
};
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class RayletConnection {
|
||||
|
@ -115,7 +130,7 @@ class RayletConnection {
|
|||
std::mutex write_mutex_;
|
||||
};
|
||||
|
||||
class RayletClient : public WorkerLeaseInterface {
|
||||
class RayletClient : public WorkerLeaseInterface, public DependencyWaiterInterface {
|
||||
public:
|
||||
/// Connect to the raylet.
|
||||
///
|
||||
|
@ -205,7 +220,7 @@ class RayletClient : public WorkerLeaseInterface {
|
|||
/// \param tag Value that will be sent to the core worker via gRPC on completion.
|
||||
/// \return ray::Status.
|
||||
ray::Status WaitForDirectActorCallArgs(const std::vector<ObjectID> &object_ids,
|
||||
int64_t tag);
|
||||
int64_t tag) override;
|
||||
|
||||
/// Push an error to the relevant driver.
|
||||
///
|
||||
|
|
|
@ -40,6 +40,13 @@ inline int64_t current_time_ms() {
|
|||
return ms_since_epoch.count();
|
||||
}
|
||||
|
||||
/// Return the current std::chrono::system_clock time as the number of
/// milliseconds elapsed since that clock's epoch.
inline int64_t current_sys_time_ms() {
|
||||
std::chrono::milliseconds ms_since_epoch =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch());
|
||||
return ms_since_epoch.count();
|
||||
}
|
||||
|
||||
/// A helper function to split a string by whitespaces.
|
||||
///
|
||||
/// \param str The string with whitespaces.
|
||||
|
|
Loading…
Add table
Reference in a new issue