[C++ worker] Add ray register part1 (#14436)

This commit is contained in:
qicosmos 2021-03-09 13:57:17 +08:00 committed by GitHub
parent a06dc39d9f
commit f2348a5456
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 332 additions and 7 deletions

View file

@ -6,6 +6,7 @@
#include <ray/api/actor_task_caller.h>
#include <ray/api/exec_funcs.h>
#include <ray/api/object_ref.h>
#include <ray/api/ray_remote.h>
#include <ray/api/ray_runtime.h>
#include <ray/api/ray_runtime_holder.h>
#include <ray/api/task_caller.h>
@ -139,7 +140,7 @@ inline static std::vector<ObjectID> ObjectRefsToObjectIDs(
template <typename T>
inline ObjectRef<T> Ray::Put(const T &obj) {
auto buffer = std::make_shared<msgpack::sbuffer>(Serializer::Serialize(obj));
auto id = internal::RayRuntime()->Put(buffer);
auto id = ray::internal::RayRuntime()->Put(buffer);
return ObjectRef<T>(id);
}
@ -150,7 +151,7 @@ inline std::shared_ptr<T> Ray::Get(const ObjectRef<T> &object) {
template <typename T>
inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectID> &ids) {
auto result = internal::RayRuntime()->Get(ids);
auto result = ray::internal::RayRuntime()->Get(ids);
std::vector<std::shared_ptr<T>> return_objects;
return_objects.reserve(result.size());
for (auto it = result.begin(); it != result.end(); it++) {
@ -168,7 +169,7 @@ inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectRef<T>>
inline WaitResult Ray::Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms) {
return internal::RayRuntime()->Wait(ids, num_objects, timeout_ms);
return ray::internal::RayRuntime()->Wait(ids, num_objects, timeout_ms);
}
template <typename ReturnType, typename FuncType, typename ExecFuncType,
@ -180,7 +181,8 @@ inline TaskCaller<ReturnType> Ray::TaskInternal(FuncType &func, ExecFuncType &ex
RemoteFunctionPtrHolder ptr;
ptr.function_pointer = reinterpret_cast<uintptr_t>(func);
ptr.exec_function_pointer = reinterpret_cast<uintptr_t>(exec_func);
return TaskCaller<ReturnType>(internal::RayRuntime().get(), ptr, std::move(task_args));
return TaskCaller<ReturnType>(ray::internal::RayRuntime().get(), ptr,
std::move(task_args));
}
template <typename ActorType, typename FuncType, typename ExecFuncType,
@ -193,7 +195,8 @@ inline ActorCreator<ActorType> Ray::CreateActorInternal(FuncType &create_func,
RemoteFunctionPtrHolder ptr;
ptr.function_pointer = reinterpret_cast<uintptr_t>(create_func);
ptr.exec_function_pointer = reinterpret_cast<uintptr_t>(exec_func);
return ActorCreator<ActorType>(internal::RayRuntime().get(), ptr, std::move(task_args));
return ActorCreator<ActorType>(ray::internal::RayRuntime().get(), ptr,
std::move(task_args));
}
/// Normal task.

View file

@ -0,0 +1,177 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ray/api/serializer.h>
#include "absl/utility/utility.h"
#include <boost/callable_traits.hpp>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
namespace ray {
namespace internal {
template <class, class>
struct AddType;
template <class First, class... Second>
struct AddType<First, std::tuple<Second...>> {
using type = std::tuple<First, Second...>;
};
/// Add a type to a tuple: AddType_t<int, std::tuple<double>> equal std::tuple<int,
/// double>.
template <class First, class Second>
using AddType_t = typename AddType<First, Second>::type;
enum ErrorCode {
OK = 0,
FAIL = 1,
};
struct VoidResponse {
int error_code;
std::string error_msg;
MSGPACK_DEFINE(error_code, error_msg);
};
template <typename T>
struct Response {
int error_code;
std::string error_msg;
T data;
MSGPACK_DEFINE(error_code, error_msg, data);
};
template <typename T>
inline static msgpack::sbuffer PackReturnValue(int error_code, std::string error_msg,
T result) {
return ray::api::Serializer::Serialize(
Response<T>{error_code, std::move(error_msg), std::move(result)});
}
inline static msgpack::sbuffer PackReturnValue(int error_code,
std::string error_msg = "ok") {
return ray::api::Serializer::Serialize(VoidResponse{error_code, std::move(error_msg)});
}
/// It's help to invoke functions and member functions, the class Invoker<Function> help
/// do type erase.
template <typename Function>
struct Invoker {
/// Invoke functions by networking stream, at first deserialize the binary data to a
/// tuple, then call function with tuple.
static inline msgpack::sbuffer Apply(const Function &func, const char *data,
size_t size) {
using args_tuple = AddType_t<std::string, boost::callable_traits::args_t<Function>>;
msgpack::sbuffer result;
try {
auto tp = ray::api::Serializer::Deserialize<args_tuple>(data, size);
result = Invoker<Function>::Call(func, std::move(tp));
} catch (msgpack::type_error &e) {
result =
PackReturnValue(ErrorCode::FAIL, std::string("invalid arguments: ") + e.what());
} catch (const std::exception &e) {
result = PackReturnValue(ErrorCode::FAIL,
std::string("function execute exception: ") + e.what());
} catch (...) {
result = PackReturnValue(ErrorCode::FAIL, "unknown exception");
}
return result;
}
template <typename F, size_t... I, typename Arg, typename... Args>
static absl::result_of_t<F(Args...)> CallInternal(const F &f,
const absl::index_sequence<I...> &,
std::tuple<Arg, Args...> tup) {
return f(std::move(std::get<I + 1>(tup))...);
}
template <typename F, typename Arg, typename... Args>
static absl::enable_if_t<std::is_void<absl::result_of_t<F(Args...)>>::value,
msgpack::sbuffer>
Call(const F &f, std::tuple<Arg, Args...> tp) {
CallInternal(f, absl::make_index_sequence<sizeof...(Args)>{}, std::move(tp));
return PackReturnValue(ErrorCode::OK);
}
template <typename F, typename Arg, typename... Args>
static absl::enable_if_t<!std::is_void<absl::result_of_t<F(Args...)>>::value,
msgpack::sbuffer>
Call(const F &f, std::tuple<Arg, Args...> tp) {
auto r = CallInternal(f, absl::make_index_sequence<sizeof...(Args)>{}, std::move(tp));
return PackReturnValue(ErrorCode::OK, "ok", r);
}
};
/// Manage all ray remote functions, add remote functions by RAY_REMOTE, get functions by
/// TaskExecutionHandler.
class FunctionManager {
public:
static FunctionManager &Instance() {
static FunctionManager instance;
return instance;
}
std::function<msgpack::sbuffer(const char *, size_t)> *GetFunction(
const std::string &func_name) {
auto it = map_invokers_.find(func_name);
if (it == map_invokers_.end()) {
return nullptr;
}
return &it->second;
}
template <typename Function>
bool RegisterRemoteFunction(std::string const &name, const Function &f) {
/// Now it is just support free function, it will be
/// improved to support member function later.
auto pair = func_ptr_to_key_map_.emplace((uint64_t)f, name);
if (!pair.second) {
return false;
}
return RegisterNonMemberFunc(name, f);
}
private:
FunctionManager() = default;
~FunctionManager() = default;
FunctionManager(const FunctionManager &) = delete;
FunctionManager(FunctionManager &&) = delete;
template <typename Function>
bool RegisterNonMemberFunc(std::string const &name, Function f) {
return map_invokers_
.emplace(name, std::bind(&Invoker<Function>::Apply, std::move(f),
std::placeholders::_1, std::placeholders::_2))
.second;
}
std::unordered_map<std::string, std::function<msgpack::sbuffer(const char *, size_t)>>
map_invokers_;
std::unordered_map<uintptr_t, std::string> func_ptr_to_key_map_;
};
} // namespace internal
} // namespace ray

View file

@ -0,0 +1,38 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ray/api/function_manager.h>
namespace ray {
namespace {
#define CONCATENATE_DIRECT(s1, s2) s1##s2
#define CONCATENATE(s1, s2) CONCATENATE_DIRECT(s1, s2)
#ifdef _MSC_VER
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __COUNTER__)
#else
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
#endif
} // namespace
namespace api {
#define RAY_REMOTE(f) \
static auto ANONYMOUS_VARIABLE(var) = \
ray::internal::FunctionManager::Instance().RegisterRemoteFunction(#f, f);
} // namespace api
} // namespace ray

View file

@ -11,7 +11,7 @@ std::once_flag Ray::is_inited_;
void Ray::Init() {
std::call_once(is_inited_, [] {
auto runtime = AbstractRayRuntime::DoInit(RayConfig::GetInstance());
internal::RayRuntimeHolder::Instance().Init(runtime);
ray::internal::RayRuntimeHolder::Instance().Init(runtime);
});
}

View file

@ -1,12 +1,39 @@
#pragma once
#include <ray/api/function_manager.h>
#include <ray/api/serializer.h>
#include <memory>
#include "absl/synchronization/mutex.h"
#include "invocation_spec.h"
#include "ray/core.h"
namespace ray {
namespace internal {
/// Execute remote functions by networking stream.
inline static msgpack::sbuffer TaskExecutionHandler(const char *data, std::size_t size) {
msgpack::sbuffer result;
do {
try {
auto p = ray::api::Serializer::Deserialize<std::tuple<std::string>>(data, size);
auto &func_name = std::get<0>(p);
auto func_ptr = FunctionManager::Instance().GetFunction(func_name);
if (func_ptr == nullptr) {
result = PackReturnValue(internal::ErrorCode::FAIL,
"unknown function: " + func_name, 0);
break;
}
result = (*func_ptr)(data, size);
} catch (const std::exception &ex) {
result = PackReturnValue(internal::ErrorCode::FAIL, ex.what());
}
} while (0);
return result;
}
} // namespace internal
namespace api {
class AbstractRayRuntime;

View file

@ -0,0 +1,80 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <ray/api.h>
#include <ray/api/serializer.h>
#include "cpp/src/ray/runtime/task/task_executor.h"
using namespace ray::api;
using namespace ray::internal;
int Return() { return 1; }
int PlusOne(int x) { return x + 1; }
RAY_REMOTE(PlusOne);
TEST(RayApiTest, DuplicateRegister) {
bool r = FunctionManager::Instance().RegisterRemoteFunction("Return", Return);
EXPECT_TRUE(r);
/// Duplicate register
bool r1 = FunctionManager::Instance().RegisterRemoteFunction("Return", Return);
EXPECT_FALSE(r1);
bool r2 = FunctionManager::Instance().RegisterRemoteFunction("PlusOne", PlusOne);
EXPECT_FALSE(r2);
}
TEST(RayApiTest, FindAndExecuteFunction) {
/// Find and call the registered function.
auto args = std::make_tuple("PlusOne", 1);
auto buf = Serializer::Serialize(args);
auto result_buf = TaskExecutionHandler(buf.data(), buf.size());
/// Deserialize result.
auto response =
Serializer::Deserialize<Response<int>>(result_buf.data(), result_buf.size());
EXPECT_EQ(response.error_code, ErrorCode::OK);
EXPECT_EQ(response.data, 2);
}
TEST(RayApiTest, VoidFunction) {
auto buf1 = Serializer::Serialize(std::make_tuple("Return"));
auto result_buf = TaskExecutionHandler(buf1.data(), buf1.size());
auto response =
Serializer::Deserialize<VoidResponse>(result_buf.data(), result_buf.size());
EXPECT_EQ(response.error_code, ErrorCode::OK);
}
/// We should consider the driver so is not same with the worker so, and find the error
/// reason.
TEST(RayApiTest, NotExistFunction) {
auto buf2 = Serializer::Serialize(std::make_tuple("Return11"));
auto result_buf = TaskExecutionHandler(buf2.data(), buf2.size());
auto response =
Serializer::Deserialize<VoidResponse>(result_buf.data(), result_buf.size());
EXPECT_EQ(response.error_code, ErrorCode::FAIL);
EXPECT_FALSE(response.error_msg.empty());
}
TEST(RayApiTest, ArgumentsNotMatch) {
auto buf = Serializer::Serialize(std::make_tuple("PlusOne", "invalid arguments"));
auto result_buf = TaskExecutionHandler(buf.data(), buf.size());
auto response =
Serializer::Deserialize<Response<int>>(result_buf.data(), result_buf.size());
EXPECT_EQ(response.error_code, ErrorCode::FAIL);
EXPECT_FALSE(response.error_msg.empty());
}