From 9f8b596aaae7d76fb4fb286766f7093ac9de30ff Mon Sep 17 00:00:00 2001 From: xiaofeng Date: Sat, 6 Aug 2022 14:35:49 +0800 Subject: [PATCH] [serve][xlang]Support deploying Python deployment from Java. (#26877) In the previously merged pr(https://github.com/ray-project/ray/pull/22726/commits), java serve's support for python deployment was not implemented. This PR is used to implement this feature. Co-authored-by: nanqi.yxf --- java/BUILD.bazel | 1 + .../io/ray/serve/config/DeploymentConfig.java | 6 +- .../io/ray/serve/config/ReplicaConfig.java | 6 +- .../io/ray/serve/deployment/Deployment.java | 3 +- .../serve/deployment/DeploymentCreator.java | 13 +-- .../java/io/ray/serve/router/ReplicaSet.java | 87 +++++++++++------ .../main/resources/test_python_deployment.py | 20 ++++ .../serve/CrossLanguageDeploymentTest.java | 95 +++++++++++++++++++ .../java/io/ray/serve/ReplicaSetTest.java | 3 +- ...st_deployment_cross_language_invocation.py | 24 ----- python/ray/serve/_private/deployment_state.py | 85 +++++++++++++---- python/ray/serve/_private/replica.py | 33 ++++++- python/ray/serve/_private/router.py | 1 + python/ray/serve/_private/utils.py | 24 +++++ python/ray/serve/config.py | 4 +- python/ray/serve/tests/test_util.py | 30 ++++++ 16 files changed, 345 insertions(+), 90 deletions(-) create mode 100644 java/serve/src/main/resources/test_python_deployment.py create mode 100644 java/serve/src/test/java/io/ray/serve/CrossLanguageDeploymentTest.java delete mode 100644 java/serve/src/test/resources/test_deployment_cross_language_invocation.py diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 7fe1bb8f6..c90f579f2 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -160,6 +160,7 @@ define_java_module( "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", + "@maven//:commons_io_commons_io", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_apache_httpcomponents_client5_httpclient5", "@maven//:org_apache_httpcomponents_core5_httpcore5", diff --git a/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java b/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java index 841d70e92..588356e7a 100644 --- a/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java +++ b/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java @@ -193,15 +193,15 @@ public class DeploymentConfig implements Serializable { io.ray.serve.generated.DeploymentConfig.newBuilder() .setNumReplicas(numReplicas) .setMaxConcurrentQueries(maxConcurrentQueries) - .setUserConfig( - ByteString.copyFrom( - MessagePackSerializer.encode(userConfig).getKey())) // TODO-xlang .setGracefulShutdownWaitLoopS(gracefulShutdownWaitLoopS) .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckTimeoutS(healthCheckTimeoutS) .setIsCrossLanguage(isCrossLanguage) .setDeploymentLanguage(deploymentLanguage); + if (null != userConfig) { + builder.setUserConfig(ByteString.copyFrom(MessagePackSerializer.encode(userConfig).getKey())); + } if (null != autoscalingConfig) { builder.setAutoscalingConfig(autoscalingConfig.toProto()); } diff --git a/java/serve/src/main/java/io/ray/serve/config/ReplicaConfig.java b/java/serve/src/main/java/io/ray/serve/config/ReplicaConfig.java index 53defb501..7c1c58301 100644 --- a/java/serve/src/main/java/io/ray/serve/config/ReplicaConfig.java +++ b/java/serve/src/main/java/io/ray/serve/config/ReplicaConfig.java @@ -142,10 +142,14 @@ public class ReplicaConfig { if (proto == null) { return null; } + Object[] initArgs = null; + if (0 != proto.getInitArgs().toByteArray().length) { + initArgs = MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null); + } ReplicaConfig replicaConfig = new ReplicaConfig( proto.getDeploymentDefName(), - MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null), // TODO-xlang + initArgs, gson.fromJson(proto.getRayActorOptions(), Map.class)); return replicaConfig; } diff --git a/java/serve/src/main/java/io/ray/serve/deployment/Deployment.java b/java/serve/src/main/java/io/ray/serve/deployment/Deployment.java index ab0763f7e..cdb49d286 100644 --- a/java/serve/src/main/java/io/ray/serve/deployment/Deployment.java +++ b/java/serve/src/main/java/io/ray/serve/deployment/Deployment.java @@ -126,7 +126,8 @@ public class Deployment { .setGracefulShutdownWaitLoopS(this.config.getGracefulShutdownWaitLoopS()) .setGracefulShutdownTimeoutS(this.config.getGracefulShutdownTimeoutS()) .setHealthCheckPeriodS(this.config.getHealthCheckPeriodS()) - .setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS()); + .setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS()) + .setLanguage(this.config.getDeploymentLanguage()); } public String getDeploymentDef() { diff --git a/java/serve/src/main/java/io/ray/serve/deployment/DeploymentCreator.java b/java/serve/src/main/java/io/ray/serve/deployment/DeploymentCreator.java index 4c79ba230..307a562ef 100644 --- a/java/serve/src/main/java/io/ray/serve/deployment/DeploymentCreator.java +++ b/java/serve/src/main/java/io/ray/serve/deployment/DeploymentCreator.java @@ -79,7 +79,7 @@ public class DeploymentCreator { private boolean routed; - private DeploymentLanguage deploymentLanguage; + private DeploymentLanguage language; public Deployment create() { @@ -97,7 +97,7 @@ public class DeploymentCreator { .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckTimeoutS(healthCheckTimeoutS) - .setDeploymentLanguage(deploymentLanguage); + .setDeploymentLanguage(language); return new Deployment( deploymentDef, @@ -246,11 +246,12 @@ public class DeploymentCreator { return this; } - public DeploymentLanguage getDeploymentLanguage() { - return deploymentLanguage; + public DeploymentLanguage getLanguage() { + return language; } - public void setDeploymentLanguage(DeploymentLanguage deploymentLanguage) { - this.deploymentLanguage = deploymentLanguage; + public DeploymentCreator setLanguage(DeploymentLanguage language) { + this.language = language; + return this; } } diff --git a/java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java index 14da1db40..5533c21c0 100644 --- a/java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java +++ b/java/serve/src/main/java/io/ray/serve/router/ReplicaSet.java @@ -3,14 +3,20 @@ package io.ray.serve.router; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; +import io.ray.api.PyActorHandle; import io.ray.api.Ray; +import io.ray.api.function.PyActorMethod; import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.TagKey; +import io.ray.serve.api.Serve; import io.ray.serve.common.Constants; +import io.ray.serve.deployment.Deployment; import io.ray.serve.exception.RayServeException; import io.ray.serve.generated.ActorNameList; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.metrics.RayServeMetrics; import io.ray.serve.replica.RayServeWrappedReplica; import io.ray.serve.util.CollectionUtil; @@ -18,6 +24,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -31,7 +38,14 @@ public class ReplicaSet { private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class); - private final Map, Set>> inFlightQueries; + // The key is the name of the actor, and the value is a set of all flight queries objectrefs of + // the actor. + private final Map>> inFlightQueries; + + // Map the actor name to the handle of the actor. + private final Map allActorHandles; + + private DeploymentLanguage language; private AtomicInteger numQueuedQueries = new AtomicInteger(); @@ -41,6 +55,16 @@ public class ReplicaSet { public ReplicaSet(String deploymentName) { this.inFlightQueries = new ConcurrentHashMap<>(); + this.allActorHandles = new ConcurrentHashMap<>(); + try { + Deployment deployment = Serve.getDeployment(deploymentName); + this.language = deployment.getConfig().getDeploymentLanguage(); + } catch (Exception e) { + LOGGER.warn( + "Failed to get language from controller. Set it to Java as default value. The exception is ", + e); + this.language = DeploymentLanguage.JAVA; + } RayServeMetrics.execute( () -> this.numQueuedQueriesGauge = @@ -54,26 +78,26 @@ public class ReplicaSet { @SuppressWarnings("unchecked") public synchronized void updateWorkerReplicas(Object actorSet) { - List actorNames = ((ActorNameList) actorSet).getNamesList(); - Set> workerReplicas = new HashSet<>(); - if (!CollectionUtil.isEmpty(actorNames)) { - actorNames.forEach( - name -> - workerReplicas.add( - (ActorHandle) - Ray.getActor(name, Constants.SERVE_NAMESPACE).get())); - } - - Set> added = - new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet())); - Set> removed = - new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas)); - - added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet())); - removed.forEach(inFlightQueries::remove); - - if (added.size() > 0 || removed.size() > 0) { - LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size()); + if (null != actorSet) { + Set actorNameSet = new HashSet<>(((ActorNameList) actorSet).getNamesList()); + Set added = new HashSet<>(Sets.difference(actorNameSet, inFlightQueries.keySet())); + Set removed = new HashSet<>(Sets.difference(inFlightQueries.keySet(), actorNameSet)); + added.forEach( + name -> { + Optional handleOptional = + Ray.getActor(name, Constants.SERVE_NAMESPACE); + if (handleOptional.isPresent()) { + allActorHandles.put(name, handleOptional.get()); + inFlightQueries.put(name, Sets.newConcurrentHashSet()); + } else { + LOGGER.warn("Can not get actor handle. actor name is {}", name); + } + }); + removed.forEach(inFlightQueries::remove); + removed.forEach(allActorHandles::remove); + if (added.size() > 0 || removed.size() > 0) { + LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size()); + } } hasPullReplica = true; } @@ -121,20 +145,29 @@ public class ReplicaSet { } loopCount++; } - List> handles = new ArrayList<>(inFlightQueries.keySet()); + List handles = new ArrayList<>(allActorHandles.values()); if (CollectionUtil.isEmpty(handles)) { throw new RayServeException("ReplicaSet found no replica."); } int randomIndex = RandomUtils.nextInt(0, handles.size()); - ActorHandle replica = + BaseActorHandle replica = handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica); - return replica - .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs()) - .remote(); + if (language == DeploymentLanguage.PYTHON) { + return ((PyActorHandle) replica) + .task( + PyActorMethod.of("handle_request_from_java"), + query.getMetadata().toByteArray(), + query.getArgs()) + .remote(); + } else { + return ((ActorHandle) replica) + .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs()) + .remote(); + } } - public Map, Set>> getInFlightQueries() { + public Map>> getInFlightQueries() { return inFlightQueries; } } diff --git a/java/serve/src/main/resources/test_python_deployment.py b/java/serve/src/main/resources/test_python_deployment.py new file mode 100644 index 000000000..7df45ba6b --- /dev/null +++ b/java/serve/src/main/resources/test_python_deployment.py @@ -0,0 +1,20 @@ +# This file is used by CrossLanguageDeploymentTest.java to test cross-language +# invocation. +from ray import serve + + +def echo_server(v): + return v + + +@serve.deployment +class Counter(object): + def __init__(self, value): + self.value = int(value) + + def increase(self, delta): + self.value += int(delta) + return str(self.value) + + def reconfigure(self, value_str): + self.value = int(value_str) diff --git a/java/serve/src/test/java/io/ray/serve/CrossLanguageDeploymentTest.java b/java/serve/src/test/java/io/ray/serve/CrossLanguageDeploymentTest.java new file mode 100644 index 000000000..5ff3c4961 --- /dev/null +++ b/java/serve/src/test/java/io/ray/serve/CrossLanguageDeploymentTest.java @@ -0,0 +1,95 @@ +package io.ray.serve; + +import io.ray.api.Ray; +import io.ray.serve.api.Serve; +import io.ray.serve.deployment.Deployment; +import io.ray.serve.generated.DeploymentLanguage; +import io.ray.serve.handle.RayServeHandle; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.TimeUnit; +import org.apache.commons.io.FileUtils; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +@Test(groups = {"cluster"}) +public class CrossLanguageDeploymentTest extends BaseServeTest { + private static final String PYTHON_MODULE = "test_python_deployment"; + + @BeforeClass + public void beforeClass() { + // Delete and re-create the temp dir. + File tempDir = + new File( + System.getProperty("java.io.tmpdir") + + File.separator + + "ray_serve_cross_language_test"); + FileUtils.deleteQuietly(tempDir); + tempDir.mkdirs(); + tempDir.deleteOnExit(); + + // Write the test Python file to the temp dir. + InputStream in = + CrossLanguageDeploymentTest.class.getResourceAsStream( + File.separator + PYTHON_MODULE + ".py"); + File pythonFile = new File(tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py"); + try { + FileUtils.copyInputStreamToFile(in, pythonFile); + } catch (IOException e) { + throw new RuntimeException(e); + } + + System.setProperty( + "ray.job.code-search-path", + System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath()); + } + + @Test + public void createPyClassTest() { + Deployment deployment = + Serve.deployment() + .setLanguage(DeploymentLanguage.PYTHON) + .setName("createPyClassTest") + .setDeploymentDef(PYTHON_MODULE + ".Counter") + .setNumReplicas(1) + .setInitArgs(new Object[] {"28"}) + .create(); + + deployment.deploy(true); + Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "34"); + } + + @Test + public void createPyMethodTest() { + Deployment deployment = + Serve.deployment() + .setLanguage(DeploymentLanguage.PYTHON) + .setName("createPyMethodTest") + .setDeploymentDef(PYTHON_MODULE + ".echo_server") + .setNumReplicas(1) + .create(); + deployment.deploy(true); + RayServeHandle handle = deployment.getHandle(); + Assert.assertEquals(Ray.get(handle.method("__call__").remote("6")), "6"); + } + + @Test + public void userConfigTest() throws InterruptedException { + Deployment deployment = + Serve.deployment() + .setLanguage(DeploymentLanguage.PYTHON) + .setName("userConfigTest") + .setDeploymentDef(PYTHON_MODULE + ".Counter") + .setNumReplicas(1) + .setUserConfig("1") + .setInitArgs(new Object[] {"28"}) + .create(); + deployment.deploy(true); + Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "7"); + deployment.options().setUserConfig("3").create().deploy(true); + TimeUnit.SECONDS.sleep(20L); + Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "9"); + } +} diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java index 76930f823..fa846affc 100644 --- a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java +++ b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java @@ -29,8 +29,7 @@ public class ReplicaSetTest extends BaseTest { ActorNameList.Builder builder = ActorNameList.newBuilder(); replicaSet.updateWorkerReplicas(builder.build()); - Map, Set>> inFlightQueries = - replicaSet.getInFlightQueries(); + Map>> inFlightQueries = replicaSet.getInFlightQueries(); Assert.assertTrue(inFlightQueries.isEmpty()); } diff --git a/java/serve/src/test/resources/test_deployment_cross_language_invocation.py b/java/serve/src/test/resources/test_deployment_cross_language_invocation.py deleted file mode 100644 index 989cb2057..000000000 --- a/java/serve/src/test/resources/test_deployment_cross_language_invocation.py +++ /dev/null @@ -1,24 +0,0 @@ -# This file is used by CrossLanguageInvocationTest.java to test cross-language -# invocation. - -import ray - - -@ray.remote -def py_return_input(v): - return v - - -@ray.remote -def py_func_python_raise_exception(): - 1 / 0 - - -@ray.remote -class Counter(object): - def __init__(self, value): - self.value = int(value) - - def increase(self, delta): - self.value += int(delta) - return str(self.value).encode("utf-8") diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index cd89a7204..512d9f4ec 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -46,6 +46,7 @@ from ray.serve._private.utils import ( format_actor_name, get_random_letters, msgpack_serialize, + msgpack_deserialize, ) from ray.serve._private.version import DeploymentVersion, VersionedReplica @@ -206,7 +207,9 @@ class ActorReplicaWrapper: # Populated in self.stop(). self._graceful_shutdown_ref: ObjectRef = None + # todo: will be confused with deployment_config.is_cross_language self._is_cross_language = False + self._deployment_is_cross_language = False @property def replica_tag(self) -> str: @@ -267,25 +270,51 @@ class ActorReplicaWrapper: ) self._actor_resources = deployment_info.replica_config.resource_dict + # it is currently not possible to create a placement group + # with no resources (https://github.com/ray-project/ray/issues/20401) + self._deployment_is_cross_language = ( + deployment_info.deployment_config.is_cross_language + ) + logger.debug( f"Starting replica {self.replica_tag} for deployment " f"{self.deployment_name}." ) actor_def = deployment_info.actor_def - init_args = ( - self.deployment_name, - self.replica_tag, - deployment_info.replica_config.serialized_deployment_def, - deployment_info.replica_config.serialized_init_args, - deployment_info.replica_config.serialized_init_kwargs, - deployment_info.deployment_config.to_proto_bytes(), - version, - self._controller_name, - self._detached, - ) - # TODO(simon): unify the constructor arguments across language if ( + deployment_info.deployment_config.deployment_language + == DeploymentLanguage.PYTHON + ): + if deployment_info.replica_config.serialized_init_args is None: + serialized_init_args = cloudpickle.dumps(()) + else: + serialized_init_args = ( + cloudpickle.dumps( + msgpack_deserialize( + deployment_info.replica_config.serialized_init_args + ) + ) + if self._deployment_is_cross_language + else deployment_info.replica_config.serialized_init_args + ) + init_args = ( + self.deployment_name, + self.replica_tag, + cloudpickle.dumps(deployment_info.replica_config.deployment_def) + if self._deployment_is_cross_language + else deployment_info.replica_config.serialized_deployment_def, + serialized_init_args, + deployment_info.replica_config.serialized_init_kwargs + if deployment_info.replica_config.serialized_init_kwargs + else cloudpickle.dumps({}), + deployment_info.deployment_config.to_proto_bytes(), + version, + self._controller_name, + self._detached, + ) + # TODO(simon): unify the constructor arguments across language + elif ( deployment_info.deployment_config.deployment_language == DeploymentLanguage.JAVA ): @@ -306,7 +335,7 @@ class ActorReplicaWrapper: deployment_info.replica_config.serialized_init_args ) ) - if deployment_info.deployment_config.is_cross_language + if self._deployment_is_cross_language else deployment_info.replica_config.serialized_init_args, # byte[] deploymentConfigBytes, deployment_info.deployment_config.to_proto_bytes(), @@ -329,16 +358,17 @@ class ActorReplicaWrapper: # Perform auto method name translation for java handles. # See https://github.com/ray-project/ray/issues/21474 + user_config = self._format_user_config( + deployment_info.deployment_config.user_config + ) if self._is_cross_language: self._actor_handle = JavaActorHandleProxy(self._actor_handle) self._allocated_obj_ref = self._actor_handle.is_allocated.remote() - self._ready_obj_ref = self._actor_handle.reconfigure.remote( - deployment_info.deployment_config.user_config - ) + self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config) else: self._allocated_obj_ref = self._actor_handle.is_allocated.remote() self._ready_obj_ref = self._actor_handle.reconfigure.remote( - deployment_info.deployment_config.user_config, + user_config, # Ensure that `is_allocated` will execute before `reconfigure`, # because `reconfigure` runs user code that could block the replica # asyncio loop. If that happens before `is_allocated` is executed, @@ -346,12 +376,23 @@ class ActorReplicaWrapper: self._allocated_obj_ref, ) + def _format_user_config(self, user_config: Any): + temp = copy(user_config) + if user_config is not None and self._deployment_is_cross_language: + if self._is_cross_language: + temp = msgpack_serialize(temp) + else: + temp = msgpack_deserialize(temp) + return temp + def update_user_config(self, user_config: Any): """ Update user config of existing actor behind current DeploymentReplica instance. """ - self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config) + self._ready_obj_ref = self._actor_handle.reconfigure.remote( + self._format_user_config(user_config) + ) def recover(self): """ @@ -424,8 +465,12 @@ class ActorReplicaWrapper: except Exception: logger.exception(f"Exception in deployment '{self._deployment_name}'") return ReplicaStartupStatus.FAILED, None - - return ReplicaStartupStatus.SUCCEEDED, version + if self._deployment_is_cross_language: + # todo: The replica's userconfig whitch java client created + # is different from the controller's userconfig + return ReplicaStartupStatus.SUCCEEDED, None + else: + return ReplicaStartupStatus.SUCCEEDED, version @property def actor_resources(self) -> Optional[Dict[str, float]]: diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 60ae85f5f..18129d6c4 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -36,6 +36,7 @@ from ray.serve._private.utils import ( parse_import_path, parse_request_item, wrap_to_ray_error, + merge_dict, ) from ray.serve._private.version import DeploymentVersion @@ -183,6 +184,24 @@ def create_replica_wrapper(name: str): query = Query(request_args, request_kwargs, request_metadata) return await self.replica.handle_request(query) + async def handle_request_from_java( + self, + proto_request_metadata: bytes, + *request_args, + **request_kwargs, + ): + from ray.serve.generated.serve_pb2 import ( + RequestMetadata as RequestMetadataProto, + ) + + proto = RequestMetadataProto.FromString(proto_request_metadata) + request_metadata: RequestMetadata = RequestMetadata( + proto.request_id, proto.endpoint, call_method=proto.call_method + ) + request_args = request_args[0] + query = Query(request_args, request_kwargs, request_metadata, return_num=1) + return await self.replica.handle_request(query) + async def is_allocated(self) -> str: """poke the replica to check whether it's alive. @@ -349,7 +368,11 @@ class RayServeReplica: method_stat = actor_stats.get( f"{_format_replica_actor_name(self.deployment_name)}.handle_request" ) - return method_stat + method_stat_java = actor_stats.get( + f"{_format_replica_actor_name(self.deployment_name)}" + f".handle_request_from_java" + ) + return merge_dict(method_stat, method_stat_java) def _collect_autoscaling_metrics(self): method_stat = self._get_handle_request_stats() @@ -487,9 +510,11 @@ class RayServeReplica: latency_ms=latency_ms, ) ) - - # Returns a small object for router to track request status. - return b"", result + if request.return_num == 1: + return result + else: + # Returns a small object for router to track request status. + return b"", result async def prepare_for_shutdown(self): """Perform graceful shutdown. diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 67d21707e..2f781ed8e 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -44,6 +44,7 @@ class Query: args: List[Any] kwargs: Dict[Any, Any] metadata: RequestMetadata + return_num: int = 2 class ReplicaSet: diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 8082ac86f..efe79ea04 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -24,6 +24,7 @@ from ray.exceptions import RayTaskError from ray.serve._private.constants import HTTP_PROXY_TIMEOUT from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request from ray.util.serialization import StandaloneSerializationContext +from ray._raylet import MessagePackSerializer import __main__ @@ -33,6 +34,7 @@ except ImportError: pd = None ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 +MESSAGE_PACK_OFFSET = 9 # Use a global singleton enum to emulate default options. We cannot use None @@ -239,6 +241,28 @@ def msgpack_serialize(obj): return serialized +def msgpack_deserialize(data): + # todo: Ray does not provide a msgpack deserialization api. + try: + obj = MessagePackSerializer.loads(data[MESSAGE_PACK_OFFSET:], None) + except Exception: + raise + return obj + + +def merge_dict(dict1, dict2): + if dict1 is None and dict2 is None: + return None + if dict1 is None: + dict1 = dict() + if dict2 is None: + dict2 = dict() + result = dict() + for key in dict1.keys() | dict2.keys(): + result[key] = sum([e.get(key, 0) for e in (dict1, dict2)]) + return result + + def get_deployment_import_path( deployment, replace_main=False, enforce_importable=False ): diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 88f6c813e..c11d233d1 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -470,8 +470,8 @@ class ReplicaConfig: return ReplicaConfig( proto.deployment_def_name, proto.deployment_def, - proto.init_args, - proto.init_kwargs, + proto.init_args if proto.init_args != b"" else None, + proto.init_kwargs if proto.init_kwargs != b"" else None, json.loads(proto.ray_actor_options), needs_pickle, ) diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index 9bcba62b9..aec458778 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -14,9 +14,39 @@ from ray.serve._private.utils import ( get_deployment_import_path, override_runtime_envs_except_env_vars, serve_encoders, + merge_dict, + msgpack_serialize, + msgpack_deserialize, ) +def test_serialize(): + data = msgpack_serialize(5) + obj = msgpack_deserialize(data) + assert 5 == obj + + +def test_merge_dict(): + dict1 = {"pending": 1, "running": 1, "finished": 1} + dict2 = {"pending": 4, "finished": 1} + merge = merge_dict(dict1, dict2) + assert merge["pending"] == 5 + assert merge["running"] == 1 + assert merge["finished"] == 2 + dict1 = None + merge = merge_dict(dict1, dict2) + assert merge["pending"] == 4 + assert merge["finished"] == 1 + try: + assert merge["running"] == 1 + assert False + except KeyError: + assert True + dict2 = None + merge = merge_dict(dict1, dict2) + assert merge is None + + def test_bytes_encoder(): data_before = {"inp": {"nest": b"bytes"}} data_after = {"inp": {"nest": "bytes"}}