From f2348a5456ab4eae28a46d18d1d3ab74ca8ea4e0 Mon Sep 17 00:00:00 2001 From: qicosmos <383121719@qq.com> Date: Tue, 9 Mar 2021 13:57:17 +0800 Subject: [PATCH] [C++ worker] Add ray register part1 (#14436) --- cpp/include/ray/api.h | 13 +- cpp/include/ray/api/function_manager.h | 177 +++++++++++++++++++++++ cpp/include/ray/api/ray_remote.h | 38 +++++ cpp/src/ray/api.cc | 2 +- cpp/src/ray/runtime/task/task_executor.h | 29 +++- cpp/src/ray/test/ray_remote_test.cc | 80 ++++++++++ 6 files changed, 332 insertions(+), 7 deletions(-) create mode 100644 cpp/include/ray/api/function_manager.h create mode 100644 cpp/include/ray/api/ray_remote.h create mode 100644 cpp/src/ray/test/ray_remote_test.cc diff --git a/cpp/include/ray/api.h b/cpp/include/ray/api.h index d23a05b58..92c253027 100644 --- a/cpp/include/ray/api.h +++ b/cpp/include/ray/api.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -139,7 +140,7 @@ inline static std::vector ObjectRefsToObjectIDs( template inline ObjectRef Ray::Put(const T &obj) { auto buffer = std::make_shared(Serializer::Serialize(obj)); - auto id = internal::RayRuntime()->Put(buffer); + auto id = ray::internal::RayRuntime()->Put(buffer); return ObjectRef(id); } @@ -150,7 +151,7 @@ inline std::shared_ptr Ray::Get(const ObjectRef &object) { template inline std::vector> Ray::Get(const std::vector &ids) { - auto result = internal::RayRuntime()->Get(ids); + auto result = ray::internal::RayRuntime()->Get(ids); std::vector> return_objects; return_objects.reserve(result.size()); for (auto it = result.begin(); it != result.end(); it++) { @@ -168,7 +169,7 @@ inline std::vector> Ray::Get(const std::vector> inline WaitResult Ray::Wait(const std::vector &ids, int num_objects, int timeout_ms) { - return internal::RayRuntime()->Wait(ids, num_objects, timeout_ms); + return ray::internal::RayRuntime()->Wait(ids, num_objects, timeout_ms); } template Ray::TaskInternal(FuncType &func, ExecFuncType &ex RemoteFunctionPtrHolder ptr; 
ptr.function_pointer = reinterpret_cast(func); ptr.exec_function_pointer = reinterpret_cast(exec_func); - return TaskCaller(internal::RayRuntime().get(), ptr, std::move(task_args)); + return TaskCaller(ray::internal::RayRuntime().get(), ptr, + std::move(task_args)); } template Ray::CreateActorInternal(FuncType &create_func, RemoteFunctionPtrHolder ptr; ptr.function_pointer = reinterpret_cast(create_func); ptr.exec_function_pointer = reinterpret_cast(exec_func); - return ActorCreator(internal::RayRuntime().get(), ptr, std::move(task_args)); + return ActorCreator(ray::internal::RayRuntime().get(), ptr, + std::move(task_args)); } /// Normal task. diff --git a/cpp/include/ray/api/function_manager.h b/cpp/include/ray/api/function_manager.h new file mode 100644 index 000000000..f3d12aa57 --- /dev/null +++ b/cpp/include/ray/api/function_manager.h @@ -0,0 +1,177 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "absl/utility/utility.h" + +#include +#include +#include +#include +#include +#include + +namespace ray { +namespace internal { + +template +struct AddType; + +template +struct AddType> { + using type = std::tuple; +}; + +/// Add a type to a tuple type: AddType_t names the resulting std::tuple type. 
+template +using AddType_t = typename AddType::type; + +enum ErrorCode { + OK = 0, + FAIL = 1, +}; + +struct VoidResponse { + int error_code; + std::string error_msg; + + MSGPACK_DEFINE(error_code, error_msg); +}; + +template +struct Response { + int error_code; + std::string error_msg; + T data; + + MSGPACK_DEFINE(error_code, error_msg, data); +}; + +template +inline static msgpack::sbuffer PackReturnValue(int error_code, std::string error_msg, + T result) { + return ray::api::Serializer::Serialize( + Response{error_code, std::move(error_msg), std::move(result)}); +} + +inline static msgpack::sbuffer PackReturnValue(int error_code, + std::string error_msg = "ok") { + return ray::api::Serializer::Serialize(VoidResponse{error_code, std::move(error_msg)}); +} + +/// Helps to invoke functions and member functions; the Invoker class performs +/// the type erasure. +template +struct Invoker { + /// Invoke a function from a serialized byte stream: first deserialize the binary data into a + /// tuple, then call the function with that tuple; any exception becomes a FAIL response. + static inline msgpack::sbuffer Apply(const Function &func, const char *data, + size_t size) { + using args_tuple = AddType_t>; + + msgpack::sbuffer result; + try { + auto tp = ray::api::Serializer::Deserialize(data, size); + result = Invoker::Call(func, std::move(tp)); + } catch (msgpack::type_error &e) { + result = + PackReturnValue(ErrorCode::FAIL, std::string("invalid arguments: ") + e.what()); + } catch (const std::exception &e) { + result = PackReturnValue(ErrorCode::FAIL, + std::string("function execute exception: ") + e.what()); + } catch (...) 
{ + result = PackReturnValue(ErrorCode::FAIL, "unknown exception"); + } + + return result; + } + + template + static absl::result_of_t CallInternal(const F &f, + const absl::index_sequence &, + std::tuple tup) { + return f(std::move(std::get(tup))...); + } + + template + static absl::enable_if_t>::value, + msgpack::sbuffer> + Call(const F &f, std::tuple tp) { + CallInternal(f, absl::make_index_sequence{}, std::move(tp)); + return PackReturnValue(ErrorCode::OK); + } + + template + static absl::enable_if_t>::value, + msgpack::sbuffer> + Call(const F &f, std::tuple tp) { + auto r = CallInternal(f, absl::make_index_sequence{}, std::move(tp)); + return PackReturnValue(ErrorCode::OK, "ok", r); + } +}; + +/// Manages all Ray remote functions: they are registered through RAY_REMOTE and looked up by +/// TaskExecutionHandler. +class FunctionManager { + public: + static FunctionManager &Instance() { + static FunctionManager instance; + return instance; + } + + std::function *GetFunction( + const std::string &func_name) { + auto it = map_invokers_.find(func_name); + if (it == map_invokers_.end()) { + return nullptr; + } + + return &it->second; + } + + template + bool RegisterRemoteFunction(std::string const &name, const Function &f) { + /// Only free functions are supported for now; member function support will come later. + /// NOTE(review): if the name insert below fails, the pointer entry added here is kept. 
+ auto pair = func_ptr_to_key_map_.emplace((uint64_t)f, name); + if (!pair.second) { + return false; + } + + return RegisterNonMemberFunc(name, f); + } + + private: + FunctionManager() = default; + ~FunctionManager() = default; + FunctionManager(const FunctionManager &) = delete; + FunctionManager(FunctionManager &&) = delete; + + template + bool RegisterNonMemberFunc(std::string const &name, Function f) { + return map_invokers_ + .emplace(name, std::bind(&Invoker::Apply, std::move(f), + std::placeholders::_1, std::placeholders::_2)) + .second; + } + + std::unordered_map> + map_invokers_; + std::unordered_map func_ptr_to_key_map_; +}; +} // namespace internal +} // namespace ray \ No newline at end of file diff --git a/cpp/include/ray/api/ray_remote.h b/cpp/include/ray/api/ray_remote.h new file mode 100644 index 000000000..10857375e --- /dev/null +++ b/cpp/include/ray/api/ray_remote.h @@ -0,0 +1,38 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace ray { + +namespace { /* NOTE(review): #define macros are not scoped by namespaces; this wrapper does not confine them. */ +#define CONCATENATE_DIRECT(s1, s2) s1##s2 +#define CONCATENATE(s1, s2) CONCATENATE_DIRECT(s1, s2) +#ifdef _MSC_VER +#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __COUNTER__) +#else +#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__) +#endif +} // namespace + +namespace api { + +#define RAY_REMOTE(f) \ + static auto ANONYMOUS_VARIABLE(var) = \ + ray::internal::FunctionManager::Instance().RegisterRemoteFunction(#f, f); + +} // namespace api +} // namespace ray \ No newline at end of file diff --git a/cpp/src/ray/api.cc b/cpp/src/ray/api.cc index f969abfc5..184aef581 100644 --- a/cpp/src/ray/api.cc +++ b/cpp/src/ray/api.cc @@ -11,7 +11,7 @@ std::once_flag Ray::is_inited_; void Ray::Init() { std::call_once(is_inited_, [] { auto runtime = AbstractRayRuntime::DoInit(RayConfig::GetInstance()); - internal::RayRuntimeHolder::Instance().Init(runtime); + ray::internal::RayRuntimeHolder::Instance().Init(runtime); }); } diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index cf2d24745..5c4012bd9 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -1,12 +1,39 @@ #pragma once +#include +#include #include - #include "absl/synchronization/mutex.h" #include "invocation_spec.h" #include "ray/core.h" namespace ray { +namespace internal { + +/// Execute the remote function named by the first element of the serialized tuple in data. 
+inline static msgpack::sbuffer TaskExecutionHandler(const char *data, std::size_t size) { + msgpack::sbuffer result; + do { /* single pass: lets the unknown-function branch leave via break */ + try { + auto p = ray::api::Serializer::Deserialize>(data, size); + auto &func_name = std::get<0>(p); + auto func_ptr = FunctionManager::Instance().GetFunction(func_name); + if (func_ptr == nullptr) { + result = PackReturnValue(internal::ErrorCode::FAIL, + "unknown function: " + func_name, 0); + break; + } + + result = (*func_ptr)(data, size); /* the invoker receives the same serialized buffer */ + } catch (const std::exception &ex) { + result = PackReturnValue(internal::ErrorCode::FAIL, ex.what()); + } + } while (0); + + return result; +} +} // namespace internal + namespace api { class AbstractRayRuntime; diff --git a/cpp/src/ray/test/ray_remote_test.cc b/cpp/src/ray/test/ray_remote_test.cc new file mode 100644 index 000000000..c71e19ef8 --- /dev/null +++ b/cpp/src/ray/test/ray_remote_test.cc @@ -0,0 +1,80 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "cpp/src/ray/runtime/task/task_executor.h" + +using namespace ray::api; +using namespace ray::internal; + +int Return() { return 1; } +int PlusOne(int x) { return x + 1; } + +RAY_REMOTE(PlusOne); + +TEST(RayApiTest, DuplicateRegister) { + bool r = FunctionManager::Instance().RegisterRemoteFunction("Return", Return); + EXPECT_TRUE(r); + + /// Registering the same name and function a second time must be rejected. + bool r1 = FunctionManager::Instance().RegisterRemoteFunction("Return", Return); + EXPECT_FALSE(r1); + + bool r2 = FunctionManager::Instance().RegisterRemoteFunction("PlusOne", PlusOne); + EXPECT_FALSE(r2); +} + +TEST(RayApiTest, FindAndExecuteFunction) { + /// Find and call the registered function. + auto args = std::make_tuple("PlusOne", 1); + auto buf = Serializer::Serialize(args); + auto result_buf = TaskExecutionHandler(buf.data(), buf.size()); + + /// Deserialize result. + auto response = + Serializer::Deserialize>(result_buf.data(), result_buf.size()); + + EXPECT_EQ(response.error_code, ErrorCode::OK); + EXPECT_EQ(response.data, 2); +} + +TEST(RayApiTest, VoidFunction) { + auto buf1 = Serializer::Serialize(std::make_tuple("Return")); + auto result_buf = TaskExecutionHandler(buf1.data(), buf1.size()); + auto response = + Serializer::Deserialize(result_buf.data(), result_buf.size()); + EXPECT_EQ(response.error_code, ErrorCode::OK); +} + +/// We should consider the case where the driver so (shared library) differs from the worker so, and find the +/// reason for the error. 
+TEST(RayApiTest, NotExistFunction) { + auto buf2 = Serializer::Serialize(std::make_tuple("Return11")); + auto result_buf = TaskExecutionHandler(buf2.data(), buf2.size()); + auto response = + Serializer::Deserialize(result_buf.data(), result_buf.size()); + EXPECT_EQ(response.error_code, ErrorCode::FAIL); + EXPECT_FALSE(response.error_msg.empty()); +} + +TEST(RayApiTest, ArgumentsNotMatch) { + auto buf = Serializer::Serialize(std::make_tuple("PlusOne", "invalid arguments")); + auto result_buf = TaskExecutionHandler(buf.data(), buf.size()); + auto response = + Serializer::Deserialize>(result_buf.data(), result_buf.size()); + EXPECT_EQ(response.error_code, ErrorCode::FAIL); + EXPECT_FALSE(response.error_msg.empty()); +}