diff --git a/cpp/include/ray/api/ray_remote.h b/cpp/include/ray/api/ray_remote.h index 38080d011..4ab42c42e 100644 --- a/cpp/include/ray/api/ray_remote.h +++ b/cpp/include/ray/api/ray_remote.h @@ -75,7 +75,7 @@ inline static int RegisterRemoteFunctions(const T &t, U... u) { } // namespace internal #define RAY_REMOTE(...) \ - static auto ANONYMOUS_VARIABLE(var) = \ + inline auto ANONYMOUS_VARIABLE(var) = \ ray::internal::RegisterRemoteFunctions(#__VA_ARGS__, __VA_ARGS__); #define RAY_FUNC(f, ...) ray::internal::underload<__VA_ARGS__>(f) diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index e731cc929..f54800109 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -164,6 +164,13 @@ Status TaskExecutor::ExecuteTask( std::tie(status, data) = GetExecuteResult(func_name, ray_args_buffer, nullptr); current_actor_ = data; } else if (task_type == ray::TaskType::ACTOR_TASK) { + if (cross_lang) { + RAY_CHECK(!typed_descriptor->ClassName().empty()); + func_name = std::string("&") + .append(typed_descriptor->ClassName()) + .append("::") + .append(typed_descriptor->FunctionName()); + } RAY_CHECK(current_actor_ != nullptr); std::tie(status, data) = GetExecuteResult(func_name, ray_args_buffer, current_actor_.get()); diff --git a/cpp/src/ray/test/cluster/counter.h b/cpp/src/ray/test/cluster/counter.h index 436d015ea..b554d5726 100644 --- a/cpp/src/ray/test/cluster/counter.h +++ b/cpp/src/ray/test/cluster/counter.h @@ -50,6 +50,9 @@ class Counter { bool is_restared = false; }; +inline Counter *CreateCounter() { return new Counter(0); } +RAY_REMOTE(CreateCounter); + class CountDownLatch { public: explicit CountDownLatch(size_t count) : m_count(count) {} diff --git a/cpp/test_python_call_cpp.py b/cpp/test_python_call_cpp.py index 2b2bb4c97..34fa6f784 100644 --- a/cpp/test_python_call_cpp.py +++ b/cpp/test_python_call_cpp.py @@ -1,11 +1,16 @@ import ray import ray.cluster_utils from ray.exceptions import CrossLanguageError +from ray.exceptions import RayActorError import pytest def test_cross_language_cpp(): - ray.init(job_config=ray.job_config.JobConfig(code_search_path=["../../plus.so"])) + ray.init( + job_config=ray.job_config.JobConfig( + code_search_path=["../../plus.so:../../counter.so"] + ) + ) obj = ray.cross_language.cpp_function("Plus1").remote(1) assert 2 == ray.get(obj) @@ -60,6 +65,37 @@ def test_cross_language_cpp(): assert students == ray.get(obj9) +def test_cross_language_cpp_actor(): + actor = ray.cross_language.cpp_actor_class("CreateCounter", "Counter").remote() + obj = actor.Plus1.remote() + assert 1 == ray.get(obj) + + actor1 = ray.cross_language.cpp_actor_class( + "RAY_FUNC(Counter::FactoryCreate)", "Counter" + ).remote("invalid arg") + obj = actor1.Plus1.remote() + with pytest.raises(RayActorError): + ray.get(obj) + + actor1 = ray.cross_language.cpp_actor_class( + "RAY_FUNC(Counter::FactoryCreate)", "Counter" + ).remote() + + obj = actor1.Plus1.remote() + assert 1 == ray.get(obj) + + obj = actor1.Add.remote(2) + assert 3 == ray.get(obj) + + obj2 = actor1.ExceptionFunc.remote() + with pytest.raises(CrossLanguageError): + ray.get(obj2) + + obj3 = actor1.NotExistFunc.remote() + with pytest.raises(CrossLanguageError): + ray.get(obj3) + + if __name__ == "__main__": import sys diff --git a/python/ray/actor.py b/python/ray/actor.py index 416cc0e8c..8431a4566 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -968,11 +968,14 @@ class ActorClass: # Update the creation descriptor based on number of arguments if meta.is_cross_language: + func_name = "" + if meta.language == Language.CPP: + func_name = meta.actor_creation_function_descriptor.function_name meta.actor_creation_function_descriptor = ( cross_language.get_function_descriptor_for_actor_method( meta.language, meta.actor_creation_function_descriptor, - "", + func_name, str(len(args) + len(kwargs)), ) ) diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py index b1fd5005e..9f4c242a3 100644 --- a/python/ray/cross_language.py +++ b/python/ray/cross_language.py @@ -56,6 +56,12 @@ def get_function_descriptor_for_actor_method( method_name, signature, ) + elif language == Language.CPP: + return CppFunctionDescriptor( + method_name, + "PYTHON", + actor_creation_function_descriptor.class_name, + ) else: raise NotImplementedError( "Cross language remote actor method " f"not support language {language}" @@ -143,3 +149,29 @@ def java_actor_class(class_name): accelerator_type=None, runtime_env=None, ) + + +@PublicAPI(stability="beta") +def cpp_actor_class(create_function_name, class_name): + """Define a Cpp actor class. + + Args: + create_function_name (str): Create cpp class function name. + class_name (str): Cpp class name. + """ + from ray.actor import ActorClass + + print("create func=", create_function_name, "class_name=", class_name) + return ActorClass._ray_from_function_descriptor( + Language.CPP, + CppFunctionDescriptor(create_function_name, "PYTHON", class_name), + max_restarts=0, + max_task_retries=0, + num_cpus=None, + num_gpus=None, + memory=None, + object_store_memory=None, + resources=None, + accelerator_type=None, + runtime_env=None, + ) diff --git a/python/ray/includes/function_descriptor.pxd b/python/ray/includes/function_descriptor.pxd index 1bf627c59..86090f68c 100644 --- a/python/ray/includes/function_descriptor.pxd +++ b/python/ray/includes/function_descriptor.pxd @@ -57,7 +57,8 @@ cdef extern from "ray/common/function_descriptor.h" nogil: @staticmethod CFunctionDescriptor BuildCpp(const c_string &function_name, - const c_string &caller) + const c_string &caller, + const c_string &class_name) @staticmethod CFunctionDescriptor Deserialize(const c_string &serialized_binary) @@ -76,3 +77,4 @@ cdef extern from "ray/common/function_descriptor.h" nogil: cdef cppclass CCppFunctionDescriptor "ray::CppFunctionDescriptor": c_string FunctionName() c_string Caller() + c_string ClassName() diff --git a/python/ray/includes/function_descriptor.pxi b/python/ray/includes/function_descriptor.pxi index 007aa0352..86991a343 100644 --- a/python/ray/includes/function_descriptor.pxi +++ b/python/ray/includes/function_descriptor.pxi @@ -332,21 +332,24 @@ cdef class CppFunctionDescriptor(FunctionDescriptor): CCppFunctionDescriptor *typed_descriptor def __cinit__(self, - function_name, caller): - self.descriptor = CFunctionDescriptorBuilder.BuildCpp(function_name, caller) + function_name, caller, class_name=""): + self.descriptor = CFunctionDescriptorBuilder.BuildCpp( + function_name, caller, class_name) self.typed_descriptor = ( self.descriptor.get()) def __reduce__(self): return CppFunctionDescriptor, (self.typed_descriptor.FunctionName(), - self.typed_descriptor.Caller()) + self.typed_descriptor.Caller(), + self.typed_descriptor.ClassName()) @staticmethod cdef from_cpp(const CFunctionDescriptor &c_function_descriptor): cdef CCppFunctionDescriptor *typed_descriptor = \ (c_function_descriptor.get()) return CppFunctionDescriptor(typed_descriptor.FunctionName(), - typed_descriptor.Caller()) + typed_descriptor.Caller(), + typed_descriptor.ClassName()) @property def function_name(self): @@ -365,3 +368,13 @@ cdef class CppFunctionDescriptor(FunctionDescriptor): The caller of the function descriptor. """ return self.typed_descriptor.Caller() + + @property + def class_name(self): + """Get the class name of current function descriptor, + when it is empty, it is a non-member function. + + Returns: + The class name of the function descriptor. + """ + return self.typed_descriptor.ClassName() diff --git a/src/ray/common/function_descriptor.cc b/src/ray/common/function_descriptor.cc index 332660886..c9f1a71de 100644 --- a/src/ray/common/function_descriptor.cc +++ b/src/ray/common/function_descriptor.cc @@ -47,11 +47,13 @@ FunctionDescriptor FunctionDescriptorBuilder::BuildPython( } FunctionDescriptor FunctionDescriptorBuilder::BuildCpp(const std::string &function_name, - const std::string &caller) { + const std::string &caller, + const std::string &class_name) { rpc::FunctionDescriptor descriptor; auto typed_descriptor = descriptor.mutable_cpp_function_descriptor(); typed_descriptor->set_function_name(function_name); typed_descriptor->set_caller(caller); + typed_descriptor->set_class_name(class_name); return ray::FunctionDescriptor(new CppFunctionDescriptor(std::move(descriptor))); } @@ -89,11 +91,11 @@ FunctionDescriptor FunctionDescriptorBuilder::FromVector( function_descriptor_list[3] // function hash ); } else if (language == rpc::Language::CPP) { - RAY_CHECK(function_descriptor_list.size() == 2); + RAY_CHECK(function_descriptor_list.size() == 3); return FunctionDescriptorBuilder::BuildCpp( - function_descriptor_list[0], // function name - function_descriptor_list[1] // caller - ); + function_descriptor_list[0], // function name + function_descriptor_list[1], // caller + function_descriptor_list[2]); // class name } else { RAY_LOG(FATAL) << "Unspported language " << language; return FunctionDescriptorBuilder::Empty(); diff --git a/src/ray/common/function_descriptor.h b/src/ray/common/function_descriptor.h index be36e63d9..0663a7c34 100644 --- a/src/ray/common/function_descriptor.h +++ b/src/ray/common/function_descriptor.h @@ -230,14 +230,16 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface { virtual size_t Hash() const { return std::hash()(ray::FunctionDescriptorType::kCppFunctionDescriptor) ^ - std::hash()(typed_message_->function_name()); + std::hash()(typed_message_->function_name()) ^ + std::hash()(typed_message_->class_name()); } inline bool operator==(const CppFunctionDescriptor &other) const { if (this == &other) { return true; } - return this->FunctionName() == other.FunctionName(); + return this->FunctionName() == other.FunctionName() && + this->ClassName() == other.ClassName(); } inline bool operator!=(const CppFunctionDescriptor &other) const { @@ -245,8 +247,9 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface { } virtual std::string ToString() const { + std::string class_name = ClassName().empty() ? "" : ", class_name=" + ClassName(); return "{type=CppFunctionDescriptor, function_name=" + - typed_message_->function_name() + "}"; + typed_message_->function_name() + class_name + "}"; } virtual std::string CallString() const { return typed_message_->function_name(); } @@ -257,6 +260,8 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface { const std::string &Caller() const { return typed_message_->caller(); } + const std::string &ClassName() const { return typed_message_->class_name(); } + private: const rpc::CppFunctionDescriptor *typed_message_; }; @@ -323,7 +328,8 @@ class FunctionDescriptorBuilder { /// /// \return a ray::CppFunctionDescriptor static FunctionDescriptor BuildCpp(const std::string &function_name, - const std::string &caller = ""); + const std::string &caller = "", + const std::string &class_name = ""); /// Build a ray::FunctionDescriptor according to input message. /// diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 481372a65..7dd009b4d 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -102,6 +102,7 @@ message CppFunctionDescriptor { /// Remote function name. string function_name = 1; string caller = 2; + string class_name = 3; } // A union wrapper for various function descriptor types.