Revert "[core] Async submitting actor registerring (#18009)" (#18719)

This reverts commit 8ce01ea2cc.
Yi Cheng 2021-09-17 13:34:12 -07:00 committed by GitHub
parent 09e760a1fd
commit cf64ab5b90
28 changed files with 149 additions and 833 deletions


@ -843,18 +843,6 @@ cc_test(
],
)
cc_test(
name = "direct_actor_transport_mock_test",
srcs = ["src/ray/core_worker/test/direct_actor_transport_mock_test.cc"],
copts = COPTS,
tags = ["team:core"],
deps = [
":core_worker_lib",
":ray_mock",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "direct_task_transport_test",
size = "small",
@ -863,20 +851,6 @@ cc_test(
tags = ["team:core"],
deps = [
":core_worker_lib",
":ray_mock",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "direct_task_transport_mock_test",
size = "small",
srcs = ["src/ray/core_worker/test/direct_task_transport_mock_test.cc"],
copts = COPTS,
tags = ["team:core"],
deps = [
":core_worker_lib",
":ray_mock",
"@com_google_googletest//:gtest_main",
],
)
@ -928,19 +902,6 @@ cc_test(
],
)
cc_test(
name = "actor_creator_test",
size = "small",
srcs = ["src/ray/core_worker/test/actor_creator_test.cc"],
copts = COPTS,
tags = ["team:core"],
deps = [
":core_worker_lib",
":ray_mock",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "actor_manager_test",
size = "small",


@ -626,11 +626,9 @@ def test_warning_task_waiting_on_actor(shutdown_only):
@ray.remote(num_cpus=1)
class Actor:
def hello(self):
pass
pass
a = Actor.remote() # noqa
ray.get(a.hello.remote())
@ray.remote(num_cpus=1)
def f():


@ -24,8 +24,7 @@ class MockActorCreatorInterface : public ActorCreatorInterface {
(const TaskSpecification &task_spec, gcs::StatusCallback callback),
(override));
MOCK_METHOD(Status, AsyncCreateActor,
(const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(const TaskSpecification &task_spec, const gcs::StatusCallback &callback),
(override));
MOCK_METHOD(void, AsyncWaitForActorRegisterFinish,
(const ActorID &actor_id, gcs::StatusCallback callback), (override));
@ -46,8 +45,7 @@ class MockDefaultActorCreator : public DefaultActorCreator {
(override));
MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override));
MOCK_METHOD(Status, AsyncCreateActor,
(const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(const TaskSpecification &task_spec, const gcs::StatusCallback &callback),
(override));
};


@ -16,8 +16,8 @@
#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/util/message_differencer.h>
#include <grpcpp/grpcpp.h>
#include <sstream>
#include "ray/common/status.h"
@ -63,11 +63,6 @@ class MessageWrapper {
/// Serialize the message to a string.
const std::string Serialize() const { return message_->SerializeAsString(); }
bool operator==(const MessageWrapper<Message> &rhs) const {
return google::protobuf::util::MessageDifferencer::Equivalent(GetMessage(),
rhs.GetMessage());
}
protected:
/// The wrapped message.
std::shared_ptr<Message> message_;


@ -236,22 +236,6 @@ ObjectID ObjectID::ForActorHandle(const ActorID &actor_id) {
/*return_index=*/1);
}
bool ObjectID::IsActorID(const ObjectID &object_id) {
for (size_t i = 0; i < (TaskID::kLength - ActorID::kLength); ++i) {
if (object_id.id_[i] != 0xff) {
return false;
}
}
return true;
}
ActorID ObjectID::ToActorID(const ObjectID &object_id) {
auto beg = reinterpret_cast<const char *>(object_id.id_) + ObjectID::kLength -
ActorID::kLength - ObjectID::kIndexBytesLength;
std::string actor_id(beg, beg + ActorID::kLength);
return ActorID::FromBinary(actor_id);
}
ObjectID ObjectID::GenerateObjectId(const std::string &task_id_binary,
ObjectIDIndexType object_index) {
RAY_CHECK(task_id_binary.size() == TaskID::Size());


@ -298,9 +298,6 @@ class ObjectID : public BaseID<ObjectID> {
/// \return The computed object ID.
static ObjectID ForActorHandle(const ActorID &actor_id);
static bool IsActorID(const ObjectID &object_id);
static ActorID ToActorID(const ObjectID &object_id);
MSGPACK_DEFINE(id_);
private:


@ -469,9 +469,6 @@ RAY_CONFIG(int64_t, grpc_keepalive_timeout_ms, 20000);
/// Whether to use log reporter in event framework
RAY_CONFIG(bool, event_log_reporter_enabled, false)
/// Whether to use log reporter in event framework
RAY_CONFIG(bool, actor_register_async, true)
/// Event severity threshold value
RAY_CONFIG(std::string, event_level, "warning")


@ -171,7 +171,7 @@ ObjectID TaskSpecification::ArgId(size_t arg_index) const {
return ObjectID::FromBinary(message_->args(arg_index).object_ref().object_id());
}
const rpc::ObjectReference &TaskSpecification::ArgRef(size_t arg_index) const {
rpc::ObjectReference TaskSpecification::ArgRef(size_t arg_index) const {
RAY_CHECK(ArgByRef(arg_index));
return message_->args(arg_index).object_ref();
}


@ -114,7 +114,7 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
ObjectID ArgId(size_t arg_index) const;
const rpc::ObjectReference &ArgRef(size_t arg_index) const;
rpc::ObjectReference ArgRef(size_t arg_index) const;
ObjectID ReturnId(size_t return_index) const;


@ -1,133 +0,0 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "ray/common/ray_config.h"
#include "ray/gcs/gcs_client.h"
namespace ray {
namespace core {
class ActorCreatorInterface {
public:
virtual ~ActorCreatorInterface() = default;
/// Register actor to GCS synchronously.
///
/// \param task_spec The specification for the actor creation task.
/// \return Status
virtual Status RegisterActor(const TaskSpecification &task_spec) = 0;
/// Asynchronously request GCS to register the actor.
/// \param task_spec The specification for the actor creation task.
/// \param callback Callback that will be called after the actor info is registered to
/// GCS
/// \return Status
virtual Status AsyncRegisterActor(const TaskSpecification &task_spec,
gcs::StatusCallback callback) = 0;
/// Asynchronously request GCS to create the actor.
///
/// \param task_spec The specification for the actor creation task.
/// \param callback Callback that will be called after the actor info is written to GCS.
/// \return Status
virtual Status AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) = 0;
/// Asynchronously wait until actor is registered successfully
///
/// \param actor_id The actor id to wait
/// \param callback The callback that will be called after actor registered
/// \return void
virtual void AsyncWaitForActorRegisterFinish(const ActorID &actor_id,
gcs::StatusCallback callback) = 0;
/// Check whether actor is activately under registering
///
/// \param actor_id The actor id to check
/// \return bool Boolean to indicate whether the actor is under registering
virtual bool IsActorInRegistering(const ActorID &actor_id) const = 0;
};
class DefaultActorCreator : public ActorCreatorInterface {
public:
explicit DefaultActorCreator(std::shared_ptr<gcs::GcsClient> gcs_client)
: gcs_client_(std::move(gcs_client)) {}
Status RegisterActor(const TaskSpecification &task_spec) override {
auto promise = std::make_shared<std::promise<void>>();
auto status = gcs_client_->Actors().AsyncRegisterActor(
task_spec, [promise](const Status &status) { promise->set_value(); });
if (status.ok() &&
promise->get_future().wait_for(std::chrono::seconds(
::RayConfig::instance().gcs_server_request_timeout_seconds())) !=
std::future_status::ready) {
std::ostringstream stream;
stream << "There was timeout in registering an actor. It is probably "
"because GCS server is dead or there's a high load there.";
return Status::TimedOut(stream.str());
}
return status;
}
Status AsyncRegisterActor(const TaskSpecification &task_spec,
gcs::StatusCallback callback) override {
if (::RayConfig::instance().actor_register_async()) {
auto actor_id = task_spec.ActorCreationId();
(*registering_actors_)[actor_id] = {};
if (callback != nullptr) {
(*registering_actors_)[actor_id].emplace_back(std::move(callback));
}
return gcs_client_->Actors().AsyncRegisterActor(
task_spec, [actor_id, this](Status status) {
std::vector<ray::gcs::StatusCallback> cbs;
cbs = std::move((*registering_actors_)[actor_id]);
registering_actors_->erase(actor_id);
for (auto &cb : cbs) {
cb(status);
}
});
} else {
callback(RegisterActor(task_spec));
return Status::OK();
}
}
bool IsActorInRegistering(const ActorID &actor_id) const override {
return registering_actors_->find(actor_id) != registering_actors_->end();
}
void AsyncWaitForActorRegisterFinish(const ActorID &actor_id,
gcs::StatusCallback callback) override {
auto iter = registering_actors_->find(actor_id);
RAY_CHECK(iter != registering_actors_->end());
iter->second.emplace_back(std::move(callback));
}
Status AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) override {
return gcs_client_->Actors().AsyncCreateActor(task_spec, callback);
}
private:
std::shared_ptr<gcs::GcsClient> gcs_client_;
using RegisteringActorType =
absl::flat_hash_map<ActorID, std::vector<ray::gcs::StatusCallback>>;
ThreadPrivate<RegisteringActorType> registering_actors_;
};
} // namespace core
} // namespace ray
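
Note: the asynchronous path removed by this revert amounted to keeping a per-actor list of pending callbacks until the GCS registration reply arrived, which is what AsyncRegisterActor, AsyncWaitForActorRegisterFinish and IsActorInRegistering did together in the deleted header above. A minimal, self-contained sketch of that bookkeeping follows (plain std C++ with stand-in names such as PendingRegistrations and string actor ids; it illustrates the pattern only and is not Ray's actual DefaultActorCreator):

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using StatusCallback = std::function<void(bool ok)>;

// Stand-in for the pending-registration bookkeeping in the removed code.
class PendingRegistrations {
 public:
  // Start an async registration and remember the callback until "GCS" replies.
  void AsyncRegister(const std::string &actor_id, StatusCallback cb) {
    auto &cbs = pending_[actor_id];
    if (cb) cbs.push_back(std::move(cb));
  }
  bool IsRegistering(const std::string &actor_id) const {
    return pending_.count(actor_id) > 0;
  }
  // Add another waiter to an in-flight registration.
  void WaitForFinish(const std::string &actor_id, StatusCallback cb) {
    pending_.at(actor_id).push_back(std::move(cb));
  }
  // Simulated GCS reply: flush every queued callback and forget the actor.
  void OnRegistered(const std::string &actor_id, bool ok) {
    auto cbs = std::move(pending_[actor_id]);
    pending_.erase(actor_id);
    for (auto &cb : cbs) cb(ok);
  }

 private:
  std::map<std::string, std::vector<StatusCallback>> pending_;
};

int main() {
  PendingRegistrations reg;
  reg.AsyncRegister("actor-1", [](bool ok) { std::cout << "submit: " << ok << "\n"; });
  reg.WaitForFinish("actor-1", [](bool ok) { std::cout << "waiter: " << ok << "\n"; });
  std::cout << "registering? " << reg.IsRegistering("actor-1") << "\n";
  reg.OnRegistered("actor-1", true);
  std::cout << "registering? " << reg.IsRegistering("actor-1") << "\n";
  return 0;
}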


@ -15,14 +15,64 @@
#pragma once
#include "absl/container/flat_hash_map.h"
#include "ray/core_worker/actor_creator.h"
#include "ray/core_worker/actor_handle.h"
#include "ray/core_worker/reference_count.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/gcs/gcs_client.h"
namespace ray {
namespace core {
class ActorCreatorInterface {
public:
virtual ~ActorCreatorInterface() = default;
/// Register actor to GCS synchronously.
///
/// \param task_spec The specification for the actor creation task.
/// \return Status
virtual Status RegisterActor(const TaskSpecification &task_spec) = 0;
/// Asynchronously request GCS to create the actor.
///
/// \param task_spec The specification for the actor creation task.
/// \param callback Callback that will be called after the actor info is written to GCS.
/// \return Status
virtual Status AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) = 0;
};
class DefaultActorCreator : public ActorCreatorInterface {
public:
explicit DefaultActorCreator(std::shared_ptr<gcs::GcsClient> gcs_client)
: gcs_client_(std::move(gcs_client)) {}
Status RegisterActor(const TaskSpecification &task_spec) override {
auto promise = std::make_shared<std::promise<void>>();
auto status = gcs_client_->Actors().AsyncRegisterActor(
task_spec, [promise](const Status &status) { promise->set_value(); });
if (status.ok() &&
promise->get_future().wait_for(std::chrono::seconds(
::RayConfig::instance().gcs_server_request_timeout_seconds())) !=
std::future_status::ready) {
std::ostringstream stream;
stream << "There was timeout in registering an actor. It is probably "
"because GCS server is dead or there's a high load there.";
return Status::TimedOut(stream.str());
}
return status;
}
Status AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) override {
return gcs_client_->Actors().AsyncCreateActor(task_spec, callback);
}
private:
std::shared_ptr<gcs::GcsClient> gcs_client_;
};
/// Class to manage lifetimes of actors that we create (actor children).
/// Currently this class is only used to publish actor DEAD event
/// for actor creation task failures. All other cases are managed


@ -627,12 +627,12 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
PushError(options_.job_id, "excess_queueing_warning", stream.str(), timestamp));
};
actor_creator_ = std::make_shared<DefaultActorCreator>(gcs_client_);
std::shared_ptr<ActorCreatorInterface> actor_creator =
std::make_shared<DefaultActorCreator>(gcs_client_);
direct_actor_submitter_ = std::shared_ptr<CoreWorkerDirectActorTaskSubmitter>(
new CoreWorkerDirectActorTaskSubmitter(*core_worker_client_pool_, *memory_store_,
*task_manager_, *actor_creator_,
on_excess_queueing));
new CoreWorkerDirectActorTaskSubmitter(core_worker_client_pool_, memory_store_,
task_manager_, on_excess_queueing));
auto node_addr_factory = [this](const NodeID &node_id) {
absl::optional<rpc::Address> addr;
@ -655,7 +655,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
direct_task_submitter_ = std::make_unique<CoreWorkerDirectTaskSubmitter>(
rpc_address_, local_raylet_client_, core_worker_client_pool_, raylet_client_factory,
std::move(lease_policy), memory_store_, task_manager_, local_raylet_id,
RayConfig::instance().worker_lease_timeout_milliseconds(), actor_creator_,
RayConfig::instance().worker_lease_timeout_milliseconds(), std::move(actor_creator),
RayConfig::instance().max_tasks_in_flight_per_worker(),
boost::asio::steady_timer(io_service_));
auto report_locality_data_callback =
@ -1115,7 +1115,7 @@ Status CoreWorker::Put(const RayObject &object,
Status CoreWorker::Put(const RayObject &object,
const std::vector<ObjectID> &contained_object_ids,
const ObjectID &object_id, bool pin_object) {
RAY_RETURN_NOT_OK(WaitForActorRegistered(contained_object_ids));
bool object_exists;
if (options_.is_local_mode ||
(RayConfig::instance().put_small_object_in_memory_store() &&
static_cast<int64_t>(object.GetSize()) < max_direct_call_object_size_)) {
@ -1123,7 +1123,6 @@ Status CoreWorker::Put(const RayObject &object,
RAY_CHECK(memory_store_->Put(object, object_id));
return Status::OK();
}
bool object_exists;
RAY_RETURN_NOT_OK(plasma_store_provider_->Put(
object, object_id, /* owner_address = */ rpc_address_, &object_exists));
if (!object_exists) {
@ -1155,15 +1154,12 @@ Status CoreWorker::CreateOwned(const std::shared_ptr<Buffer> &metadata,
bool created_by_worker,
const std::unique_ptr<rpc::Address> &owner_address,
bool inline_small_object) {
auto status = WaitForActorRegistered(contained_object_ids);
if (!status.ok()) {
return status;
}
*object_id = ObjectID::FromIndex(worker_context_.GetCurrentTaskID(),
worker_context_.GetNextPutIndex());
rpc::Address real_owner_address =
owner_address != nullptr ? *owner_address : rpc_address_;
bool owned_by_us = real_owner_address.worker_id() == rpc_address_.worker_id();
auto status = Status::OK();
if (owned_by_us) {
reference_counter_->AddOwnedObject(*object_id, contained_object_ids, rpc_address_,
CurrentCallSite(), data_size + metadata->Size(),
@ -1754,6 +1750,7 @@ Status CoreWorker::CreateActor(const RayFunction &function,
<< "Actor " << actor_id << " already exists";
*return_actor_id = actor_id;
TaskSpecification task_spec = builder.Build();
Status status;
if (options_.is_local_mode) {
// TODO(suquark): Should we consider namespace in local mode? Currently
// it looks like two actors with two different namespaces become the
@ -1775,35 +1772,9 @@ Status CoreWorker::CreateActor(const RayFunction &function,
}
task_manager_->AddPendingTask(rpc_address_, task_spec, CurrentCallSite(),
max_retries);
if (actor_name.empty()) {
io_service_.post(
[this, task_spec = std::move(task_spec)]() {
RAY_UNUSED(actor_creator_->AsyncRegisterActor(
task_spec, [this, task_spec](Status status) {
if (!status.ok()) {
RAY_LOG(ERROR)
<< "Failed to register actor: " << task_spec.ActorCreationId()
<< ". Error message: " << status.ToString();
} else {
RAY_UNUSED(direct_task_submitter_->SubmitTask(task_spec));
}
}));
},
"ActorCreator.AsyncRegisterActor");
} else {
auto status = actor_creator_->RegisterActor(task_spec);
if (!status.ok()) {
return status;
}
io_service_.post(
[this, task_spec = std::move(task_spec)]() {
RAY_UNUSED(direct_task_submitter_->SubmitTask(task_spec));
},
"CoreWorker.SubmitTask");
}
status = direct_task_submitter_->SubmitTask(task_spec);
}
return Status::OK();
return status;
}
Status CoreWorker::CreatePlacementGroup(
@ -2587,20 +2558,8 @@ void CoreWorker::HandleWaitForActorOutOfScope(
};
const auto actor_id = ActorID::FromBinary(request.actor_id());
if (actor_creator_->IsActorInRegistering(actor_id)) {
actor_creator_->AsyncWaitForActorRegisterFinish(
actor_id, [this, actor_id, respond = std::move(respond)](auto status) {
if (!status.ok()) {
respond(actor_id);
} else {
RAY_LOG(DEBUG) << "Received HandleWaitForActorOutOfScope for " << actor_id;
actor_manager_->WaitForActorOutOfScope(actor_id, std::move(respond));
}
});
} else {
RAY_LOG(DEBUG) << "Received HandleWaitForActorOutOfScope for " << actor_id;
actor_manager_->WaitForActorOutOfScope(actor_id, std::move(respond));
}
RAY_LOG(DEBUG) << "Received HandleWaitForActorOutOfScope for " << actor_id;
actor_manager_->WaitForActorOutOfScope(actor_id, std::move(respond));
}
void CoreWorker::ProcessSubscribeForObjectEviction(
@ -3181,47 +3140,5 @@ std::shared_ptr<gcs::GcsClient> CoreWorker::GetGcsClient() const { return gcs_cl
bool CoreWorker::IsExiting() const { return exiting_; }
Status CoreWorker::WaitForActorRegistered(const std::vector<ObjectID> &ids) {
std::vector<ActorID> actor_ids;
for (const auto &id : ids) {
if (ObjectID::IsActorID(id)) {
actor_ids.emplace_back(ObjectID::ToActorID(id));
}
}
if (actor_ids.empty()) {
return Status::OK();
}
std::promise<void> promise;
auto future = promise.get_future();
std::vector<Status> ret;
int counter = 0;
// Post to service pool to avoid mutex
io_service_.post([&, this]() {
for (const auto &id : actor_ids) {
if (actor_creator_->IsActorInRegistering(id)) {
++counter;
actor_creator_->AsyncWaitForActorRegisterFinish(
id, [&counter, &promise, &ret](Status status) {
ret.push_back(status);
--counter;
if (counter == 0) {
promise.set_value();
}
});
}
}
if (counter == 0) {
promise.set_value();
}
});
future.wait();
for (const auto &s : ret) {
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
} // namespace core
} // namespace ray


@ -1238,8 +1238,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
return call_site;
}
Status WaitForActorRegistered(const std::vector<ObjectID> &ids);
/// Shared state of the worker. Includes process-level and thread-level state.
/// TODO(edoakes): we should move process-level state into this class and make
/// this a ThreadContext.
@ -1314,9 +1312,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
// Tracks the currently pending tasks.
std::shared_ptr<TaskManager> task_manager_;
// A class for actor creation.
std::shared_ptr<ActorCreatorInterface> actor_creator_;
// Interface to submit tasks directly to other actors.
std::shared_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;


@ -1,94 +0,0 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/core_worker/actor_creator.h"
#include "ray/common/test_util.h"
#include "mock/ray/gcs/gcs_client.h"
// clang-format on
namespace ray {
namespace core {
class ActorCreatorTest : public ::testing::Test {
public:
ActorCreatorTest() {}
void SetUp() override {
gcs_client = std::make_shared<ray::gcs::MockGcsClient>();
actor_creator = std::make_unique<DefaultActorCreator>(gcs_client);
}
TaskSpecification GetTaskSpec(const ActorID &actor_id) {
rpc::TaskSpec task_spec;
task_spec.set_type(rpc::TaskType::ACTOR_CREATION_TASK);
rpc::ActorCreationTaskSpec actor_creation_task_spec;
actor_creation_task_spec.set_actor_id(actor_id.Binary());
task_spec.mutable_actor_creation_task_spec()->CopyFrom(actor_creation_task_spec);
return TaskSpecification(task_spec);
}
std::shared_ptr<ray::gcs::MockGcsClient> gcs_client;
std::unique_ptr<DefaultActorCreator> actor_creator;
};
TEST_F(ActorCreatorTest, IsRegister) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
auto task_spec = GetTaskSpec(actor_id);
std::function<void(Status)> cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(task_spec, ::testing::_))
.WillOnce(
::testing::DoAll(::testing::SaveArg<1>(&cb), ::testing::Return(Status::OK())));
ASSERT_TRUE(actor_creator->AsyncRegisterActor(task_spec, nullptr).ok());
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
cb(Status::OK());
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
}
TEST_F(ActorCreatorTest, AsyncWaitForFinish) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto task_spec = GetTaskSpec(actor_id);
std::function<void(Status)> cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(::testing::_, ::testing::_))
.WillRepeatedly(
::testing::DoAll(::testing::SaveArg<1>(&cb), ::testing::Return(Status::OK())));
int cnt = 0;
auto per_finish_cb = [&cnt](Status status) {
ASSERT_TRUE(status.ok());
cnt++;
};
ASSERT_TRUE(actor_creator->AsyncRegisterActor(task_spec, per_finish_cb).ok());
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
for (int i = 0; i < 100; ++i) {
actor_creator->AsyncWaitForActorRegisterFinish(actor_id, per_finish_cb);
}
cb(Status::OK());
ASSERT_FALSE(actor_creator->IsActorInRegistering(actor_id));
ASSERT_EQ(101, cnt);
}
} // namespace core
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog,
ray::RayLog::ShutDownRayLog, argv[0],
ray::RayLogLevel::INFO,
/*log_dir=*/"");
ray::RayLog::InstallFailureSignalHandler();
return RUN_ALL_TESTS();
}


@ -1,115 +0,0 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/core_worker/actor_creator.h"
#include "mock/ray/core_worker/task_manager.h"
#include "mock/ray/gcs/gcs_client.h"
// clang-format on
namespace ray {
namespace core {
using namespace ::testing;
class DirectTaskTransportTest : public ::testing::Test {
public:
void SetUp() override {
gcs_client = std::make_shared<ray::gcs::MockGcsClient>();
actor_creator = std::make_unique<DefaultActorCreator>(gcs_client);
task_finisher = std::make_shared<MockTaskFinisherInterface>();
client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &) { return nullptr; });
memory_store = std::make_unique<CoreWorkerMemoryStore>();
actor_task_submitter = std::make_unique<CoreWorkerDirectActorTaskSubmitter>(
*client_pool, *memory_store, *task_finisher, *actor_creator, nullptr);
}
TaskSpecification GetActorTaskSpec(const ActorID &actor_id) {
rpc::TaskSpec task_spec;
task_spec.set_type(rpc::TaskType::ACTOR_TASK);
task_spec.mutable_actor_task_spec()->set_actor_id(actor_id.Binary());
task_spec.set_task_id(
TaskID::ForActorTask(JobID::FromInt(10), TaskID::Nil(), 0, actor_id).Binary());
return TaskSpecification(task_spec);
}
TaskSpecification GetCreatingTaskSpec(const ActorID &actor_id) {
rpc::TaskSpec task_spec;
task_spec.set_task_id(TaskID::ForActorCreationTask(actor_id).Binary());
task_spec.set_type(rpc::TaskType::ACTOR_CREATION_TASK);
rpc::ActorCreationTaskSpec actor_creation_task_spec;
actor_creation_task_spec.set_actor_id(actor_id.Binary());
task_spec.mutable_actor_creation_task_spec()->CopyFrom(actor_creation_task_spec);
return TaskSpecification(task_spec);
}
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> actor_task_submitter;
std::shared_ptr<rpc::CoreWorkerClientPool> client_pool;
std::unique_ptr<CoreWorkerMemoryStore> memory_store;
std::shared_ptr<MockTaskFinisherInterface> task_finisher;
std::unique_ptr<DefaultActorCreator> actor_creator;
std::shared_ptr<ray::gcs::MockGcsClient> gcs_client;
};
TEST_F(DirectTaskTransportTest, ActorRegisterFailure) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id)));
ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id)));
auto creation_task_spec = GetCreatingTaskSpec(actor_id);
auto task_spec = GetActorTaskSpec(actor_id);
auto task_arg = task_spec.GetMutableMessage().add_args();
auto inline_obj_ref = task_arg->add_nested_inlined_refs();
inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary());
std::function<void(Status)> register_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(creation_task_spec, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&register_cb),
::testing::Return(Status::OK())));
ASSERT_TRUE(actor_creator->AsyncRegisterActor(creation_task_spec, nullptr).ok());
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
actor_task_submitter->AddActorQueueIfNotExists(actor_id);
ASSERT_TRUE(actor_task_submitter->SubmitTask(task_spec).ok());
EXPECT_CALL(*task_finisher,
PendingTaskFailed(task_spec.TaskId(),
rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, _, _, _));
register_cb(Status::IOError(""));
}
TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
ASSERT_TRUE(ObjectID::IsActorID(ObjectID::ForActorHandle(actor_id)));
ASSERT_EQ(actor_id, ObjectID::ToActorID(ObjectID::ForActorHandle(actor_id)));
auto creation_task_spec = GetCreatingTaskSpec(actor_id);
auto task_spec = GetActorTaskSpec(actor_id);
auto task_arg = task_spec.GetMutableMessage().add_args();
auto inline_obj_ref = task_arg->add_nested_inlined_refs();
inline_obj_ref->set_object_id(ObjectID::ForActorHandle(actor_id).Binary());
std::function<void(Status)> register_cb;
EXPECT_CALL(*gcs_client->mock_actor_accessor,
AsyncRegisterActor(creation_task_spec, ::testing::_))
.WillOnce(::testing::DoAll(::testing::SaveArg<1>(&register_cb),
::testing::Return(Status::OK())));
ASSERT_TRUE(actor_creator->AsyncRegisterActor(creation_task_spec, nullptr).ok());
ASSERT_TRUE(actor_creator->IsActorInRegistering(actor_id));
actor_task_submitter->AddActorQueueIfNotExists(actor_id);
ASSERT_TRUE(actor_task_submitter->SubmitTask(task_spec).ok());
EXPECT_CALL(*task_finisher, PendingTaskFailed(_, _, _, _, _)).Times(0);
register_cb(Status::OK());
}
} // namespace core
} // namespace ray


@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/common/asio/instrumented_io_context.h"
@ -22,9 +21,6 @@
#include "ray/core_worker/transport/direct_task_transport.h"
#include "ray/raylet_client/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "mock/ray/core_worker/actor_creator.h"
#include "mock/ray/core_worker/task_manager.h"
// clang-format on
// clang-format off
#include "mock/ray/core_worker/task_manager.h"
@ -96,23 +92,20 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
class DirectActorSubmitterTest : public ::testing::Test {
public:
DirectActorSubmitterTest()
: client_pool_(
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
store_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
task_finisher_(std::make_shared<MockTaskFinisherInterface>()),
submitter_(
std::make_shared<rpc::CoreWorkerClientPool>([&](const rpc::Address &addr) {
num_clients_connected_++;
return worker_client_;
})),
worker_client_(std::make_shared<MockWorkerClient>()),
store_(std::make_shared<CoreWorkerMemoryStore>()),
task_finisher_(std::make_shared<MockTaskFinisherInterface>()),
submitter_(*client_pool_, *store_, *task_finisher_, actor_creator_,
[this](const ActorID &actor_id, int64_t num_queued) {
last_queue_warning_ = num_queued;
}) {}
}),
store_, task_finisher_, [this](const ActorID &actor_id, int64_t num_queued) {
last_queue_warning_ = num_queued;
}) {}
int num_clients_connected_ = 0;
int64_t last_queue_warning_ = 0;
MockActorCreatorInterface actor_creator_;
std::shared_ptr<rpc::CoreWorkerClientPool> client_pool_;
std::shared_ptr<MockWorkerClient> worker_client_;
std::shared_ptr<CoreWorkerMemoryStore> store_;
std::shared_ptr<MockTaskFinisherInterface> task_finisher_;


@ -1,92 +0,0 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
#include "ray/core_worker/transport/direct_task_transport.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "mock/ray/core_worker/actor_creator.h"
#include "mock/ray/core_worker/task_manager.h"
#include "mock/ray/core_worker/lease_policy.h"
#include "mock/ray/raylet_client/raylet_client.h"
// clang-format on
namespace ray {
namespace core {
using namespace ::testing;
class DirectTaskTransportTest : public ::testing::Test {
public:
void SetUp() override {
raylet_client = std::make_shared<raylet::MockRayletClient>();
task_finisher = std::make_shared<MockTaskFinisherInterface>();
actor_creator = std::make_shared<MockActorCreatorInterface>();
lease_policy = std::make_shared<MockLeasePolicyInterface>();
auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &) { return nullptr; });
task_submitter = std::make_unique<CoreWorkerDirectTaskSubmitter>(
rpc::Address(), /* rpc_address */
raylet_client, /* lease_client */
client_pool, /* core_worker_client_pool */
nullptr, /* lease_client_factory */
lease_policy, /* lease_policy */
std::make_shared<CoreWorkerMemoryStore>(), task_finisher,
NodeID::Nil(), /* local_raylet_id */
0, /* lease_timeout_ms */
actor_creator);
}
TaskSpecification GetCreatingTaskSpec(const ActorID &actor_id) {
rpc::TaskSpec task_spec;
task_spec.set_task_id(TaskID::ForActorCreationTask(actor_id).Binary());
task_spec.set_type(rpc::TaskType::ACTOR_CREATION_TASK);
rpc::ActorCreationTaskSpec actor_creation_task_spec;
actor_creation_task_spec.set_actor_id(actor_id.Binary());
task_spec.mutable_actor_creation_task_spec()->CopyFrom(actor_creation_task_spec);
return TaskSpecification(task_spec);
}
std::unique_ptr<CoreWorkerDirectTaskSubmitter> task_submitter;
std::shared_ptr<raylet::MockRayletClient> raylet_client;
std::shared_ptr<MockTaskFinisherInterface> task_finisher;
std::shared_ptr<MockActorCreatorInterface> actor_creator;
std::shared_ptr<MockLeasePolicyInterface> lease_policy;
};
TEST_F(DirectTaskTransportTest, ActorRegisterOk) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto task_spec = GetCreatingTaskSpec(actor_id);
EXPECT_CALL(*task_finisher, CompletePendingTask(task_spec.TaskId(), _, _));
rpc::ClientCallback<rpc::CreateActorReply> create_cb;
EXPECT_CALL(*actor_creator, AsyncCreateActor(task_spec, _))
.WillOnce(DoAll(SaveArg<1>(&create_cb), Return(Status::OK())));
ASSERT_TRUE(task_submitter->SubmitTask(task_spec).ok());
create_cb(Status::OK(), rpc::CreateActorReply());
}
TEST_F(DirectTaskTransportTest, ActorCreationFail) {
auto actor_id = ActorID::FromHex("f4ce02420592ca68c1738a0d01000000");
auto task_spec = GetCreatingTaskSpec(actor_id);
EXPECT_CALL(*task_finisher, CompletePendingTask(_, _, _)).Times(0);
EXPECT_CALL(*task_finisher,
PendingTaskFailed(task_spec.TaskId(), rpc::ErrorType::ACTOR_CREATION_FAILED,
_, _, true));
rpc::ClientCallback<rpc::CreateActorReply> create_cb;
EXPECT_CALL(*actor_creator, AsyncCreateActor(task_spec, _))
.WillOnce(DoAll(SaveArg<1>(&create_cb), Return(Status::OK())));
ASSERT_TRUE(task_submitter->SubmitTask(task_spec).ok());
create_cb(Status::IOError(""), rpc::CreateActorReply());
}
} // namespace core
} // namespace ray


@ -234,22 +234,12 @@ class MockActorCreator : public ActorCreatorInterface {
return Status::OK();
};
Status AsyncRegisterActor(const TaskSpecification &task_spec,
gcs::StatusCallback callback) override {
return Status::OK();
}
Status AsyncCreateActor(
const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback) override {
return Status::OK();
}
void AsyncWaitForActorRegisterFinish(const ActorID &,
gcs::StatusCallback callback) override {}
bool IsActorInRegistering(const ActorID &actor_id) const override { return false; }
~MockActorCreator() {}
};
@ -299,11 +289,10 @@ TEST(TestMemoryStore, TestPromoteToPlasma) {
TEST(LocalDependencyResolverTest, TestNoDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
MockActorCreator actor_creator;
LocalDependencyResolver resolver(*store, *task_finisher, actor_creator);
LocalDependencyResolver resolver(store, task_finisher);
TaskSpecification task;
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok);
ASSERT_EQ(task_finisher->num_inlined_dependencies, 0);
}
@ -311,8 +300,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) {
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
MockActorCreator actor_creator;
LocalDependencyResolver resolver(*store, *task_finisher, actor_creator);
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
@ -322,7 +310,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
TaskSpecification task;
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok);
ASSERT_TRUE(task.ArgByRef(0));
// Checks that the object id is still a direct call id.
@ -333,8 +321,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
MockActorCreator actor_creator;
LocalDependencyResolver resolver(*store, *task_finisher, actor_creator);
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
ObjectID obj2 = ObjectID::FromRandom();
auto data = GenerateRandomObject();
@ -345,7 +332,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
// Tests that the task proto was rewritten to have inline argument values.
ASSERT_TRUE(ok);
ASSERT_FALSE(task.ArgByRef(0));
@ -359,8 +346,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
MockActorCreator actor_creator;
LocalDependencyResolver resolver(*store, *task_finisher, actor_creator);
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
ObjectID obj2 = ObjectID::FromRandom();
auto data = GenerateRandomObject();
@ -368,7 +354,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store->Put(*data, obj1));
@ -388,8 +374,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
TEST(LocalDependencyResolverTest, TestInlinedObjectIds) {
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto task_finisher = std::make_shared<MockTaskFinisher>();
MockActorCreator actor_creator;
LocalDependencyResolver resolver(*store, *task_finisher, actor_creator);
LocalDependencyResolver resolver(store, task_finisher);
ObjectID obj1 = ObjectID::FromRandom();
ObjectID obj2 = ObjectID::FromRandom();
ObjectID obj3 = ObjectID::FromRandom();
@ -398,7 +383,7 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) {
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj1.Binary());
task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok](Status) { ok = true; });
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store->Put(*data, obj1));


@ -19,26 +19,16 @@ namespace core {
struct TaskState {
TaskState(TaskSpecification t,
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> deps,
std::vector<ActorID> actor_ids)
: task(t),
local_dependencies(std::move(deps)),
actor_dependencies(std::move(actor_ids)),
status(Status::OK()) {
obj_dependencies_remaining = local_dependencies.size();
actor_dependencies_remaining = actor_dependencies.size();
}
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> deps)
: task(t), local_dependencies(deps), dependencies_remaining(deps.size()) {}
/// The task to be run.
TaskSpecification task;
/// The local dependencies to resolve for this task. Objects are nullptr if not yet
/// resolved.
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
std::vector<ActorID> actor_dependencies;
/// Number of local dependencies that aren't yet resolved (have nullptrs in the above
/// map).
size_t actor_dependencies_remaining;
size_t obj_dependencies_remaining;
Status status;
size_t dependencies_remaining;
};
void InlineDependencies(
@ -80,38 +70,28 @@ void InlineDependencies(
RAY_CHECK(found >= dependencies.size());
}
void LocalDependencyResolver::ResolveDependencies(
TaskSpecification &task, std::function<void(Status)> on_complete) {
void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task,
std::function<void()> on_complete) {
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> local_dependencies;
std::vector<ActorID> actor_dependences;
for (size_t i = 0; i < task.NumArgs(); i++) {
if (task.ArgByRef(i)) {
local_dependencies.emplace(task.ArgId(i), nullptr);
}
for (const auto &in : task.ArgInlinedRefs(i)) {
auto object_id = ObjectID::FromBinary(in.object_id());
if (ObjectID::IsActorID(object_id)) {
auto actor_id = ObjectID::ToActorID(object_id);
if (actor_creator_.IsActorInRegistering(actor_id)) {
actor_dependences.emplace_back(ObjectID::ToActorID(object_id));
}
}
}
}
if (local_dependencies.empty() && actor_dependences.empty()) {
on_complete(Status::OK());
if (local_dependencies.empty()) {
on_complete();
return;
}
// This is deleted when the last dependency fetch callback finishes.
std::shared_ptr<TaskState> state = std::make_shared<TaskState>(
task, std::move(local_dependencies), std::move(actor_dependences));
std::shared_ptr<TaskState> state =
std::make_shared<TaskState>(task, std::move(local_dependencies));
num_pending_ += 1;
for (const auto &it : state->local_dependencies) {
const ObjectID &obj_id = it.first;
in_memory_store_.GetAsync(obj_id, [this, state, obj_id,
on_complete](std::shared_ptr<RayObject> obj) {
in_memory_store_->GetAsync(obj_id, [this, state, obj_id,
on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr);
bool complete = false;
std::vector<ObjectID> inlined_dependency_ids;
@ -119,35 +99,21 @@ void LocalDependencyResolver::ResolveDependencies(
{
absl::MutexLock lock(&mu_);
state->local_dependencies[obj_id] = std::move(obj);
if (--state->obj_dependencies_remaining == 0) {
if (--state->dependencies_remaining == 0) {
InlineDependencies(state->local_dependencies, state->task,
&inlined_dependency_ids, &contained_ids);
if (state->actor_dependencies_remaining == 0) {
complete = true;
num_pending_ -= 1;
}
complete = true;
num_pending_ -= 1;
}
}
if (inlined_dependency_ids.size() > 0) {
task_finisher_.OnTaskDependenciesInlined(inlined_dependency_ids, contained_ids);
task_finisher_->OnTaskDependenciesInlined(inlined_dependency_ids, contained_ids);
}
if (complete) {
on_complete(state->status);
on_complete();
}
});
}
for (const auto &actor_id : state->actor_dependencies) {
actor_creator_.AsyncWaitForActorRegisterFinish(
actor_id, [state, on_complete](Status status) {
if (!status.ok()) {
state->status = status;
}
if (--state->actor_dependencies_remaining == 0) {
on_complete(state->status);
}
});
}
}
} // namespace core


@ -18,7 +18,6 @@
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/actor_creator.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/task_manager.h"
@ -28,13 +27,9 @@ namespace core {
// This class is thread-safe.
class LocalDependencyResolver {
public:
LocalDependencyResolver(CoreWorkerMemoryStore &store,
TaskFinisherInterface &task_finisher,
ActorCreatorInterface &actor_creator)
: in_memory_store_(store),
task_finisher_(task_finisher),
actor_creator_(actor_creator),
num_pending_(0) {}
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher)
: in_memory_store_(store), task_finisher_(task_finisher), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified
/// callback when done. Direct call ids in the task specification will be resolved
@ -44,8 +39,7 @@ class LocalDependencyResolver {
///
/// Postcondition: all direct call id arguments that haven't been spilled to plasma
/// are converted to values and all remaining arguments are arguments in the task spec.
void ResolveDependencies(TaskSpecification &task,
std::function<void(Status)> on_complete);
void ResolveDependencies(TaskSpecification &task, std::function<void()> on_complete);
/// Return the number of tasks pending dependency resolution.
/// TODO(ekl) this should be exposed in worker stats.
@ -53,12 +47,11 @@ class LocalDependencyResolver {
private:
/// The in-memory store.
CoreWorkerMemoryStore &in_memory_store_;
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
/// Used to complete tasks.
TaskFinisherInterface &task_finisher_;
std::shared_ptr<TaskFinisherInterface> task_finisher_;
ActorCreatorInterface &actor_creator_;
/// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_;


@ -87,7 +87,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
if (task_queued) {
// We must release the lock before resolving the task dependencies since
// the callback may get called in the same call stack.
resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id](Status status) {
resolver_.ResolveDependencies(task_spec, [this, send_pos, actor_id]() {
absl::MutexLock lock(&mu_);
auto queue = client_queues_.find(actor_id);
RAY_CHECK(queue != client_queues_.end());
@ -95,20 +95,13 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
// Only dispatch tasks if the submitted task is still queued. The task
// may have been dequeued if the actor has since failed.
if (it != queue->second.requests.end()) {
if (status.ok()) {
it->second.second = true;
SendPendingTasks(actor_id);
} else {
auto task_id = it->second.first.TaskId();
queue->second.requests.erase(it);
task_finisher_.PendingTaskFailed(
task_id, rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, &status);
}
it->second.second = true;
SendPendingTasks(actor_id);
}
});
} else {
// Do not hold the lock while calling into task_finisher_.
task_finisher_.MarkTaskCanceled(task_id);
task_finisher_->MarkTaskCanceled(task_id);
std::shared_ptr<rpc::RayException> creation_task_exception = nullptr;
{
absl::MutexLock lock(&mu_);
@ -118,8 +111,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
auto status = Status::IOError("cancelling task of dead actor");
// No need to increment the number of completed tasks since the actor is
// dead.
RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED,
&status, creation_task_exception));
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED,
&status, creation_task_exception));
}
// If the task submission subsequently fails, then the client will receive
@ -129,7 +122,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
void CoreWorkerDirectActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) {
queue.rpc_client = nullptr;
core_worker_client_pool_.Disconnect(WorkerID::FromBinary(queue.worker_id));
core_worker_client_pool_->Disconnect(WorkerID::FromBinary(queue.worker_id));
queue.worker_id.clear();
queue.pending_force_kill.reset();
}
@ -174,7 +167,7 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id,
// Update the mapping so new RPCs go out with the right intended worker id.
queue->second.worker_id = address.worker_id();
// Create a new connection to the actor.
queue->second.rpc_client = core_worker_client_pool_.GetOrConnect(address);
queue->second.rpc_client = core_worker_client_pool_->GetOrConnect(address);
// TODO(swang): This assumes that all replies from the previous incarnation
// of the actor have been received. Fix this by setting an epoch for each
// actor task, so we can ignore completed tasks from old epochs.
@ -221,12 +214,12 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectActor(
auto status = Status::IOError("cancelling all pending tasks of dead actor");
while (head != requests.end()) {
const auto &task_spec = head->second.first;
task_finisher_.MarkTaskCanceled(task_spec.TaskId());
task_finisher_->MarkTaskCanceled(task_spec.TaskId());
// No need to increment the number of completed tasks since the actor is
// dead.
RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_spec.TaskId(),
rpc::ErrorType::ACTOR_DIED, &status,
creation_task_exception));
RAY_UNUSED(!task_finisher_->PendingTaskFailed(task_spec.TaskId(),
rpc::ErrorType::ACTOR_DIED, &status,
creation_task_exception));
head = requests.erase(head);
}
@ -235,7 +228,7 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectActor(
RAY_LOG(INFO) << "Failing tasks waiting for death info, size="
<< wait_for_death_info_tasks.size() << ", actor_id=" << actor_id;
for (auto &net_err_task : wait_for_death_info_tasks) {
RAY_UNUSED(task_finisher_.MarkPendingTaskFailed(
RAY_UNUSED(task_finisher_->MarkPendingTaskFailed(
net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception));
}
@ -259,7 +252,7 @@ void CoreWorkerDirectActorTaskSubmitter::CheckTimeoutTasks() {
while (deque_itr != queue.wait_for_death_info_tasks.end() &&
/*timeout timestamp*/ deque_itr->first < current_time_ms()) {
auto task_spec = deque_itr->second;
task_finisher_.MarkPendingTaskFailed(task_spec, rpc::ErrorType::ACTOR_DIED);
task_finisher_->MarkPendingTaskFailed(task_spec, rpc::ErrorType::ACTOR_DIED);
deque_itr = queue.wait_for_death_info_tasks.erase(deque_itr);
}
}
@ -361,7 +354,7 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
// because the tasks are pushed directly to the actor, not placed on any queues
// in task_finisher_.
} else if (status.ok()) {
task_finisher_.CompletePendingTask(task_id, reply, addr);
task_finisher_->CompletePendingTask(task_id, reply, addr);
} else {
// push task failed due to network error. For example, actor is dead
// and no process response for the push task.
@ -371,7 +364,7 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
auto &queue = queue_pair->second;
bool immediately_mark_object_fail = (queue.state == rpc::ActorTableData::DEAD);
bool will_retry = task_finisher_.PendingTaskFailed(
bool will_retry = task_finisher_->PendingTaskFailed(
task_id, rpc::ErrorType::ACTOR_DIED, &status, queue.creation_task_exception,
immediately_mark_object_fail);
if (will_retry) {


@ -28,7 +28,6 @@
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/actor_creator.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/fiber.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
@ -37,6 +36,8 @@
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/worker/core_worker_client.h"
namespace {} // namespace
namespace ray {
namespace core {
@ -67,11 +68,12 @@ class CoreWorkerDirectActorTaskSubmitter
: public CoreWorkerDirectActorTaskSubmitterInterface {
public:
CoreWorkerDirectActorTaskSubmitter(
rpc::CoreWorkerClientPool &core_worker_client_pool, CoreWorkerMemoryStore &store,
TaskFinisherInterface &task_finisher, ActorCreatorInterface &actor_creator,
std::shared_ptr<rpc::CoreWorkerClientPool> core_worker_client_pool,
std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<TaskFinisherInterface> task_finisher,
std::function<void(const ActorID &, int64_t)> warn_excess_queueing)
: core_worker_client_pool_(core_worker_client_pool),
resolver_(store, task_finisher, actor_creator),
resolver_(store, task_finisher),
task_finisher_(task_finisher),
warn_excess_queueing_(warn_excess_queueing) {
next_queueing_warn_threshold_ =
@ -260,7 +262,7 @@ class CoreWorkerDirectActorTaskSubmitter
bool IsActorAlive(const ActorID &actor_id) const;
/// Pool for producing new core worker clients.
rpc::CoreWorkerClientPool &core_worker_client_pool_;
std::shared_ptr<rpc::CoreWorkerClientPool> core_worker_client_pool_;
/// Mutex to protect the various maps below.
mutable absl::Mutex mu_;
@ -271,7 +273,7 @@ class CoreWorkerDirectActorTaskSubmitter
LocalDependencyResolver resolver_;
/// Used to complete tasks.
TaskFinisherInterface &task_finisher_;
std::shared_ptr<TaskFinisherInterface> task_finisher_;
/// Used to warn of excessive queueing.
std::function<void(const ActorID &, int64_t num_queued)> warn_excess_queueing_;


@ -23,13 +23,21 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
RAY_LOG(DEBUG) << "Submit task " << task_spec.TaskId();
num_tasks_submitted_++;
resolver_.ResolveDependencies(task_spec, [this, task_spec](Status status) {
if (task_spec.IsActorCreationTask()) {
// Synchronously register the actor to GCS server.
// Previously, we asynchronously registered the actor after all its dependencies were
// resolved. This caused a problem: if the owner of the actor dies before dependencies
// are resolved, the actor will never be created. But the actor handle may already be
// passed to other workers. In this case, the actor tasks will hang forever.
// So we fixed this issue by synchronously registering the actor. If the owner dies
// before dependencies are resolved, GCS will notice this and mark the actor as dead.
auto status = actor_creator_->RegisterActor(task_spec);
if (!status.ok()) {
RAY_LOG(ERROR) << "Resolving task dependencies failed " << status.ToString();
RAY_UNUSED(task_finisher_->PendingTaskFailed(
task_spec.TaskId(), rpc::ErrorType::DEPENDENCY_RESOLUTION_FAILED, &status));
return;
return status;
}
}
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
RAY_LOG(DEBUG) << "Task dependencies resolved " << task_spec.TaskId();
if (task_spec.IsActorCreationTask()) {
// If gcs actor management is enabled, the actor creation task will be sent to
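
Note: the comment in the hunk above explains why actor registration becomes synchronous again in SubmitTask. As a rough, self-contained illustration of that restored control flow (simplified stand-in functions only, not the real CoreWorkerDirectTaskSubmitter or GCS client API): register the actor with GCS and block, with a timeout, before resolving dependencies and submitting the task.

#include <chrono>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for the async GCS registration RPC; here the "reply" is immediate.
bool AsyncRegisterWithGcs(const std::string &actor_id, std::function<void()> on_done) {
  on_done();
  return true;
}

// Block on a promise until the registration callback fires or a timeout
// elapses, mirroring the synchronous RegisterActor shown earlier in this diff.
bool RegisterActorSync(const std::string &actor_id, std::chrono::seconds timeout) {
  auto promise = std::make_shared<std::promise<void>>();
  if (!AsyncRegisterWithGcs(actor_id, [promise] { promise->set_value(); })) {
    return false;
  }
  return promise->get_future().wait_for(timeout) == std::future_status::ready;
}

bool SubmitTask(const std::string &task, bool is_actor_creation) {
  if (is_actor_creation) {
    // Register first so GCS knows about the actor even if the owner dies
    // before the task's dependencies are resolved.
    if (!RegisterActorSync(task, std::chrono::seconds(5))) {
      return false;  // surfaces as a failed or timed-out creation task
    }
  }
  // ... resolve dependencies, then hand the task to the submission path.
  std::cout << "submitting " << task << "\n";
  return true;
}

int main() { return SubmitTask("create-actor-1", /*is_actor_creation=*/true) ? 0 : 1; }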


@ -71,11 +71,11 @@ class CoreWorkerDirectTaskSubmitter {
local_lease_client_(lease_client),
lease_client_factory_(lease_client_factory),
lease_policy_(std::move(lease_policy)),
resolver_(*store, *task_finisher, *actor_creator),
resolver_(store, task_finisher),
task_finisher_(task_finisher),
lease_timeout_ms_(lease_timeout_ms),
local_raylet_id_(local_raylet_id),
actor_creator_(actor_creator),
actor_creator_(std::move(actor_creator)),
client_cache_(core_worker_client_pool),
max_tasks_in_flight_per_worker_(max_tasks_in_flight_per_worker),
cancel_retry_timer_(std::move(cancel_timer)) {}


@ -18,7 +18,6 @@
#include <memory>
#include <string>
#include <vector>
#include "gtest/gtest_prod.h"
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/status.h"
@ -167,7 +166,7 @@ class GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Constructor of GcsClient.
///
/// \param options Options for client.
GcsClient(const GcsClientOptions &options = GcsClientOptions()) : options_(options) {}
GcsClient(const GcsClientOptions &options) : options_(options) {}
GcsClientOptions options_;


@ -137,8 +137,6 @@ enum ErrorType {
// been deleted from distributed memory. This can happen in distributed
// reference counting, due to a bug or corner case.
OBJECT_DELETED = 10;
// Indicates there is some error when resolving the dependence
DEPENDENCY_RESOLUTION_FAILED = 11;
}
/// The task exception encapsulates all information about task


@ -444,9 +444,6 @@ class RayletClient : public RayletClientInterface {
/// The number of object ID pin RPCs currently in flight.
std::atomic<int64_t> pins_in_flight_{0};
protected:
RayletClient() {}
};
} // namespace raylet


@ -21,10 +21,8 @@
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include "ray/util/logging.h"
#include "ray/util/macros.h"
#ifdef _WIN32
@ -239,77 +237,3 @@ inline void SetThreadName(const std::string &thread_name) {
pthread_setname_np(pthread_self(), thread_name.substr(0, 15).c_str());
#endif
}
inline std::string GetThreadName() {
#if defined(__linux__)
char name[128];
auto rc = pthread_getname_np(pthread_self(), name, sizeof(name));
if (rc != 0) {
return "ERROR";
} else {
return name;
}
#else
return "UNKNOWN";
#endif
}
namespace ray {
template <typename T>
class ThreadPrivate {
public:
template <typename... Ts>
ThreadPrivate(Ts &&... ts) : t_(std::forward<Ts>(ts)...) {}
T &operator*() {
ThreadCheck();
return t_;
}
T *operator->() {
ThreadCheck();
return &t_;
}
const T &operator*() const {
ThreadCheck();
return t_;
}
const T *operator->() const {
ThreadCheck();
return &t_;
}
private:
void ThreadCheck() const {
// ThreadCheck is not a thread safe function and at the same time, multiple
// threads might be accessing id_ at the same time.
// Here we only introduce mutex to protect write instead of read for the
// following reasons:
// - read and write at the same time for `id_` is fine since this is a
// trivial object. And since we are using this to detect errors,
// it doesn't matter which value it is.
// - read and write of `thread_name_` is not good. But it will only be
// read when we crash the program.
//
if (id_ == std::thread::id()) {
// Protect thread_name_
std::lock_guard<std::mutex> _(mutex_);
thread_name_ = GetThreadName();
RAY_LOG(DEBUG) << "First accessed in thread " << thread_name_;
id_ = std::this_thread::get_id();
}
RAY_CHECK(id_ == std::this_thread::get_id())
<< "A variable private to thread " << thread_name_ << " was accessed in thread "
<< GetThreadName();
}
T t_;
mutable std::string thread_name_;
mutable std::thread::id id_;
mutable std::mutex mutex_;
};
} // namespace ray