From b6fe6156f5cc6d4490e4ed3cd347ebcd500c2f1e Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Thu, 25 Aug 2022 10:05:05 +0800 Subject: [PATCH] [C++ worker]Support ActorHandle type return value (#28077) Before we support `ActorHandle` type as parameter, this PR adds support for `ActorHandle` type as return type. --- cpp/include/ray/api.h | 7 +++++-- cpp/include/ray/api/actor_handle.h | 4 +++- cpp/include/ray/api/common_types.h | 10 ++++++++++ cpp/include/ray/api/function_manager.h | 16 ++++++---------- cpp/include/ray/api/object_ref.h | 7 +++++++ cpp/include/ray/api/ray_runtime.h | 6 ++---- cpp/include/ray/api/task_options.h | 2 ++ cpp/src/ray/test/cluster/cluster_mode_test.cc | 6 ++++++ 8 files changed, 41 insertions(+), 17 deletions(-) diff --git a/cpp/include/ray/api.h b/cpp/include/ray/api.h index 1ae484373..f868d570d 100644 --- a/cpp/include/ray/api.h +++ b/cpp/include/ray/api.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -270,7 +271,8 @@ inline ray::internal::TaskCaller Task(F func) { static_assert(!ray::internal::is_python_v, "Must be a cpp function."); static_assert(!std::is_member_function_pointer_v, "Incompatible type: member function cannot be called with ray::Task."); - ray::internal::RemoteFunctionHolder remote_func_holder(std::move(func)); + auto func_name = internal::FunctionManager::Instance().GetFunctionName(func); + ray::internal::RemoteFunctionHolder remote_func_holder(std::move(func_name)); return ray::internal::TaskCaller(ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)); } @@ -278,7 +280,8 @@ inline ray::internal::TaskCaller Task(F func) { /// Creating an actor. template inline ray::internal::ActorCreator Actor(F create_func) { - ray::internal::RemoteFunctionHolder remote_func_holder(std::move(create_func)); + auto func_name = internal::FunctionManager::Instance().GetFunctionName(create_func); + ray::internal::RemoteFunctionHolder remote_func_holder(std::move(func_name)); return ray::internal::ActorCreator(ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)); } diff --git a/cpp/include/ray/api/actor_handle.h b/cpp/include/ray/api/actor_handle.h index 66822cf3f..2b5aabc1e 100644 --- a/cpp/include/ray/api/actor_handle.h +++ b/cpp/include/ray/api/actor_handle.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include namespace ray { @@ -47,7 +48,8 @@ class ActorHandle { static_assert( std::is_same::value || std::is_base_of::value, "Class types must be same."); - ray::internal::RemoteFunctionHolder remote_func_holder(actor_func); + auto func_name = internal::FunctionManager::Instance().GetFunctionName(actor_func); + ray::internal::RemoteFunctionHolder remote_func_holder(func_name); return ray::internal::ActorTaskCaller( internal::GetRayRuntime().get(), id_, std::move(remote_func_holder)); } diff --git a/cpp/include/ray/api/common_types.h b/cpp/include/ray/api/common_types.h index 6e20cebd7..71defe2fd 100644 --- a/cpp/include/ray/api/common_types.h +++ b/cpp/include/ray/api/common_types.h @@ -42,5 +42,15 @@ struct TaskArg { std::string_view meta_str; }; +using ArgsBuffer = msgpack::sbuffer; +using ArgsBufferList = std::vector; + +using RemoteFunction = std::function; +using RemoteFunctionMap_t = std::unordered_map; + +using RemoteMemberFunction = + std::function; +using RemoteMemberFunctionMap_t = std::unordered_map; + } // namespace internal } // namespace ray \ No newline at end of file diff --git a/cpp/include/ray/api/function_manager.h b/cpp/include/ray/api/function_manager.h index 324b65392..b99ee3406 100644 --- a/cpp/include/ray/api/function_manager.h +++ b/cpp/include/ray/api/function_manager.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -33,6 +34,11 @@ namespace internal { template inline static std::enable_if_t::value, msgpack::sbuffer> PackReturnValue(T result) { + if constexpr (is_actor_handle_v) { + auto serialized_actor_handle = + RayRuntimeHolder::Instance().Runtime()->SerializeActorHandle(result.ID()); + return Serializer::Serialize(serialized_actor_handle); + } return Serializer::Serialize(std::move(result)); } @@ -48,16 +54,6 @@ inline static msgpack::sbuffer PackVoid() { msgpack::sbuffer PackError(std::string error_msg); -using ArgsBuffer = msgpack::sbuffer; -using ArgsBufferList = std::vector; - -using RemoteFunction = std::function; -using RemoteFunctionMap_t = std::unordered_map; - -using RemoteMemberFunction = - std::function; -using RemoteMemberFunctionMap_t = std::unordered_map; - /// It's help to invoke functions and member functions, the class Invoker help /// do type erase. template diff --git a/cpp/include/ray/api/object_ref.h b/cpp/include/ray/api/object_ref.h index 09acb3b3d..bcf6a8ad3 100644 --- a/cpp/include/ray/api/object_ref.h +++ b/cpp/include/ray/api/object_ref.h @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -118,6 +119,12 @@ inline static std::shared_ptr GetFromRuntime(const ObjectRef &object) { packed_object->data(), packed_object->size(), internal::XLANG_HEADER_LEN); } + if constexpr (ray::internal::is_actor_handle_v) { + auto actor_handle = ray::internal::Serializer::Deserialize( + packed_object->data(), packed_object->size()); + return std::make_shared(T::FromBytes(actor_handle)); + } + return ray::internal::Serializer::Deserialize>( packed_object->data(), packed_object->size()); } diff --git a/cpp/include/ray/api/ray_runtime.h b/cpp/include/ray/api/ray_runtime.h index 3638899ea..5066629f3 100644 --- a/cpp/include/ray/api/ray_runtime.h +++ b/cpp/include/ray/api/ray_runtime.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include #include @@ -38,9 +38,7 @@ struct RemoteFunctionHolder { this->class_name = class_name; this->lang_type = lang_type; } - template - RemoteFunctionHolder(F func) { - auto func_name = FunctionManager::Instance().GetFunctionName(func); + RemoteFunctionHolder(std::string func_name) { if (func_name.empty()) { throw RayException( "Function not found. Please use RAY_REMOTE to register this function."); diff --git a/cpp/include/ray/api/task_options.h b/cpp/include/ray/api/task_options.h index c5bfd3ec9..c2284249d 100644 --- a/cpp/include/ray/api/task_options.h +++ b/cpp/include/ray/api/task_options.h @@ -14,6 +14,8 @@ #pragma once +#include + #include namespace ray { diff --git a/cpp/src/ray/test/cluster/cluster_mode_test.cc b/cpp/src/ray/test/cluster/cluster_mode_test.cc index 630933dbd..ccc344a2b 100644 --- a/cpp/src/ray/test/cluster/cluster_mode_test.cc +++ b/cpp/src/ray/test/cluster/cluster_mode_test.cc @@ -242,9 +242,15 @@ TEST(RayClusterModeTest, ActorHandleTest) { auto actor1 = ray::Actor(RAY_FUNC(Counter::FactoryCreate)).Remote(); auto obj1 = actor1.Task(&Counter::Plus1).Remote(); EXPECT_EQ(1, *obj1.Get()); + // Test `ActorHandle` type object as parameter. auto actor2 = ray::Actor(RAY_FUNC(Counter::FactoryCreate)).Remote(); auto obj2 = actor2.Task(&Counter::Plus1ForActor).Remote(actor1); EXPECT_EQ(2, *obj2.Get()); + // Test `ActorHandle` type object as return value. + std::string child_actor_name = "child_actor_name"; + auto child_actor = + actor1.Task(&Counter::CreateChildActor).Remote(child_actor_name).Get(); + EXPECT_EQ(1, *child_actor->Task(&Counter::Plus1).Remote().Get()); } TEST(RayClusterModeTest, PythonInvocationTest) {