mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[C++ worker] Add ray register part1 (#14436)
This commit is contained in:
parent
a06dc39d9f
commit
f2348a5456
6 changed files with 332 additions and 7 deletions
|
@ -6,6 +6,7 @@
|
|||
#include <ray/api/actor_task_caller.h>
|
||||
#include <ray/api/exec_funcs.h>
|
||||
#include <ray/api/object_ref.h>
|
||||
#include <ray/api/ray_remote.h>
|
||||
#include <ray/api/ray_runtime.h>
|
||||
#include <ray/api/ray_runtime_holder.h>
|
||||
#include <ray/api/task_caller.h>
|
||||
|
@ -139,7 +140,7 @@ inline static std::vector<ObjectID> ObjectRefsToObjectIDs(
|
|||
template <typename T>
|
||||
inline ObjectRef<T> Ray::Put(const T &obj) {
|
||||
auto buffer = std::make_shared<msgpack::sbuffer>(Serializer::Serialize(obj));
|
||||
auto id = internal::RayRuntime()->Put(buffer);
|
||||
auto id = ray::internal::RayRuntime()->Put(buffer);
|
||||
return ObjectRef<T>(id);
|
||||
}
|
||||
|
||||
|
@ -150,7 +151,7 @@ inline std::shared_ptr<T> Ray::Get(const ObjectRef<T> &object) {
|
|||
|
||||
template <typename T>
|
||||
inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectID> &ids) {
|
||||
auto result = internal::RayRuntime()->Get(ids);
|
||||
auto result = ray::internal::RayRuntime()->Get(ids);
|
||||
std::vector<std::shared_ptr<T>> return_objects;
|
||||
return_objects.reserve(result.size());
|
||||
for (auto it = result.begin(); it != result.end(); it++) {
|
||||
|
@ -168,7 +169,7 @@ inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectRef<T>>
|
|||
|
||||
inline WaitResult Ray::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
int timeout_ms) {
|
||||
return internal::RayRuntime()->Wait(ids, num_objects, timeout_ms);
|
||||
return ray::internal::RayRuntime()->Wait(ids, num_objects, timeout_ms);
|
||||
}
|
||||
|
||||
template <typename ReturnType, typename FuncType, typename ExecFuncType,
|
||||
|
@ -180,7 +181,8 @@ inline TaskCaller<ReturnType> Ray::TaskInternal(FuncType &func, ExecFuncType &ex
|
|||
RemoteFunctionPtrHolder ptr;
|
||||
ptr.function_pointer = reinterpret_cast<uintptr_t>(func);
|
||||
ptr.exec_function_pointer = reinterpret_cast<uintptr_t>(exec_func);
|
||||
return TaskCaller<ReturnType>(internal::RayRuntime().get(), ptr, std::move(task_args));
|
||||
return TaskCaller<ReturnType>(ray::internal::RayRuntime().get(), ptr,
|
||||
std::move(task_args));
|
||||
}
|
||||
|
||||
template <typename ActorType, typename FuncType, typename ExecFuncType,
|
||||
|
@ -193,7 +195,8 @@ inline ActorCreator<ActorType> Ray::CreateActorInternal(FuncType &create_func,
|
|||
RemoteFunctionPtrHolder ptr;
|
||||
ptr.function_pointer = reinterpret_cast<uintptr_t>(create_func);
|
||||
ptr.exec_function_pointer = reinterpret_cast<uintptr_t>(exec_func);
|
||||
return ActorCreator<ActorType>(internal::RayRuntime().get(), ptr, std::move(task_args));
|
||||
return ActorCreator<ActorType>(ray::internal::RayRuntime().get(), ptr,
|
||||
std::move(task_args));
|
||||
}
|
||||
|
||||
/// Normal task.
|
||||
|
|
177
cpp/include/ray/api/function_manager.h
Normal file
177
cpp/include/ray/api/function_manager.h
Normal file
|
@ -0,0 +1,177 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ray/api/serializer.h>
|
||||
#include "absl/utility/utility.h"
|
||||
|
||||
#include <boost/callable_traits.hpp>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace ray {
|
||||
namespace internal {
|
||||
|
||||
template <class, class>
|
||||
struct AddType;
|
||||
|
||||
template <class First, class... Second>
|
||||
struct AddType<First, std::tuple<Second...>> {
|
||||
using type = std::tuple<First, Second...>;
|
||||
};
|
||||
|
||||
/// Add a type to a tuple: AddType_t<int, std::tuple<double>> equal std::tuple<int,
|
||||
/// double>.
|
||||
template <class First, class Second>
|
||||
using AddType_t = typename AddType<First, Second>::type;
|
||||
|
||||
enum ErrorCode {
|
||||
OK = 0,
|
||||
FAIL = 1,
|
||||
};
|
||||
|
||||
struct VoidResponse {
|
||||
int error_code;
|
||||
std::string error_msg;
|
||||
|
||||
MSGPACK_DEFINE(error_code, error_msg);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Response {
|
||||
int error_code;
|
||||
std::string error_msg;
|
||||
T data;
|
||||
|
||||
MSGPACK_DEFINE(error_code, error_msg, data);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline static msgpack::sbuffer PackReturnValue(int error_code, std::string error_msg,
|
||||
T result) {
|
||||
return ray::api::Serializer::Serialize(
|
||||
Response<T>{error_code, std::move(error_msg), std::move(result)});
|
||||
}
|
||||
|
||||
inline static msgpack::sbuffer PackReturnValue(int error_code,
|
||||
std::string error_msg = "ok") {
|
||||
return ray::api::Serializer::Serialize(VoidResponse{error_code, std::move(error_msg)});
|
||||
}
|
||||
|
||||
/// It's help to invoke functions and member functions, the class Invoker<Function> help
|
||||
/// do type erase.
|
||||
template <typename Function>
|
||||
struct Invoker {
|
||||
/// Invoke functions by networking stream, at first deserialize the binary data to a
|
||||
/// tuple, then call function with tuple.
|
||||
static inline msgpack::sbuffer Apply(const Function &func, const char *data,
|
||||
size_t size) {
|
||||
using args_tuple = AddType_t<std::string, boost::callable_traits::args_t<Function>>;
|
||||
|
||||
msgpack::sbuffer result;
|
||||
try {
|
||||
auto tp = ray::api::Serializer::Deserialize<args_tuple>(data, size);
|
||||
result = Invoker<Function>::Call(func, std::move(tp));
|
||||
} catch (msgpack::type_error &e) {
|
||||
result =
|
||||
PackReturnValue(ErrorCode::FAIL, std::string("invalid arguments: ") + e.what());
|
||||
} catch (const std::exception &e) {
|
||||
result = PackReturnValue(ErrorCode::FAIL,
|
||||
std::string("function execute exception: ") + e.what());
|
||||
} catch (...) {
|
||||
result = PackReturnValue(ErrorCode::FAIL, "unknown exception");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename F, size_t... I, typename Arg, typename... Args>
|
||||
static absl::result_of_t<F(Args...)> CallInternal(const F &f,
|
||||
const absl::index_sequence<I...> &,
|
||||
std::tuple<Arg, Args...> tup) {
|
||||
return f(std::move(std::get<I + 1>(tup))...);
|
||||
}
|
||||
|
||||
template <typename F, typename Arg, typename... Args>
|
||||
static absl::enable_if_t<std::is_void<absl::result_of_t<F(Args...)>>::value,
|
||||
msgpack::sbuffer>
|
||||
Call(const F &f, std::tuple<Arg, Args...> tp) {
|
||||
CallInternal(f, absl::make_index_sequence<sizeof...(Args)>{}, std::move(tp));
|
||||
return PackReturnValue(ErrorCode::OK);
|
||||
}
|
||||
|
||||
template <typename F, typename Arg, typename... Args>
|
||||
static absl::enable_if_t<!std::is_void<absl::result_of_t<F(Args...)>>::value,
|
||||
msgpack::sbuffer>
|
||||
Call(const F &f, std::tuple<Arg, Args...> tp) {
|
||||
auto r = CallInternal(f, absl::make_index_sequence<sizeof...(Args)>{}, std::move(tp));
|
||||
return PackReturnValue(ErrorCode::OK, "ok", r);
|
||||
}
|
||||
};
|
||||
|
||||
/// Manage all ray remote functions, add remote functions by RAY_REMOTE, get functions by
|
||||
/// TaskExecutionHandler.
|
||||
class FunctionManager {
|
||||
public:
|
||||
static FunctionManager &Instance() {
|
||||
static FunctionManager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::function<msgpack::sbuffer(const char *, size_t)> *GetFunction(
|
||||
const std::string &func_name) {
|
||||
auto it = map_invokers_.find(func_name);
|
||||
if (it == map_invokers_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return &it->second;
|
||||
}
|
||||
|
||||
template <typename Function>
|
||||
bool RegisterRemoteFunction(std::string const &name, const Function &f) {
|
||||
/// Now it is just support free function, it will be
|
||||
/// improved to support member function later.
|
||||
auto pair = func_ptr_to_key_map_.emplace((uint64_t)f, name);
|
||||
if (!pair.second) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return RegisterNonMemberFunc(name, f);
|
||||
}
|
||||
|
||||
private:
|
||||
FunctionManager() = default;
|
||||
~FunctionManager() = default;
|
||||
FunctionManager(const FunctionManager &) = delete;
|
||||
FunctionManager(FunctionManager &&) = delete;
|
||||
|
||||
template <typename Function>
|
||||
bool RegisterNonMemberFunc(std::string const &name, Function f) {
|
||||
return map_invokers_
|
||||
.emplace(name, std::bind(&Invoker<Function>::Apply, std::move(f),
|
||||
std::placeholders::_1, std::placeholders::_2))
|
||||
.second;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::function<msgpack::sbuffer(const char *, size_t)>>
|
||||
map_invokers_;
|
||||
std::unordered_map<uintptr_t, std::string> func_ptr_to_key_map_;
|
||||
};
|
||||
} // namespace internal
|
||||
} // namespace ray
|
38
cpp/include/ray/api/ray_remote.h
Normal file
38
cpp/include/ray/api/ray_remote.h
Normal file
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ray/api/function_manager.h>
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace {
|
||||
#define CONCATENATE_DIRECT(s1, s2) s1##s2
|
||||
#define CONCATENATE(s1, s2) CONCATENATE_DIRECT(s1, s2)
|
||||
#ifdef _MSC_VER
|
||||
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __COUNTER__)
|
||||
#else
|
||||
#define ANONYMOUS_VARIABLE(str) CONCATENATE(str, __LINE__)
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
namespace api {
|
||||
|
||||
#define RAY_REMOTE(f) \
|
||||
static auto ANONYMOUS_VARIABLE(var) = \
|
||||
ray::internal::FunctionManager::Instance().RegisterRemoteFunction(#f, f);
|
||||
|
||||
} // namespace api
|
||||
} // namespace ray
|
|
@ -11,7 +11,7 @@ std::once_flag Ray::is_inited_;
|
|||
void Ray::Init() {
|
||||
std::call_once(is_inited_, [] {
|
||||
auto runtime = AbstractRayRuntime::DoInit(RayConfig::GetInstance());
|
||||
internal::RayRuntimeHolder::Instance().Init(runtime);
|
||||
ray::internal::RayRuntimeHolder::Instance().Init(runtime);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,39 @@
|
|||
#pragma once
|
||||
|
||||
#include <ray/api/function_manager.h>
|
||||
#include <ray/api/serializer.h>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "invocation_spec.h"
|
||||
#include "ray/core.h"
|
||||
|
||||
namespace ray {
|
||||
namespace internal {
|
||||
|
||||
/// Execute remote functions by networking stream.
|
||||
inline static msgpack::sbuffer TaskExecutionHandler(const char *data, std::size_t size) {
|
||||
msgpack::sbuffer result;
|
||||
do {
|
||||
try {
|
||||
auto p = ray::api::Serializer::Deserialize<std::tuple<std::string>>(data, size);
|
||||
auto &func_name = std::get<0>(p);
|
||||
auto func_ptr = FunctionManager::Instance().GetFunction(func_name);
|
||||
if (func_ptr == nullptr) {
|
||||
result = PackReturnValue(internal::ErrorCode::FAIL,
|
||||
"unknown function: " + func_name, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
result = (*func_ptr)(data, size);
|
||||
} catch (const std::exception &ex) {
|
||||
result = PackReturnValue(internal::ErrorCode::FAIL, ex.what());
|
||||
}
|
||||
} while (0);
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
namespace api {
|
||||
|
||||
class AbstractRayRuntime;
|
||||
|
|
80
cpp/src/ray/test/ray_remote_test.cc
Normal file
80
cpp/src/ray/test/ray_remote_test.cc
Normal file
|
@ -0,0 +1,80 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ray/api.h>
|
||||
#include <ray/api/serializer.h>
|
||||
#include "cpp/src/ray/runtime/task/task_executor.h"
|
||||
|
||||
using namespace ray::api;
|
||||
using namespace ray::internal;
|
||||
|
||||
int Return() { return 1; }
|
||||
int PlusOne(int x) { return x + 1; }
|
||||
|
||||
RAY_REMOTE(PlusOne);
|
||||
|
||||
TEST(RayApiTest, DuplicateRegister) {
|
||||
bool r = FunctionManager::Instance().RegisterRemoteFunction("Return", Return);
|
||||
EXPECT_TRUE(r);
|
||||
|
||||
/// Duplicate register
|
||||
bool r1 = FunctionManager::Instance().RegisterRemoteFunction("Return", Return);
|
||||
EXPECT_FALSE(r1);
|
||||
|
||||
bool r2 = FunctionManager::Instance().RegisterRemoteFunction("PlusOne", PlusOne);
|
||||
EXPECT_FALSE(r2);
|
||||
}
|
||||
|
||||
TEST(RayApiTest, FindAndExecuteFunction) {
|
||||
/// Find and call the registered function.
|
||||
auto args = std::make_tuple("PlusOne", 1);
|
||||
auto buf = Serializer::Serialize(args);
|
||||
auto result_buf = TaskExecutionHandler(buf.data(), buf.size());
|
||||
|
||||
/// Deserialize result.
|
||||
auto response =
|
||||
Serializer::Deserialize<Response<int>>(result_buf.data(), result_buf.size());
|
||||
|
||||
EXPECT_EQ(response.error_code, ErrorCode::OK);
|
||||
EXPECT_EQ(response.data, 2);
|
||||
}
|
||||
|
||||
TEST(RayApiTest, VoidFunction) {
|
||||
auto buf1 = Serializer::Serialize(std::make_tuple("Return"));
|
||||
auto result_buf = TaskExecutionHandler(buf1.data(), buf1.size());
|
||||
auto response =
|
||||
Serializer::Deserialize<VoidResponse>(result_buf.data(), result_buf.size());
|
||||
EXPECT_EQ(response.error_code, ErrorCode::OK);
|
||||
}
|
||||
|
||||
/// We should consider the driver so is not same with the worker so, and find the error
|
||||
/// reason.
|
||||
TEST(RayApiTest, NotExistFunction) {
|
||||
auto buf2 = Serializer::Serialize(std::make_tuple("Return11"));
|
||||
auto result_buf = TaskExecutionHandler(buf2.data(), buf2.size());
|
||||
auto response =
|
||||
Serializer::Deserialize<VoidResponse>(result_buf.data(), result_buf.size());
|
||||
EXPECT_EQ(response.error_code, ErrorCode::FAIL);
|
||||
EXPECT_FALSE(response.error_msg.empty());
|
||||
}
|
||||
|
||||
TEST(RayApiTest, ArgumentsNotMatch) {
|
||||
auto buf = Serializer::Serialize(std::make_tuple("PlusOne", "invalid arguments"));
|
||||
auto result_buf = TaskExecutionHandler(buf.data(), buf.size());
|
||||
auto response =
|
||||
Serializer::Deserialize<Response<int>>(result_buf.data(), result_buf.size());
|
||||
EXPECT_EQ(response.error_code, ErrorCode::FAIL);
|
||||
EXPECT_FALSE(response.error_msg.empty());
|
||||
}
|
Loading…
Add table
Reference in a new issue