[Lint] One parameter/argument per line for C++ code (#22725)

It's really annoying to deal with parameter/argument conflicts. This is even frustrating when we merge code from the community to Ant's internal code base with hundreds of conflicts caused by parameters/arguments.

In this PR, I updated the clang-format style to make parameters/arguments stay on different lines if they can't fit into a single line.

There are several benefits:

* Conflict resolving is easier.
* Less potential human mistakes when resolving conflicts.
* Git history and Git blame are more straightforward.
* Better readability.
* Align with the new Python format style.
This commit is contained in:
Kai Yang 2022-03-13 17:05:44 +08:00 committed by GitHub
parent f15bcb21dc
commit e9755d87a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
396 changed files with 10812 additions and 5275 deletions

View file

@ -3,3 +3,5 @@ ColumnLimit: 90
DerivePointerAlignment: false DerivePointerAlignment: false
IndentCaseLabels: false IndentCaseLabels: false
PointerAlignment: Right PointerAlignment: Right
BinPackArguments: false
BinPackParameters: false

View file

@ -37,7 +37,8 @@ namespace {
using ::google::protobuf::io::CodedInputStream; using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::FileInputStream; using ::google::protobuf::io::FileInputStream;
bool ReadExtraAction(const std::string &path, blaze::ExtraActionInfo *info, bool ReadExtraAction(const std::string &path,
blaze::ExtraActionInfo *info,
blaze::CppCompileInfo *cpp_info) { blaze::CppCompileInfo *cpp_info) {
int fd = ::open(path.c_str(), O_RDONLY, S_IREAD | S_IWRITE); int fd = ::open(path.c_str(), O_RDONLY, S_IREAD | S_IWRITE);
if (fd < 0) { if (fd < 0) {
@ -97,8 +98,8 @@ int main(int argc, char **argv) {
std::vector<std::string> args; std::vector<std::string> args;
args.push_back(cpp_info.tool()); args.push_back(cpp_info.tool());
args.insert(args.end(), cpp_info.compiler_option().begin(), args.insert(
cpp_info.compiler_option().end()); args.end(), cpp_info.compiler_option().begin(), cpp_info.compiler_option().end());
if (std::find(args.begin(), args.end(), "-c") == args.end()) { if (std::find(args.begin(), args.end(), "-c") == args.end()) {
args.push_back("-c"); args.push_back("-c");
args.push_back(cpp_info.source_file()); args.push_back(cpp_info.source_file());

View file

@ -82,7 +82,8 @@ std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>> &object
/// \return Two arrays, one containing locally available objects, one containing the /// \return Two arrays, one containing locally available objects, one containing the
/// rest. /// rest.
template <typename T> template <typename T>
WaitResult<T> Wait(const std::vector<ray::ObjectRef<T>> &objects, int num_objects, WaitResult<T> Wait(const std::vector<ray::ObjectRef<T>> &objects,
int num_objects,
int timeout_ms); int timeout_ms);
/// Create a `TaskCaller` for calling remote function. /// Create a `TaskCaller` for calling remote function.
@ -196,7 +197,8 @@ inline std::vector<std::shared_ptr<T>> Get(const std::vector<ray::ObjectRef<T>>
} }
template <typename T> template <typename T>
inline WaitResult<T> Wait(const std::vector<ray::ObjectRef<T>> &objects, int num_objects, inline WaitResult<T> Wait(const std::vector<ray::ObjectRef<T>> &objects,
int num_objects,
int timeout_ms) { int timeout_ms) {
auto object_ids = ObjectRefsToObjectIDs<T>(objects); auto object_ids = ObjectRefsToObjectIDs<T>(objects);
auto results = auto results =
@ -214,9 +216,10 @@ inline WaitResult<T> Wait(const std::vector<ray::ObjectRef<T>> &objects, int num
} }
inline ray::internal::ActorCreator<PyActorClass> Actor(PyActorClass func) { inline ray::internal::ActorCreator<PyActorClass> Actor(PyActorClass func) {
ray::internal::RemoteFunctionHolder remote_func_holder( ray::internal::RemoteFunctionHolder remote_func_holder(func.module_name,
func.module_name, func.function_name, func.class_name, func.function_name,
ray::internal::LangType::PYTHON); func.class_name,
ray::internal::LangType::PYTHON);
return {ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)}; return {ray::internal::GetRayRuntime().get(), std::move(remote_func_holder)};
} }

View file

@ -80,13 +80,15 @@ ActorHandle<GetActorType<F>, is_python_v<F>> ActorCreator<F>::Remote(Args &&...a
if constexpr (is_python_v<F>) { if constexpr (is_python_v<F>) {
using ArgsTuple = std::tuple<Args...>; using ArgsTuple = std::tuple<Args...>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else { } else {
StaticCheck<F, Args...>(); StaticCheck<F, Args...>();
using ArgsTuple = RemoveReference_t<boost::callable_traits::args_t<F>>; using ArgsTuple = RemoveReference_t<boost::callable_traits::args_t<F>>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }

View file

@ -45,8 +45,8 @@ class ActorHandle {
std::is_same<ActorType, Self>::value || std::is_base_of<Self, ActorType>::value, std::is_same<ActorType, Self>::value || std::is_base_of<Self, ActorType>::value,
"Class types must be same."); "Class types must be same.");
ray::internal::RemoteFunctionHolder remote_func_holder(actor_func); ray::internal::RemoteFunctionHolder remote_func_holder(actor_func);
return ray::internal::ActorTaskCaller<F>(internal::GetRayRuntime().get(), id_, return ray::internal::ActorTaskCaller<F>(
std::move(remote_func_holder)); internal::GetRayRuntime().get(), id_, std::move(remote_func_holder));
} }
template <typename R> template <typename R>

View file

@ -26,7 +26,8 @@ class ActorTaskCaller {
public: public:
ActorTaskCaller() = default; ActorTaskCaller() = default;
ActorTaskCaller(RayRuntime *runtime, const std::string &id, ActorTaskCaller(RayRuntime *runtime,
const std::string &id,
RemoteFunctionHolder remote_function_holder) RemoteFunctionHolder remote_function_holder)
: runtime_(runtime), : runtime_(runtime),
id_(id), id_(id),
@ -68,13 +69,15 @@ ObjectRef<boost::callable_traits::return_type_t<F>> ActorTaskCaller<F>::Remote(
if constexpr (is_python_v<F>) { if constexpr (is_python_v<F>) {
using ArgsTuple = std::tuple<Args...>; using ArgsTuple = std::tuple<Args...>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else { } else {
StaticCheck<F, Args...>(); StaticCheck<F, Args...>();
using ArgsTuple = RemoveReference_t<RemoveFirst_t<boost::callable_traits::args_t<F>>>; using ArgsTuple = RemoveReference_t<RemoveFirst_t<boost::callable_traits::args_t<F>>>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }

View file

@ -26,7 +26,8 @@ namespace internal {
class Arguments { class Arguments {
public: public:
template <typename OriginArgType, typename InputArgTypes> template <typename OriginArgType, typename InputArgTypes>
static void WrapArgsImpl(bool cross_lang, std::vector<TaskArg> *task_args, static void WrapArgsImpl(bool cross_lang,
std::vector<TaskArg> *task_args,
InputArgTypes &&arg) { InputArgTypes &&arg) {
if constexpr (is_object_ref_v<OriginArgType>) { if constexpr (is_object_ref_v<OriginArgType>) {
PushReferenceArg(task_args, std::forward<InputArgTypes>(arg)); PushReferenceArg(task_args, std::forward<InputArgTypes>(arg));
@ -65,8 +66,10 @@ class Arguments {
} }
template <typename OriginArgsTuple, size_t... I, typename... InputArgTypes> template <typename OriginArgsTuple, size_t... I, typename... InputArgTypes>
static void WrapArgs(bool cross_lang, std::vector<TaskArg> *task_args, static void WrapArgs(bool cross_lang,
std::index_sequence<I...>, InputArgTypes &&...args) { std::vector<TaskArg> *task_args,
std::index_sequence<I...>,
InputArgTypes &&...args) {
(void)std::initializer_list<int>{ (void)std::initializer_list<int>{
(WrapArgsImpl<std::tuple_element_t<I, OriginArgsTuple>>( (WrapArgsImpl<std::tuple_element_t<I, OriginArgsTuple>>(
cross_lang, task_args, std::forward<InputArgTypes>(args)), cross_lang, task_args, std::forward<InputArgTypes>(args)),
@ -77,7 +80,8 @@ class Arguments {
} }
private: private:
static void PushValueArg(std::vector<TaskArg> *task_args, msgpack::sbuffer &&buffer, static void PushValueArg(std::vector<TaskArg> *task_args,
msgpack::sbuffer &&buffer,
std::string_view meta_str = "") { std::string_view meta_str = "") {
/// Pass by value. /// Pass by value.
TaskArg task_arg; TaskArg task_arg;

View file

@ -84,7 +84,8 @@ struct Invoker {
return result; return result;
} }
static inline msgpack::sbuffer ApplyMember(const Function &func, msgpack::sbuffer *ptr, static inline msgpack::sbuffer ApplyMember(const Function &func,
msgpack::sbuffer *ptr,
const ArgsBufferList &args_buffer) { const ArgsBufferList &args_buffer) {
using RetrunType = boost::callable_traits::return_type_t<Function>; using RetrunType = boost::callable_traits::return_type_t<Function>;
using ArgsTuple = using ArgsTuple =
@ -124,7 +125,8 @@ struct Invoker {
} }
} }
static inline bool GetArgsTuple(std::tuple<> &tup, const ArgsBufferList &args_buffer, static inline bool GetArgsTuple(std::tuple<> &tup,
const ArgsBufferList &args_buffer,
std::index_sequence<>) { std::index_sequence<>) {
return true; return true;
} }
@ -155,7 +157,8 @@ struct Invoker {
} }
template <typename R, typename F, size_t... I, typename... Args> template <typename R, typename F, size_t... I, typename... Args>
static R CallInternal(const F &f, const std::index_sequence<I...> &, static R CallInternal(const F &f,
const std::index_sequence<I...> &,
std::tuple<Args...> args) { std::tuple<Args...> args) {
(void)args; (void)args;
using ArgsTuple = boost::callable_traits::args_t<F>; using ArgsTuple = boost::callable_traits::args_t<F>;
@ -165,21 +168,23 @@ struct Invoker {
template <typename R, typename F, typename Self, typename... Args> template <typename R, typename F, typename Self, typename... Args>
static std::enable_if_t<std::is_void<R>::value, msgpack::sbuffer> CallMember( static std::enable_if_t<std::is_void<R>::value, msgpack::sbuffer> CallMember(
const F &f, Self *self, std::tuple<Args...> args) { const F &f, Self *self, std::tuple<Args...> args) {
CallMemberInternal<R>(f, self, std::make_index_sequence<sizeof...(Args)>{}, CallMemberInternal<R>(
std::move(args)); f, self, std::make_index_sequence<sizeof...(Args)>{}, std::move(args));
return PackVoid(); return PackVoid();
} }
template <typename R, typename F, typename Self, typename... Args> template <typename R, typename F, typename Self, typename... Args>
static std::enable_if_t<!std::is_void<R>::value, msgpack::sbuffer> CallMember( static std::enable_if_t<!std::is_void<R>::value, msgpack::sbuffer> CallMember(
const F &f, Self *self, std::tuple<Args...> args) { const F &f, Self *self, std::tuple<Args...> args) {
auto r = CallMemberInternal<R>(f, self, std::make_index_sequence<sizeof...(Args)>{}, auto r = CallMemberInternal<R>(
std::move(args)); f, self, std::make_index_sequence<sizeof...(Args)>{}, std::move(args));
return PackReturnValue(r); return PackReturnValue(r);
} }
template <typename R, typename F, typename Self, size_t... I, typename... Args> template <typename R, typename F, typename Self, size_t... I, typename... Args>
static R CallMemberInternal(const F &f, Self *self, const std::index_sequence<I...> &, static R CallMemberInternal(const F &f,
Self *self,
const std::index_sequence<I...> &,
std::tuple<Args...> args) { std::tuple<Args...> args) {
(void)args; (void)args;
using ArgsTuple = boost::callable_traits::args_t<F>; using ArgsTuple = boost::callable_traits::args_t<F>;
@ -288,16 +293,20 @@ class FunctionManager {
template <typename Function> template <typename Function>
bool RegisterNonMemberFunc(std::string const &name, Function f) { bool RegisterNonMemberFunc(std::string const &name, Function f) {
return map_invokers_ return map_invokers_
.emplace(name, std::bind(&Invoker<Function>::Apply, std::move(f), .emplace(
std::placeholders::_1)) name,
std::bind(&Invoker<Function>::Apply, std::move(f), std::placeholders::_1))
.second; .second;
} }
template <typename Function> template <typename Function>
bool RegisterMemberFunc(std::string const &name, Function f) { bool RegisterMemberFunc(std::string const &name, Function f) {
return map_mem_func_invokers_ return map_mem_func_invokers_
.emplace(name, std::bind(&Invoker<Function>::ApplyMember, std::move(f), .emplace(name,
std::placeholders::_1, std::placeholders::_2)) std::bind(&Invoker<Function>::ApplyMember,
std::move(f),
std::placeholders::_1,
std::placeholders::_2))
.second; .second;
} }

View file

@ -75,7 +75,8 @@ class RayLogger {
virtual std::ostream &Stream() = 0; virtual std::ostream &Stream() = 0;
}; };
std::unique_ptr<RayLogger> CreateRayLogger(const char *file_name, int line_number, std::unique_ptr<RayLogger> CreateRayLogger(const char *file_name,
int line_number,
RayLoggerLevel severity); RayLoggerLevel severity);
bool IsLevelEnabled(RayLoggerLevel log_level); bool IsLevelEnabled(RayLoggerLevel log_level);

View file

@ -14,10 +14,12 @@
#pragma once #pragma once
#include <ray/api/ray_exception.h> #include <ray/api/ray_exception.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "boost/optional.hpp" #include "boost/optional.hpp"
namespace ray { namespace ray {

View file

@ -60,7 +60,8 @@ inline static int RegisterRemoteFunctions(const T &t, U... u) {
(void)std::initializer_list<int>{ (void)std::initializer_list<int>{
(FunctionManager::Instance().RegisterRemoteFunction( (FunctionManager::Instance().RegisterRemoteFunction(
std::string(func_names[index].data(), func_names[index].length()), u), std::string(func_names[index].data(), func_names[index].length()), u),
index++, 0)...}; index++,
0)...};
return 0; return 0;
} }

View file

@ -29,7 +29,8 @@ namespace internal {
struct RemoteFunctionHolder { struct RemoteFunctionHolder {
RemoteFunctionHolder() = default; RemoteFunctionHolder() = default;
RemoteFunctionHolder(const std::string &module_name, const std::string &function_name, RemoteFunctionHolder(const std::string &module_name,
const std::string &function_name,
const std::string &class_name = "", const std::string &class_name = "",
LangType lang_type = LangType::CPP) { LangType lang_type = LangType::CPP) {
this->module_name = module_name; this->module_name = module_name;
@ -61,7 +62,8 @@ class RayRuntime {
virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get( virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get(
const std::vector<std::string> &ids) = 0; const std::vector<std::string> &ids) = 0;
virtual std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects, virtual std::vector<bool> Wait(const std::vector<std::string> &ids,
int num_objects,
int timeout_ms) = 0; int timeout_ms) = 0;
virtual std::string Call(const RemoteFunctionHolder &remote_function_holder, virtual std::string Call(const RemoteFunctionHolder &remote_function_holder,
@ -71,7 +73,8 @@ class RayRuntime {
std::vector<TaskArg> &args, std::vector<TaskArg> &args,
const ActorCreationOptions &create_options) = 0; const ActorCreationOptions &create_options) = 0;
virtual std::string CallActor(const RemoteFunctionHolder &remote_function_holder, virtual std::string CallActor(const RemoteFunctionHolder &remote_function_holder,
const std::string &actor, std::vector<TaskArg> &args, const std::string &actor,
std::vector<TaskArg> &args,
const CallOptions &call_options) = 0; const CallOptions &call_options) = 0;
virtual void AddLocalReference(const std::string &id) = 0; virtual void AddLocalReference(const std::string &id) = 0;
virtual void RemoveLocalReference(const std::string &id) = 0; virtual void RemoveLocalReference(const std::string &id) = 0;

View file

@ -76,13 +76,15 @@ ObjectRef<boost::callable_traits::return_type_t<F>> TaskCaller<F>::Remote(
if constexpr (is_python_v<F>) { if constexpr (is_python_v<F>) {
using ArgsTuple = std::tuple<Args...>; using ArgsTuple = std::tuple<Args...>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/true,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else { } else {
StaticCheck<F, Args...>(); StaticCheck<F, Args...>();
using ArgsTuple = RemoveReference_t<boost::callable_traits::args_t<F>>; using ArgsTuple = RemoveReference_t<boost::callable_traits::args_t<F>>;
Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false, &args_, Arguments::WrapArgs<ArgsTuple>(/*cross_lang=*/false,
&args_,
std::make_index_sequence<sizeof...(Args)>{}, std::make_index_sequence<sizeof...(Args)>{},
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }

View file

@ -64,7 +64,8 @@ struct PlacementGroupCreationOptions {
class PlacementGroup { class PlacementGroup {
public: public:
PlacementGroup() = default; PlacementGroup() = default;
PlacementGroup(std::string id, PlacementGroupCreationOptions options, PlacementGroup(std::string id,
PlacementGroupCreationOptions options,
PlacementGroupState state = PlacementGroupState::UNRECOGNIZED) PlacementGroupState state = PlacementGroupState::UNRECOGNIZED)
: id_(std::move(id)), options_(std::move(options)), state_(state) {} : id_(std::move(id)), options_(std::move(options)), state_(state) {}
std::string GetID() const { return id_; } std::string GetID() const { return id_; }

View file

@ -14,10 +14,10 @@
#pragma once #pragma once
#include <list>
#include <ray/api/object_ref.h> #include <ray/api/object_ref.h>
#include <list>
namespace ray { namespace ray {
/// \param T The type of object. /// \param T The type of object.

View file

@ -25,11 +25,15 @@ ABSL_FLAG(std::string, ray_address, "", "The address of the Ray cluster to conne
/// absl::flags does not provide a IsDefaultValue method, so use a non-empty dummy default /// absl::flags does not provide a IsDefaultValue method, so use a non-empty dummy default
/// value to support empty redis password. /// value to support empty redis password.
ABSL_FLAG(std::string, ray_redis_password, "absl::flags dummy default value", ABSL_FLAG(std::string,
ray_redis_password,
"absl::flags dummy default value",
"Prevents external clients without the password from connecting to Redis " "Prevents external clients without the password from connecting to Redis "
"if provided."); "if provided.");
ABSL_FLAG(std::string, ray_code_search_path, "", ABSL_FLAG(std::string,
ray_code_search_path,
"",
"A list of directories or files of dynamic libraries that specify the " "A list of directories or files of dynamic libraries that specify the "
"search path for user code. Only searching the top level under a directory. " "search path for user code. Only searching the top level under a directory. "
"':' is used as the separator."); "':' is used as the separator.");
@ -38,10 +42,14 @@ ABSL_FLAG(std::string, ray_job_id, "", "Assigned job id.");
ABSL_FLAG(int32_t, ray_node_manager_port, 0, "The port to use for the node manager."); ABSL_FLAG(int32_t, ray_node_manager_port, 0, "The port to use for the node manager.");
ABSL_FLAG(std::string, ray_raylet_socket_name, "", ABSL_FLAG(std::string,
ray_raylet_socket_name,
"",
"It will specify the socket name used by the raylet if provided."); "It will specify the socket name used by the raylet if provided.");
ABSL_FLAG(std::string, ray_plasma_store_socket_name, "", ABSL_FLAG(std::string,
ray_plasma_store_socket_name,
"",
"It will specify the socket name used by the plasma store if provided."); "It will specify the socket name used by the plasma store if provided.");
ABSL_FLAG(std::string, ray_session_dir, "", "The path of this session."); ABSL_FLAG(std::string, ray_session_dir, "", "The path of this session.");
@ -50,12 +58,16 @@ ABSL_FLAG(std::string, ray_logs_dir, "", "Logs dir for workers.");
ABSL_FLAG(std::string, ray_node_ip_address, "", "The ip address for this node."); ABSL_FLAG(std::string, ray_node_ip_address, "", "The ip address for this node.");
ABSL_FLAG(std::string, ray_head_args, "", ABSL_FLAG(std::string,
ray_head_args,
"",
"The command line args to be appended as parameters of the `ray start` " "The command line args to be appended as parameters of the `ray start` "
"command. It takes effect only if Ray head is started by a driver. Run `ray " "command. It takes effect only if Ray head is started by a driver. Run `ray "
"start --help` for details."); "start --help` for details.");
ABSL_FLAG(int64_t, startup_token, -1, ABSL_FLAG(int64_t,
startup_token,
-1,
"The startup token assigned to this worker process by the raylet."); "The startup token assigned to this worker process by the raylet.");
namespace ray { namespace ray {
@ -81,8 +93,8 @@ void ConfigInternal::Init(RayConfig &config, int argc, char **argv) {
if (!FLAGS_ray_code_search_path.CurrentValue().empty()) { if (!FLAGS_ray_code_search_path.CurrentValue().empty()) {
// Code search path like this "/path1/xxx.so:/path2". // Code search path like this "/path1/xxx.so:/path2".
code_search_path = absl::StrSplit(FLAGS_ray_code_search_path.CurrentValue(), ':', code_search_path = absl::StrSplit(
absl::SkipEmpty()); FLAGS_ray_code_search_path.CurrentValue(), ':', absl::SkipEmpty());
} }
if (!FLAGS_ray_address.CurrentValue().empty()) { if (!FLAGS_ray_address.CurrentValue().empty()) {
SetBootstrapAddress(FLAGS_ray_address.CurrentValue()); SetBootstrapAddress(FLAGS_ray_address.CurrentValue());
@ -150,8 +162,8 @@ void ConfigInternal::SetBootstrapAddress(std::string_view address) {
auto pos = address.find(':'); auto pos = address.find(':');
RAY_CHECK(pos != std::string::npos); RAY_CHECK(pos != std::string::npos);
bootstrap_ip = address.substr(0, pos); bootstrap_ip = address.substr(0, pos);
auto ret = std::from_chars(address.data() + pos + 1, address.data() + address.size(), auto ret = std::from_chars(
bootstrap_port); address.data() + pos + 1, address.data() + address.size(), bootstrap_port);
RAY_CHECK(ret.ec == std::errc()); RAY_CHECK(ret.ec == std::errc());
} }
} // namespace internal } // namespace internal

View file

@ -111,7 +111,8 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get(
} }
std::vector<bool> AbstractRayRuntime::Wait(const std::vector<std::string> &ids, std::vector<bool> AbstractRayRuntime::Wait(const std::vector<std::string> &ids,
int num_objects, int timeout_ms) { int num_objects,
int timeout_ms) {
return object_store_->Wait(StringIDsToObjectIDs(ids), num_objects, timeout_ms); return object_store_->Wait(StringIDsToObjectIDs(ids), num_objects, timeout_ms);
} }
@ -129,7 +130,8 @@ std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
auto meta_str = arg.meta_str; auto meta_str = arg.meta_str;
metadata = std::make_shared<ray::LocalMemoryBuffer>( metadata = std::make_shared<ray::LocalMemoryBuffer>(
reinterpret_cast<uint8_t *>(const_cast<char *>(meta_str.data())), reinterpret_cast<uint8_t *>(const_cast<char *>(meta_str.data())),
meta_str.size(), true); meta_str.size(),
true);
} }
ray_arg = absl::make_unique<ray::TaskArgByValue>(std::make_shared<ray::RayObject>( ray_arg = absl::make_unique<ray::TaskArgByValue>(std::make_shared<ray::RayObject>(
memory_buffer, metadata, std::vector<rpc::ObjectReference>())); memory_buffer, metadata, std::vector<rpc::ObjectReference>()));
@ -141,7 +143,8 @@ std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
auto &core_worker = CoreWorkerProcess::GetCoreWorker(); auto &core_worker = CoreWorkerProcess::GetCoreWorker();
owner_address = core_worker.GetOwnerAddress(id); owner_address = core_worker.GetOwnerAddress(id);
} }
ray_arg = absl::make_unique<ray::TaskArgByReference>(id, owner_address, ray_arg = absl::make_unique<ray::TaskArgByReference>(id,
owner_address,
/*call_site=*/""); /*call_site=*/"");
} }
ray_args.push_back(std::move(ray_arg)); ray_args.push_back(std::move(ray_arg));
@ -181,8 +184,10 @@ std::string AbstractRayRuntime::CreateActor(
} }
std::string AbstractRayRuntime::CallActor( std::string AbstractRayRuntime::CallActor(
const RemoteFunctionHolder &remote_function_holder, const std::string &actor, const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::internal::TaskArg> &args, const CallOptions &call_options) { const std::string &actor,
std::vector<ray::internal::TaskArg> &args,
const CallOptions &call_options) {
InvocationSpec invocation_spec{}; InvocationSpec invocation_spec{};
if (remote_function_holder.lang_type == LangType::PYTHON) { if (remote_function_holder.lang_type == LangType::PYTHON) {
const auto native_actor_handle = CoreWorkerProcess::GetCoreWorker().GetActorHandle( const auto native_actor_handle = CoreWorkerProcess::GetCoreWorker().GetActorHandle(
@ -192,11 +197,11 @@ std::string AbstractRayRuntime::CallActor(
RemoteFunctionHolder func_holder = remote_function_holder; RemoteFunctionHolder func_holder = remote_function_holder;
func_holder.module_name = typed_descriptor->ModuleName(); func_holder.module_name = typed_descriptor->ModuleName();
func_holder.class_name = typed_descriptor->ClassName(); func_holder.class_name = typed_descriptor->ClassName();
invocation_spec = BuildInvocationSpec1(TaskType::ACTOR_TASK, func_holder, args, invocation_spec = BuildInvocationSpec1(
ActorID::FromBinary(actor)); TaskType::ACTOR_TASK, func_holder, args, ActorID::FromBinary(actor));
} else { } else {
invocation_spec = BuildInvocationSpec1(TaskType::ACTOR_TASK, remote_function_holder, invocation_spec = BuildInvocationSpec1(
args, ActorID::FromBinary(actor)); TaskType::ACTOR_TASK, remote_function_holder, args, ActorID::FromBinary(actor));
} }
return task_submitter_->SubmitActorTask(invocation_spec, call_options).Binary(); return task_submitter_->SubmitActorTask(invocation_spec, call_options).Binary();
@ -308,7 +313,8 @@ PlacementGroup AbstractRayRuntime::GeneratePlacementGroup(const std::string &str
options.bundles.emplace_back(bundle); options.bundles.emplace_back(bundle);
} }
options.strategy = PlacementStrategy(pg_table_data.strategy()); options.strategy = PlacementStrategy(pg_table_data.strategy());
PlacementGroup group(pg_table_data.placement_group_id(), std::move(options), PlacementGroup group(pg_table_data.placement_group_id(),
std::move(options),
PlacementGroupState(pg_table_data.state())); PlacementGroupState(pg_table_data.state()));
return group; return group;
} }

View file

@ -54,7 +54,8 @@ class AbstractRayRuntime : public RayRuntime {
std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids); std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids);
std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects, std::vector<bool> Wait(const std::vector<std::string> &ids,
int num_objects,
int timeout_ms); int timeout_ms);
std::string Call(const RemoteFunctionHolder &remote_function_holder, std::string Call(const RemoteFunctionHolder &remote_function_holder,

View file

@ -24,7 +24,8 @@ namespace ray {
namespace internal { namespace internal {
LocalModeRayRuntime::LocalModeRayRuntime() LocalModeRayRuntime::LocalModeRayRuntime()
: worker_(ray::core::WorkerType::DRIVER, ComputeDriverIdFromJob(JobID::Nil()), : worker_(ray::core::WorkerType::DRIVER,
ComputeDriverIdFromJob(JobID::Nil()),
JobID::Nil()) { JobID::Nil()) {
object_store_ = std::unique_ptr<ObjectStore>(new LocalModeObjectStore(*this)); object_store_ = std::unique_ptr<ObjectStore>(new LocalModeObjectStore(*this));
task_submitter_ = std::unique_ptr<TaskSubmitter>(new LocalModeTaskSubmitter(*this)); task_submitter_ = std::unique_ptr<TaskSubmitter>(new LocalModeTaskSubmitter(*this));

View file

@ -26,7 +26,8 @@ class RayLoggerImpl : public RayLogger, public ray::RayLog {
std::ostream &Stream() override { return ray::RayLog::Stream(); } std::ostream &Stream() override { return ray::RayLog::Stream(); }
}; };
std::unique_ptr<RayLogger> CreateRayLogger(const char *file_name, int line_number, std::unique_ptr<RayLogger> CreateRayLogger(const char *file_name,
int line_number,
RayLoggerLevel severity) { RayLoggerLevel severity) {
return std::make_unique<RayLoggerImpl>(file_name, line_number, severity); return std::make_unique<RayLoggerImpl>(file_name, line_number, severity);
} }

View file

@ -58,9 +58,12 @@ std::shared_ptr<msgpack::sbuffer> LocalModeObjectStore::GetRaw(const ObjectID &o
std::vector<std::shared_ptr<msgpack::sbuffer>> LocalModeObjectStore::GetRaw( std::vector<std::shared_ptr<msgpack::sbuffer>> LocalModeObjectStore::GetRaw(
const std::vector<ObjectID> &ids, int timeout_ms) { const std::vector<ObjectID> &ids, int timeout_ms) {
std::vector<std::shared_ptr<::ray::RayObject>> results; std::vector<std::shared_ptr<::ray::RayObject>> results;
::ray::Status status = ::ray::Status status = memory_store_->Get(ids,
memory_store_->Get(ids, (int)ids.size(), timeout_ms, (int)ids.size(),
local_mode_ray_tuntime_.GetWorkerContext(), false, &results); timeout_ms,
local_mode_ray_tuntime_.GetWorkerContext(),
false,
&results);
if (!status.ok()) { if (!status.ok()) {
throw RayException("Get object error: " + status.ToString()); throw RayException("Get object error: " + status.ToString());
} }
@ -78,15 +81,18 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> LocalModeObjectStore::GetRaw(
} }
std::vector<bool> LocalModeObjectStore::Wait(const std::vector<ObjectID> &ids, std::vector<bool> LocalModeObjectStore::Wait(const std::vector<ObjectID> &ids,
int num_objects, int timeout_ms) { int num_objects,
int timeout_ms) {
absl::flat_hash_set<ObjectID> memory_object_ids; absl::flat_hash_set<ObjectID> memory_object_ids;
for (const auto &object_id : ids) { for (const auto &object_id : ids) {
memory_object_ids.insert(object_id); memory_object_ids.insert(object_id);
} }
absl::flat_hash_set<ObjectID> ready; absl::flat_hash_set<ObjectID> ready;
::ray::Status status = ::ray::Status status = memory_store_->Wait(memory_object_ids,
memory_store_->Wait(memory_object_ids, num_objects, timeout_ms, num_objects,
local_mode_ray_tuntime_.GetWorkerContext(), &ready); timeout_ms,
local_mode_ray_tuntime_.GetWorkerContext(),
&ready);
if (!status.ok()) { if (!status.ok()) {
throw RayException("Wait object error: " + status.ToString()); throw RayException("Wait object error: " + status.ToString());
} }

View file

@ -29,7 +29,8 @@ class LocalModeObjectStore : public ObjectStore {
public: public:
LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime); LocalModeObjectStore(LocalModeRayRuntime &local_mode_ray_tuntime);
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, std::vector<bool> Wait(const std::vector<ObjectID> &ids,
int num_objects,
int timeout_ms); int timeout_ms);
void AddLocalReference(const std::string &id); void AddLocalReference(const std::string &id);

View file

@ -34,7 +34,8 @@ void NativeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
auto buffer = std::make_shared<::ray::LocalMemoryBuffer>( auto buffer = std::make_shared<::ray::LocalMemoryBuffer>(
reinterpret_cast<uint8_t *>(data->data()), data->size(), true); reinterpret_cast<uint8_t *>(data->data()), data->size(), true);
auto status = core_worker.Put( auto status = core_worker.Put(
::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()), {}, ::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()),
{},
object_id); object_id);
if (!status.ok()) { if (!status.ok()) {
throw RayException("Put object error"); throw RayException("Put object error");
@ -48,7 +49,8 @@ void NativeObjectStore::PutRaw(std::shared_ptr<msgpack::sbuffer> data,
auto buffer = std::make_shared<::ray::LocalMemoryBuffer>( auto buffer = std::make_shared<::ray::LocalMemoryBuffer>(
reinterpret_cast<uint8_t *>(data->data()), data->size(), true); reinterpret_cast<uint8_t *>(data->data()), data->size(), true);
auto status = core_worker.Put( auto status = core_worker.Put(
::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()), {}, ::ray::RayObject(buffer, nullptr, std::vector<rpc::ObjectReference>()),
{},
object_id); object_id);
if (!status.ok()) { if (!status.ok()) {
throw RayException("Put object error"); throw RayException("Put object error");
@ -113,7 +115,8 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> NativeObjectStore::GetRaw(
} }
std::vector<bool> NativeObjectStore::Wait(const std::vector<ObjectID> &ids, std::vector<bool> NativeObjectStore::Wait(const std::vector<ObjectID> &ids,
int num_objects, int timeout_ms) { int num_objects,
int timeout_ms) {
std::vector<bool> results; std::vector<bool> results;
auto &core_worker = CoreWorkerProcess::GetCoreWorker(); auto &core_worker = CoreWorkerProcess::GetCoreWorker();
// TODO(SongGuyang): Support `fetch_local` option in API. // TODO(SongGuyang): Support `fetch_local` option in API.

View file

@ -24,7 +24,8 @@ namespace internal {
class NativeObjectStore : public ObjectStore { class NativeObjectStore : public ObjectStore {
public: public:
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, std::vector<bool> Wait(const std::vector<ObjectID> &ids,
int num_objects,
int timeout_ms); int timeout_ms);
void AddLocalReference(const std::string &id); void AddLocalReference(const std::string &id);

View file

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <msgpack.hpp> #include <msgpack.hpp>
#include "ray/common/id.h" #include "ray/common/id.h"
namespace ray { namespace ray {
@ -67,7 +68,8 @@ class ObjectStore {
/// \param[in] num_objects The minimum number of objects to wait. /// \param[in] num_objects The minimum number of objects to wait.
/// \param[in] timeout_ms The maximum wait time in milliseconds. /// \param[in] timeout_ms The maximum wait time in milliseconds.
/// \return A vector that indicates each object has appeared or not. /// \return A vector that indicates each object has appeared or not.
virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids,
int num_objects,
int timeout_ms) = 0; int timeout_ms) = 0;
/// Increase the reference count for this object ID. /// Increase the reference count for this object ID.

View file

@ -49,27 +49,41 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
invocation.name.empty() ? functionDescriptor->DefaultTaskName() : invocation.name; invocation.name.empty() ? functionDescriptor->DefaultTaskName() : invocation.name;
// TODO (Alex): Properly set the depth here? // TODO (Alex): Properly set the depth here?
builder.SetCommonTaskSpec(task_id, task_name, rpc::Language::CPP, functionDescriptor, builder.SetCommonTaskSpec(task_id,
task_name,
rpc::Language::CPP,
functionDescriptor,
local_mode_ray_tuntime_.GetCurrentJobID(), local_mode_ray_tuntime_.GetCurrentJobID(),
local_mode_ray_tuntime_.GetCurrentTaskId(), 0, local_mode_ray_tuntime_.GetCurrentTaskId(),
local_mode_ray_tuntime_.GetCurrentTaskId(), address, 1, 0,
required_resources, required_placement_resources, "", local_mode_ray_tuntime_.GetCurrentTaskId(),
address,
1,
required_resources,
required_placement_resources,
"",
/*depth=*/0); /*depth=*/0);
if (invocation.task_type == TaskType::NORMAL_TASK) { if (invocation.task_type == TaskType::NORMAL_TASK) {
} else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) { } else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) {
invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID(); invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID();
rpc::SchedulingStrategy scheduling_strategy; rpc::SchedulingStrategy scheduling_strategy;
scheduling_strategy.mutable_default_scheduling_strategy(); scheduling_strategy.mutable_default_scheduling_strategy();
builder.SetActorCreationTaskSpec(invocation.actor_id, /*serialized_actor_handle=*/"", builder.SetActorCreationTaskSpec(invocation.actor_id,
scheduling_strategy, options.max_restarts, /*serialized_actor_handle=*/"",
/*max_task_retries=*/0, {}, options.max_concurrency); scheduling_strategy,
options.max_restarts,
/*max_task_retries=*/0,
{},
options.max_concurrency);
} else if (invocation.task_type == TaskType::ACTOR_TASK) { } else if (invocation.task_type == TaskType::ACTOR_TASK) {
const TaskID actor_creation_task_id = const TaskID actor_creation_task_id =
TaskID::ForActorCreationTask(invocation.actor_id); TaskID::ForActorCreationTask(invocation.actor_id);
const ObjectID actor_creation_dummy_object_id = const ObjectID actor_creation_dummy_object_id =
ObjectID::FromIndex(actor_creation_task_id, 1); ObjectID::FromIndex(actor_creation_task_id, 1);
builder.SetActorTaskSpec(invocation.actor_id, actor_creation_dummy_object_id, builder.SetActorTaskSpec(invocation.actor_id,
ObjectID(), invocation.actor_counter); actor_creation_dummy_object_id,
ObjectID(),
invocation.actor_counter);
} else { } else {
throw RayException("unknown task type"); throw RayException("unknown task type");
} }
@ -92,20 +106,20 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
/// TODO(SongGuyang): Handle task dependencies. /// TODO(SongGuyang): Handle task dependencies.
/// Execute actor task directly in the main thread because we must guarantee the actor /// Execute actor task directly in the main thread because we must guarantee the actor
/// task executed by calling order. /// task executed by calling order.
TaskExecutor::Invoke(task_specification, actor, runtime, actor_contexts_, TaskExecutor::Invoke(
actor_contexts_mutex_); task_specification, actor, runtime, actor_contexts_, actor_contexts_mutex_);
} else { } else {
boost::asio::post(*thread_pool_.get(), boost::asio::post(
std::bind( *thread_pool_.get(),
[actor, mutex, runtime, this](TaskSpecification &ts) { std::bind(
if (mutex) { [actor, mutex, runtime, this](TaskSpecification &ts) {
absl::MutexLock lock(mutex.get()); if (mutex) {
} absl::MutexLock lock(mutex.get());
TaskExecutor::Invoke(ts, actor, runtime, }
this->actor_contexts_, TaskExecutor::Invoke(
this->actor_contexts_mutex_); ts, actor, runtime, this->actor_contexts_, this->actor_contexts_mutex_);
}, },
std::move(task_specification))); std::move(task_specification)));
} }
return return_object_id; return return_object_id;
} }

View file

@ -33,7 +33,8 @@ RayFunction BuildRayFunction(InvocationSpec &invocation) {
auto function_descriptor = FunctionDescriptorBuilder::BuildPython( auto function_descriptor = FunctionDescriptorBuilder::BuildPython(
invocation.remote_function_holder.module_name, invocation.remote_function_holder.module_name,
invocation.remote_function_holder.class_name, invocation.remote_function_holder.class_name,
invocation.remote_function_holder.function_name, ""); invocation.remote_function_holder.function_name,
"");
return RayFunction(ray::Language::PYTHON, function_descriptor); return RayFunction(ray::Language::PYTHON, function_descriptor);
} else { } else {
throw RayException("not supported yet"); throw RayException("not supported yet");
@ -76,8 +77,13 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation,
bundle_id.second); bundle_id.second);
placement_group_scheduling_strategy->set_placement_group_capture_child_tasks(false); placement_group_scheduling_strategy->set_placement_group_capture_child_tasks(false);
} }
return_refs = core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, return_refs = core_worker.SubmitTask(BuildRayFunction(invocation),
options, 1, false, scheduling_strategy, ""); invocation.args,
options,
1,
false,
scheduling_strategy,
"");
} }
std::vector<ObjectID> return_ids; std::vector<ObjectID> return_ids;
for (const auto &ref : return_refs.value()) { for (const auto &ref : return_refs.value()) {
@ -121,8 +127,8 @@ ActorID NativeTaskSubmitter::CreateActor(InvocationSpec &invocation,
/*is_asyncio=*/false, /*is_asyncio=*/false,
scheduling_strategy}; scheduling_strategy};
ActorID actor_id; ActorID actor_id;
auto status = core_worker.CreateActor(BuildRayFunction(invocation), invocation.args, auto status = core_worker.CreateActor(
actor_options, "", &actor_id); BuildRayFunction(invocation), invocation.args, actor_options, "", &actor_id);
if (!status.ok()) { if (!status.ok()) {
throw RayException("Create actor error"); throw RayException("Create actor error");
} }
@ -150,8 +156,10 @@ ActorID NativeTaskSubmitter::GetActor(const std::string &actor_name) const {
ray::PlacementGroup NativeTaskSubmitter::CreatePlacementGroup( ray::PlacementGroup NativeTaskSubmitter::CreatePlacementGroup(
const ray::PlacementGroupCreationOptions &create_options) { const ray::PlacementGroupCreationOptions &create_options) {
auto options = ray::core::PlacementGroupCreationOptions( auto options = ray::core::PlacementGroupCreationOptions(
create_options.name, (ray::core::PlacementStrategy)create_options.strategy, create_options.name,
create_options.bundles, false); (ray::core::PlacementStrategy)create_options.strategy,
create_options.bundles,
false);
ray::PlacementGroupID placement_group_id; ray::PlacementGroupID placement_group_id;
auto status = CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup( auto status = CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup(
options, &placement_group_id); options, &placement_group_id);

View file

@ -85,7 +85,8 @@ std::unique_ptr<ObjectID> TaskExecutor::Execute(InvocationSpec &invocation) {
/// TODO(qicosmos): Need to add more details of the error messages, such as object id, /// TODO(qicosmos): Need to add more details of the error messages, such as object id,
/// task id etc. /// task id etc.
std::pair<Status, std::shared_ptr<msgpack::sbuffer>> GetExecuteResult( std::pair<Status, std::shared_ptr<msgpack::sbuffer>> GetExecuteResult(
const std::string &func_name, const ArgsBufferList &args_buffer, const std::string &func_name,
const ArgsBufferList &args_buffer,
msgpack::sbuffer *actor_ptr) { msgpack::sbuffer *actor_ptr) {
try { try {
EntryFuntion entry_function; EntryFuntion entry_function;
@ -122,11 +123,14 @@ std::pair<Status, std::shared_ptr<msgpack::sbuffer>> GetExecuteResult(
} }
Status TaskExecutor::ExecuteTask( Status TaskExecutor::ExecuteTask(
ray::TaskType task_type, const std::string task_name, const RayFunction &ray_function, ray::TaskType task_type,
const std::string task_name,
const RayFunction &ray_function,
const std::unordered_map<std::string, double> &required_resources, const std::unordered_map<std::string, double> &required_resources,
const std::vector<std::shared_ptr<ray::RayObject>> &args_buffer, const std::vector<std::shared_ptr<ray::RayObject>> &args_buffer,
const std::vector<rpc::ObjectReference> &arg_refs, const std::vector<rpc::ObjectReference> &arg_refs,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids,
const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results, std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes, std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error, bool *is_application_level_error,
@ -210,8 +214,12 @@ Status TaskExecutor::ExecuteTask(
size_t total = cross_lang ? (XLANG_HEADER_LEN + data_size) : data_size; size_t total = cross_lang ? (XLANG_HEADER_LEN + data_size) : data_size;
RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().AllocateReturnObject( RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().AllocateReturnObject(
result_id, total, meta_buffer, std::vector<ray::ObjectID>(), result_id,
&task_output_inlined_bytes, result_ptr)); total,
meta_buffer,
std::vector<ray::ObjectID>(),
&task_output_inlined_bytes,
result_ptr));
auto result = *result_ptr; auto result = *result_ptr;
if (result != nullptr) { if (result != nullptr) {
@ -243,7 +251,8 @@ Status TaskExecutor::ExecuteTask(
} }
void TaskExecutor::Invoke( void TaskExecutor::Invoke(
const TaskSpecification &task_spec, std::shared_ptr<msgpack::sbuffer> actor, const TaskSpecification &task_spec,
std::shared_ptr<msgpack::sbuffer> actor,
AbstractRayRuntime *runtime, AbstractRayRuntime *runtime,
std::unordered_map<ActorID, std::unique_ptr<ActorContext>> &actor_contexts, std::unordered_map<ActorID, std::unique_ptr<ActorContext>> &actor_contexts,
absl::Mutex &actor_contexts_mutex) { absl::Mutex &actor_contexts_mutex) {
@ -267,8 +276,8 @@ void TaskExecutor::Invoke(
std::shared_ptr<msgpack::sbuffer> data; std::shared_ptr<msgpack::sbuffer> data;
try { try {
if (actor) { if (actor) {
auto result = TaskExecutionHandler(typed_descriptor->FunctionName(), args_buffer, auto result = TaskExecutionHandler(
actor.get()); typed_descriptor->FunctionName(), args_buffer, actor.get());
data = std::make_shared<msgpack::sbuffer>(std::move(result)); data = std::make_shared<msgpack::sbuffer>(std::move(result));
runtime->Put(std::move(data), task_spec.ReturnId(0)); runtime->Put(std::move(data), task_spec.ReturnId(0));
} else { } else {

View file

@ -71,18 +71,21 @@ class TaskExecutor {
std::unique_ptr<ObjectID> Execute(InvocationSpec &invocation); std::unique_ptr<ObjectID> Execute(InvocationSpec &invocation);
static void Invoke( static void Invoke(
const TaskSpecification &task_spec, std::shared_ptr<msgpack::sbuffer> actor, const TaskSpecification &task_spec,
std::shared_ptr<msgpack::sbuffer> actor,
AbstractRayRuntime *runtime, AbstractRayRuntime *runtime,
std::unordered_map<ActorID, std::unique_ptr<ActorContext>> &actor_contexts, std::unordered_map<ActorID, std::unique_ptr<ActorContext>> &actor_contexts,
absl::Mutex &actor_contexts_mutex); absl::Mutex &actor_contexts_mutex);
static Status ExecuteTask( static Status ExecuteTask(
ray::TaskType task_type, const std::string task_name, ray::TaskType task_type,
const std::string task_name,
const RayFunction &ray_function, const RayFunction &ray_function,
const std::unordered_map<std::string, double> &required_resources, const std::unordered_map<std::string, double> &required_resources,
const std::vector<std::shared_ptr<ray::RayObject>> &args, const std::vector<std::shared_ptr<ray::RayObject>> &args,
const std::vector<rpc::ObjectReference> &arg_refs, const std::vector<rpc::ObjectReference> &arg_refs,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, const std::vector<ObjectID> &return_ids,
const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results, std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes, std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error, bool *is_application_level_error,

View file

@ -98,8 +98,14 @@ class Counter {
} }
}; };
RAY_REMOTE(Counter::FactoryCreate, &Counter::Plus1, &Counter::Plus, &Counter::Triple, RAY_REMOTE(Counter::FactoryCreate,
&Counter::Add, &Counter::GetVal, &Counter::GetIntVal, &Counter::GetList); &Counter::Plus1,
&Counter::Plus,
&Counter::Triple,
&Counter::Add,
&Counter::GetVal,
&Counter::GetIntVal,
&Counter::GetList);
TEST(RayApiTest, LogTest) { TEST(RayApiTest, LogTest) {
auto log_path = boost::filesystem::current_path().string() + "/tmp/"; auto log_path = boost::filesystem::current_path().string() + "/tmp/";
@ -325,8 +331,8 @@ TEST(RayApiTest, CompareWithFuture) {
TEST(RayApiTest, CreateAndRemovePlacementGroup) { TEST(RayApiTest, CreateAndRemovePlacementGroup) {
std::vector<std::unordered_map<std::string, double>> bundles{{{"CPU", 1}}}; std::vector<std::unordered_map<std::string, double>> bundles{{{"CPU", 1}}};
ray::PlacementGroupCreationOptions options1{"first_placement_group", bundles, ray::PlacementGroupCreationOptions options1{
ray::PlacementStrategy::PACK}; "first_placement_group", bundles, ray::PlacementStrategy::PACK};
auto first_placement_group = ray::CreatePlacementGroup(options1); auto first_placement_group = ray::CreatePlacementGroup(options1);
EXPECT_TRUE(first_placement_group.Wait(10)); EXPECT_TRUE(first_placement_group.Wait(10));

View file

@ -45,8 +45,8 @@ struct Person {
TEST(RayClusterModeTest, FullTest) { TEST(RayClusterModeTest, FullTest) {
ray::RayConfig config; ray::RayConfig config;
config.head_args = {"--num-cpus", "2", "--resources", config.head_args = {
"{\"resource1\":1,\"resource2\":2}"}; "--num-cpus", "2", "--resources", "{\"resource1\":1,\"resource2\":2}"};
if (absl::GetFlag<bool>(FLAGS_external_cluster)) { if (absl::GetFlag<bool>(FLAGS_external_cluster)) {
auto port = absl::GetFlag<int32_t>(FLAGS_redis_port); auto port = absl::GetFlag<int32_t>(FLAGS_redis_port);
std::string password = absl::GetFlag<std::string>(FLAGS_redis_password); std::string password = absl::GetFlag<std::string>(FLAGS_redis_password);
@ -401,8 +401,8 @@ TEST(RayClusterModeTest, CreateAndRemovePlacementGroup) {
TEST(RayClusterModeTest, CreatePlacementGroupExceedsClusterResource) { TEST(RayClusterModeTest, CreatePlacementGroupExceedsClusterResource) {
std::vector<std::unordered_map<std::string, double>> bundles{{{"CPU", 10000}}}; std::vector<std::unordered_map<std::string, double>> bundles{{{"CPU", 10000}}};
ray::PlacementGroupCreationOptions options{"first_placement_group", bundles, ray::PlacementGroupCreationOptions options{
ray::PlacementStrategy::PACK}; "first_placement_group", bundles, ray::PlacementStrategy::PACK};
auto first_placement_group = ray::CreatePlacementGroup(options); auto first_placement_group = ray::CreatePlacementGroup(options);
EXPECT_FALSE(first_placement_group.Wait(3)); EXPECT_FALSE(first_placement_group.Wait(3));
ray::RemovePlacementGroup(first_placement_group.GetID()); ray::RemovePlacementGroup(first_placement_group.GetID());

View file

@ -83,11 +83,18 @@ bool Counter::CheckRestartInActorCreationTask() { return is_restared; }
bool Counter::CheckRestartInActorTask() { return ray::WasCurrentActorRestarted(); } bool Counter::CheckRestartInActorTask() { return ray::WasCurrentActorRestarted(); }
RAY_REMOTE(RAY_FUNC(Counter::FactoryCreate), Counter::FactoryCreateException, RAY_REMOTE(RAY_FUNC(Counter::FactoryCreate),
Counter::FactoryCreateException,
RAY_FUNC(Counter::FactoryCreate, int), RAY_FUNC(Counter::FactoryCreate, int),
RAY_FUNC(Counter::FactoryCreate, int, int), &Counter::Plus1, &Counter::Add, RAY_FUNC(Counter::FactoryCreate, int, int),
&Counter::Exit, &Counter::GetPid, &Counter::ExceptionFunc, &Counter::Plus1,
&Counter::CheckRestartInActorCreationTask, &Counter::CheckRestartInActorTask, &Counter::Add,
&Counter::GetVal, &Counter::GetIntVal); &Counter::Exit,
&Counter::GetPid,
&Counter::ExceptionFunc,
&Counter::CheckRestartInActorCreationTask,
&Counter::CheckRestartInActorTask,
&Counter::GetVal,
&Counter::GetIntVal);
RAY_REMOTE(ActorConcurrentCall::FactoryCreate, &ActorConcurrentCall::CountDown); RAY_REMOTE(ActorConcurrentCall::FactoryCreate, &ActorConcurrentCall::CountDown);

View file

@ -38,5 +38,15 @@ Student GetStudent(Student student) { return student; }
std::map<int, Student> GetStudents(std::map<int, Student> students) { return students; } std::map<int, Student> GetStudents(std::map<int, Student> students) { return students; }
RAY_REMOTE(Return1, Plus1, Plus, ThrowTask, ReturnLargeArray, Echo, GetMap, GetArray, RAY_REMOTE(Return1,
GetList, GetTuple, GetStudent, GetStudents); Plus1,
Plus,
ThrowTask,
ReturnLargeArray,
Echo,
GetMap,
GetArray,
GetList,
GetTuple,
GetStudent,
GetStudents);

View file

@ -151,8 +151,8 @@ void StartServer() {
// different nodes if possible. // different nodes if possible.
std::vector<std::unordered_map<std::string, double>> bundles{RESOUECES, RESOUECES}; std::vector<std::unordered_map<std::string, double>> bundles{RESOUECES, RESOUECES};
ray::PlacementGroupCreationOptions options{"kv_server_pg", bundles, ray::PlacementGroupCreationOptions options{
ray::PlacementStrategy::SPREAD}; "kv_server_pg", bundles, ray::PlacementStrategy::SPREAD};
auto placement_group = ray::CreatePlacementGroup(options); auto placement_group = ray::CreatePlacementGroup(options);
// Wait until the placement group is created. // Wait until the placement group is created.
assert(placement_group.Wait(10)); assert(placement_group.Wait(10));

View file

@ -27,11 +27,18 @@ namespace internal {
using ray::core::CoreWorkerProcess; using ray::core::CoreWorkerProcess;
using ray::core::WorkerType; using ray::core::WorkerType;
void ProcessHelper::StartRayNode(const int redis_port, const std::string redis_password, void ProcessHelper::StartRayNode(const int redis_port,
const std::string redis_password,
const std::vector<std::string> &head_args) { const std::vector<std::string> &head_args) {
std::vector<std::string> cmdargs( std::vector<std::string> cmdargs({"ray",
{"ray", "start", "--head", "--port", std::to_string(redis_port), "--redis-password", "start",
redis_password, "--node-ip-address", GetNodeIpAddress()}); "--head",
"--port",
std::to_string(redis_port),
"--redis-password",
redis_password,
"--node-ip-address",
GetNodeIpAddress()});
if (!head_args.empty()) { if (!head_args.empty()) {
cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end()); cmdargs.insert(cmdargs.end(), head_args.begin(), head_args.end());
} }
@ -63,8 +70,8 @@ std::unique_ptr<ray::gcs::GlobalStateAccessor> ProcessHelper::CreateGlobalStateA
std::vector<std::string> address; std::vector<std::string> address;
boost::split(address, redis_address, boost::is_any_of(":")); boost::split(address, redis_address, boost::is_any_of(":"));
RAY_CHECK(address.size() == 2); RAY_CHECK(address.size() == 2);
ray::gcs::GcsClientOptions client_options(address[0], std::stoi(address[1]), ray::gcs::GcsClientOptions client_options(
redis_password); address[0], std::stoi(address[1]), redis_password);
auto global_state_accessor = auto global_state_accessor =
std::make_unique<ray::gcs::GlobalStateAccessor>(client_options); std::make_unique<ray::gcs::GlobalStateAccessor>(client_options);
@ -79,7 +86,8 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
if (ConfigInternal::Instance().worker_type == WorkerType::DRIVER && if (ConfigInternal::Instance().worker_type == WorkerType::DRIVER &&
bootstrap_ip.empty()) { bootstrap_ip.empty()) {
bootstrap_ip = "127.0.0.1"; bootstrap_ip = "127.0.0.1";
StartRayNode(bootstrap_port, ConfigInternal::Instance().redis_password, StartRayNode(bootstrap_port,
ConfigInternal::Instance().redis_password,
ConfigInternal::Instance().head_args); ConfigInternal::Instance().head_args);
} }
if (bootstrap_ip == "127.0.0.1") { if (bootstrap_ip == "127.0.0.1") {
@ -129,7 +137,8 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
gcs::GcsClientOptions gcs_options = gcs::GcsClientOptions gcs_options =
::RayConfig::instance().bootstrap_with_gcs() ::RayConfig::instance().bootstrap_with_gcs()
? gcs::GcsClientOptions(bootstrap_address) ? gcs::GcsClientOptions(bootstrap_address)
: gcs::GcsClientOptions(bootstrap_ip, ConfigInternal::Instance().bootstrap_port, : gcs::GcsClientOptions(bootstrap_ip,
ConfigInternal::Instance().bootstrap_port,
ConfigInternal::Instance().redis_password); ConfigInternal::Instance().redis_password);
CoreWorkerOptions options; CoreWorkerOptions options;

View file

@ -29,7 +29,8 @@ class ProcessHelper {
public: public:
void RayStart(CoreWorkerOptions::TaskExecutionCallback callback); void RayStart(CoreWorkerOptions::TaskExecutionCallback callback);
void RayStop(); void RayStop();
void StartRayNode(const int redis_port, const std::string redis_password, void StartRayNode(const int redis_port,
const std::string redis_password,
const std::vector<std::string> &head_args = {}); const std::vector<std::string> &head_args = {});
void StopRayNode(); void StopRayNode();

View file

@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include "util.h" #include "util.h"
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include "ray/util/logging.h" #include "ray/util/logging.h"
namespace ray { namespace ray {
@ -27,8 +29,8 @@ std::string GetNodeIpAddress(const std::string &address) {
try { try {
boost::asio::io_service netService; boost::asio::io_service netService;
boost::asio::ip::udp::resolver resolver(netService); boost::asio::ip::udp::resolver resolver(netService);
boost::asio::ip::udp::resolver::query query(boost::asio::ip::udp::v4(), parts[0], boost::asio::ip::udp::resolver::query query(
parts[1]); boost::asio::ip::udp::v4(), parts[0], parts[1]);
boost::asio::ip::udp::resolver::iterator endpoints = resolver.resolve(query); boost::asio::ip::udp::resolver::iterator endpoints = resolver.resolve(query);
boost::asio::ip::udp::endpoint ep = *endpoints; boost::asio::ip::udp::endpoint ep = *endpoints;
boost::asio::ip::udp::socket socket(netService); boost::asio::ip::udp::socket socket(netService);

View file

@ -14,10 +14,10 @@
#include <ray/api.h> #include <ray/api.h>
#include <ray/util/logging.h> #include <ray/util/logging.h>
#include "ray/core_worker/common.h"
#include "ray/core_worker/core_worker.h"
#include "../config_internal.h" #include "../config_internal.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/core_worker.h"
int main(int argc, char **argv) { int main(int argc, char **argv) {
RAY_LOG(INFO) << "CPP default worker started."; RAY_LOG(INFO) << "CPP default worker started.";

View file

@ -19,17 +19,23 @@ namespace core {
class MockActorCreatorInterface : public ActorCreatorInterface { class MockActorCreatorInterface : public ActorCreatorInterface {
public: public:
MOCK_METHOD(Status, RegisterActor, (const TaskSpecification &task_spec), MOCK_METHOD(Status,
RegisterActor,
(const TaskSpecification &task_spec),
(const, override)); (const, override));
MOCK_METHOD(Status, AsyncRegisterActor, MOCK_METHOD(Status,
AsyncRegisterActor,
(const TaskSpecification &task_spec, gcs::StatusCallback callback), (const TaskSpecification &task_spec, gcs::StatusCallback callback),
(override)); (override));
MOCK_METHOD(Status, AsyncCreateActor, MOCK_METHOD(Status,
AsyncCreateActor,
(const TaskSpecification &task_spec, (const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback), const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(override)); (override));
MOCK_METHOD(void, AsyncWaitForActorRegisterFinish, MOCK_METHOD(void,
(const ActorID &actor_id, gcs::StatusCallback callback), (override)); AsyncWaitForActorRegisterFinish,
(const ActorID &actor_id, gcs::StatusCallback callback),
(override));
MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override)); MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override));
}; };
@ -41,13 +47,17 @@ namespace core {
class MockDefaultActorCreator : public DefaultActorCreator { class MockDefaultActorCreator : public DefaultActorCreator {
public: public:
MOCK_METHOD(Status, RegisterActor, (const TaskSpecification &task_spec), MOCK_METHOD(Status,
RegisterActor,
(const TaskSpecification &task_spec),
(const, override)); (const, override));
MOCK_METHOD(Status, AsyncRegisterActor, MOCK_METHOD(Status,
AsyncRegisterActor,
(const TaskSpecification &task_spec, gcs::StatusCallback callback), (const TaskSpecification &task_spec, gcs::StatusCallback callback),
(override)); (override));
MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override)); MOCK_METHOD(bool, IsActorInRegistering, (const ActorID &actor_id), (const, override));
MOCK_METHOD(Status, AsyncCreateActor, MOCK_METHOD(Status,
AsyncCreateActor,
(const TaskSpecification &task_spec, (const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback), const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(override)); (override));

View file

@ -39,105 +39,134 @@ namespace core {
class MockCoreWorker : public CoreWorker { class MockCoreWorker : public CoreWorker {
public: public:
MOCK_METHOD(void, HandlePushTask, MOCK_METHOD(void,
(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, HandlePushTask,
(const rpc::PushTaskRequest &request,
rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleDirectActorCallArgWaitComplete, MOCK_METHOD(void,
HandleDirectActorCallArgWaitComplete,
(const rpc::DirectActorCallArgWaitCompleteRequest &request, (const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetObjectStatus, MOCK_METHOD(void,
HandleGetObjectStatus,
(const rpc::GetObjectStatusRequest &request, (const rpc::GetObjectStatusRequest &request,
rpc::GetObjectStatusReply *reply, rpc::GetObjectStatusReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleWaitForActorOutOfScope, MOCK_METHOD(void,
HandleWaitForActorOutOfScope,
(const rpc::WaitForActorOutOfScopeRequest &request, (const rpc::WaitForActorOutOfScopeRequest &request,
rpc::WaitForActorOutOfScopeReply *reply, rpc::WaitForActorOutOfScopeReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandlePubsubLongPolling, MOCK_METHOD(void,
HandlePubsubLongPolling,
(const rpc::PubsubLongPollingRequest &request, (const rpc::PubsubLongPollingRequest &request,
rpc::PubsubLongPollingReply *reply, rpc::PubsubLongPollingReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandlePubsubCommandBatch, MOCK_METHOD(void,
HandlePubsubCommandBatch,
(const rpc::PubsubCommandBatchRequest &request, (const rpc::PubsubCommandBatchRequest &request,
rpc::PubsubCommandBatchReply *reply, rpc::PubsubCommandBatchReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleAddObjectLocationOwner, MOCK_METHOD(void,
HandleAddObjectLocationOwner,
(const rpc::AddObjectLocationOwnerRequest &request, (const rpc::AddObjectLocationOwnerRequest &request,
rpc::AddObjectLocationOwnerReply *reply, rpc::AddObjectLocationOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRemoveObjectLocationOwner, MOCK_METHOD(void,
HandleRemoveObjectLocationOwner,
(const rpc::RemoveObjectLocationOwnerRequest &request, (const rpc::RemoveObjectLocationOwnerRequest &request,
rpc::RemoveObjectLocationOwnerReply *reply, rpc::RemoveObjectLocationOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetObjectLocationsOwner, MOCK_METHOD(void,
HandleGetObjectLocationsOwner,
(const rpc::GetObjectLocationsOwnerRequest &request, (const rpc::GetObjectLocationsOwnerRequest &request,
rpc::GetObjectLocationsOwnerReply *reply, rpc::GetObjectLocationsOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleKillActor, MOCK_METHOD(void,
(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, HandleKillActor,
(const rpc::KillActorRequest &request,
rpc::KillActorReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCancelTask, MOCK_METHOD(void,
(const rpc::CancelTaskRequest &request, rpc::CancelTaskReply *reply, HandleCancelTask,
(const rpc::CancelTaskRequest &request,
rpc::CancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRemoteCancelTask, MOCK_METHOD(void,
HandleRemoteCancelTask,
(const rpc::RemoteCancelTaskRequest &request, (const rpc::RemoteCancelTaskRequest &request,
rpc::RemoteCancelTaskReply *reply, rpc::RemoteCancelTaskReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandlePlasmaObjectReady, MOCK_METHOD(void,
HandlePlasmaObjectReady,
(const rpc::PlasmaObjectReadyRequest &request, (const rpc::PlasmaObjectReadyRequest &request,
rpc::PlasmaObjectReadyReply *reply, rpc::PlasmaObjectReadyReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetCoreWorkerStats, MOCK_METHOD(void,
HandleGetCoreWorkerStats,
(const rpc::GetCoreWorkerStatsRequest &request, (const rpc::GetCoreWorkerStatsRequest &request,
rpc::GetCoreWorkerStatsReply *reply, rpc::GetCoreWorkerStatsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleLocalGC, MOCK_METHOD(void,
(const rpc::LocalGCRequest &request, rpc::LocalGCReply *reply, HandleLocalGC,
(const rpc::LocalGCRequest &request,
rpc::LocalGCReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRunOnUtilWorker, MOCK_METHOD(void,
HandleRunOnUtilWorker,
(const rpc::RunOnUtilWorkerRequest &request, (const rpc::RunOnUtilWorkerRequest &request,
rpc::RunOnUtilWorkerReply *reply, rpc::RunOnUtilWorkerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleSpillObjects, MOCK_METHOD(void,
(const rpc::SpillObjectsRequest &request, rpc::SpillObjectsReply *reply, HandleSpillObjects,
(const rpc::SpillObjectsRequest &request,
rpc::SpillObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleAddSpilledUrl, MOCK_METHOD(void,
(const rpc::AddSpilledUrlRequest &request, rpc::AddSpilledUrlReply *reply, HandleAddSpilledUrl,
(const rpc::AddSpilledUrlRequest &request,
rpc::AddSpilledUrlReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRestoreSpilledObjects, MOCK_METHOD(void,
HandleRestoreSpilledObjects,
(const rpc::RestoreSpilledObjectsRequest &request, (const rpc::RestoreSpilledObjectsRequest &request,
rpc::RestoreSpilledObjectsReply *reply, rpc::RestoreSpilledObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleDeleteSpilledObjects, MOCK_METHOD(void,
HandleDeleteSpilledObjects,
(const rpc::DeleteSpilledObjectsRequest &request, (const rpc::DeleteSpilledObjectsRequest &request,
rpc::DeleteSpilledObjectsReply *reply, rpc::DeleteSpilledObjectsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleExit, MOCK_METHOD(void,
(const rpc::ExitRequest &request, rpc::ExitReply *reply, HandleExit,
(const rpc::ExitRequest &request,
rpc::ExitReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleAssignObjectOwner, MOCK_METHOD(void,
HandleAssignObjectOwner,
(const rpc::AssignObjectOwnerRequest &request, (const rpc::AssignObjectOwnerRequest &request,
rpc::AssignObjectOwnerReply *reply, rpc::AssignObjectOwnerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -27,7 +27,9 @@ namespace core {
class MockLocalityDataProviderInterface : public LocalityDataProviderInterface { class MockLocalityDataProviderInterface : public LocalityDataProviderInterface {
public: public:
MOCK_METHOD(absl::optional<LocalityData>, GetLocalityData, (const ObjectID &object_id), MOCK_METHOD(absl::optional<LocalityData>,
GetLocalityData,
(const ObjectID &object_id),
(override)); (override));
}; };
@ -39,8 +41,10 @@ namespace core {
class MockLeasePolicyInterface : public LeasePolicyInterface { class MockLeasePolicyInterface : public LeasePolicyInterface {
public: public:
MOCK_METHOD((std::pair<rpc::Address, bool>), GetBestNodeForTask, MOCK_METHOD((std::pair<rpc::Address, bool>),
(const TaskSpecification &spec), (override)); GetBestNodeForTask,
(const TaskSpecification &spec),
(override));
}; };
} // namespace core } // namespace core

View file

@ -19,28 +19,43 @@ namespace core {
class MockTaskFinisherInterface : public TaskFinisherInterface { class MockTaskFinisherInterface : public TaskFinisherInterface {
public: public:
MOCK_METHOD(void, CompletePendingTask, MOCK_METHOD(void,
(const TaskID &task_id, const rpc::PushTaskReply &reply, CompletePendingTask,
(const TaskID &task_id,
const rpc::PushTaskReply &reply,
const rpc::Address &actor_addr), const rpc::Address &actor_addr),
(override)); (override));
MOCK_METHOD(void, FailPendingTask, MOCK_METHOD(void,
(const TaskID &task_id, rpc::ErrorType error_type, const Status *status, FailPendingTask,
const rpc::RayErrorInfo *ray_error_info, bool mark_task_object_failed), (const TaskID &task_id,
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info,
bool mark_task_object_failed),
(override)); (override));
MOCK_METHOD(bool, FailOrRetryPendingTask, MOCK_METHOD(bool,
(const TaskID &task_id, rpc::ErrorType error_type, const Status *status, FailOrRetryPendingTask,
const rpc::RayErrorInfo *ray_error_info, bool mark_task_object_failed), (const TaskID &task_id,
rpc::ErrorType error_type,
const Status *status,
const rpc::RayErrorInfo *ray_error_info,
bool mark_task_object_failed),
(override)); (override));
MOCK_METHOD(void, OnTaskDependenciesInlined, MOCK_METHOD(void,
OnTaskDependenciesInlined,
(const std::vector<ObjectID> &inlined_dependency_ids, (const std::vector<ObjectID> &inlined_dependency_ids,
const std::vector<ObjectID> &contained_ids), const std::vector<ObjectID> &contained_ids),
(override)); (override));
MOCK_METHOD(bool, MarkTaskCanceled, (const TaskID &task_id), (override)); MOCK_METHOD(bool, MarkTaskCanceled, (const TaskID &task_id), (override));
MOCK_METHOD(void, MarkTaskReturnObjectsFailed, MOCK_METHOD(void,
(const TaskSpecification &spec, rpc::ErrorType error_type, MarkTaskReturnObjectsFailed,
(const TaskSpecification &spec,
rpc::ErrorType error_type,
const rpc::RayErrorInfo *ray_error_info), const rpc::RayErrorInfo *ray_error_info),
(override)); (override));
MOCK_METHOD(absl::optional<TaskSpecification>, GetTaskSpec, (const TaskID &task_id), MOCK_METHOD(absl::optional<TaskSpecification>,
GetTaskSpec,
(const TaskID &task_id),
(const, override)); (const, override));
MOCK_METHOD(bool, RetryTaskIfPossible, (const TaskID &task_id), (override)); MOCK_METHOD(bool, RetryTaskIfPossible, (const TaskID &task_id), (override));
MOCK_METHOD(void, MarkDependenciesResolved, (const TaskID &task_id), (override)); MOCK_METHOD(void, MarkDependenciesResolved, (const TaskID &task_id), (override));
@ -54,8 +69,10 @@ namespace core {
class MockTaskResubmissionInterface : public TaskResubmissionInterface { class MockTaskResubmissionInterface : public TaskResubmissionInterface {
public: public:
MOCK_METHOD(bool, ResubmitTask, MOCK_METHOD(bool,
(const TaskID &task_id, std::vector<ObjectID> *task_deps), (override)); ResubmitTask,
(const TaskID &task_id, std::vector<ObjectID> *task_deps),
(override));
}; };
} // namespace core } // namespace core

View file

@ -17,24 +17,34 @@ namespace core {
class MockCoreWorkerDirectActorTaskSubmitterInterface class MockCoreWorkerDirectActorTaskSubmitterInterface
: public CoreWorkerDirectActorTaskSubmitterInterface { : public CoreWorkerDirectActorTaskSubmitterInterface {
public: public:
MOCK_METHOD(void, AddActorQueueIfNotExists, MOCK_METHOD(void,
(const ActorID &actor_id, int32_t max_pending_calls), (override)); AddActorQueueIfNotExists,
MOCK_METHOD(void, ConnectActor, (const ActorID &actor_id, int32_t max_pending_calls),
(const ActorID &actor_id, const rpc::Address &address, (override));
MOCK_METHOD(void,
ConnectActor,
(const ActorID &actor_id,
const rpc::Address &address,
int64_t num_restarts), int64_t num_restarts),
(override)); (override));
MOCK_METHOD(void, DisconnectActor, MOCK_METHOD(void,
(const ActorID &actor_id, int64_t num_restarts, bool dead, DisconnectActor,
(const ActorID &actor_id,
int64_t num_restarts,
bool dead,
const rpc::RayException *creation_task_exception), const rpc::RayException *creation_task_exception),
(override)); (override));
MOCK_METHOD(void, KillActor, MOCK_METHOD(void,
(const ActorID &actor_id, bool force_kill, bool no_restart), (override)); KillActor,
(const ActorID &actor_id, bool force_kill, bool no_restart),
(override));
MOCK_METHOD(void, CheckTimeoutTasks, (), (override)); MOCK_METHOD(void, CheckTimeoutTasks, (), (override));
}; };
class MockDependencyWaiter : public DependencyWaiter { class MockDependencyWaiter : public DependencyWaiter {
public: public:
MOCK_METHOD(void, Wait, MOCK_METHOD(void,
Wait,
(const std::vector<rpc::ObjectReference> &dependencies, (const std::vector<rpc::ObjectReference> &dependencies,
std::function<void()> on_dependencies_available), std::function<void()> on_dependencies_available),
(override)); (override));
@ -42,13 +52,16 @@ class MockDependencyWaiter : public DependencyWaiter {
class MockSchedulingQueue : public SchedulingQueue { class MockSchedulingQueue : public SchedulingQueue {
public: public:
MOCK_METHOD(void, Add, MOCK_METHOD(void,
(int64_t seq_no, int64_t client_processed_up_to, Add,
(int64_t seq_no,
int64_t client_processed_up_to,
std::function<void(rpc::SendReplyCallback)> accept_request, std::function<void(rpc::SendReplyCallback)> accept_request,
std::function<void(rpc::SendReplyCallback)> reject_request, std::function<void(rpc::SendReplyCallback)> reject_request,
rpc::SendReplyCallback send_reply_callback, rpc::SendReplyCallback send_reply_callback,
const std::string &concurrency_group_name, const std::string &concurrency_group_name,
const ray::FunctionDescriptor &function_descriptor, TaskID task_id, const ray::FunctionDescriptor &function_descriptor,
TaskID task_id,
const std::vector<rpc::ObjectReference> &dependencies), const std::vector<rpc::ObjectReference> &dependencies),
(override)); (override));
MOCK_METHOD(void, ScheduleRequests, (), (override)); MOCK_METHOD(void, ScheduleRequests, (), (override));

View file

@ -20,37 +20,53 @@ namespace gcs {
class MockActorInfoAccessor : public ActorInfoAccessor { class MockActorInfoAccessor : public ActorInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncGet, MOCK_METHOD(Status,
AsyncGet,
(const ActorID &actor_id, (const ActorID &actor_id,
const OptionalItemCallback<rpc::ActorTableData> &callback), const OptionalItemCallback<rpc::ActorTableData> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
(const MultiItemCallback<rpc::ActorTableData> &callback), (override)); AsyncGetAll,
MOCK_METHOD(Status, AsyncGetByName, (const MultiItemCallback<rpc::ActorTableData> &callback),
(const std::string &name, const std::string &ray_namespace, (override));
MOCK_METHOD(Status,
AsyncGetByName,
(const std::string &name,
const std::string &ray_namespace,
const OptionalItemCallback<rpc::ActorTableData> &callback, const OptionalItemCallback<rpc::ActorTableData> &callback,
int64_t timeout_ms), int64_t timeout_ms),
(override)); (override));
MOCK_METHOD(Status, AsyncListNamedActors, MOCK_METHOD(Status,
(bool all_namespaces, const std::string &ray_namespace, AsyncListNamedActors,
(bool all_namespaces,
const std::string &ray_namespace,
const OptionalItemCallback<std::vector<rpc::NamedActorInfo>> &callback, const OptionalItemCallback<std::vector<rpc::NamedActorInfo>> &callback,
int64_t timeout_ms), int64_t timeout_ms),
(override)); (override));
MOCK_METHOD(Status, AsyncRegisterActor, MOCK_METHOD(Status,
(const TaskSpecification &task_spec, const StatusCallback &callback, AsyncRegisterActor,
(const TaskSpecification &task_spec,
const StatusCallback &callback,
int64_t timeout_ms), int64_t timeout_ms),
(override)); (override));
MOCK_METHOD(Status, SyncRegisterActor, (const TaskSpecification &task_spec), MOCK_METHOD(Status,
SyncRegisterActor,
(const TaskSpecification &task_spec),
(override)); (override));
MOCK_METHOD(Status, AsyncKillActor, MOCK_METHOD(Status,
(const ActorID &actor_id, bool force_kill, bool no_restart, AsyncKillActor,
(const ActorID &actor_id,
bool force_kill,
bool no_restart,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncCreateActor, MOCK_METHOD(Status,
AsyncCreateActor,
(const TaskSpecification &task_spec, (const TaskSpecification &task_spec,
const rpc::ClientCallback<rpc::CreateActorReply> &callback), const rpc::ClientCallback<rpc::CreateActorReply> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncSubscribe, MOCK_METHOD(Status,
AsyncSubscribe,
(const ActorID &actor_id, (const ActorID &actor_id,
(const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe), (const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe),
const StatusCallback &done), const StatusCallback &done),
@ -68,20 +84,28 @@ namespace gcs {
class MockJobInfoAccessor : public JobInfoAccessor { class MockJobInfoAccessor : public JobInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncAdd, MOCK_METHOD(Status,
AsyncAdd,
(const std::shared_ptr<rpc::JobTableData> &data_ptr, (const std::shared_ptr<rpc::JobTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncMarkFinished, MOCK_METHOD(Status,
(const JobID &job_id, const StatusCallback &callback), (override)); AsyncMarkFinished,
MOCK_METHOD(Status, AsyncSubscribeAll, (const JobID &job_id, const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncSubscribeAll,
((const SubscribeCallback<JobID, rpc::JobTableData> &subscribe), ((const SubscribeCallback<JobID, rpc::JobTableData> &subscribe),
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, (const MultiItemCallback<rpc::JobTableData> &callback), MOCK_METHOD(Status,
AsyncGetAll,
(const MultiItemCallback<rpc::JobTableData> &callback),
(override)); (override));
MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override));
MOCK_METHOD(Status, AsyncGetNextJobID, (const ItemCallback<JobID> &callback), MOCK_METHOD(Status,
AsyncGetNextJobID,
(const ItemCallback<JobID> &callback),
(override)); (override));
}; };
@ -93,35 +117,49 @@ namespace gcs {
class MockNodeInfoAccessor : public NodeInfoAccessor { class MockNodeInfoAccessor : public NodeInfoAccessor {
public: public:
MOCK_METHOD(Status, RegisterSelf, MOCK_METHOD(Status,
RegisterSelf,
(const rpc::GcsNodeInfo &local_node_info, const StatusCallback &callback), (const rpc::GcsNodeInfo &local_node_info, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, DrainSelf, (), (override)); MOCK_METHOD(Status, DrainSelf, (), (override));
MOCK_METHOD(const NodeID &, GetSelfId, (), (const, override)); MOCK_METHOD(const NodeID &, GetSelfId, (), (const, override));
MOCK_METHOD(const rpc::GcsNodeInfo &, GetSelfInfo, (), (const, override)); MOCK_METHOD(const rpc::GcsNodeInfo &, GetSelfInfo, (), (const, override));
MOCK_METHOD(Status, AsyncRegister, MOCK_METHOD(Status,
AsyncRegister,
(const rpc::GcsNodeInfo &node_info, const StatusCallback &callback), (const rpc::GcsNodeInfo &node_info, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDrainNode, MOCK_METHOD(Status,
(const NodeID &node_id, const StatusCallback &callback), (override)); AsyncDrainNode,
MOCK_METHOD(Status, AsyncGetAll, (const MultiItemCallback<rpc::GcsNodeInfo> &callback), (const NodeID &node_id, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncSubscribeToNodeChange, MOCK_METHOD(Status,
AsyncGetAll,
(const MultiItemCallback<rpc::GcsNodeInfo> &callback),
(override));
MOCK_METHOD(Status,
AsyncSubscribeToNodeChange,
((const SubscribeCallback<NodeID, rpc::GcsNodeInfo> &subscribe), ((const SubscribeCallback<NodeID, rpc::GcsNodeInfo> &subscribe),
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
MOCK_METHOD(const rpc::GcsNodeInfo *, Get, MOCK_METHOD(const rpc::GcsNodeInfo *,
(const NodeID &node_id, bool filter_dead_nodes), (const, override)); Get,
MOCK_METHOD((const absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> &), GetAll, (), (const NodeID &node_id, bool filter_dead_nodes),
(const, override));
MOCK_METHOD((const absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> &),
GetAll,
(),
(const, override)); (const, override));
MOCK_METHOD(bool, IsRemoved, (const NodeID &node_id), (const, override)); MOCK_METHOD(bool, IsRemoved, (const NodeID &node_id), (const, override));
MOCK_METHOD(Status, AsyncReportHeartbeat, MOCK_METHOD(Status,
AsyncReportHeartbeat,
(const std::shared_ptr<rpc::HeartbeatTableData> &data_ptr, (const std::shared_ptr<rpc::HeartbeatTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override));
MOCK_METHOD(Status, AsyncGetInternalConfig, MOCK_METHOD(Status,
(const OptionalItemCallback<std::string> &callback), (override)); AsyncGetInternalConfig,
(const OptionalItemCallback<std::string> &callback),
(override));
}; };
} // namespace gcs } // namespace gcs
@ -132,24 +170,32 @@ namespace gcs {
class MockNodeResourceInfoAccessor : public NodeResourceInfoAccessor { class MockNodeResourceInfoAccessor : public NodeResourceInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncGetResources, MOCK_METHOD(Status,
AsyncGetResources,
(const NodeID &node_id, const OptionalItemCallback<ResourceMap> &callback), (const NodeID &node_id, const OptionalItemCallback<ResourceMap> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAllAvailableResources, MOCK_METHOD(Status,
(const MultiItemCallback<rpc::AvailableResources> &callback), (override)); AsyncGetAllAvailableResources,
MOCK_METHOD(Status, AsyncSubscribeToResources, (const MultiItemCallback<rpc::AvailableResources> &callback),
(override));
MOCK_METHOD(Status,
AsyncSubscribeToResources,
(const ItemCallback<rpc::NodeResourceChange> &subscribe, (const ItemCallback<rpc::NodeResourceChange> &subscribe,
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override)); MOCK_METHOD(void, AsyncResubscribe, (bool is_pubsub_server_restarted), (override));
MOCK_METHOD(Status, AsyncReportResourceUsage, MOCK_METHOD(Status,
AsyncReportResourceUsage,
(const std::shared_ptr<rpc::ResourcesData> &data_ptr, (const std::shared_ptr<rpc::ResourcesData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(void, AsyncReReportResourceUsage, (), (override)); MOCK_METHOD(void, AsyncReReportResourceUsage, (), (override));
MOCK_METHOD(Status, AsyncGetAllResourceUsage, MOCK_METHOD(Status,
(const ItemCallback<rpc::ResourceUsageBatchData> &callback), (override)); AsyncGetAllResourceUsage,
MOCK_METHOD(Status, AsyncSubscribeBatchedResourceUsage, (const ItemCallback<rpc::ResourceUsageBatchData> &callback),
(override));
MOCK_METHOD(Status,
AsyncSubscribeBatchedResourceUsage,
(const ItemCallback<rpc::ResourceUsageBatchData> &subscribe, (const ItemCallback<rpc::ResourceUsageBatchData> &subscribe,
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
@ -163,7 +209,8 @@ namespace gcs {
class MockErrorInfoAccessor : public ErrorInfoAccessor { class MockErrorInfoAccessor : public ErrorInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncReportJobError, MOCK_METHOD(Status,
AsyncReportJobError,
(const std::shared_ptr<rpc::ErrorTableData> &data_ptr, (const std::shared_ptr<rpc::ErrorTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
@ -177,12 +224,15 @@ namespace gcs {
class MockStatsInfoAccessor : public StatsInfoAccessor { class MockStatsInfoAccessor : public StatsInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncAddProfileData, MOCK_METHOD(Status,
AsyncAddProfileData,
(const std::shared_ptr<rpc::ProfileTableData> &data_ptr, (const std::shared_ptr<rpc::ProfileTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
(const MultiItemCallback<rpc::ProfileTableData> &callback), (override)); AsyncGetAll,
(const MultiItemCallback<rpc::ProfileTableData> &callback),
(override));
}; };
} // namespace gcs } // namespace gcs
@ -193,21 +243,27 @@ namespace gcs {
class MockWorkerInfoAccessor : public WorkerInfoAccessor { class MockWorkerInfoAccessor : public WorkerInfoAccessor {
public: public:
MOCK_METHOD(Status, AsyncSubscribeToWorkerFailures, MOCK_METHOD(Status,
AsyncSubscribeToWorkerFailures,
(const ItemCallback<rpc::WorkerDeltaData> &subscribe, (const ItemCallback<rpc::WorkerDeltaData> &subscribe,
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
MOCK_METHOD(Status, AsyncReportWorkerFailure, MOCK_METHOD(Status,
AsyncReportWorkerFailure,
(const std::shared_ptr<rpc::WorkerTableData> &data_ptr, (const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGet, MOCK_METHOD(Status,
AsyncGet,
(const WorkerID &worker_id, (const WorkerID &worker_id,
const OptionalItemCallback<rpc::WorkerTableData> &callback), const OptionalItemCallback<rpc::WorkerTableData> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
(const MultiItemCallback<rpc::WorkerTableData> &callback), (override)); AsyncGetAll,
MOCK_METHOD(Status, AsyncAdd, (const MultiItemCallback<rpc::WorkerTableData> &callback),
(override));
MOCK_METHOD(Status,
AsyncAdd,
(const std::shared_ptr<rpc::WorkerTableData> &data_ptr, (const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
@ -222,23 +278,33 @@ namespace gcs {
class MockPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { class MockPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor {
public: public:
MOCK_METHOD(Status, SyncCreatePlacementGroup, MOCK_METHOD(Status,
(const PlacementGroupSpecification &placement_group_spec), (override)); SyncCreatePlacementGroup,
MOCK_METHOD(Status, AsyncGet, (const PlacementGroupSpecification &placement_group_spec),
(override));
MOCK_METHOD(Status,
AsyncGet,
(const PlacementGroupID &placement_group_id, (const PlacementGroupID &placement_group_id,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback), const OptionalItemCallback<rpc::PlacementGroupTableData> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetByName, MOCK_METHOD(Status,
(const std::string &placement_group_name, const std::string &ray_namespace, AsyncGetByName,
(const std::string &placement_group_name,
const std::string &ray_namespace,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback, const OptionalItemCallback<rpc::PlacementGroupTableData> &callback,
int64_t timeout_ms), int64_t timeout_ms),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
AsyncGetAll,
(const MultiItemCallback<rpc::PlacementGroupTableData> &callback), (const MultiItemCallback<rpc::PlacementGroupTableData> &callback),
(override)); (override));
MOCK_METHOD(Status, SyncRemovePlacementGroup, MOCK_METHOD(Status,
(const PlacementGroupID &placement_group_id), (override)); SyncRemovePlacementGroup,
MOCK_METHOD(Status, SyncWaitUntilReady, (const PlacementGroupID &placement_group_id), (const PlacementGroupID &placement_group_id),
(override));
MOCK_METHOD(Status,
SyncWaitUntilReady,
(const PlacementGroupID &placement_group_id),
(override)); (override));
}; };
@ -250,24 +316,37 @@ namespace gcs {
class MockInternalKVAccessor : public InternalKVAccessor { class MockInternalKVAccessor : public InternalKVAccessor {
public: public:
MOCK_METHOD(Status, AsyncInternalKVKeys, MOCK_METHOD(Status,
(const std::string &ns, const std::string &prefix, AsyncInternalKVKeys,
(const std::string &ns,
const std::string &prefix,
const OptionalItemCallback<std::vector<std::string>> &callback), const OptionalItemCallback<std::vector<std::string>> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncInternalKVGet, MOCK_METHOD(Status,
(const std::string &ns, const std::string &key, AsyncInternalKVGet,
(const std::string &ns,
const std::string &key,
const OptionalItemCallback<std::string> &callback), const OptionalItemCallback<std::string> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncInternalKVPut, MOCK_METHOD(Status,
(const std::string &ns, const std::string &key, const std::string &value, AsyncInternalKVPut,
bool overwrite, const OptionalItemCallback<int> &callback), (const std::string &ns,
const std::string &key,
const std::string &value,
bool overwrite,
const OptionalItemCallback<int> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncInternalKVExists, MOCK_METHOD(Status,
(const std::string &ns, const std::string &key, AsyncInternalKVExists,
(const std::string &ns,
const std::string &key,
const OptionalItemCallback<bool> &callback), const OptionalItemCallback<bool> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncInternalKVDel, MOCK_METHOD(Status,
(const std::string &ns, const std::string &key, bool del_by_prefix, AsyncInternalKVDel,
(const std::string &ns,
const std::string &key,
bool del_by_prefix,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
}; };

View file

@ -27,34 +27,44 @@ namespace gcs {
class MockGcsActorManager : public GcsActorManager { class MockGcsActorManager : public GcsActorManager {
public: public:
MOCK_METHOD(void, HandleRegisterActor, MOCK_METHOD(void,
(const rpc::RegisterActorRequest &request, rpc::RegisterActorReply *reply, HandleRegisterActor,
(const rpc::RegisterActorRequest &request,
rpc::RegisterActorReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCreateActor, MOCK_METHOD(void,
(const rpc::CreateActorRequest &request, rpc::CreateActorReply *reply, HandleCreateActor,
(const rpc::CreateActorRequest &request,
rpc::CreateActorReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetActorInfo, MOCK_METHOD(void,
(const rpc::GetActorInfoRequest &request, rpc::GetActorInfoReply *reply, HandleGetActorInfo,
(const rpc::GetActorInfoRequest &request,
rpc::GetActorInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetNamedActorInfo, MOCK_METHOD(void,
HandleGetNamedActorInfo,
(const rpc::GetNamedActorInfoRequest &request, (const rpc::GetNamedActorInfoRequest &request,
rpc::GetNamedActorInfoReply *reply, rpc::GetNamedActorInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleListNamedActors, MOCK_METHOD(void,
HandleListNamedActors,
(const rpc::ListNamedActorsRequest &request, (const rpc::ListNamedActorsRequest &request,
rpc::ListNamedActorsReply *reply, rpc::ListNamedActorsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllActorInfo, MOCK_METHOD(void,
HandleGetAllActorInfo,
(const rpc::GetAllActorInfoRequest &request, (const rpc::GetAllActorInfoRequest &request,
rpc::GetAllActorInfoReply *reply, rpc::GetAllActorInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleKillActorViaGcs, MOCK_METHOD(void,
HandleKillActorViaGcs,
(const rpc::KillActorViaGcsRequest &request, (const rpc::KillActorViaGcsRequest &request,
rpc::KillActorViaGcsReply *reply, rpc::KillActorViaGcsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -20,13 +20,17 @@ class MockGcsActorSchedulerInterface : public GcsActorSchedulerInterface {
MOCK_METHOD(void, Schedule, (std::shared_ptr<GcsActor> actor), (override)); MOCK_METHOD(void, Schedule, (std::shared_ptr<GcsActor> actor), (override));
MOCK_METHOD(void, Reschedule, (std::shared_ptr<GcsActor> actor), (override)); MOCK_METHOD(void, Reschedule, (std::shared_ptr<GcsActor> actor), (override));
MOCK_METHOD(std::vector<ActorID>, CancelOnNode, (const NodeID &node_id), (override)); MOCK_METHOD(std::vector<ActorID>, CancelOnNode, (const NodeID &node_id), (override));
MOCK_METHOD(void, CancelOnLeasing, MOCK_METHOD(void,
CancelOnLeasing,
(const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id), (const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id),
(override)); (override));
MOCK_METHOD(ActorID, CancelOnWorker, (const NodeID &node_id, const WorkerID &worker_id), MOCK_METHOD(ActorID,
CancelOnWorker,
(const NodeID &node_id, const WorkerID &worker_id),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, ReleaseUnusedWorkers, void,
ReleaseUnusedWorkers,
((const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers)), ((const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers)),
(override)); (override));
}; };
@ -42,25 +46,36 @@ class MockGcsActorScheduler : public GcsActorScheduler {
MOCK_METHOD(void, Schedule, (std::shared_ptr<GcsActor> actor), (override)); MOCK_METHOD(void, Schedule, (std::shared_ptr<GcsActor> actor), (override));
MOCK_METHOD(void, Reschedule, (std::shared_ptr<GcsActor> actor), (override)); MOCK_METHOD(void, Reschedule, (std::shared_ptr<GcsActor> actor), (override));
MOCK_METHOD(std::vector<ActorID>, CancelOnNode, (const NodeID &node_id), (override)); MOCK_METHOD(std::vector<ActorID>, CancelOnNode, (const NodeID &node_id), (override));
MOCK_METHOD(void, CancelOnLeasing, MOCK_METHOD(void,
CancelOnLeasing,
(const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id), (const NodeID &node_id, const ActorID &actor_id, const TaskID &task_id),
(override)); (override));
MOCK_METHOD(ActorID, CancelOnWorker, (const NodeID &node_id, const WorkerID &worker_id), MOCK_METHOD(ActorID,
CancelOnWorker,
(const NodeID &node_id, const WorkerID &worker_id),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, ReleaseUnusedWorkers, void,
ReleaseUnusedWorkers,
((const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers)), ((const std::unordered_map<NodeID, std::vector<WorkerID>> &node_to_workers)),
(override)); (override));
MOCK_METHOD(std::shared_ptr<rpc::GcsNodeInfo>, SelectNode, MOCK_METHOD(std::shared_ptr<rpc::GcsNodeInfo>,
(std::shared_ptr<GcsActor> actor), (override)); SelectNode,
MOCK_METHOD(void, HandleWorkerLeaseReply, (std::shared_ptr<GcsActor> actor),
(std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node,
const Status &status, const rpc::RequestWorkerLeaseReply &reply),
(override)); (override));
MOCK_METHOD(void, RetryLeasingWorkerFromNode, MOCK_METHOD(void,
HandleWorkerLeaseReply,
(std::shared_ptr<GcsActor> actor,
std::shared_ptr<rpc::GcsNodeInfo> node,
const Status &status,
const rpc::RequestWorkerLeaseReply &reply),
(override));
MOCK_METHOD(void,
RetryLeasingWorkerFromNode,
(std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node), (std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node),
(override)); (override));
MOCK_METHOD(void, RetryCreatingActorOnWorker, MOCK_METHOD(void,
RetryCreatingActorOnWorker,
(std::shared_ptr<GcsActor> actor, std::shared_ptr<GcsLeasedWorker> worker), (std::shared_ptr<GcsActor> actor, std::shared_ptr<GcsLeasedWorker> worker),
(override)); (override));
}; };
@ -73,11 +88,16 @@ namespace gcs {
class MockRayletBasedActorScheduler : public RayletBasedActorScheduler { class MockRayletBasedActorScheduler : public RayletBasedActorScheduler {
public: public:
MOCK_METHOD(std::shared_ptr<rpc::GcsNodeInfo>, SelectNode, MOCK_METHOD(std::shared_ptr<rpc::GcsNodeInfo>,
(std::shared_ptr<GcsActor> actor), (override)); SelectNode,
MOCK_METHOD(void, HandleWorkerLeaseReply, (std::shared_ptr<GcsActor> actor),
(std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node, (override));
const Status &status, const rpc::RequestWorkerLeaseReply &reply), MOCK_METHOD(void,
HandleWorkerLeaseReply,
(std::shared_ptr<GcsActor> actor,
std::shared_ptr<rpc::GcsNodeInfo> node,
const Status &status,
const rpc::RequestWorkerLeaseReply &reply),
(override)); (override));
}; };

View file

@ -17,13 +17,16 @@ namespace gcs {
class MockGcsHeartbeatManager : public GcsHeartbeatManager { class MockGcsHeartbeatManager : public GcsHeartbeatManager {
public: public:
MOCK_METHOD(void, HandleReportHeartbeat, MOCK_METHOD(void,
HandleReportHeartbeat,
(const rpc::ReportHeartbeatRequest &request, (const rpc::ReportHeartbeatRequest &request,
rpc::ReportHeartbeatReply *reply, rpc::ReportHeartbeatReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCheckAlive, MOCK_METHOD(void,
(const rpc::CheckAliveRequest &request, rpc::CheckAliveReply *reply, HandleCheckAlive,
(const rpc::CheckAliveRequest &request,
rpc::CheckAliveReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
}; };

View file

@ -17,29 +17,40 @@ namespace gcs {
class MockGcsJobManager : public GcsJobManager { class MockGcsJobManager : public GcsJobManager {
public: public:
MOCK_METHOD(void, HandleAddJob, MOCK_METHOD(void,
(const rpc::AddJobRequest &request, rpc::AddJobReply *reply, HandleAddJob,
(const rpc::AddJobRequest &request,
rpc::AddJobReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleMarkJobFinished, MOCK_METHOD(void,
HandleMarkJobFinished,
(const rpc::MarkJobFinishedRequest &request, (const rpc::MarkJobFinishedRequest &request,
rpc::MarkJobFinishedReply *reply, rpc::MarkJobFinishedReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllJobInfo, MOCK_METHOD(void,
(const rpc::GetAllJobInfoRequest &request, rpc::GetAllJobInfoReply *reply, HandleGetAllJobInfo,
(const rpc::GetAllJobInfoRequest &request,
rpc::GetAllJobInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReportJobError, MOCK_METHOD(void,
(const rpc::ReportJobErrorRequest &request, rpc::ReportJobErrorReply *reply, HandleReportJobError,
(const rpc::ReportJobErrorRequest &request,
rpc::ReportJobErrorReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetNextJobID, MOCK_METHOD(void,
(const rpc::GetNextJobIDRequest &request, rpc::GetNextJobIDReply *reply, HandleGetNextJobID,
(const rpc::GetNextJobIDRequest &request,
rpc::GetNextJobIDReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, AddJobFinishedListener, MOCK_METHOD(void,
(std::function<void(std::shared_ptr<JobID>)> listener), (override)); AddJobFinishedListener,
(std::function<void(std::shared_ptr<JobID>)> listener),
(override));
}; };
} // namespace gcs } // namespace gcs

View file

@ -19,24 +19,37 @@ class MockInternalKVInterface : public ray::gcs::InternalKVInterface {
public: public:
MockInternalKVInterface() {} MockInternalKVInterface() {}
MOCK_METHOD(void, Get, MOCK_METHOD(void,
(const std::string &ns, const std::string &key, Get,
(const std::string &ns,
const std::string &key,
std::function<void(std::optional<std::string>)> callback), std::function<void(std::optional<std::string>)> callback),
(override)); (override));
MOCK_METHOD(void, Put, MOCK_METHOD(void,
(const std::string &ns, const std::string &key, const std::string &value, Put,
bool overwrite, std::function<void(bool)> callback), (const std::string &ns,
(override)); const std::string &key,
MOCK_METHOD(void, Del, const std::string &value,
(const std::string &ns, const std::string &key, bool del_by_prefix, bool overwrite,
std::function<void(int64_t)> callback),
(override));
MOCK_METHOD(void, Exists,
(const std::string &ns, const std::string &key,
std::function<void(bool)> callback), std::function<void(bool)> callback),
(override)); (override));
MOCK_METHOD(void, Keys, MOCK_METHOD(void,
(const std::string &ns, const std::string &prefix, Del,
(const std::string &ns,
const std::string &key,
bool del_by_prefix,
std::function<void(int64_t)> callback),
(override));
MOCK_METHOD(void,
Exists,
(const std::string &ns,
const std::string &key,
std::function<void(bool)> callback),
(override));
MOCK_METHOD(void,
Keys,
(const std::string &ns,
const std::string &prefix,
std::function<void(std::vector<std::string>)> callback), std::function<void(std::vector<std::string>)> callback),
(override)); (override));
MOCK_METHOD(instrumented_io_context &, GetEventLoop, (), (override)); MOCK_METHOD(instrumented_io_context &, GetEventLoop, (), (override));

View file

@ -18,19 +18,26 @@ namespace gcs {
class MockGcsNodeManager : public GcsNodeManager { class MockGcsNodeManager : public GcsNodeManager {
public: public:
MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {} MockGcsNodeManager() : GcsNodeManager(nullptr, nullptr, nullptr) {}
MOCK_METHOD(void, HandleRegisterNode, MOCK_METHOD(void,
(const rpc::RegisterNodeRequest &request, rpc::RegisterNodeReply *reply, HandleRegisterNode,
(const rpc::RegisterNodeRequest &request,
rpc::RegisterNodeReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleDrainNode, MOCK_METHOD(void,
(const rpc::DrainNodeRequest &request, rpc::DrainNodeReply *reply, HandleDrainNode,
(const rpc::DrainNodeRequest &request,
rpc::DrainNodeReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllNodeInfo, MOCK_METHOD(void,
(const rpc::GetAllNodeInfoRequest &request, rpc::GetAllNodeInfoReply *reply, HandleGetAllNodeInfo,
(const rpc::GetAllNodeInfoRequest &request,
rpc::GetAllNodeInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetInternalConfig, MOCK_METHOD(void,
HandleGetInternalConfig,
(const rpc::GetInternalConfigRequest &request, (const rpc::GetInternalConfigRequest &request,
rpc::GetInternalConfigReply *reply, rpc::GetInternalConfigReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -27,32 +27,38 @@ namespace gcs {
class MockGcsPlacementGroupManager : public GcsPlacementGroupManager { class MockGcsPlacementGroupManager : public GcsPlacementGroupManager {
public: public:
MOCK_METHOD(void, HandleCreatePlacementGroup, MOCK_METHOD(void,
HandleCreatePlacementGroup,
(const rpc::CreatePlacementGroupRequest &request, (const rpc::CreatePlacementGroupRequest &request,
rpc::CreatePlacementGroupReply *reply, rpc::CreatePlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRemovePlacementGroup, MOCK_METHOD(void,
HandleRemovePlacementGroup,
(const rpc::RemovePlacementGroupRequest &request, (const rpc::RemovePlacementGroupRequest &request,
rpc::RemovePlacementGroupReply *reply, rpc::RemovePlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetPlacementGroup, MOCK_METHOD(void,
HandleGetPlacementGroup,
(const rpc::GetPlacementGroupRequest &request, (const rpc::GetPlacementGroupRequest &request,
rpc::GetPlacementGroupReply *reply, rpc::GetPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetNamedPlacementGroup, MOCK_METHOD(void,
HandleGetNamedPlacementGroup,
(const rpc::GetNamedPlacementGroupRequest &request, (const rpc::GetNamedPlacementGroupRequest &request,
rpc::GetNamedPlacementGroupReply *reply, rpc::GetNamedPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllPlacementGroup, MOCK_METHOD(void,
HandleGetAllPlacementGroup,
(const rpc::GetAllPlacementGroupRequest &request, (const rpc::GetAllPlacementGroupRequest &request,
rpc::GetAllPlacementGroupReply *reply, rpc::GetAllPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleWaitPlacementGroupUntilReady, MOCK_METHOD(void,
HandleWaitPlacementGroupUntilReady,
(const rpc::WaitPlacementGroupUntilReadyRequest &request, (const rpc::WaitPlacementGroupUntilReadyRequest &request,
rpc::WaitPlacementGroupUntilReadyReply *reply, rpc::WaitPlacementGroupUntilReadyReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -28,25 +28,34 @@ namespace gcs {
class MockGcsPlacementGroupSchedulerInterface class MockGcsPlacementGroupSchedulerInterface
: public GcsPlacementGroupSchedulerInterface { : public GcsPlacementGroupSchedulerInterface {
public: public:
MOCK_METHOD(void, ScheduleUnplacedBundles, MOCK_METHOD(void,
ScheduleUnplacedBundles,
(std::shared_ptr<GcsPlacementGroup> placement_group, (std::shared_ptr<GcsPlacementGroup> placement_group,
PGSchedulingFailureCallback failure_callback, PGSchedulingFailureCallback failure_callback,
PGSchedulingSuccessfulCallback success_callback), PGSchedulingSuccessfulCallback success_callback),
(override)); (override));
MOCK_METHOD((absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>>), MOCK_METHOD((absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>>),
GetBundlesOnNode, (const NodeID &node_id), (override)); GetBundlesOnNode,
MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, (const NodeID &node_id),
(const PlacementGroupID &placement_group_id), (override)); (override));
MOCK_METHOD(void, MarkScheduleCancelled, (const PlacementGroupID &placement_group_id), MOCK_METHOD(void,
DestroyPlacementGroupBundleResourcesIfExists,
(const PlacementGroupID &placement_group_id),
(override));
MOCK_METHOD(void,
MarkScheduleCancelled,
(const PlacementGroupID &placement_group_id),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, ReleaseUnusedBundles, void,
ReleaseUnusedBundles,
((const absl::flat_hash_map<NodeID, std::vector<rpc::Bundle>> &node_to_bundles)), ((const absl::flat_hash_map<NodeID, std::vector<rpc::Bundle>> &node_to_bundles)),
(override)); (override));
MOCK_METHOD(void, Initialize, MOCK_METHOD(void,
Initialize,
((const absl::flat_hash_map< ((const absl::flat_hash_map<
PlacementGroupID, std::vector<std::shared_ptr<BundleSpecification>>> PlacementGroupID,
&group_to_bundles)), std::vector<std::shared_ptr<BundleSpecification>>> &group_to_bundles)),
(override)); (override));
}; };
@ -69,7 +78,8 @@ namespace gcs {
class MockGcsScheduleStrategy : public GcsScheduleStrategy { class MockGcsScheduleStrategy : public GcsScheduleStrategy {
public: public:
MOCK_METHOD( MOCK_METHOD(
ScheduleResult, Schedule, ScheduleResult,
Schedule,
(const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles, (const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles,
const std::unique_ptr<ScheduleContext> &context, const std::unique_ptr<ScheduleContext> &context,
GcsResourceScheduler &gcs_resource_scheduler), GcsResourceScheduler &gcs_resource_scheduler),
@ -85,7 +95,8 @@ namespace gcs {
class MockGcsPackStrategy : public GcsPackStrategy { class MockGcsPackStrategy : public GcsPackStrategy {
public: public:
MOCK_METHOD( MOCK_METHOD(
ScheduleResult, Schedule, ScheduleResult,
Schedule,
(const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles, (const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles,
const std::unique_ptr<ScheduleContext> &context, const std::unique_ptr<ScheduleContext> &context,
GcsResourceScheduler &gcs_resource_scheduler), GcsResourceScheduler &gcs_resource_scheduler),
@ -101,7 +112,8 @@ namespace gcs {
class MockGcsSpreadStrategy : public GcsSpreadStrategy { class MockGcsSpreadStrategy : public GcsSpreadStrategy {
public: public:
MOCK_METHOD( MOCK_METHOD(
ScheduleResult, Schedule, ScheduleResult,
Schedule,
(const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles, (const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles,
const std::unique_ptr<ScheduleContext> &context, const std::unique_ptr<ScheduleContext> &context,
GcsResourceScheduler &gcs_resource_scheduler), GcsResourceScheduler &gcs_resource_scheduler),
@ -117,7 +129,8 @@ namespace gcs {
class MockGcsStrictPackStrategy : public GcsStrictPackStrategy { class MockGcsStrictPackStrategy : public GcsStrictPackStrategy {
public: public:
MOCK_METHOD( MOCK_METHOD(
ScheduleResult, Schedule, ScheduleResult,
Schedule,
(const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles, (const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles,
const std::unique_ptr<ScheduleContext> &context, const std::unique_ptr<ScheduleContext> &context,
GcsResourceScheduler &gcs_resource_scheduler), GcsResourceScheduler &gcs_resource_scheduler),
@ -133,7 +146,8 @@ namespace gcs {
class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy { class MockGcsStrictSpreadStrategy : public GcsStrictSpreadStrategy {
public: public:
MOCK_METHOD( MOCK_METHOD(
ScheduleResult, Schedule, ScheduleResult,
Schedule,
(const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles, (const std::vector<std::shared_ptr<const ray::BundleSpecification>> &bundles,
const std::unique_ptr<ScheduleContext> &context, const std::unique_ptr<ScheduleContext> &context,
GcsResourceScheduler &gcs_resource_scheduler), GcsResourceScheduler &gcs_resource_scheduler),
@ -168,19 +182,27 @@ namespace gcs {
class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler { class MockGcsPlacementGroupScheduler : public GcsPlacementGroupScheduler {
public: public:
MOCK_METHOD(void, ScheduleUnplacedBundles, MOCK_METHOD(void,
ScheduleUnplacedBundles,
(std::shared_ptr<GcsPlacementGroup> placement_group, (std::shared_ptr<GcsPlacementGroup> placement_group,
PGSchedulingFailureCallback failure_handler, PGSchedulingFailureCallback failure_handler,
PGSchedulingSuccessfulCallback success_handler), PGSchedulingSuccessfulCallback success_handler),
(override)); (override));
MOCK_METHOD(void, DestroyPlacementGroupBundleResourcesIfExists, MOCK_METHOD(void,
(const PlacementGroupID &placement_group_id), (override)); DestroyPlacementGroupBundleResourcesIfExists,
MOCK_METHOD(void, MarkScheduleCancelled, (const PlacementGroupID &placement_group_id), (const PlacementGroupID &placement_group_id),
(override));
MOCK_METHOD(void,
MarkScheduleCancelled,
(const PlacementGroupID &placement_group_id),
(override)); (override));
MOCK_METHOD((absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>>), MOCK_METHOD((absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>>),
GetBundlesOnNode, (const NodeID &node_id), (override)); GetBundlesOnNode,
(const NodeID &node_id),
(override));
MOCK_METHOD( MOCK_METHOD(
void, ReleaseUnusedBundles, void,
ReleaseUnusedBundles,
((const absl::flat_hash_map<NodeID, std::vector<rpc::Bundle>> &node_to_bundles)), ((const absl::flat_hash_map<NodeID, std::vector<rpc::Bundle>> &node_to_bundles)),
(override)); (override));
}; };

View file

@ -18,21 +18,26 @@ namespace gcs {
class MockGcsResourceManager : public GcsResourceManager { class MockGcsResourceManager : public GcsResourceManager {
public: public:
using GcsResourceManager::GcsResourceManager; using GcsResourceManager::GcsResourceManager;
MOCK_METHOD(void, HandleGetResources, MOCK_METHOD(void,
(const rpc::GetResourcesRequest &request, rpc::GetResourcesReply *reply, HandleGetResources,
(const rpc::GetResourcesRequest &request,
rpc::GetResourcesReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllAvailableResources, MOCK_METHOD(void,
HandleGetAllAvailableResources,
(const rpc::GetAllAvailableResourcesRequest &request, (const rpc::GetAllAvailableResourcesRequest &request,
rpc::GetAllAvailableResourcesReply *reply, rpc::GetAllAvailableResourcesReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReportResourceUsage, MOCK_METHOD(void,
HandleReportResourceUsage,
(const rpc::ReportResourceUsageRequest &request, (const rpc::ReportResourceUsageRequest &request,
rpc::ReportResourceUsageReply *reply, rpc::ReportResourceUsageReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllResourceUsage, MOCK_METHOD(void,
HandleGetAllResourceUsage,
(const rpc::GetAllResourceUsageRequest &request, (const rpc::GetAllResourceUsageRequest &request,
rpc::GetAllResourceUsageReply *reply, rpc::GetAllResourceUsageReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -17,7 +17,8 @@ namespace gcs {
class MockNodeScorer : public NodeScorer { class MockNodeScorer : public NodeScorer {
public: public:
MOCK_METHOD(double, Score, MOCK_METHOD(double,
Score,
(const ResourceSet &required_resources, (const ResourceSet &required_resources,
const SchedulingResources &node_resources), const SchedulingResources &node_resources),
(override)); (override));
@ -31,7 +32,8 @@ namespace gcs {
class MockLeastResourceScorer : public LeastResourceScorer { class MockLeastResourceScorer : public LeastResourceScorer {
public: public:
MOCK_METHOD(double, Score, MOCK_METHOD(double,
Score,
(const ResourceSet &required_resources, (const ResourceSet &required_resources,
const SchedulingResources &node_resources), const SchedulingResources &node_resources),
(override)); (override));

View file

@ -18,13 +18,18 @@ namespace gcs {
template <typename Key, typename Data> template <typename Key, typename Data>
class MockGcsTable : public GcsTable<Key, Data> { class MockGcsTable : public GcsTable<Key, Data> {
public: public:
MOCK_METHOD(Status, Put, MOCK_METHOD(Status,
Put,
(const Key &key, const Data &value, const StatusCallback &callback), (const Key &key, const Data &value, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, Delete, (const Key &key, const StatusCallback &callback), MOCK_METHOD(Status,
Delete,
(const Key &key, const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
BatchDelete,
(const std::vector<Key> &keys, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, BatchDelete,
(const std::vector<Key> &keys, const StatusCallback &callback), (override));
}; };
} // namespace gcs } // namespace gcs
@ -36,13 +41,18 @@ namespace gcs {
template <typename Key, typename Data> template <typename Key, typename Data>
class MockGcsTableWithJobId : public GcsTableWithJobId<Key, Data> { class MockGcsTableWithJobId : public GcsTableWithJobId<Key, Data> {
public: public:
MOCK_METHOD(Status, Put, MOCK_METHOD(Status,
Put,
(const Key &key, const Data &value, const StatusCallback &callback), (const Key &key, const Data &value, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, Delete, (const Key &key, const StatusCallback &callback), MOCK_METHOD(Status,
Delete,
(const Key &key, const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
BatchDelete,
(const std::vector<Key> &keys, const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, BatchDelete,
(const std::vector<Key> &keys, const StatusCallback &callback), (override));
MOCK_METHOD(JobID, GetJobIdFromKey, (const Key &key), (override)); MOCK_METHOD(JobID, GetJobIdFromKey, (const Key &key), (override));
}; };

View file

@ -17,22 +17,28 @@ namespace gcs {
class MockGcsWorkerManager : public GcsWorkerManager { class MockGcsWorkerManager : public GcsWorkerManager {
public: public:
MOCK_METHOD(void, HandleReportWorkerFailure, MOCK_METHOD(void,
HandleReportWorkerFailure,
(const rpc::ReportWorkerFailureRequest &request, (const rpc::ReportWorkerFailureRequest &request,
rpc::ReportWorkerFailureReply *reply, rpc::ReportWorkerFailureReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetWorkerInfo, MOCK_METHOD(void,
(const rpc::GetWorkerInfoRequest &request, rpc::GetWorkerInfoReply *reply, HandleGetWorkerInfo,
(const rpc::GetWorkerInfoRequest &request,
rpc::GetWorkerInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllWorkerInfo, MOCK_METHOD(void,
HandleGetAllWorkerInfo,
(const rpc::GetAllWorkerInfoRequest &request, (const rpc::GetAllWorkerInfoRequest &request,
rpc::GetAllWorkerInfoReply *reply, rpc::GetAllWorkerInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleAddWorkerInfo, MOCK_METHOD(void,
(const rpc::AddWorkerInfoRequest &request, rpc::AddWorkerInfoReply *reply, HandleAddWorkerInfo,
(const rpc::AddWorkerInfoRequest &request,
rpc::AddWorkerInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
}; };

View file

@ -17,11 +17,14 @@ namespace rpc {
class MockDefaultStatsHandler : public DefaultStatsHandler { class MockDefaultStatsHandler : public DefaultStatsHandler {
public: public:
MOCK_METHOD(void, HandleAddProfileData, MOCK_METHOD(void,
(const AddProfileDataRequest &request, AddProfileDataReply *reply, HandleAddProfileData,
(const AddProfileDataRequest &request,
AddProfileDataReply *reply,
SendReplyCallback send_reply_callback), SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetAllProfileInfo, MOCK_METHOD(void,
HandleGetAllProfileInfo,
(const rpc::GetAllProfileInfoRequest &request, (const rpc::GetAllProfileInfoRequest &request,
rpc::GetAllProfileInfoReply *reply, rpc::GetAllProfileInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -17,8 +17,11 @@ namespace gcs {
class MockGcsPubSub : public GcsPubSub { class MockGcsPubSub : public GcsPubSub {
public: public:
MOCK_METHOD(Status, Publish, MOCK_METHOD(Status,
(const std::string &channel, const std::string &id, const std::string &data, Publish,
(const std::string &channel,
const std::string &id,
const std::string &data,
const StatusCallback &done), const StatusCallback &done),
(override)); (override));
}; };

View file

@ -17,46 +17,68 @@ namespace gcs {
class MockInMemoryStoreClient : public InMemoryStoreClient { class MockInMemoryStoreClient : public InMemoryStoreClient {
public: public:
MOCK_METHOD(Status, AsyncPut, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPut,
const std::string &data, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncPutWithIndex, const std::string &data,
(const std::string &table_name, const std::string &key,
const std::string &index_key, const std::string &data,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGet, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPutWithIndex,
(const std::string &table_name,
const std::string &key,
const std::string &index_key,
const std::string &data,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncGet,
(const std::string &table_name,
const std::string &key,
const OptionalItemCallback<std::string> &callback), const OptionalItemCallback<std::string> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncGetByIndex,
(const std::string &table_name,
const std::string &index_key,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
AsyncGetAll,
(const std::string &table_name, (const std::string &table_name,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncDelete, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDelete,
(const std::string &table_name,
const std::string &key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDeleteWithIndex,
const std::string &index_key, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncBatchDelete, const std::string &index_key,
(const std::string &table_name, const std::vector<std::string> &keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::vector<std::string> &keys, AsyncBatchDelete,
(const std::string &table_name,
const std::vector<std::string> &keys,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncBatchDeleteWithIndex,
(const std::string &table_name,
const std::vector<std::string> &keys,
const std::vector<std::string> &index_keys, const std::vector<std::string> &index_keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncDeleteByIndex,
(const std::string &table_name,
const std::string &index_key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(int, GetNextJobID, (), (override)); MOCK_METHOD(int, GetNextJobID, (), (override));

View file

@ -18,46 +18,68 @@ namespace gcs {
class MockRedisStoreClient : public RedisStoreClient { class MockRedisStoreClient : public RedisStoreClient {
public: public:
MockRedisStoreClient() : RedisStoreClient(nullptr) {} MockRedisStoreClient() : RedisStoreClient(nullptr) {}
MOCK_METHOD(Status, AsyncPut, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPut,
const std::string &data, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncPutWithIndex, const std::string &data,
(const std::string &table_name, const std::string &key,
const std::string &index_key, const std::string &data,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGet, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPutWithIndex,
(const std::string &table_name,
const std::string &key,
const std::string &index_key,
const std::string &data,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncGet,
(const std::string &table_name,
const std::string &key,
const OptionalItemCallback<std::string> &callback), const OptionalItemCallback<std::string> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncGetByIndex,
(const std::string &table_name,
const std::string &index_key,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
AsyncGetAll,
(const std::string &table_name, (const std::string &table_name,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncDelete, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDelete,
(const std::string &table_name,
const std::string &key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDeleteWithIndex,
const std::string &index_key, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncBatchDelete, const std::string &index_key,
(const std::string &table_name, const std::vector<std::string> &keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::vector<std::string> &keys, AsyncBatchDelete,
(const std::string &table_name,
const std::vector<std::string> &keys,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncBatchDeleteWithIndex,
(const std::string &table_name,
const std::vector<std::string> &keys,
const std::vector<std::string> &index_keys, const std::vector<std::string> &index_keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncDeleteByIndex,
(const std::string &table_name,
const std::string &index_key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(int, GetNextJobID, (), (override)); MOCK_METHOD(int, GetNextJobID, (), (override));

View file

@ -17,46 +17,68 @@ namespace gcs {
class MockStoreClient : public StoreClient { class MockStoreClient : public StoreClient {
public: public:
MOCK_METHOD(Status, AsyncPut, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPut,
const std::string &data, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncPutWithIndex, const std::string &data,
(const std::string &table_name, const std::string &key,
const std::string &index_key, const std::string &data,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGet, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncPutWithIndex,
(const std::string &table_name,
const std::string &key,
const std::string &index_key,
const std::string &data,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncGet,
(const std::string &table_name,
const std::string &key,
const OptionalItemCallback<std::string> &callback), const OptionalItemCallback<std::string> &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncGetByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncGetByIndex,
(const std::string &table_name,
const std::string &index_key,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncGetAll, MOCK_METHOD(Status,
AsyncGetAll,
(const std::string &table_name, (const std::string &table_name,
(const MapCallback<std::string, std::string> &callback)), (const MapCallback<std::string, std::string> &callback)),
(override)); (override));
MOCK_METHOD(Status, AsyncDelete, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDelete,
(const std::string &table_name,
const std::string &key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &key, AsyncDeleteWithIndex,
const std::string &index_key, const StatusCallback &callback), (const std::string &table_name,
(override)); const std::string &key,
MOCK_METHOD(Status, AsyncBatchDelete, const std::string &index_key,
(const std::string &table_name, const std::vector<std::string> &keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncBatchDeleteWithIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::vector<std::string> &keys, AsyncBatchDelete,
(const std::string &table_name,
const std::vector<std::string> &keys,
const StatusCallback &callback),
(override));
MOCK_METHOD(Status,
AsyncBatchDeleteWithIndex,
(const std::string &table_name,
const std::vector<std::string> &keys,
const std::vector<std::string> &index_keys, const std::vector<std::string> &index_keys,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(Status, AsyncDeleteByIndex, MOCK_METHOD(Status,
(const std::string &table_name, const std::string &index_key, AsyncDeleteByIndex,
(const std::string &table_name,
const std::string &index_key,
const StatusCallback &callback), const StatusCallback &callback),
(override)); (override));
MOCK_METHOD(int, GetNextJobID, (), (override)); MOCK_METHOD(int, GetNextJobID, (), (override));

View file

@ -54,19 +54,26 @@ namespace pubsub {
class MockPublisherInterface : public PublisherInterface { class MockPublisherInterface : public PublisherInterface {
public: public:
MOCK_METHOD(bool, RegisterSubscription, MOCK_METHOD(bool,
(const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, RegisterSubscription,
(const rpc::ChannelType channel_type,
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id_binary), const std::optional<std::string> &key_id_binary),
(override)); (override));
MOCK_METHOD(void, Publish, MOCK_METHOD(void,
(const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, Publish,
(const rpc::ChannelType channel_type,
const rpc::PubMessage &pub_message,
const std::string &key_id_binary), const std::string &key_id_binary),
(override)); (override));
MOCK_METHOD(void, PublishFailure, MOCK_METHOD(void,
PublishFailure,
(const rpc::ChannelType channel_type, const std::string &key_id_binary), (const rpc::ChannelType channel_type, const std::string &key_id_binary),
(override)); (override));
MOCK_METHOD(bool, UnregisterSubscription, MOCK_METHOD(bool,
(const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, UnregisterSubscription,
(const rpc::ChannelType channel_type,
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id_binary), const std::optional<std::string> &key_id_binary),
(override)); (override));
}; };
@ -79,19 +86,26 @@ namespace pubsub {
class MockPublisher : public Publisher { class MockPublisher : public Publisher {
public: public:
MOCK_METHOD(bool, RegisterSubscription, MOCK_METHOD(bool,
(const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, RegisterSubscription,
(const rpc::ChannelType channel_type,
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id_binary), const std::optional<std::string> &key_id_binary),
(override)); (override));
MOCK_METHOD(void, Publish, MOCK_METHOD(void,
(const rpc::ChannelType channel_type, const rpc::PubMessage &pub_message, Publish,
(const rpc::ChannelType channel_type,
const rpc::PubMessage &pub_message,
const std::string &key_id_binary), const std::string &key_id_binary),
(override)); (override));
MOCK_METHOD(void, PublishFailure, MOCK_METHOD(void,
PublishFailure,
(const rpc::ChannelType channel_type, const std::string &key_id_binary), (const rpc::ChannelType channel_type, const std::string &key_id_binary),
(override)); (override));
MOCK_METHOD(bool, UnregisterSubscription, MOCK_METHOD(bool,
(const rpc::ChannelType channel_type, const SubscriberID &subscriber_id, UnregisterSubscription,
(const rpc::ChannelType channel_type,
const SubscriberID &subscriber_id,
const std::optional<std::string> &key_id_binary), const std::optional<std::string> &key_id_binary),
(override)); (override));
}; };

View file

@ -17,25 +17,33 @@ namespace pubsub {
class MockSubscriberInterface : public SubscriberInterface { class MockSubscriberInterface : public SubscriberInterface {
public: public:
MOCK_METHOD(bool, Subscribe, MOCK_METHOD(bool,
Subscribe,
(std::unique_ptr<rpc::SubMessage> sub_message, (std::unique_ptr<rpc::SubMessage> sub_message,
const rpc::ChannelType channel_type, const rpc::Address &publisher_address, const rpc::ChannelType channel_type,
const std::string &key_id, SubscribeDoneCallback subscribe_done_callback, const rpc::Address &publisher_address,
SubscriptionItemCallback subscription_callback, const std::string &key_id,
SubscriptionFailureCallback subscription_failure_callback),
(override));
MOCK_METHOD(bool, SubscribeChannel,
(std::unique_ptr<rpc::SubMessage> sub_message,
const rpc::ChannelType channel_type, const rpc::Address &publisher_address,
SubscribeDoneCallback subscribe_done_callback, SubscribeDoneCallback subscribe_done_callback,
SubscriptionItemCallback subscription_callback, SubscriptionItemCallback subscription_callback,
SubscriptionFailureCallback subscription_failure_callback), SubscriptionFailureCallback subscription_failure_callback),
(override)); (override));
MOCK_METHOD(bool, Unsubscribe, MOCK_METHOD(bool,
(const rpc::ChannelType channel_type, const rpc::Address &publisher_address, SubscribeChannel,
(std::unique_ptr<rpc::SubMessage> sub_message,
const rpc::ChannelType channel_type,
const rpc::Address &publisher_address,
SubscribeDoneCallback subscribe_done_callback,
SubscriptionItemCallback subscription_callback,
SubscriptionFailureCallback subscription_failure_callback),
(override));
MOCK_METHOD(bool,
Unsubscribe,
(const rpc::ChannelType channel_type,
const rpc::Address &publisher_address,
const std::string &key_id), const std::string &key_id),
(override)); (override));
MOCK_METHOD(bool, UnsubscribeChannel, MOCK_METHOD(bool,
UnsubscribeChannel,
(const rpc::ChannelType channel_type, (const rpc::ChannelType channel_type,
const rpc::Address &publisher_address), const rpc::Address &publisher_address),
(override)); (override));
@ -50,11 +58,13 @@ namespace pubsub {
class MockSubscriberClientInterface : public SubscriberClientInterface { class MockSubscriberClientInterface : public SubscriberClientInterface {
public: public:
MOCK_METHOD(void, PubsubLongPolling, MOCK_METHOD(void,
PubsubLongPolling,
(const rpc::PubsubLongPollingRequest &request, (const rpc::PubsubLongPollingRequest &request,
const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback), const rpc::ClientCallback<rpc::PubsubLongPollingReply> &callback),
(override)); (override));
MOCK_METHOD(void, PubsubCommandBatch, MOCK_METHOD(void,
PubsubCommandBatch,
(const rpc::PubsubCommandBatchRequest &request, (const rpc::PubsubCommandBatchRequest &request,
const rpc::ClientCallback<rpc::PubsubCommandBatchReply> &callback), const rpc::ClientCallback<rpc::PubsubCommandBatchReply> &callback),
(override)); (override));

View file

@ -17,15 +17,20 @@ namespace raylet {
class MockAgentManager : public AgentManager { class MockAgentManager : public AgentManager {
public: public:
MOCK_METHOD(void, HandleRegisterAgent, MOCK_METHOD(void,
(const rpc::RegisterAgentRequest &request, rpc::RegisterAgentReply *reply, HandleRegisterAgent,
(const rpc::RegisterAgentRequest &request,
rpc::RegisterAgentReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, CreateRuntimeEnv, MOCK_METHOD(void,
(const JobID &job_id, const std::string &serialized_runtime_env, CreateRuntimeEnv,
(const JobID &job_id,
const std::string &serialized_runtime_env,
CreateRuntimeEnvCallback callback), CreateRuntimeEnvCallback callback),
(override)); (override));
MOCK_METHOD(void, DeleteRuntimeEnv, MOCK_METHOD(void,
DeleteRuntimeEnv,
(const std::string &serialized_runtime_env, (const std::string &serialized_runtime_env,
DeleteRuntimeEnvCallback callback), DeleteRuntimeEnvCallback callback),
(override)); (override));
@ -39,8 +44,10 @@ namespace raylet {
class MockDefaultAgentManagerServiceHandler : public DefaultAgentManagerServiceHandler { class MockDefaultAgentManagerServiceHandler : public DefaultAgentManagerServiceHandler {
public: public:
MOCK_METHOD(void, HandleRegisterAgent, MOCK_METHOD(void,
(const rpc::RegisterAgentRequest &request, rpc::RegisterAgentReply *reply, HandleRegisterAgent,
(const rpc::RegisterAgentRequest &request,
rpc::RegisterAgentReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
}; };

View file

@ -17,7 +17,8 @@ namespace raylet {
class MockTaskDependencyManagerInterface : public TaskDependencyManagerInterface { class MockTaskDependencyManagerInterface : public TaskDependencyManagerInterface {
public: public:
MOCK_METHOD(bool, RequestTaskDependencies, MOCK_METHOD(bool,
RequestTaskDependencies,
(const TaskID &task_id, (const TaskID &task_id,
const std::vector<rpc::ObjectReference> &required_objects), const std::vector<rpc::ObjectReference> &required_objects),
(override)); (override));

View file

@ -37,88 +37,110 @@ namespace raylet {
class MockNodeManager : public NodeManager { class MockNodeManager : public NodeManager {
public: public:
MOCK_METHOD(void, HandleUpdateResourceUsage, MOCK_METHOD(void,
HandleUpdateResourceUsage,
(const rpc::UpdateResourceUsageRequest &request, (const rpc::UpdateResourceUsageRequest &request,
rpc::UpdateResourceUsageReply *reply, rpc::UpdateResourceUsageReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRequestResourceReport, MOCK_METHOD(void,
HandleRequestResourceReport,
(const rpc::RequestResourceReportRequest &request, (const rpc::RequestResourceReportRequest &request,
rpc::RequestResourceReportReply *reply, rpc::RequestResourceReportReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandlePrepareBundleResources, MOCK_METHOD(void,
HandlePrepareBundleResources,
(const rpc::PrepareBundleResourcesRequest &request, (const rpc::PrepareBundleResourcesRequest &request,
rpc::PrepareBundleResourcesReply *reply, rpc::PrepareBundleResourcesReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCommitBundleResources, MOCK_METHOD(void,
HandleCommitBundleResources,
(const rpc::CommitBundleResourcesRequest &request, (const rpc::CommitBundleResourcesRequest &request,
rpc::CommitBundleResourcesReply *reply, rpc::CommitBundleResourcesReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCancelResourceReserve, MOCK_METHOD(void,
HandleCancelResourceReserve,
(const rpc::CancelResourceReserveRequest &request, (const rpc::CancelResourceReserveRequest &request,
rpc::CancelResourceReserveReply *reply, rpc::CancelResourceReserveReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRequestWorkerLease, MOCK_METHOD(void,
HandleRequestWorkerLease,
(const rpc::RequestWorkerLeaseRequest &request, (const rpc::RequestWorkerLeaseRequest &request,
rpc::RequestWorkerLeaseReply *reply, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReportWorkerBacklog, MOCK_METHOD(void,
HandleReportWorkerBacklog,
(const rpc::ReportWorkerBacklogRequest &request, (const rpc::ReportWorkerBacklogRequest &request,
rpc::ReportWorkerBacklogReply *reply, rpc::ReportWorkerBacklogReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReturnWorker, MOCK_METHOD(void,
(const rpc::ReturnWorkerRequest &request, rpc::ReturnWorkerReply *reply, HandleReturnWorker,
(const rpc::ReturnWorkerRequest &request,
rpc::ReturnWorkerReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReleaseUnusedWorkers, MOCK_METHOD(void,
HandleReleaseUnusedWorkers,
(const rpc::ReleaseUnusedWorkersRequest &request, (const rpc::ReleaseUnusedWorkersRequest &request,
rpc::ReleaseUnusedWorkersReply *reply, rpc::ReleaseUnusedWorkersReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleCancelWorkerLease, MOCK_METHOD(void,
HandleCancelWorkerLease,
(const rpc::CancelWorkerLeaseRequest &request, (const rpc::CancelWorkerLeaseRequest &request,
rpc::CancelWorkerLeaseReply *reply, rpc::CancelWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandlePinObjectIDs, MOCK_METHOD(void,
(const rpc::PinObjectIDsRequest &request, rpc::PinObjectIDsReply *reply, HandlePinObjectIDs,
(const rpc::PinObjectIDsRequest &request,
rpc::PinObjectIDsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetNodeStats, MOCK_METHOD(void,
(const rpc::GetNodeStatsRequest &request, rpc::GetNodeStatsReply *reply, HandleGetNodeStats,
(const rpc::GetNodeStatsRequest &request,
rpc::GetNodeStatsReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGlobalGC, MOCK_METHOD(void,
(const rpc::GlobalGCRequest &request, rpc::GlobalGCReply *reply, HandleGlobalGC,
(const rpc::GlobalGCRequest &request,
rpc::GlobalGCReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleFormatGlobalMemoryInfo, MOCK_METHOD(void,
HandleFormatGlobalMemoryInfo,
(const rpc::FormatGlobalMemoryInfoRequest &request, (const rpc::FormatGlobalMemoryInfoRequest &request,
rpc::FormatGlobalMemoryInfoReply *reply, rpc::FormatGlobalMemoryInfoReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleRequestObjectSpillage, MOCK_METHOD(void,
HandleRequestObjectSpillage,
(const rpc::RequestObjectSpillageRequest &request, (const rpc::RequestObjectSpillageRequest &request,
rpc::RequestObjectSpillageReply *reply, rpc::RequestObjectSpillageReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleReleaseUnusedBundles, MOCK_METHOD(void,
HandleReleaseUnusedBundles,
(const rpc::ReleaseUnusedBundlesRequest &request, (const rpc::ReleaseUnusedBundlesRequest &request,
rpc::ReleaseUnusedBundlesReply *reply, rpc::ReleaseUnusedBundlesReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetSystemConfig, MOCK_METHOD(void,
HandleGetSystemConfig,
(const rpc::GetSystemConfigRequest &request, (const rpc::GetSystemConfigRequest &request,
rpc::GetSystemConfigReply *reply, rpc::GetSystemConfigReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, HandleGetGcsServerAddress, MOCK_METHOD(void,
HandleGetGcsServerAddress,
(const rpc::GetGcsServerAddressRequest &request, (const rpc::GetGcsServerAddressRequest &request,
rpc::GetGcsServerAddressReply *reply, rpc::GetGcsServerAddressReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),

View file

@ -27,31 +27,49 @@ namespace raylet {
class MockClusterTaskManager : public ClusterTaskManager { class MockClusterTaskManager : public ClusterTaskManager {
public: public:
MOCK_METHOD(void, QueueAndScheduleTask, MOCK_METHOD(void,
(const RayTask &task, rpc::RequestWorkerLeaseReply *reply, QueueAndScheduleTask,
(const RayTask &task,
rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(void, TasksUnblocked, (const std::vector<TaskID> &ready_ids), (override)); MOCK_METHOD(void, TasksUnblocked, (const std::vector<TaskID> &ready_ids), (override));
MOCK_METHOD(void, TaskFinished, MOCK_METHOD(void,
(std::shared_ptr<WorkerInterface> worker, RayTask *task), (override)); TaskFinished,
MOCK_METHOD(bool, CancelTask, (const TaskID &task_id, bool runtime_env_setup_failed), (std::shared_ptr<WorkerInterface> worker, RayTask *task),
(override)); (override));
MOCK_METHOD(void, FillPendingActorInfo, (rpc::GetNodeStatsReply * reply), MOCK_METHOD(bool,
CancelTask,
(const TaskID &task_id, bool runtime_env_setup_failed),
(override));
MOCK_METHOD(void,
FillPendingActorInfo,
(rpc::GetNodeStatsReply * reply),
(const, override)); (const, override));
MOCK_METHOD(void, FillResourceUsage, MOCK_METHOD(void,
FillResourceUsage,
(rpc::ResourcesData & data, (rpc::ResourcesData & data,
const std::shared_ptr<SchedulingResources> &last_reported_resources), const std::shared_ptr<SchedulingResources> &last_reported_resources),
(override)); (override));
MOCK_METHOD(bool, AnyPendingTasksForResourceAcquisition, MOCK_METHOD(bool,
(RayTask * exemplar, bool *any_pending, int *num_pending_actor_creation, AnyPendingTasksForResourceAcquisition,
(RayTask * exemplar,
bool *any_pending,
int *num_pending_actor_creation,
int *num_pending_tasks), int *num_pending_tasks),
(const, override)); (const, override));
MOCK_METHOD(void, ReleaseWorkerResources, (std::shared_ptr<WorkerInterface> worker), MOCK_METHOD(void,
ReleaseWorkerResources,
(std::shared_ptr<WorkerInterface> worker),
(override));
MOCK_METHOD(bool,
ReleaseCpuResourcesFromUnblockedWorker,
(std::shared_ptr<WorkerInterface> worker),
(override));
MOCK_METHOD(bool,
ReturnCpuResourcesToBlockedWorker,
(std::shared_ptr<WorkerInterface> worker),
(override)); (override));
MOCK_METHOD(bool, ReleaseCpuResourcesFromUnblockedWorker,
(std::shared_ptr<WorkerInterface> worker), (override));
MOCK_METHOD(bool, ReturnCpuResourcesToBlockedWorker,
(std::shared_ptr<WorkerInterface> worker), (override));
MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override)); MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override));
MOCK_METHOD(void, RecordMetrics, (), (override)); MOCK_METHOD(void, RecordMetrics, (), (override));
MOCK_METHOD(std::string, DebugStr, (), (const, override)); MOCK_METHOD(std::string, DebugStr, (), (const, override));

View file

@ -17,33 +17,50 @@ namespace raylet {
class MockClusterTaskManagerInterface : public ClusterTaskManagerInterface { class MockClusterTaskManagerInterface : public ClusterTaskManagerInterface {
public: public:
MOCK_METHOD(void, ReleaseWorkerResources, (std::shared_ptr<WorkerInterface> worker), MOCK_METHOD(void,
ReleaseWorkerResources,
(std::shared_ptr<WorkerInterface> worker),
(override));
MOCK_METHOD(bool,
ReleaseCpuResourcesFromUnblockedWorker,
(std::shared_ptr<WorkerInterface> worker),
(override));
MOCK_METHOD(bool,
ReturnCpuResourcesToBlockedWorker,
(std::shared_ptr<WorkerInterface> worker),
(override)); (override));
MOCK_METHOD(bool, ReleaseCpuResourcesFromUnblockedWorker,
(std::shared_ptr<WorkerInterface> worker), (override));
MOCK_METHOD(bool, ReturnCpuResourcesToBlockedWorker,
(std::shared_ptr<WorkerInterface> worker), (override));
MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override)); MOCK_METHOD(void, ScheduleAndDispatchTasks, (), (override));
MOCK_METHOD(void, TasksUnblocked, (const std::vector<TaskID> &ready_ids), (override)); MOCK_METHOD(void, TasksUnblocked, (const std::vector<TaskID> &ready_ids), (override));
MOCK_METHOD(void, FillResourceUsage, MOCK_METHOD(void,
FillResourceUsage,
(rpc::ResourcesData & data, (rpc::ResourcesData & data,
const std::shared_ptr<SchedulingResources> &last_reported_resources), const std::shared_ptr<SchedulingResources> &last_reported_resources),
(override)); (override));
MOCK_METHOD(void, FillPendingActorInfo, (rpc::GetNodeStatsReply * reply), MOCK_METHOD(void,
FillPendingActorInfo,
(rpc::GetNodeStatsReply * reply),
(const, override)); (const, override));
MOCK_METHOD(void, TaskFinished, MOCK_METHOD(void,
(std::shared_ptr<WorkerInterface> worker, RayTask *task), (override)); TaskFinished,
MOCK_METHOD(bool, CancelTask, (std::shared_ptr<WorkerInterface> worker, RayTask *task),
(override));
MOCK_METHOD(bool,
CancelTask,
(const TaskID &task_id, (const TaskID &task_id,
rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type,
const std::string &scheduling_failure_message), const std::string &scheduling_failure_message),
(override)); (override));
MOCK_METHOD(void, QueueAndScheduleTask, MOCK_METHOD(void,
(const RayTask &task, rpc::RequestWorkerLeaseReply *reply, QueueAndScheduleTask,
(const RayTask &task,
rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback), rpc::SendReplyCallback send_reply_callback),
(override)); (override));
MOCK_METHOD(bool, AnyPendingTasksForResourceAcquisition, MOCK_METHOD(bool,
(RayTask * exemplar, bool *any_pending, int *num_pending_actor_creation, AnyPendingTasksForResourceAcquisition,
(RayTask * exemplar,
bool *any_pending,
int *num_pending_actor_creation,
int *num_pending_tasks), int *num_pending_tasks),
(const, override)); (const, override));
MOCK_METHOD(std::string, DebugStr, (), (const, override)); MOCK_METHOD(std::string, DebugStr, (), (const, override));

View file

@ -29,7 +29,9 @@ class MockWorkerInterface : public WorkerInterface {
MOCK_METHOD(Language, GetLanguage, (), (const, override)); MOCK_METHOD(Language, GetLanguage, (), (const, override));
MOCK_METHOD(const std::string, IpAddress, (), (const, override)); MOCK_METHOD(const std::string, IpAddress, (), (const, override));
MOCK_METHOD(void, Connect, (int port), (override)); MOCK_METHOD(void, Connect, (int port), (override));
MOCK_METHOD(void, Connect, (std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client), MOCK_METHOD(void,
Connect,
(std::shared_ptr<rpc::CoreWorkerClientInterface> rpc_client),
(override)); (override));
MOCK_METHOD(int, Port, (), (const, override)); MOCK_METHOD(int, Port, (), (const, override));
MOCK_METHOD(int, AssignedPort, (), (const, override)); MOCK_METHOD(int, AssignedPort, (), (const, override));
@ -38,7 +40,9 @@ class MockWorkerInterface : public WorkerInterface {
MOCK_METHOD(const TaskID &, GetAssignedTaskId, (), (const, override)); MOCK_METHOD(const TaskID &, GetAssignedTaskId, (), (const, override));
MOCK_METHOD(bool, AddBlockedTaskId, (const TaskID &task_id), (override)); MOCK_METHOD(bool, AddBlockedTaskId, (const TaskID &task_id), (override));
MOCK_METHOD(bool, RemoveBlockedTaskId, (const TaskID &task_id), (override)); MOCK_METHOD(bool, RemoveBlockedTaskId, (const TaskID &task_id), (override));
MOCK_METHOD(const std::unordered_set<TaskID> &, GetBlockedTaskIds, (), MOCK_METHOD(const std::unordered_set<TaskID> &,
GetBlockedTaskIds,
(),
(const, override)); (const, override));
MOCK_METHOD(const JobID &, GetAssignedJobId, (), (const, override)); MOCK_METHOD(const JobID &, GetAssignedJobId, (), (const, override));
MOCK_METHOD(int, GetRuntimeEnvHash, (), (const, override)); MOCK_METHOD(int, GetRuntimeEnvHash, (), (const, override));
@ -52,16 +56,22 @@ class MockWorkerInterface : public WorkerInterface {
MOCK_METHOD(void, DirectActorCallArgWaitComplete, (int64_t tag), (override)); MOCK_METHOD(void, DirectActorCallArgWaitComplete, (int64_t tag), (override));
MOCK_METHOD(const BundleID &, GetBundleId, (), (const, override)); MOCK_METHOD(const BundleID &, GetBundleId, (), (const, override));
MOCK_METHOD(void, SetBundleId, (const BundleID &bundle_id), (override)); MOCK_METHOD(void, SetBundleId, (const BundleID &bundle_id), (override));
MOCK_METHOD(void, SetAllocatedInstances, MOCK_METHOD(void,
SetAllocatedInstances,
(const std::shared_ptr<TaskResourceInstances> &allocated_instances), (const std::shared_ptr<TaskResourceInstances> &allocated_instances),
(override)); (override));
MOCK_METHOD(std::shared_ptr<TaskResourceInstances>, GetAllocatedInstances, (), MOCK_METHOD(std::shared_ptr<TaskResourceInstances>,
GetAllocatedInstances,
(),
(override)); (override));
MOCK_METHOD(void, ClearAllocatedInstances, (), (override)); MOCK_METHOD(void, ClearAllocatedInstances, (), (override));
MOCK_METHOD(void, SetLifetimeAllocatedInstances, MOCK_METHOD(void,
SetLifetimeAllocatedInstances,
(const std::shared_ptr<TaskResourceInstances> &allocated_instances), (const std::shared_ptr<TaskResourceInstances> &allocated_instances),
(override)); (override));
MOCK_METHOD(std::shared_ptr<TaskResourceInstances>, GetLifetimeAllocatedInstances, (), MOCK_METHOD(std::shared_ptr<TaskResourceInstances>,
GetLifetimeAllocatedInstances,
(),
(override)); (override));
MOCK_METHOD(void, ClearLifetimeAllocatedInstances, (), (override)); MOCK_METHOD(void, ClearLifetimeAllocatedInstances, (), (override));
MOCK_METHOD(RayTask &, GetAssignedTask, (), (override)); MOCK_METHOD(RayTask &, GetAssignedTask, (), (override));

View file

@ -17,14 +17,20 @@ namespace raylet {
class MockWorkerPoolInterface : public WorkerPoolInterface { class MockWorkerPoolInterface : public WorkerPoolInterface {
public: public:
MOCK_METHOD(void, PopWorker, MOCK_METHOD(void,
(const TaskSpecification &task_spec, const PopWorkerCallback &callback, PopWorker,
(const TaskSpecification &task_spec,
const PopWorkerCallback &callback,
const std::string &allocated_instances_serialized_json), const std::string &allocated_instances_serialized_json),
(override)); (override));
MOCK_METHOD(void, PushWorker, (const std::shared_ptr<WorkerInterface> &worker), MOCK_METHOD(void,
PushWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override)); (override));
MOCK_METHOD(const std::vector<std::shared_ptr<WorkerInterface>>, MOCK_METHOD(const std::vector<std::shared_ptr<WorkerInterface>>,
GetAllRegisteredWorkers, (bool filter_dead_workers), (override)); GetAllRegisteredWorkers,
(bool filter_dead_workers),
(override));
}; };
} // namespace raylet } // namespace raylet
@ -35,24 +41,36 @@ namespace raylet {
class MockIOWorkerPoolInterface : public IOWorkerPoolInterface { class MockIOWorkerPoolInterface : public IOWorkerPoolInterface {
public: public:
MOCK_METHOD(void, PushSpillWorker, (const std::shared_ptr<WorkerInterface> &worker), MOCK_METHOD(void,
PushSpillWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override)); (override));
MOCK_METHOD(void, PopSpillWorker, MOCK_METHOD(void,
PopSpillWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback), (std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override)); (override));
MOCK_METHOD(void, PushRestoreWorker, (const std::shared_ptr<WorkerInterface> &worker), MOCK_METHOD(void,
PushRestoreWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override)); (override));
MOCK_METHOD(void, PopRestoreWorker, MOCK_METHOD(void,
PopRestoreWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback), (std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override)); (override));
MOCK_METHOD(void, PushDeleteWorker, (const std::shared_ptr<WorkerInterface> &worker), MOCK_METHOD(void,
PushDeleteWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override)); (override));
MOCK_METHOD(void, PopDeleteWorker, MOCK_METHOD(void,
PopDeleteWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback), (std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override)); (override));
MOCK_METHOD(void, PushUtilWorker, (const std::shared_ptr<WorkerInterface> &worker), MOCK_METHOD(void,
PushUtilWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override)); (override));
MOCK_METHOD(void, PopUtilWorker, MOCK_METHOD(void,
PopUtilWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback), (std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override)); (override));
}; };
@ -65,13 +83,16 @@ namespace raylet {
class MockWorkerPool : public WorkerPool { class MockWorkerPool : public WorkerPool {
public: public:
MOCK_METHOD(Process, StartProcess, MOCK_METHOD(Process,
StartProcess,
(const std::vector<std::string> &worker_command_args, (const std::vector<std::string> &worker_command_args,
const ProcessEnvironment &env), const ProcessEnvironment &env),
(override)); (override));
MOCK_METHOD(void, WarnAboutSize, (), (override)); MOCK_METHOD(void, WarnAboutSize, (), (override));
MOCK_METHOD(void, PopWorkerCallbackAsync, MOCK_METHOD(void,
(const PopWorkerCallback &callback, std::shared_ptr<WorkerInterface> worker, PopWorkerCallbackAsync,
(const PopWorkerCallback &callback,
std::shared_ptr<WorkerInterface> worker,
PopWorkerStatus status), PopWorkerStatus status),
(override)); (override));
}; };

View file

@ -16,7 +16,8 @@ namespace ray {
class MockPinObjectsInterface : public PinObjectsInterface { class MockPinObjectsInterface : public PinObjectsInterface {
public: public:
MOCK_METHOD(void, PinObjectIDs, MOCK_METHOD(void,
PinObjectIDs,
(const rpc::Address &caller_address, (const rpc::Address &caller_address,
const std::vector<ObjectID> &object_ids, const std::vector<ObjectID> &object_ids,
const ray::rpc::ClientCallback<ray::rpc::PinObjectIDsReply> &callback), const ray::rpc::ClientCallback<ray::rpc::PinObjectIDsReply> &callback),
@ -30,20 +31,28 @@ namespace ray {
class MockWorkerLeaseInterface : public WorkerLeaseInterface { class MockWorkerLeaseInterface : public WorkerLeaseInterface {
public: public:
MOCK_METHOD( MOCK_METHOD(
void, RequestWorkerLease, void,
(const rpc::TaskSpec &task_spec, bool grant_or_reject, RequestWorkerLease,
(const rpc::TaskSpec &task_spec,
bool grant_or_reject,
const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback, const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
const int64_t backlog_size, const bool is_selected_based_on_locality), const int64_t backlog_size,
const bool is_selected_based_on_locality),
(override)); (override));
MOCK_METHOD(ray::Status, ReturnWorker, MOCK_METHOD(ray::Status,
(int worker_port, const WorkerID &worker_id, bool disconnect_worker, ReturnWorker,
(int worker_port,
const WorkerID &worker_id,
bool disconnect_worker,
bool worker_exiting), bool worker_exiting),
(override)); (override));
MOCK_METHOD(void, ReleaseUnusedWorkers, MOCK_METHOD(void,
ReleaseUnusedWorkers,
(const std::vector<WorkerID> &workers_in_use, (const std::vector<WorkerID> &workers_in_use,
const rpc::ClientCallback<rpc::ReleaseUnusedWorkersReply> &callback), const rpc::ClientCallback<rpc::ReleaseUnusedWorkersReply> &callback),
(override)); (override));
MOCK_METHOD(void, CancelWorkerLease, MOCK_METHOD(void,
CancelWorkerLease,
(const TaskID &task_id, (const TaskID &task_id,
const rpc::ClientCallback<rpc::CancelWorkerLeaseReply> &callback), const rpc::ClientCallback<rpc::CancelWorkerLeaseReply> &callback),
(override)); (override));
@ -56,21 +65,25 @@ namespace ray {
class MockResourceReserveInterface : public ResourceReserveInterface { class MockResourceReserveInterface : public ResourceReserveInterface {
public: public:
MOCK_METHOD( MOCK_METHOD(
void, PrepareBundleResources, void,
PrepareBundleResources,
(const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs, (const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs,
const ray::rpc::ClientCallback<ray::rpc::PrepareBundleResourcesReply> &callback), const ray::rpc::ClientCallback<ray::rpc::PrepareBundleResourcesReply> &callback),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, CommitBundleResources, void,
CommitBundleResources,
(const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs, (const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs,
const ray::rpc::ClientCallback<ray::rpc::CommitBundleResourcesReply> &callback), const ray::rpc::ClientCallback<ray::rpc::CommitBundleResourcesReply> &callback),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, CancelResourceReserve, void,
CancelResourceReserve,
(const BundleSpecification &bundle_spec, (const BundleSpecification &bundle_spec,
const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback), const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback),
(override)); (override));
MOCK_METHOD(void, ReleaseUnusedBundles, MOCK_METHOD(void,
ReleaseUnusedBundles,
(const std::vector<rpc::Bundle> &bundles_in_use, (const std::vector<rpc::Bundle> &bundles_in_use,
const rpc::ClientCallback<rpc::ReleaseUnusedBundlesReply> &callback), const rpc::ClientCallback<rpc::ReleaseUnusedBundlesReply> &callback),
(override)); (override));
@ -82,7 +95,8 @@ namespace ray {
class MockDependencyWaiterInterface : public DependencyWaiterInterface { class MockDependencyWaiterInterface : public DependencyWaiterInterface {
public: public:
MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, MOCK_METHOD(ray::Status,
WaitForDirectActorCallArgs,
(const std::vector<rpc::ObjectReference> &references, int64_t tag), (const std::vector<rpc::ObjectReference> &references, int64_t tag),
(override)); (override));
}; };
@ -93,11 +107,13 @@ namespace ray {
class MockResourceTrackingInterface : public ResourceTrackingInterface { class MockResourceTrackingInterface : public ResourceTrackingInterface {
public: public:
MOCK_METHOD(void, UpdateResourceUsage, MOCK_METHOD(void,
UpdateResourceUsage,
(std::string & serialized_resource_usage_batch, (std::string & serialized_resource_usage_batch,
const rpc::ClientCallback<rpc::UpdateResourceUsageReply> &callback), const rpc::ClientCallback<rpc::UpdateResourceUsageReply> &callback),
(override)); (override));
MOCK_METHOD(void, RequestResourceReport, MOCK_METHOD(void,
RequestResourceReport,
(const rpc::ClientCallback<rpc::RequestResourceReportReply> &callback), (const rpc::ClientCallback<rpc::RequestResourceReportReply> &callback),
(override)); (override));
}; };
@ -108,71 +124,92 @@ namespace ray {
class MockRayletClientInterface : public RayletClientInterface { class MockRayletClientInterface : public RayletClientInterface {
public: public:
MOCK_METHOD(ray::Status, WaitForDirectActorCallArgs, MOCK_METHOD(ray::Status,
WaitForDirectActorCallArgs,
(const std::vector<rpc::ObjectReference> &references, int64_t tag), (const std::vector<rpc::ObjectReference> &references, int64_t tag),
(override)); (override));
MOCK_METHOD(void, ReportWorkerBacklog, MOCK_METHOD(void,
ReportWorkerBacklog,
(const WorkerID &worker_id, (const WorkerID &worker_id,
const std::vector<rpc::WorkerBacklogReport> &backlog_reports), const std::vector<rpc::WorkerBacklogReport> &backlog_reports),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, RequestWorkerLease, void,
(const rpc::TaskSpec &resource_spec, bool grant_or_reject, RequestWorkerLease,
(const rpc::TaskSpec &resource_spec,
bool grant_or_reject,
const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback, const ray::rpc::ClientCallback<ray::rpc::RequestWorkerLeaseReply> &callback,
const int64_t backlog_size, const bool is_selected_based_on_locality), const int64_t backlog_size,
const bool is_selected_based_on_locality),
(override)); (override));
MOCK_METHOD(ray::Status, ReturnWorker, MOCK_METHOD(ray::Status,
(int worker_port, const WorkerID &worker_id, bool disconnect_worker, ReturnWorker,
(int worker_port,
const WorkerID &worker_id,
bool disconnect_worker,
bool worker_exiting), bool worker_exiting),
(override)); (override));
MOCK_METHOD(void, ReleaseUnusedWorkers, MOCK_METHOD(void,
ReleaseUnusedWorkers,
(const std::vector<WorkerID> &workers_in_use, (const std::vector<WorkerID> &workers_in_use,
const rpc::ClientCallback<rpc::ReleaseUnusedWorkersReply> &callback), const rpc::ClientCallback<rpc::ReleaseUnusedWorkersReply> &callback),
(override)); (override));
MOCK_METHOD(void, CancelWorkerLease, MOCK_METHOD(void,
CancelWorkerLease,
(const TaskID &task_id, (const TaskID &task_id,
const rpc::ClientCallback<rpc::CancelWorkerLeaseReply> &callback), const rpc::ClientCallback<rpc::CancelWorkerLeaseReply> &callback),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, PrepareBundleResources, void,
PrepareBundleResources,
(const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs, (const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs,
const ray::rpc::ClientCallback<ray::rpc::PrepareBundleResourcesReply> &callback), const ray::rpc::ClientCallback<ray::rpc::PrepareBundleResourcesReply> &callback),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, CommitBundleResources, void,
CommitBundleResources,
(const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs, (const std::vector<std::shared_ptr<const BundleSpecification>> &bundle_specs,
const ray::rpc::ClientCallback<ray::rpc::CommitBundleResourcesReply> &callback), const ray::rpc::ClientCallback<ray::rpc::CommitBundleResourcesReply> &callback),
(override)); (override));
MOCK_METHOD( MOCK_METHOD(
void, CancelResourceReserve, void,
CancelResourceReserve,
(const BundleSpecification &bundle_spec, (const BundleSpecification &bundle_spec,
const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback), const ray::rpc::ClientCallback<ray::rpc::CancelResourceReserveReply> &callback),
(override)); (override));
MOCK_METHOD(void, ReleaseUnusedBundles, MOCK_METHOD(void,
ReleaseUnusedBundles,
(const std::vector<rpc::Bundle> &bundles_in_use, (const std::vector<rpc::Bundle> &bundles_in_use,
const rpc::ClientCallback<rpc::ReleaseUnusedBundlesReply> &callback), const rpc::ClientCallback<rpc::ReleaseUnusedBundlesReply> &callback),
(override)); (override));
MOCK_METHOD(void, PinObjectIDs, MOCK_METHOD(void,
PinObjectIDs,
(const rpc::Address &caller_address, (const rpc::Address &caller_address,
const std::vector<ObjectID> &object_ids, const std::vector<ObjectID> &object_ids,
const ray::rpc::ClientCallback<ray::rpc::PinObjectIDsReply> &callback), const ray::rpc::ClientCallback<ray::rpc::PinObjectIDsReply> &callback),
(override)); (override));
MOCK_METHOD(void, GetSystemConfig, MOCK_METHOD(void,
GetSystemConfig,
(const rpc::ClientCallback<rpc::GetSystemConfigReply> &callback), (const rpc::ClientCallback<rpc::GetSystemConfigReply> &callback),
(override)); (override));
MOCK_METHOD(void, GetGcsServerAddress, MOCK_METHOD(void,
GetGcsServerAddress,
(const rpc::ClientCallback<rpc::GetGcsServerAddressReply> &callback), (const rpc::ClientCallback<rpc::GetGcsServerAddressReply> &callback),
(override)); (override));
MOCK_METHOD(void, UpdateResourceUsage, MOCK_METHOD(void,
UpdateResourceUsage,
(std::string & serialized_resource_usage_batch, (std::string & serialized_resource_usage_batch,
const rpc::ClientCallback<rpc::UpdateResourceUsageReply> &callback), const rpc::ClientCallback<rpc::UpdateResourceUsageReply> &callback),
(override)); (override));
MOCK_METHOD(void, RequestResourceReport, MOCK_METHOD(void,
RequestResourceReport,
(const rpc::ClientCallback<rpc::RequestResourceReportReply> &callback), (const rpc::ClientCallback<rpc::RequestResourceReportReply> &callback),
(override)); (override));
MOCK_METHOD(void, ShutdownRaylet, MOCK_METHOD(void,
(const NodeID &node_id, bool graceful, ShutdownRaylet,
(const NodeID &node_id,
bool graceful,
const rpc::ClientCallback<rpc::ShutdownRayletReply> &callback), const rpc::ClientCallback<rpc::ShutdownRayletReply> &callback),
(override)); (override));
}; };

View file

@ -29,86 +29,108 @@ class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientIn
public CoreWorkerClientInterface { public CoreWorkerClientInterface {
public: public:
MOCK_METHOD(const rpc::Address &, Addr, (), (const, override)); MOCK_METHOD(const rpc::Address &, Addr, (), (const, override));
MOCK_METHOD(void, PushActorTask, MOCK_METHOD(void,
(std::unique_ptr<PushTaskRequest> request, bool skip_queue, PushActorTask,
(std::unique_ptr<PushTaskRequest> request,
bool skip_queue,
const ClientCallback<PushTaskReply> &callback), const ClientCallback<PushTaskReply> &callback),
(override)); (override));
MOCK_METHOD(void, PushNormalTask, MOCK_METHOD(void,
PushNormalTask,
(std::unique_ptr<PushTaskRequest> request, (std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback), const ClientCallback<PushTaskReply> &callback),
(override)); (override));
MOCK_METHOD(void, DirectActorCallArgWaitComplete, MOCK_METHOD(void,
DirectActorCallArgWaitComplete,
(const DirectActorCallArgWaitCompleteRequest &request, (const DirectActorCallArgWaitCompleteRequest &request,
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback), const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback),
(override)); (override));
MOCK_METHOD(void, GetObjectStatus, MOCK_METHOD(void,
GetObjectStatus,
(const GetObjectStatusRequest &request, (const GetObjectStatusRequest &request,
const ClientCallback<GetObjectStatusReply> &callback), const ClientCallback<GetObjectStatusReply> &callback),
(override)); (override));
MOCK_METHOD(void, WaitForActorOutOfScope, MOCK_METHOD(void,
WaitForActorOutOfScope,
(const WaitForActorOutOfScopeRequest &request, (const WaitForActorOutOfScopeRequest &request,
const ClientCallback<WaitForActorOutOfScopeReply> &callback), const ClientCallback<WaitForActorOutOfScopeReply> &callback),
(override)); (override));
MOCK_METHOD(void, PubsubLongPolling, MOCK_METHOD(void,
PubsubLongPolling,
(const PubsubLongPollingRequest &request, (const PubsubLongPollingRequest &request,
const ClientCallback<PubsubLongPollingReply> &callback), const ClientCallback<PubsubLongPollingReply> &callback),
(override)); (override));
MOCK_METHOD(void, PubsubCommandBatch, MOCK_METHOD(void,
PubsubCommandBatch,
(const PubsubCommandBatchRequest &request, (const PubsubCommandBatchRequest &request,
const ClientCallback<PubsubCommandBatchReply> &callback), const ClientCallback<PubsubCommandBatchReply> &callback),
(override)); (override));
MOCK_METHOD(void, UpdateObjectLocationBatch, MOCK_METHOD(void,
UpdateObjectLocationBatch,
(const UpdateObjectLocationBatchRequest &request, (const UpdateObjectLocationBatchRequest &request,
const ClientCallback<UpdateObjectLocationBatchReply> &callback), const ClientCallback<UpdateObjectLocationBatchReply> &callback),
(override)); (override));
MOCK_METHOD(void, GetObjectLocationsOwner, MOCK_METHOD(void,
GetObjectLocationsOwner,
(const GetObjectLocationsOwnerRequest &request, (const GetObjectLocationsOwnerRequest &request,
const ClientCallback<GetObjectLocationsOwnerReply> &callback), const ClientCallback<GetObjectLocationsOwnerReply> &callback),
(override)); (override));
MOCK_METHOD(void, KillActor, MOCK_METHOD(void,
KillActor,
(const KillActorRequest &request, (const KillActorRequest &request,
const ClientCallback<KillActorReply> &callback), const ClientCallback<KillActorReply> &callback),
(override)); (override));
MOCK_METHOD(void, CancelTask, MOCK_METHOD(void,
CancelTask,
(const CancelTaskRequest &request, (const CancelTaskRequest &request,
const ClientCallback<CancelTaskReply> &callback), const ClientCallback<CancelTaskReply> &callback),
(override)); (override));
MOCK_METHOD(void, RemoteCancelTask, MOCK_METHOD(void,
RemoteCancelTask,
(const RemoteCancelTaskRequest &request, (const RemoteCancelTaskRequest &request,
const ClientCallback<RemoteCancelTaskReply> &callback), const ClientCallback<RemoteCancelTaskReply> &callback),
(override)); (override));
MOCK_METHOD(void, GetCoreWorkerStats, MOCK_METHOD(void,
GetCoreWorkerStats,
(const GetCoreWorkerStatsRequest &request, (const GetCoreWorkerStatsRequest &request,
const ClientCallback<GetCoreWorkerStatsReply> &callback), const ClientCallback<GetCoreWorkerStatsReply> &callback),
(override)); (override));
MOCK_METHOD(void, LocalGC, MOCK_METHOD(void,
LocalGC,
(const LocalGCRequest &request, (const LocalGCRequest &request,
const ClientCallback<LocalGCReply> &callback), const ClientCallback<LocalGCReply> &callback),
(override)); (override));
MOCK_METHOD(void, SpillObjects, MOCK_METHOD(void,
SpillObjects,
(const SpillObjectsRequest &request, (const SpillObjectsRequest &request,
const ClientCallback<SpillObjectsReply> &callback), const ClientCallback<SpillObjectsReply> &callback),
(override)); (override));
MOCK_METHOD(void, RestoreSpilledObjects, MOCK_METHOD(void,
RestoreSpilledObjects,
(const RestoreSpilledObjectsRequest &request, (const RestoreSpilledObjectsRequest &request,
const ClientCallback<RestoreSpilledObjectsReply> &callback), const ClientCallback<RestoreSpilledObjectsReply> &callback),
(override)); (override));
MOCK_METHOD(void, DeleteSpilledObjects, MOCK_METHOD(void,
DeleteSpilledObjects,
(const DeleteSpilledObjectsRequest &request, (const DeleteSpilledObjectsRequest &request,
const ClientCallback<DeleteSpilledObjectsReply> &callback), const ClientCallback<DeleteSpilledObjectsReply> &callback),
(override)); (override));
MOCK_METHOD(void, AddSpilledUrl, MOCK_METHOD(void,
AddSpilledUrl,
(const AddSpilledUrlRequest &request, (const AddSpilledUrlRequest &request,
const ClientCallback<AddSpilledUrlReply> &callback), const ClientCallback<AddSpilledUrlReply> &callback),
(override)); (override));
MOCK_METHOD(void, PlasmaObjectReady, MOCK_METHOD(void,
PlasmaObjectReady,
(const PlasmaObjectReadyRequest &request, (const PlasmaObjectReadyRequest &request,
const ClientCallback<PlasmaObjectReadyReply> &callback), const ClientCallback<PlasmaObjectReadyReply> &callback),
(override)); (override));
MOCK_METHOD(void, Exit, MOCK_METHOD(void,
Exit,
(const ExitRequest &request, const ClientCallback<ExitReply> &callback), (const ExitRequest &request, const ClientCallback<ExitReply> &callback),
(override)); (override));
MOCK_METHOD(void, AssignObjectOwner, MOCK_METHOD(void,
AssignObjectOwner,
(const AssignObjectOwnerRequest &request, (const AssignObjectOwnerRequest &request,
const ClientCallback<AssignObjectOwnerReply> &callback), const ClientCallback<AssignObjectOwnerReply> &callback),
(override)); (override));

View file

@ -17,7 +17,8 @@
#include <boost/asio.hpp> #include <boost/asio.hpp>
inline std::shared_ptr<boost::asio::deadline_timer> execute_after_us( inline std::shared_ptr<boost::asio::deadline_timer> execute_after_us(
instrumented_io_context &io_context, std::function<void()> fn, instrumented_io_context &io_context,
std::function<void()> fn,
int64_t delay_microseconds) { int64_t delay_microseconds) {
auto timer = std::make_shared<boost::asio::deadline_timer>(io_context); auto timer = std::make_shared<boost::asio::deadline_timer>(io_context);
timer->expires_from_now(boost::posix_time::microseconds(delay_microseconds)); timer->expires_from_now(boost::posix_time::microseconds(delay_microseconds));

View file

@ -17,6 +17,7 @@
#include <atomic> #include <atomic>
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include <thread> #include <thread>
#include "ray/common/asio/instrumented_io_context.h" #include "ray/common/asio/instrumented_io_context.h"
namespace ray { namespace ray {

View file

@ -31,7 +31,8 @@ PeriodicalRunner::~PeriodicalRunner() {
RAY_LOG(DEBUG) << "PeriodicalRunner is destructed"; RAY_LOG(DEBUG) << "PeriodicalRunner is destructed";
} }
void PeriodicalRunner::RunFnPeriodically(std::function<void()> fn, uint64_t period_ms, void PeriodicalRunner::RunFnPeriodically(std::function<void()> fn,
uint64_t period_ms,
const std::string name) { const std::string name) {
if (period_ms > 0) { if (period_ms > 0) {
auto timer = std::make_shared<boost::asio::deadline_timer>(io_service_); auto timer = std::make_shared<boost::asio::deadline_timer>(io_service_);
@ -53,13 +54,14 @@ void PeriodicalRunner::RunFnPeriodically(std::function<void()> fn, uint64_t peri
} }
void PeriodicalRunner::DoRunFnPeriodically( void PeriodicalRunner::DoRunFnPeriodically(
const std::function<void()> &fn, boost::posix_time::milliseconds period, const std::function<void()> &fn,
boost::posix_time::milliseconds period,
std::shared_ptr<boost::asio::deadline_timer> timer) { std::shared_ptr<boost::asio::deadline_timer> timer) {
fn(); fn();
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
timer->expires_from_now(period); timer->expires_from_now(period);
timer->async_wait([this, fn = std::move(fn), period, timer->async_wait([this, fn = std::move(fn), period, timer = std::move(timer)](
timer = std::move(timer)](const boost::system::error_code &error) { const boost::system::error_code &error) {
if (error == boost::asio::error::operation_aborted) { if (error == boost::asio::error::operation_aborted) {
// `operation_aborted` is set when `timer` is canceled or destroyed. // `operation_aborted` is set when `timer` is canceled or destroyed.
// The Monitor lifetime may be short than the object who use it. (e.g. // The Monitor lifetime may be short than the object who use it. (e.g.
@ -72,8 +74,10 @@ void PeriodicalRunner::DoRunFnPeriodically(
} }
void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( void PeriodicalRunner::DoRunFnPeriodicallyInstrumented(
const std::function<void()> &fn, boost::posix_time::milliseconds period, const std::function<void()> &fn,
std::shared_ptr<boost::asio::deadline_timer> timer, const std::string name) { boost::posix_time::milliseconds period,
std::shared_ptr<boost::asio::deadline_timer> timer,
const std::string name) {
fn(); fn();
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
timer->expires_from_now(period); timer->expires_from_now(period);
@ -81,7 +85,10 @@ void PeriodicalRunner::DoRunFnPeriodicallyInstrumented(
// which the handler was elgible to execute on the event loop but was queued by the // which the handler was elgible to execute on the event loop but was queued by the
// event loop. // event loop.
auto stats_handle = io_service_.stats().RecordStart(name, period.total_nanoseconds()); auto stats_handle = io_service_.stats().RecordStart(name, period.total_nanoseconds());
timer->async_wait([this, fn = std::move(fn), period, timer = std::move(timer), timer->async_wait([this,
fn = std::move(fn),
period,
timer = std::move(timer),
stats_handle = std::move(stats_handle), stats_handle = std::move(stats_handle),
name](const boost::system::error_code &error) { name](const boost::system::error_code &error) {
io_service_.stats().RecordExecution( io_service_.stats().RecordExecution(

View file

@ -34,7 +34,8 @@ class PeriodicalRunner {
~PeriodicalRunner(); ~PeriodicalRunner();
void RunFnPeriodically(std::function<void()> fn, uint64_t period_ms, void RunFnPeriodically(std::function<void()> fn,
uint64_t period_ms,
const std::string name = "UNKNOWN") LOCKS_EXCLUDED(mutex_); const std::string name = "UNKNOWN") LOCKS_EXCLUDED(mutex_);
private: private:

View file

@ -143,7 +143,8 @@ class SharedMemoryBuffer : public Buffer {
} }
static std::shared_ptr<SharedMemoryBuffer> Slice(const std::shared_ptr<Buffer> &buffer, static std::shared_ptr<SharedMemoryBuffer> Slice(const std::shared_ptr<Buffer> &buffer,
int64_t offset, int64_t size) { int64_t offset,
int64_t size) {
return std::make_shared<SharedMemoryBuffer>(buffer, offset, size); return std::make_shared<SharedMemoryBuffer>(buffer, offset, size);
} }

View file

@ -139,7 +139,8 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na
original_resource_name, bundle_spec.PlacementGroupId(), bundle_spec.Index()); original_resource_name, bundle_spec.PlacementGroupId(), bundle_spec.Index());
} }
bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, bool IsBundleIndex(const std::string &resource,
const PlacementGroupID &group_id,
const int bundle_index) { const int bundle_index) {
return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" + return resource.find(kGroupKeyword + std::to_string(bundle_index) + "_" +
group_id.Hex()) != std::string::npos; group_id.Hex()) != std::string::npos;

View file

@ -102,7 +102,8 @@ std::string FormatPlacementGroupResource(const std::string &original_resource_na
const BundleSpecification &bundle_spec); const BundleSpecification &bundle_spec);
/// Return whether a formatted resource is a bundle of the given index. /// Return whether a formatted resource is a bundle of the given index.
bool IsBundleIndex(const std::string &resource, const PlacementGroupID &group_id, bool IsBundleIndex(const std::string &resource,
const PlacementGroupID &group_id,
const int bundle_index); const int bundle_index);
/// Return the original resource name of the placement group resource. /// Return the original resource name of the placement group resource.

View file

@ -30,8 +30,10 @@
namespace ray { namespace ray {
Status ConnectSocketRetry(local_stream_socket &socket, const std::string &endpoint, Status ConnectSocketRetry(local_stream_socket &socket,
int num_retries, int64_t timeout_in_ms) { const std::string &endpoint,
int num_retries,
int64_t timeout_in_ms) {
RAY_CHECK(num_retries != 0); RAY_CHECK(num_retries != 0);
// Pick the default values if the user did not specify. // Pick the default values if the user did not specify.
if (num_retries < 0) { if (num_retries < 0) {
@ -114,7 +116,8 @@ void ServerConnection::WriteBufferAsync(
const auto stats_handle = const auto stats_handle =
io_context.stats().RecordStart("ClientConnection.async_write.WriteBufferAsync"); io_context.stats().RecordStart("ClientConnection.async_write.WriteBufferAsync");
boost::asio::async_write( boost::asio::async_write(
socket_, buffer, socket_,
buffer,
[handler, stats_handle = std::move(stats_handle)]( [handler, stats_handle = std::move(stats_handle)](
const boost::system::error_code &ec, size_t bytes_transferred) { const boost::system::error_code &ec, size_t bytes_transferred) {
EventTracker::RecordExecution( EventTracker::RecordExecution(
@ -123,7 +126,8 @@ void ServerConnection::WriteBufferAsync(
}); });
} else { } else {
boost::asio::async_write( boost::asio::async_write(
socket_, buffer, socket_,
buffer,
[handler](const boost::system::error_code &ec, size_t bytes_transferred) { [handler](const boost::system::error_code &ec, size_t bytes_transferred) {
handler(boost_to_ray_status(ec)); handler(boost_to_ray_status(ec));
}); });
@ -162,7 +166,8 @@ void ServerConnection::ReadBufferAsync(
const auto stats_handle = const auto stats_handle =
io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync");
boost::asio::async_read( boost::asio::async_read(
socket_, buffer, socket_,
buffer,
[handler, stats_handle = std::move(stats_handle)]( [handler, stats_handle = std::move(stats_handle)](
const boost::system::error_code &ec, size_t bytes_transferred) { const boost::system::error_code &ec, size_t bytes_transferred) {
EventTracker::RecordExecution( EventTracker::RecordExecution(
@ -171,14 +176,16 @@ void ServerConnection::ReadBufferAsync(
}); });
} else { } else {
boost::asio::async_read( boost::asio::async_read(
socket_, buffer, socket_,
buffer,
[handler](const boost::system::error_code &ec, size_t bytes_transferred) { [handler](const boost::system::error_code &ec, size_t bytes_transferred) {
handler(boost_to_ray_status(ec)); handler(boost_to_ray_status(ec));
}); });
} }
} }
ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, ray::Status ServerConnection::WriteMessage(int64_t type,
int64_t length,
const uint8_t *message) { const uint8_t *message) {
sync_writes_ += 1; sync_writes_ += 1;
bytes_written_ += length; bytes_written_ += length;
@ -218,7 +225,9 @@ Status ServerConnection::ReadMessage(int64_t type, std::vector<uint8_t> *message
} }
void ServerConnection::WriteMessageAsync( void ServerConnection::WriteMessageAsync(
int64_t type, int64_t length, const uint8_t *message, int64_t type,
int64_t length,
const uint8_t *message,
const std::function<void(const ray::Status &)> &handler) { const std::function<void(const ray::Status &)> &handler) {
async_writes_ += 1; async_writes_ += 1;
bytes_written_ += length; bytes_written_ += length;
@ -294,8 +303,12 @@ void ServerConnection::DoAsyncWrites() {
const auto stats_handle = const auto stats_handle =
io_context.stats().RecordStart("ClientConnection.async_write.DoAsyncWrites"); io_context.stats().RecordStart("ClientConnection.async_write.DoAsyncWrites");
boost::asio::async_write( boost::asio::async_write(
socket_, message_buffers, socket_,
[this, this_ptr, num_messages, call_handlers, message_buffers,
[this,
this_ptr,
num_messages,
call_handlers,
stats_handle = std::move(stats_handle)](const boost::system::error_code &error, stats_handle = std::move(stats_handle)](const boost::system::error_code &error,
size_t bytes_transferred) { size_t bytes_transferred) {
EventTracker::RecordExecution( EventTracker::RecordExecution(
@ -319,7 +332,8 @@ void ServerConnection::DoAsyncWrites() {
}); });
} else { } else {
boost::asio::async_write( boost::asio::async_write(
ServerConnection::socket_, message_buffers, ServerConnection::socket_,
message_buffers,
[this, this_ptr, num_messages, call_handlers]( [this, this_ptr, num_messages, call_handlers](
const boost::system::error_code &error, size_t bytes_transferred) { const boost::system::error_code &error, size_t bytes_transferred) {
ray::Status status = boost_to_ray_status(error); ray::Status status = boost_to_ray_status(error);
@ -341,22 +355,30 @@ void ServerConnection::DoAsyncWrites() {
} }
std::shared_ptr<ClientConnection> ClientConnection::Create( std::shared_ptr<ClientConnection> ClientConnection::Create(
ClientHandler &client_handler, MessageHandler &message_handler, ClientHandler &client_handler,
local_stream_socket &&socket, const std::string &debug_label, MessageHandler &message_handler,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type, local_stream_socket &&socket,
const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names,
int64_t error_message_type,
const std::vector<uint8_t> &error_message_data) { const std::vector<uint8_t> &error_message_data) {
std::shared_ptr<ClientConnection> self(new ClientConnection( std::shared_ptr<ClientConnection> self(new ClientConnection(message_handler,
message_handler, std::move(socket), debug_label, message_type_enum_names, std::move(socket),
error_message_type, error_message_data)); debug_label,
message_type_enum_names,
error_message_type,
error_message_data));
// Let our manager process our new connection. // Let our manager process our new connection.
client_handler(*self); client_handler(*self);
return self; return self;
} }
ClientConnection::ClientConnection( ClientConnection::ClientConnection(
MessageHandler &message_handler, local_stream_socket &&socket, MessageHandler &message_handler,
local_stream_socket &&socket,
const std::string &debug_label, const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type, const std::vector<std::string> &message_type_enum_names,
int64_t error_message_type,
const std::vector<uint8_t> &error_message_data) const std::vector<uint8_t> &error_message_data)
: ServerConnection(std::move(socket)), : ServerConnection(std::move(socket)),
registered_(false), registered_(false),
@ -386,7 +408,8 @@ void ClientConnection::ProcessMessages() {
const auto stats_handle = const auto stats_handle =
io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync");
boost::asio::async_read( boost::asio::async_read(
ServerConnection::socket_, header, ServerConnection::socket_,
header,
[this, this_ptr, stats_handle = std::move(stats_handle)]( [this, this_ptr, stats_handle = std::move(stats_handle)](
const boost::system::error_code &ec, size_t bytes_transferred) { const boost::system::error_code &ec, size_t bytes_transferred) {
EventTracker::RecordExecution( EventTracker::RecordExecution(
@ -394,7 +417,8 @@ void ClientConnection::ProcessMessages() {
std::move(stats_handle)); std::move(stats_handle));
}); });
} else { } else {
boost::asio::async_read(ServerConnection::socket_, header, boost::asio::async_read(ServerConnection::socket_,
header,
boost::bind(&ClientConnection::ProcessMessageHeader, boost::bind(&ClientConnection::ProcessMessageHeader,
shared_ClientConnection_from_this(), shared_ClientConnection_from_this(),
boost::asio::placeholders::error)); boost::asio::placeholders::error));
@ -428,14 +452,16 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code &err
const auto stats_handle = const auto stats_handle =
io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync"); io_context.stats().RecordStart("ClientConnection.async_read.ReadBufferAsync");
boost::asio::async_read( boost::asio::async_read(
ServerConnection::socket_, boost::asio::buffer(read_message_), ServerConnection::socket_,
boost::asio::buffer(read_message_),
[this, this_ptr, stats_handle = std::move(stats_handle)]( [this, this_ptr, stats_handle = std::move(stats_handle)](
const boost::system::error_code &ec, size_t bytes_transferred) { const boost::system::error_code &ec, size_t bytes_transferred) {
EventTracker::RecordExecution([this, this_ptr, ec]() { ProcessMessage(ec); }, EventTracker::RecordExecution([this, this_ptr, ec]() { ProcessMessage(ec); },
std::move(stats_handle)); std::move(stats_handle));
}); });
} else { } else {
boost::asio::async_read(ServerConnection::socket_, boost::asio::buffer(read_message_), boost::asio::async_read(ServerConnection::socket_,
boost::asio::buffer(read_message_),
boost::bind(&ClientConnection::ProcessMessage, boost::bind(&ClientConnection::ProcessMessage,
shared_ClientConnection_from_this(), shared_ClientConnection_from_this(),
boost::asio::placeholders::error)); boost::asio::placeholders::error));

View file

@ -31,8 +31,10 @@ typedef boost::asio::generic::stream_protocol local_stream_protocol;
typedef boost::asio::basic_stream_socket<local_stream_protocol> local_stream_socket; typedef boost::asio::basic_stream_socket<local_stream_protocol> local_stream_socket;
/// Connect to a socket with retry times. /// Connect to a socket with retry times.
Status ConnectSocketRetry(local_stream_socket &socket, const std::string &endpoint, Status ConnectSocketRetry(local_stream_socket &socket,
int num_retries = -1, int64_t timeout_in_ms = -1); const std::string &endpoint,
int num_retries = -1,
int64_t timeout_in_ms = -1);
/// \typename ServerConnection /// \typename ServerConnection
/// ///
@ -63,7 +65,9 @@ class ServerConnection : public std::enable_shared_from_this<ServerConnection> {
/// \param length The size in bytes of the message. /// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer. /// \param message A pointer to the message buffer.
/// \param handler A callback to run on write completion. /// \param handler A callback to run on write completion.
void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, void WriteMessageAsync(int64_t type,
int64_t length,
const uint8_t *message,
const std::function<void(const ray::Status &)> &handler); const std::function<void(const ray::Status &)> &handler);
/// Read a message from the client. /// Read a message from the client.
@ -169,8 +173,8 @@ class ServerConnection : public std::enable_shared_from_this<ServerConnection> {
class ClientConnection; class ClientConnection;
using ClientHandler = std::function<void(ClientConnection &)>; using ClientHandler = std::function<void(ClientConnection &)>;
using MessageHandler = std::function<void(std::shared_ptr<ClientConnection>, int64_t, using MessageHandler = std::function<void(
const std::vector<uint8_t> &)>; std::shared_ptr<ClientConnection>, int64_t, const std::vector<uint8_t> &)>;
static std::vector<uint8_t> _dummy_error_message_data; static std::vector<uint8_t> _dummy_error_message_data;
/// \typename ClientConnection /// \typename ClientConnection
@ -195,9 +199,12 @@ class ClientConnection : public ServerConnection {
/// \param error_message_data the companion data to the error message type. /// \param error_message_data the companion data to the error message type.
/// \return std::shared_ptr<ClientConnection>. /// \return std::shared_ptr<ClientConnection>.
static std::shared_ptr<ClientConnection> Create( static std::shared_ptr<ClientConnection> Create(
ClientHandler &new_client_handler, MessageHandler &message_handler, ClientHandler &new_client_handler,
local_stream_socket &&socket, const std::string &debug_label, MessageHandler &message_handler,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type, local_stream_socket &&socket,
const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names,
int64_t error_message_type,
const std::vector<uint8_t> &error_message_data = _dummy_error_message_data); const std::vector<uint8_t> &error_message_data = _dummy_error_message_data);
std::shared_ptr<ClientConnection> shared_ClientConnection_from_this() { std::shared_ptr<ClientConnection> shared_ClientConnection_from_this() {
@ -215,9 +222,11 @@ class ClientConnection : public ServerConnection {
protected: protected:
/// A protected constructor for a node client connection. /// A protected constructor for a node client connection.
ClientConnection( ClientConnection(
MessageHandler &message_handler, local_stream_socket &&socket, MessageHandler &message_handler,
local_stream_socket &&socket,
const std::string &debug_label, const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type, const std::vector<std::string> &message_type_enum_names,
int64_t error_message_type,
const std::vector<uint8_t> &error_message_data = _dummy_error_message_data); const std::vector<uint8_t> &error_message_data = _dummy_error_message_data);
/// Process an error from the last operation, then process the message /// Process an error from the last operation, then process the message
/// header from the client. /// header from the client.

View file

@ -71,7 +71,9 @@ std::shared_ptr<StatsHandle> EventTracker::RecordStart(
ray::stats::STATS_operation_count.Record(curr_count, name); ray::stats::STATS_operation_count.Record(curr_count, name);
ray::stats::STATS_operation_active_count.Record(curr_count, name); ray::stats::STATS_operation_active_count.Record(curr_count, name);
return std::make_shared<StatsHandle>( return std::make_shared<StatsHandle>(
name, absl::GetCurrentTimeNanos() + expected_queueing_delay_ns, std::move(stats), name,
absl::GetCurrentTimeNanos() + expected_queueing_delay_ns,
std::move(stats),
global_stats_); global_stats_);
} }
@ -165,7 +167,8 @@ std::vector<std::pair<std::string, EventStats>> EventTracker::get_event_stats()
absl::ReaderMutexLock lock(&mutex_); absl::ReaderMutexLock lock(&mutex_);
std::vector<std::pair<std::string, EventStats>> stats; std::vector<std::pair<std::string, EventStats>> stats;
stats.reserve(post_handler_stats_.size()); stats.reserve(post_handler_stats_.size());
std::transform(post_handler_stats_.begin(), post_handler_stats_.end(), std::transform(post_handler_stats_.begin(),
post_handler_stats_.end(),
std::back_inserter(stats), std::back_inserter(stats),
[](const std::pair<std::string, std::shared_ptr<GuardedEventStats>> &p) { [](const std::pair<std::string, std::shared_ptr<GuardedEventStats>> &p) {
return std::make_pair(p.first, to_event_stats_view(p.second)); return std::make_pair(p.first, to_event_stats_view(p.second));
@ -181,7 +184,8 @@ std::string EventTracker::StatsString() const {
} }
auto stats = get_event_stats(); auto stats = get_event_stats();
// Sort stats by cumulative count, outside of the table lock. // Sort stats by cumulative count, outside of the table lock.
sort(stats.begin(), stats.end(), sort(stats.begin(),
stats.end(),
[](const std::pair<std::string, EventStats> &a, [](const std::pair<std::string, EventStats> &a,
const std::pair<std::string, EventStats> &b) { const std::pair<std::string, EventStats> &b) {
return a.second.cum_count > b.second.cum_count; return a.second.cum_count > b.second.cum_count;

View file

@ -71,7 +71,8 @@ struct StatsHandle {
std::shared_ptr<GuardedGlobalStats> global_stats; std::shared_ptr<GuardedGlobalStats> global_stats;
std::atomic<bool> execution_recorded; std::atomic<bool> execution_recorded;
StatsHandle(std::string event_name_, int64_t start_time_, StatsHandle(std::string event_name_,
int64_t start_time_,
std::shared_ptr<GuardedEventStats> handler_stats_, std::shared_ptr<GuardedEventStats> handler_stats_,
std::shared_ptr<GuardedGlobalStats> global_stats_) std::shared_ptr<GuardedGlobalStats> global_stats_)
: event_name(std::move(event_name_)), : event_name(std::move(event_name_)),

View file

@ -33,8 +33,10 @@ FunctionDescriptor FunctionDescriptorBuilder::BuildJava(const std::string &class
} }
FunctionDescriptor FunctionDescriptorBuilder::BuildPython( FunctionDescriptor FunctionDescriptorBuilder::BuildPython(
const std::string &module_name, const std::string &class_name, const std::string &module_name,
const std::string &function_name, const std::string &function_hash) { const std::string &class_name,
const std::string &function_name,
const std::string &function_hash) {
rpc::FunctionDescriptor descriptor; rpc::FunctionDescriptor descriptor;
auto typed_descriptor = descriptor.mutable_python_function_descriptor(); auto typed_descriptor = descriptor.mutable_python_function_descriptor();
typed_descriptor->set_module_name(module_name); typed_descriptor->set_module_name(module_name);

View file

@ -82,14 +82,16 @@ inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) {
} else { } else {
// Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means // Unlike `UNKNOWN`, `ABORTED` is never generated by the library, so using it means
// more robust. // more robust.
return grpc::Status(grpc::StatusCode::ABORTED, ray_status.CodeAsString(), return grpc::Status(
ray_status.message()); grpc::StatusCode::ABORTED, ray_status.CodeAsString(), ray_status.message());
} }
} }
inline std::string GrpcStatusToRayStatusMessage(const grpc::Status &grpc_status) { inline std::string GrpcStatusToRayStatusMessage(const grpc::Status &grpc_status) {
return absl::StrCat("RPC Error message: ", grpc_status.error_message(), return absl::StrCat("RPC Error message: ",
"; RPC Error details: ", grpc_status.error_details()); grpc_status.error_message(),
"; RPC Error details: ",
grpc_status.error_details());
} }
/// Helper function that converts a gRPC status to ray status. /// Helper function that converts a gRPC status to ray status.
@ -135,8 +137,8 @@ inline std::vector<ID> IdVectorFromProtobuf(
const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) { const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) {
auto str_vec = VectorFromProtobuf(pb_repeated); auto str_vec = VectorFromProtobuf(pb_repeated);
std::vector<ID> ret; std::vector<ID> ret;
std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(ret), std::transform(
&ID::FromBinary); str_vec.begin(), str_vec.end(), std::back_inserter(ret), &ID::FromBinary);
return ret; return ret;
} }

View file

@ -40,14 +40,17 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed);
/// A helper function to generate the unique bytes by hash. /// A helper function to generate the unique bytes by hash.
__suppress_ubsan__("undefined") std::string __suppress_ubsan__("undefined") std::string
GenerateUniqueBytes(const JobID &job_id, const TaskID &parent_task_id, GenerateUniqueBytes(const JobID &job_id,
size_t parent_task_counter, size_t extra_bytes, size_t length) { const TaskID &parent_task_id,
size_t parent_task_counter,
size_t extra_bytes,
size_t length) {
RAY_CHECK(length <= DIGEST_SIZE); RAY_CHECK(length <= DIGEST_SIZE);
SHA256_CTX ctx; SHA256_CTX ctx;
sha256_init(&ctx); sha256_init(&ctx);
sha256_update(&ctx, reinterpret_cast<const BYTE *>(job_id.Data()), job_id.Size()); sha256_update(&ctx, reinterpret_cast<const BYTE *>(job_id.Data()), job_id.Size());
sha256_update(&ctx, reinterpret_cast<const BYTE *>(parent_task_id.Data()), sha256_update(
parent_task_id.Size()); &ctx, reinterpret_cast<const BYTE *>(parent_task_id.Data()), parent_task_id.Size());
sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter));
if (extra_bytes > 0) { if (extra_bytes > 0) {
sha256_update(&ctx, (const BYTE *)&extra_bytes, sizeof(extra_bytes)); sha256_update(&ctx, (const BYTE *)&extra_bytes, sizeof(extra_bytes));
@ -124,14 +127,17 @@ __suppress_ubsan__("undefined") uint64_t
return h; return h;
} }
ActorID ActorID::Of(const JobID &job_id, const TaskID &parent_task_id, ActorID ActorID::Of(const JobID &job_id,
const TaskID &parent_task_id,
const size_t parent_task_counter) { const size_t parent_task_counter) {
// NOTE(swang): Include the current time in the hash for the actor ID so that // NOTE(swang): Include the current time in the hash for the actor ID so that
// we avoid duplicating a previous actor ID, which is not allowed by the GCS. // we avoid duplicating a previous actor ID, which is not allowed by the GCS.
// See https://github.com/ray-project/ray/issues/10481. // See https://github.com/ray-project/ray/issues/10481.
auto data = auto data = GenerateUniqueBytes(job_id,
GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, parent_task_id,
absl::GetCurrentTimeNanos(), ActorID::kUniqueBytesLength); parent_task_counter,
absl::GetCurrentTimeNanos(),
ActorID::kUniqueBytesLength);
std::copy_n(job_id.Data(), JobID::kLength, std::back_inserter(data)); std::copy_n(job_id.Data(), JobID::kLength, std::back_inserter(data));
RAY_CHECK(data.size() == kLength); RAY_CHECK(data.size() == kLength);
return ActorID::FromBinary(data); return ActorID::FromBinary(data);
@ -175,19 +181,22 @@ TaskID TaskID::ForActorCreationTask(const ActorID &actor_id) {
return TaskID::FromBinary(data); return TaskID::FromBinary(data);
} }
TaskID TaskID::ForActorTask(const JobID &job_id, const TaskID &parent_task_id, TaskID TaskID::ForActorTask(const JobID &job_id,
size_t parent_task_counter, const ActorID &actor_id) { const TaskID &parent_task_id,
std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, 0, size_t parent_task_counter,
TaskID::kUniqueBytesLength); const ActorID &actor_id) {
std::string data = GenerateUniqueBytes(
job_id, parent_task_id, parent_task_counter, 0, TaskID::kUniqueBytesLength);
std::copy_n(actor_id.Data(), ActorID::kLength, std::back_inserter(data)); std::copy_n(actor_id.Data(), ActorID::kLength, std::back_inserter(data));
RAY_CHECK(data.size() == TaskID::kLength); RAY_CHECK(data.size() == TaskID::kLength);
return TaskID::FromBinary(data); return TaskID::FromBinary(data);
} }
TaskID TaskID::ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, TaskID TaskID::ForNormalTask(const JobID &job_id,
const TaskID &parent_task_id,
size_t parent_task_counter) { size_t parent_task_counter) {
std::string data = GenerateUniqueBytes(job_id, parent_task_id, parent_task_counter, 0, std::string data = GenerateUniqueBytes(
TaskID::kUniqueBytesLength); job_id, parent_task_id, parent_task_counter, 0, TaskID::kUniqueBytesLength);
const auto dummy_actor_id = ActorID::NilFromJob(job_id); const auto dummy_actor_id = ActorID::NilFromJob(job_id);
std::copy_n(dummy_actor_id.Data(), ActorID::kLength, std::back_inserter(data)); std::copy_n(dummy_actor_id.Data(), ActorID::kLength, std::back_inserter(data));
RAY_CHECK(data.size() == TaskID::kLength); RAY_CHECK(data.size() == TaskID::kLength);

View file

@ -144,7 +144,8 @@ class ActorID : public BaseID<ActorID> {
/// \param parent_task_counter The counter of the parent task. /// \param parent_task_counter The counter of the parent task.
/// ///
/// \return The random `ActorID`. /// \return The random `ActorID`.
static ActorID Of(const JobID &job_id, const TaskID &parent_task_id, static ActorID Of(const JobID &job_id,
const TaskID &parent_task_id,
const size_t parent_task_counter); const size_t parent_task_counter);
/// Creates a nil ActorID with the given job. /// Creates a nil ActorID with the given job.
@ -210,8 +211,10 @@ class TaskID : public BaseID<TaskID> {
/// \param actor_id The ID of the actor to which this task belongs. /// \param actor_id The ID of the actor to which this task belongs.
/// ///
/// \return The ID of the actor task. /// \return The ID of the actor task.
static TaskID ForActorTask(const JobID &job_id, const TaskID &parent_task_id, static TaskID ForActorTask(const JobID &job_id,
size_t parent_task_counter, const ActorID &actor_id); const TaskID &parent_task_id,
size_t parent_task_counter,
const ActorID &actor_id);
/// Creates a TaskID for normal task. /// Creates a TaskID for normal task.
/// ///
@ -221,7 +224,8 @@ class TaskID : public BaseID<TaskID> {
/// parent task before this one. /// parent task before this one.
/// ///
/// \return The ID of the normal task. /// \return The ID of the normal task.
static TaskID ForNormalTask(const JobID &job_id, const TaskID &parent_task_id, static TaskID ForNormalTask(const JobID &job_id,
const TaskID &parent_task_id,
size_t parent_task_counter); size_t parent_task_counter);
/// Given a base task ID, create a task ID that represents the n-th execution /// Given a base task ID, create a task ID that represents the n-th execution

View file

@ -23,8 +23,11 @@ std::string GetValidLocalIp(int port, int64_t timeout_ms) {
const std::string localhost_ip = "127.0.0.1"; const std::string localhost_ip = "127.0.0.1";
bool is_timeout; bool is_timeout;
if (async_client.Connect(kPublicDNSServerIp, kPublicDNSServerPort, timeout_ms, if (async_client.Connect(kPublicDNSServerIp,
&is_timeout, &error_code)) { kPublicDNSServerPort,
timeout_ms,
&is_timeout,
&error_code)) {
address = async_client.GetLocalIPAddress(); address = async_client.GetLocalIPAddress();
} else { } else {
address = localhost_ip; address = localhost_ip;
@ -41,7 +44,9 @@ std::string GetValidLocalIp(int port, int64_t timeout_ms) {
primary_endpoint.address(ip_candidate); primary_endpoint.address(ip_candidate);
AsyncClient client; AsyncClient client;
if (client.Connect(primary_endpoint.address().to_string(), port, timeout_ms, if (client.Connect(primary_endpoint.address().to_string(),
port,
timeout_ms,
&is_timeout)) { &is_timeout)) {
success = true; success = true;
break; break;
@ -110,8 +115,8 @@ std::vector<boost::asio::ip::address> GetValidLocalIpCandidates() {
freeifaddrs(ifs_info); freeifaddrs(ifs_info);
// Bigger prefixes must be tested first in CompNameAndIps // Bigger prefixes must be tested first in CompNameAndIps
std::sort(prefixes_and_priorities.begin(), prefixes_and_priorities.end(), std::sort(
CompPrefixLen); prefixes_and_priorities.begin(), prefixes_and_priorities.end(), CompPrefixLen);
// Filter out interfaces with small possibility of being desired to be used to serve // Filter out interfaces with small possibility of being desired to be used to serve
std::sort(ifnames_and_ips.begin(), ifnames_and_ips.end(), CompNamesAndIps); std::sort(ifnames_and_ips.begin(), ifnames_and_ips.end(), CompNamesAndIps);
@ -140,7 +145,8 @@ std::vector<boost::asio::ip::address> GetValidLocalIpCandidates() {
instrumented_io_context io_context; instrumented_io_context io_context;
boost::asio::ip::tcp::resolver resolver(io_context); boost::asio::ip::tcp::resolver resolver(io_context);
boost::asio::ip::tcp::resolver::query query( boost::asio::ip::tcp::resolver::query query(
boost::asio::ip::host_name(), "", boost::asio::ip::host_name(),
"",
boost::asio::ip::resolver_query_base::flags::v4_mapped); boost::asio::ip::resolver_query_base::flags::v4_mapped);
boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query); boost::asio::ip::tcp::resolver::iterator iter = resolver.resolve(query);

View file

@ -57,7 +57,10 @@ class AsyncClient {
/// \param is_timeout Whether connection timeout. /// \param is_timeout Whether connection timeout.
/// \param error_code Set to indicate what error occurred, if any. /// \param error_code Set to indicate what error occurred, if any.
/// \return Whether the connection is successful. /// \return Whether the connection is successful.
bool Connect(const std::string &ip, int port, int64_t timeout_ms, bool *is_timeout, bool Connect(const std::string &ip,
int port,
int64_t timeout_ms,
bool *is_timeout,
boost::system::error_code *error_code = nullptr) { boost::system::error_code *error_code = nullptr) {
try { try {
auto endpoint = auto endpoint =
@ -65,9 +68,11 @@ class AsyncClient {
bool is_connected = false; bool is_connected = false;
*is_timeout = false; *is_timeout = false;
socket_.async_connect(endpoint, boost::bind(&AsyncClient::ConnectHandle, this, socket_.async_connect(endpoint,
boost::asio::placeholders::error, boost::bind(&AsyncClient::ConnectHandle,
boost::ref(is_connected))); this,
boost::asio::placeholders::error,
boost::ref(is_connected)));
// Set a deadline for the asynchronous operation. // Set a deadline for the asynchronous operation.
timer_.expires_from_now(boost::posix_time::milliseconds(timeout_ms)); timer_.expires_from_now(boost::posix_time::milliseconds(timeout_ms));

View file

@ -65,10 +65,13 @@ class PlacementGroupSpecBuilder {
/// ///
/// \return Reference to the builder object itself. /// \return Reference to the builder object itself.
PlacementGroupSpecBuilder &SetPlacementGroupSpec( PlacementGroupSpecBuilder &SetPlacementGroupSpec(
const PlacementGroupID &placement_group_id, std::string name, const PlacementGroupID &placement_group_id,
std::string name,
const std::vector<std::unordered_map<std::string, double>> &bundles, const std::vector<std::unordered_map<std::string, double>> &bundles,
const rpc::PlacementStrategy strategy, const bool is_detached, const rpc::PlacementStrategy strategy,
const JobID &creator_job_id, const ActorID &creator_actor_id, const bool is_detached,
const JobID &creator_job_id,
const ActorID &creator_actor_id,
bool is_creator_detached_actor) { bool is_creator_detached_actor) {
message_->set_placement_group_id(placement_group_id.Binary()); message_->set_placement_group_id(placement_group_id.Binary());
message_->set_name(name); message_->set_name(name);

View file

@ -245,7 +245,8 @@ RAY_CONFIG(uint64_t, object_manager_default_chunk_size, 5 * 1024 * 1024)
/// The maximum number of outbound bytes to allow to be outstanding. This avoids /// The maximum number of outbound bytes to allow to be outstanding. This avoids
/// excessive memory usage during object broadcast to many receivers. /// excessive memory usage during object broadcast to many receivers.
RAY_CONFIG(uint64_t, object_manager_max_bytes_in_flight, RAY_CONFIG(uint64_t,
object_manager_max_bytes_in_flight,
((uint64_t)2) * 1024 * 1024 * 1024) ((uint64_t)2) * 1024 * 1024 * 1024)
/// Maximum number of ids in one batch to send to GCS to delete keys. /// Maximum number of ids in one batch to send to GCS to delete keys.
@ -478,7 +479,8 @@ RAY_CONFIG(bool, enable_light_weight_resource_report, true)
// The number of seconds to wait for the Raylet to start. This is normally // The number of seconds to wait for the Raylet to start. This is normally
// fast, but when RAY_preallocate_plasma_memory=1 is set, it may take some time // fast, but when RAY_preallocate_plasma_memory=1 is set, it may take some time
// (a few GB/s) to populate all the pages on Raylet startup. // (a few GB/s) to populate all the pages on Raylet startup.
RAY_CONFIG(uint32_t, raylet_start_wait_time_s, RAY_CONFIG(uint32_t,
raylet_start_wait_time_s,
std::getenv("RAY_preallocate_plasma_memory") != nullptr && std::getenv("RAY_preallocate_plasma_memory") != nullptr &&
std::getenv("RAY_preallocate_plasma_memory") == std::string("1") std::getenv("RAY_preallocate_plasma_memory") == std::string("1")
? 120 ? 120

View file

@ -74,7 +74,8 @@ std::shared_ptr<ray::LocalMemoryBuffer> MakeSerializedErrorBuffer(
kMessagePackOffset); kMessagePackOffset);
// copy msgpack-serialized bytes // copy msgpack-serialized bytes
std::memcpy(final_buffer->Data() + kMessagePackOffset, std::memcpy(final_buffer->Data() + kMessagePackOffset,
msgpack_serialized_exception.data(), msgpack_serialized_exception.size()); msgpack_serialized_exception.data(),
msgpack_serialized_exception.size());
// copy offset // copy offset
msgpack::sbuffer msgpack_int; msgpack::sbuffer msgpack_int;
msgpack::pack(msgpack_int, msgpack_serialized_exception.size()); msgpack::pack(msgpack_int, msgpack_serialized_exception.size());

View file

@ -39,7 +39,8 @@ class RayObject {
/// \param[in] metadata Metadata of the ray object. /// \param[in] metadata Metadata of the ray object.
/// \param[in] nested_rfs ObjectRefs that were serialized in data. /// \param[in] nested_rfs ObjectRefs that were serialized in data.
/// \param[in] copy_data Whether this class should hold a copy of data. /// \param[in] copy_data Whether this class should hold a copy of data.
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata, RayObject(const std::shared_ptr<Buffer> &data,
const std::shared_ptr<Buffer> &metadata,
const std::vector<rpc::ObjectReference> &nested_refs, const std::vector<rpc::ObjectReference> &nested_refs,
bool copy_data = false) { bool copy_data = false) {
Init(data, metadata, nested_refs, copy_data); Init(data, metadata, nested_refs, copy_data);
@ -125,7 +126,8 @@ class RayObject {
int64_t CreationTimeNanos() const { return creation_time_nanos_; } int64_t CreationTimeNanos() const { return creation_time_nanos_; }
private: private:
void Init(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata, void Init(const std::shared_ptr<Buffer> &data,
const std::shared_ptr<Buffer> &metadata,
const std::vector<rpc::ObjectReference> &nested_refs, const std::vector<rpc::ObjectReference> &nested_refs,
bool copy_data = false) { bool copy_data = false) {
data_ = data; data_ = data;
@ -138,7 +140,8 @@ class RayObject {
// If this object is required to hold a copy of the data, // If this object is required to hold a copy of the data,
// make a copy if the passed in buffers don't already have a copy. // make a copy if the passed in buffers don't already have a copy.
if (data_ && !data_->OwnsData()) { if (data_ && !data_->OwnsData()) {
data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(), data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(),
data_->Size(),
/*copy_data=*/true); /*copy_data=*/true);
} }

View file

@ -176,7 +176,8 @@ int TaskSpecification::GetRuntimeEnvHash() const {
required_resource = GetRequiredResources().GetResourceMap(); required_resource = GetRequiredResources().GetResourceMap();
} }
WorkerCacheKey env = { WorkerCacheKey env = {
SerializedRuntimeEnv(), required_resource, SerializedRuntimeEnv(),
required_resource,
IsActorCreationTask() && RayConfig::instance().isolate_workers_across_task_types(), IsActorCreationTask() && RayConfig::instance().isolate_workers_across_task_types(),
GetRequiredResources().GetResource("GPU") > 0 && GetRequiredResources().GetResource("GPU") > 0 &&
RayConfig::instance().isolate_workers_across_resource_types()}; RayConfig::instance().isolate_workers_across_resource_types()};
@ -451,7 +452,8 @@ std::string TaskSpecification::CallSiteString() const {
WorkerCacheKey::WorkerCacheKey( WorkerCacheKey::WorkerCacheKey(
const std::string serialized_runtime_env, const std::string serialized_runtime_env,
const absl::flat_hash_map<std::string, double> &required_resources, bool is_actor, const absl::flat_hash_map<std::string, double> &required_resources,
bool is_actor,
bool is_gpu) bool is_gpu)
: serialized_runtime_env(serialized_runtime_env), : serialized_runtime_env(serialized_runtime_env),
required_resources(std::move(required_resources)), required_resources(std::move(required_resources)),

View file

@ -87,7 +87,8 @@ struct ConcurrencyGroup {
ConcurrencyGroup() = default; ConcurrencyGroup() = default;
ConcurrencyGroup(const std::string &name, uint32_t max_concurrency, ConcurrencyGroup(const std::string &name,
uint32_t max_concurrency,
const std::vector<ray::FunctionDescriptor> &fds) const std::vector<ray::FunctionDescriptor> &fds)
: name(name), max_concurrency(max_concurrency), function_descriptors(fds) {} : name(name), max_concurrency(max_concurrency), function_descriptors(fds) {}
@ -347,7 +348,8 @@ class WorkerCacheKey {
/// resource type isolation between workers is enabled. /// resource type isolation between workers is enabled.
WorkerCacheKey(const std::string serialized_runtime_env, WorkerCacheKey(const std::string serialized_runtime_env,
const absl::flat_hash_map<std::string, double> &required_resources, const absl::flat_hash_map<std::string, double> &required_resources,
bool is_actor, bool is_gpu); bool is_actor,
bool is_gpu);
bool operator==(const WorkerCacheKey &k) const; bool operator==(const WorkerCacheKey &k) const;

View file

@ -34,7 +34,8 @@ class TaskArgByReference : public TaskArg {
/// ///
/// \param[in] object_id Id of the argument. /// \param[in] object_id Id of the argument.
/// \return The task argument. /// \return The task argument.
TaskArgByReference(const ObjectID &object_id, const rpc::Address &owner_address, TaskArgByReference(const ObjectID &object_id,
const rpc::Address &owner_address,
const std::string &call_site) const std::string &call_site)
: id_(object_id), owner_address_(owner_address), call_site_(call_site) {} : id_(object_id), owner_address_(owner_address), call_site_(call_site) {}
@ -97,13 +98,20 @@ class TaskSpecBuilder {
/// ///
/// \return Reference to the builder object itself. /// \return Reference to the builder object itself.
TaskSpecBuilder &SetCommonTaskSpec( TaskSpecBuilder &SetCommonTaskSpec(
const TaskID &task_id, const std::string name, const Language &language, const TaskID &task_id,
const ray::FunctionDescriptor &function_descriptor, const JobID &job_id, const std::string name,
const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id, const Language &language,
const rpc::Address &caller_address, uint64_t num_returns, const ray::FunctionDescriptor &function_descriptor,
const JobID &job_id,
const TaskID &parent_task_id,
uint64_t parent_counter,
const TaskID &caller_id,
const rpc::Address &caller_address,
uint64_t num_returns,
const std::unordered_map<std::string, double> &required_resources, const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources, const std::unordered_map<std::string, double> &required_placement_resources,
const std::string &debugger_breakpoint, int64_t depth, const std::string &debugger_breakpoint,
int64_t depth,
const std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr, const std::shared_ptr<rpc::RuntimeEnvInfo> runtime_env_info = nullptr,
const std::string &concurrency_group_name = "") { const std::string &concurrency_group_name = "") {
message_->set_type(TaskType::NORMAL_TASK); message_->set_type(TaskType::NORMAL_TASK);
@ -130,7 +138,8 @@ class TaskSpecBuilder {
return *this; return *this;
} }
TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions, TaskSpecBuilder &SetNormalTaskSpec(int max_retries,
bool retry_exceptions,
const rpc::SchedulingStrategy &scheduling_strategy) { const rpc::SchedulingStrategy &scheduling_strategy) {
message_->set_max_retries(max_retries); message_->set_max_retries(max_retries);
message_->set_retry_exceptions(retry_exceptions); message_->set_retry_exceptions(retry_exceptions);
@ -142,8 +151,10 @@ class TaskSpecBuilder {
/// See `common.proto` for meaning of the arguments. /// See `common.proto` for meaning of the arguments.
/// ///
/// \return Reference to the builder object itself. /// \return Reference to the builder object itself.
TaskSpecBuilder &SetDriverTaskSpec(const TaskID &task_id, const Language &language, TaskSpecBuilder &SetDriverTaskSpec(const TaskID &task_id,
const JobID &job_id, const TaskID &parent_task_id, const Language &language,
const JobID &job_id,
const TaskID &parent_task_id,
const TaskID &caller_id, const TaskID &caller_id,
const rpc::Address &caller_address) { const rpc::Address &caller_address) {
message_->set_type(TaskType::DRIVER_TASK); message_->set_type(TaskType::DRIVER_TASK);
@ -170,14 +181,20 @@ class TaskSpecBuilder {
/// ///
/// \return Reference to the builder object itself. /// \return Reference to the builder object itself.
TaskSpecBuilder &SetActorCreationTaskSpec( TaskSpecBuilder &SetActorCreationTaskSpec(
const ActorID &actor_id, const std::string &serialized_actor_handle, const ActorID &actor_id,
const rpc::SchedulingStrategy &scheduling_strategy, int64_t max_restarts = 0, const std::string &serialized_actor_handle,
const rpc::SchedulingStrategy &scheduling_strategy,
int64_t max_restarts = 0,
int64_t max_task_retries = 0, int64_t max_task_retries = 0,
const std::vector<std::string> &dynamic_worker_options = {}, const std::vector<std::string> &dynamic_worker_options = {},
int max_concurrency = 1, bool is_detached = false, std::string name = "", int max_concurrency = 1,
std::string ray_namespace = "", bool is_asyncio = false, bool is_detached = false,
std::string name = "",
std::string ray_namespace = "",
bool is_asyncio = false,
const std::vector<ConcurrencyGroup> &concurrency_groups = {}, const std::vector<ConcurrencyGroup> &concurrency_groups = {},
const std::string &extension_data = "", bool execute_out_of_order = false) { const std::string &extension_data = "",
bool execute_out_of_order = false) {
message_->set_type(TaskType::ACTOR_CREATION_TASK); message_->set_type(TaskType::ACTOR_CREATION_TASK);
auto actor_creation_spec = message_->mutable_actor_creation_task_spec(); auto actor_creation_spec = message_->mutable_actor_creation_task_spec();
actor_creation_spec->set_actor_id(actor_id.Binary()); actor_creation_spec->set_actor_id(actor_id.Binary());

View file

@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "ray/common/client_connection.h" #include "ray/common/client_connection.h"
#include "ray/common/asio/instrumented_io_context.h"
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include <boost/asio/error.hpp> #include <boost/asio/error.hpp>
@ -22,6 +21,7 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ray/common/asio/instrumented_io_context.h"
namespace ray { namespace ray {
namespace raylet { namespace raylet {
@ -45,8 +45,10 @@ class ClientConnectionTest : public ::testing::Test {
#endif #endif
} }
ray::Status WriteBadMessage(std::shared_ptr<ray::ClientConnection> conn, int64_t type, ray::Status WriteBadMessage(std::shared_ptr<ray::ClientConnection> conn,
int64_t length, const uint8_t *message) { int64_t type,
int64_t length,
const uint8_t *message) {
std::vector<boost::asio::const_buffer> message_buffers; std::vector<boost::asio::const_buffer> message_buffers;
auto write_cookie = 123456; // incorrect version. auto write_cookie = 123456; // incorrect version.
message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie))); message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie)));
@ -69,18 +71,19 @@ TEST_F(ClientConnectionTest, SimpleSyncWrite) {
ClientHandler client_handler = [](ClientConnection &client) {}; ClientHandler client_handler = [](ClientConnection &client) {};
MessageHandler message_handler = MessageHandler message_handler = [&arr, &num_messages](
[&arr, &num_messages](std::shared_ptr<ClientConnection> client, std::shared_ptr<ClientConnection> client,
int64_t message_type, const std::vector<uint8_t> &message) { int64_t message_type,
ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); const std::vector<uint8_t> &message) {
num_messages += 1; ASSERT_TRUE(!std::memcmp(arr, message.data(), 5));
}; num_messages += 1;
};
auto conn1 = ClientConnection::Create(client_handler, message_handler, std::move(in_), auto conn1 = ClientConnection::Create(
"conn1", {}, error_message_type_); client_handler, message_handler, std::move(in_), "conn1", {}, error_message_type_);
auto conn2 = ClientConnection::Create(client_handler, message_handler, std::move(out_), auto conn2 = ClientConnection::Create(
"conn2", {}, error_message_type_); client_handler, message_handler, std::move(out_), "conn2", {}, error_message_type_);
RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr)); RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr));
RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr)); RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr));
@ -121,11 +124,15 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) {
} }
}; };
auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), auto writer = ClientConnection::Create(
"writer", {}, error_message_type_); client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_);
reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), reader = ClientConnection::Create(client_handler,
"reader", {}, error_message_type_); message_handler,
std::move(out_),
"reader",
{},
error_message_type_);
std::function<void(const ray::Status &)> callback = [](const ray::Status &status) { std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
RAY_CHECK_OK(status); RAY_CHECK_OK(status);
@ -178,8 +185,8 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) {
int64_t message_type, int64_t message_type,
const std::vector<uint8_t> &message) {}; const std::vector<uint8_t> &message) {};
auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), auto writer = ClientConnection::Create(
"writer", {}, error_message_type_); client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_);
std::function<void(const ray::Status &)> callback = [](const ray::Status &status) { std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
ASSERT_TRUE(!status.ok()); ASSERT_TRUE(!status.ok());
@ -199,8 +206,8 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) {
int64_t message_type, int64_t message_type,
const std::vector<uint8_t> &message) {}; const std::vector<uint8_t> &message) {};
auto writer = ClientConnection::Create(client_handler, noop_handler, std::move(in_), auto writer = ClientConnection::Create(
"writer", {}, error_message_type_); client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_);
std::function<void(const ray::Status &)> callback = std::function<void(const ray::Status &)> callback =
[writer](const ray::Status &status) { [writer](const ray::Status &status) {
@ -217,18 +224,23 @@ TEST_F(ClientConnectionTest, ProcessBadMessage) {
ClientHandler client_handler = [](ClientConnection &client) {}; ClientHandler client_handler = [](ClientConnection &client) {};
MessageHandler message_handler = MessageHandler message_handler = [&arr, &num_messages](
[&arr, &num_messages](std::shared_ptr<ClientConnection> client, std::shared_ptr<ClientConnection> client,
int64_t message_type, const std::vector<uint8_t> &message) { int64_t message_type,
ASSERT_TRUE(!std::memcmp(arr, message.data(), 5)); const std::vector<uint8_t> &message) {
num_messages += 1; ASSERT_TRUE(!std::memcmp(arr, message.data(), 5));
}; num_messages += 1;
};
auto writer = ClientConnection::Create(client_handler, message_handler, std::move(in_), auto writer = ClientConnection::Create(
"writer", {}, error_message_type_); client_handler, message_handler, std::move(in_), "writer", {}, error_message_type_);
auto reader = ClientConnection::Create(client_handler, message_handler, std::move(out_), auto reader = ClientConnection::Create(client_handler,
"reader", {}, error_message_type_); message_handler,
std::move(out_),
"reader",
{},
error_message_type_);
// If client ID is set, bad message would crash the test. // If client ID is set, bad message would crash the test.
// reader->SetClientID(UniqueID::FromRandom()); // reader->SetClientID(UniqueID::FromRandom());

View file

@ -100,7 +100,8 @@ std::string TestSetupUtil::StartGcsServer(const std::string &redis_address) {
std::string gcs_server_socket_name = std::string gcs_server_socket_name =
ray::JoinPaths(ray::GetUserTempDir(), "gcs_server" + ObjectID::FromRandom().Hex()); ray::JoinPaths(ray::GetUserTempDir(), "gcs_server" + ObjectID::FromRandom().Hex());
std::vector<std::string> cmdargs( std::vector<std::string> cmdargs(
{TEST_GCS_SERVER_EXEC_PATH, "--redis_address=" + redis_address, {TEST_GCS_SERVER_EXEC_PATH,
"--redis_address=" + redis_address,
"--config_list=" + "--config_list=" +
absl::Base64Escape(R"({"object_timeout_milliseconds": 2000})")}); absl::Base64Escape(R"({"object_timeout_milliseconds": 2000})")});
if (RayConfig::instance().bootstrap_with_gcs()) { if (RayConfig::instance().bootstrap_with_gcs()) {
@ -129,15 +130,21 @@ std::string TestSetupUtil::StartRaylet(const std::string &node_ip_address,
std::string plasma_store_socket_name = std::string plasma_store_socket_name =
ray::JoinPaths(ray::GetUserTempDir(), "store" + ObjectID::FromRandom().Hex()); ray::JoinPaths(ray::GetUserTempDir(), "store" + ObjectID::FromRandom().Hex());
std::vector<std::string> cmdargs( std::vector<std::string> cmdargs(
{TEST_RAYLET_EXEC_PATH, "--raylet_socket_name=" + raylet_socket_name, {TEST_RAYLET_EXEC_PATH,
"--store_socket_name=" + plasma_store_socket_name, "--object_manager_port=0", "--raylet_socket_name=" + raylet_socket_name,
"--store_socket_name=" + plasma_store_socket_name,
"--object_manager_port=0",
"--node_manager_port=" + std::to_string(port), "--node_manager_port=" + std::to_string(port),
"--node_ip_address=" + node_ip_address, "--redis_port=6379", "--min-worker-port=0", "--node_ip_address=" + node_ip_address,
"--max-worker-port=0", "--maximum_startup_concurrency=10", "--redis_port=6379",
"--min-worker-port=0",
"--max-worker-port=0",
"--maximum_startup_concurrency=10",
"--static_resource_list=" + resource, "--static_resource_list=" + resource,
"--python_worker_command=" + "--python_worker_command=" + CreateCommandLine({TEST_MOCK_WORKER_EXEC_PATH,
CreateCommandLine({TEST_MOCK_WORKER_EXEC_PATH, plasma_store_socket_name, plasma_store_socket_name,
raylet_socket_name, std::to_string(port)}), raylet_socket_name,
std::to_string(port)}),
"--object_store_memory=10000000"}); "--object_store_memory=10000000"});
if (RayConfig::instance().bootstrap_with_gcs()) { if (RayConfig::instance().bootstrap_with_gcs()) {
cmdargs.push_back("--gcs-address=" + bootstrap_address); cmdargs.push_back("--gcs-address=" + bootstrap_address);
@ -178,7 +185,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms) {
return false; return false;
} }
void WaitForExpectedCount(std::atomic<int> &current_count, int expected_count, void WaitForExpectedCount(std::atomic<int> &current_count,
int expected_count,
int timeout_ms) { int timeout_ms) {
auto condition = [&current_count, expected_count]() { auto condition = [&current_count, expected_count]() {
return current_count == expected_count; return current_count == expected_count;

View file

@ -63,7 +63,8 @@ bool WaitForCondition(std::function<bool()> condition, int timeout_ms);
/// \param[in] expected_count The expected count. /// \param[in] expected_count The expected count.
/// \param[in] timeout_ms Timeout in milliseconds to wait for for. /// \param[in] timeout_ms Timeout in milliseconds to wait for for.
/// \return Whether the expected count is met. /// \return Whether the expected count is met.
void WaitForExpectedCount(std::atomic<int> &current_count, int expected_count, void WaitForExpectedCount(std::atomic<int> &current_count,
int expected_count,
int timeout_ms = 60000); int timeout_ms = 60000);
/// Used to kill process whose pid is stored in `socket_name.id` file. /// Used to kill process whose pid is stored in `socket_name.id` file.
@ -117,7 +118,8 @@ class TestSetupUtil {
static std::string StartGcsServer(const std::string &redis_address); static std::string StartGcsServer(const std::string &redis_address);
static void StopGcsServer(const std::string &gcs_server_socket_name); static void StopGcsServer(const std::string &gcs_server_socket_name);
static std::string StartRaylet(const std::string &node_ip_address, const int &port, static std::string StartRaylet(const std::string &node_ip_address,
const int &port,
const std::string &bootstrap_address, const std::string &bootstrap_address,
const std::string &resource, const std::string &resource,
std::string *store_socket_name); std::string *store_socket_name);

View file

@ -20,12 +20,18 @@ namespace ray {
namespace core { namespace core {
namespace { namespace {
rpc::ActorHandle CreateInnerActorHandle( rpc::ActorHandle CreateInnerActorHandle(
const class ActorID &actor_id, const TaskID &owner_id, const class ActorID &actor_id,
const rpc::Address &owner_address, const class JobID &job_id, const TaskID &owner_id,
const ObjectID &initial_cursor, const Language actor_language, const rpc::Address &owner_address,
const class JobID &job_id,
const ObjectID &initial_cursor,
const Language actor_language,
const FunctionDescriptor &actor_creation_task_function_descriptor, const FunctionDescriptor &actor_creation_task_function_descriptor,
const std::string &extension_data, int64_t max_task_retries, const std::string &name, const std::string &extension_data,
const std::string &ray_namespace, int32_t max_pending_calls, int64_t max_task_retries,
const std::string &name,
const std::string &ray_namespace,
int32_t max_pending_calls,
bool execute_out_of_order) { bool execute_out_of_order) {
rpc::ActorHandle inner; rpc::ActorHandle inner;
inner.set_actor_id(actor_id.Data(), actor_id.Size()); inner.set_actor_id(actor_id.Data(), actor_id.Size());
@ -79,17 +85,32 @@ rpc::ActorHandle CreateInnerActorHandleFromActorTableData(
} // namespace } // namespace
ActorHandle::ActorHandle( ActorHandle::ActorHandle(
const class ActorID &actor_id, const TaskID &owner_id, const class ActorID &actor_id,
const rpc::Address &owner_address, const class JobID &job_id, const TaskID &owner_id,
const ObjectID &initial_cursor, const Language actor_language, const rpc::Address &owner_address,
const class JobID &job_id,
const ObjectID &initial_cursor,
const Language actor_language,
const FunctionDescriptor &actor_creation_task_function_descriptor, const FunctionDescriptor &actor_creation_task_function_descriptor,
const std::string &extension_data, int64_t max_task_retries, const std::string &name, const std::string &extension_data,
const std::string &ray_namespace, int32_t max_pending_calls, int64_t max_task_retries,
const std::string &name,
const std::string &ray_namespace,
int32_t max_pending_calls,
bool execute_out_of_order) bool execute_out_of_order)
: ActorHandle(CreateInnerActorHandle( : ActorHandle(CreateInnerActorHandle(actor_id,
actor_id, owner_id, owner_address, job_id, initial_cursor, actor_language, owner_id,
actor_creation_task_function_descriptor, extension_data, max_task_retries, name, owner_address,
ray_namespace, max_pending_calls, execute_out_of_order)) {} job_id,
initial_cursor,
actor_language,
actor_creation_task_function_descriptor,
extension_data,
max_task_retries,
name,
ray_namespace,
max_pending_calls,
execute_out_of_order)) {}
ActorHandle::ActorHandle(const std::string &serialized) ActorHandle::ActorHandle(const std::string &serialized)
: ActorHandle(CreateInnerActorHandleFromString(serialized)) {} : ActorHandle(CreateInnerActorHandleFromString(serialized)) {}
@ -103,7 +124,8 @@ void ActorHandle::SetActorTaskSpec(TaskSpecBuilder &builder, const ObjectID new_
const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(GetActorID()); const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(GetActorID());
const ObjectID actor_creation_dummy_object_id = const ObjectID actor_creation_dummy_object_id =
ObjectID::FromIndex(actor_creation_task_id, /*index=*/1); ObjectID::FromIndex(actor_creation_task_id, /*index=*/1);
builder.SetActorTaskSpec(GetActorID(), actor_creation_dummy_object_id, builder.SetActorTaskSpec(GetActorID(),
actor_creation_dummy_object_id,
/*previous_actor_task_dummy_object_id=*/actor_cursor_, /*previous_actor_task_dummy_object_id=*/actor_cursor_,
task_counter_++); task_counter_++);
actor_cursor_ = new_cursor; actor_cursor_ = new_cursor;

Some files were not shown because too many files have changed in this diff Show more