[serve][xlang]Support deploying Python deployment from Java. (#26877)

The previously merged PR (https://github.com/ray-project/ray/pull/22726/commits) did not implement Java Serve's support for Python deployments. This PR implements that feature.

Co-authored-by: nanqi.yxf <nanqi.yxf@antgroup.com>
This commit is contained in:
xiaofeng 2022-08-06 14:35:49 +08:00 committed by GitHub
parent 50e278f58b
commit 9f8b596aaa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 345 additions and 90 deletions

View file

@ -160,6 +160,7 @@ define_java_module(
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:commons_io_commons_io",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_httpcomponents_client5_httpclient5",
"@maven//:org_apache_httpcomponents_core5_httpcore5",

View file

@ -193,15 +193,15 @@ public class DeploymentConfig implements Serializable {
io.ray.serve.generated.DeploymentConfig.newBuilder()
.setNumReplicas(numReplicas)
.setMaxConcurrentQueries(maxConcurrentQueries)
.setUserConfig(
ByteString.copyFrom(
MessagePackSerializer.encode(userConfig).getKey())) // TODO-xlang
.setGracefulShutdownWaitLoopS(gracefulShutdownWaitLoopS)
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage);
if (null != userConfig) {
builder.setUserConfig(ByteString.copyFrom(MessagePackSerializer.encode(userConfig).getKey()));
}
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}

View file

@ -142,10 +142,14 @@ public class ReplicaConfig {
if (proto == null) {
return null;
}
Object[] initArgs = null;
if (0 != proto.getInitArgs().toByteArray().length) {
initArgs = MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null);
}
ReplicaConfig replicaConfig =
new ReplicaConfig(
proto.getDeploymentDefName(),
MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null), // TODO-xlang
initArgs,
gson.fromJson(proto.getRayActorOptions(), Map.class));
return replicaConfig;
}

View file

@ -126,7 +126,8 @@ public class Deployment {
.setGracefulShutdownWaitLoopS(this.config.getGracefulShutdownWaitLoopS())
.setGracefulShutdownTimeoutS(this.config.getGracefulShutdownTimeoutS())
.setHealthCheckPeriodS(this.config.getHealthCheckPeriodS())
.setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS());
.setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS())
.setLanguage(this.config.getDeploymentLanguage());
}
public String getDeploymentDef() {

View file

@ -79,7 +79,7 @@ public class DeploymentCreator {
private boolean routed;
private DeploymentLanguage deploymentLanguage;
private DeploymentLanguage language;
public Deployment create() {
@ -97,7 +97,7 @@ public class DeploymentCreator {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setDeploymentLanguage(deploymentLanguage);
.setDeploymentLanguage(language);
return new Deployment(
deploymentDef,
@ -246,11 +246,12 @@ public class DeploymentCreator {
return this;
}
public DeploymentLanguage getDeploymentLanguage() {
return deploymentLanguage;
public DeploymentLanguage getLanguage() {
return language;
}
public void setDeploymentLanguage(DeploymentLanguage deploymentLanguage) {
this.deploymentLanguage = deploymentLanguage;
public DeploymentCreator setLanguage(DeploymentLanguage language) {
this.language = language;
return this;
}
}

View file

@ -3,14 +3,20 @@ package io.ray.serve.router;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorMethod;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants;
import io.ray.serve.deployment.Deployment;
import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.metrics.RayServeMetrics;
import io.ray.serve.replica.RayServeWrappedReplica;
import io.ray.serve.util.CollectionUtil;
@ -18,6 +24,7 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
@ -31,7 +38,14 @@ public class ReplicaSet {
private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
// The key is the name of the actor, and the value is the set of ObjectRefs of that actor's
// in-flight queries.
private final Map<String, Set<ObjectRef<Object>>> inFlightQueries;
// Map the actor name to the handle of the actor.
private final Map<String, BaseActorHandle> allActorHandles;
private DeploymentLanguage language;
private AtomicInteger numQueuedQueries = new AtomicInteger();
@ -41,6 +55,16 @@ public class ReplicaSet {
public ReplicaSet(String deploymentName) {
this.inFlightQueries = new ConcurrentHashMap<>();
this.allActorHandles = new ConcurrentHashMap<>();
try {
Deployment deployment = Serve.getDeployment(deploymentName);
this.language = deployment.getConfig().getDeploymentLanguage();
} catch (Exception e) {
LOGGER.warn(
"Failed to get language from controller. Set it to Java as default value. The exception is ",
e);
this.language = DeploymentLanguage.JAVA;
}
RayServeMetrics.execute(
() ->
this.numQueuedQueriesGauge =
@ -54,26 +78,26 @@ public class ReplicaSet {
@SuppressWarnings("unchecked")
public synchronized void updateWorkerReplicas(Object actorSet) {
List<String> actorNames = ((ActorNameList) actorSet).getNamesList();
Set<ActorHandle<RayServeWrappedReplica>> workerReplicas = new HashSet<>();
if (!CollectionUtil.isEmpty(actorNames)) {
actorNames.forEach(
name ->
workerReplicas.add(
(ActorHandle<RayServeWrappedReplica>)
Ray.getActor(name, Constants.SERVE_NAMESPACE).get()));
}
Set<ActorHandle<RayServeWrappedReplica>> added =
new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet()));
Set<ActorHandle<RayServeWrappedReplica>> removed =
new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas));
added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet()));
removed.forEach(inFlightQueries::remove);
if (added.size() > 0 || removed.size() > 0) {
LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
if (null != actorSet) {
Set<String> actorNameSet = new HashSet<>(((ActorNameList) actorSet).getNamesList());
Set<String> added = new HashSet<>(Sets.difference(actorNameSet, inFlightQueries.keySet()));
Set<String> removed = new HashSet<>(Sets.difference(inFlightQueries.keySet(), actorNameSet));
added.forEach(
name -> {
Optional<BaseActorHandle> handleOptional =
Ray.getActor(name, Constants.SERVE_NAMESPACE);
if (handleOptional.isPresent()) {
allActorHandles.put(name, handleOptional.get());
inFlightQueries.put(name, Sets.newConcurrentHashSet());
} else {
LOGGER.warn("Can not get actor handle. actor name is {}", name);
}
});
removed.forEach(inFlightQueries::remove);
removed.forEach(allActorHandles::remove);
if (added.size() > 0 || removed.size() > 0) {
LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
}
}
hasPullReplica = true;
}
@ -121,20 +145,29 @@ public class ReplicaSet {
}
loopCount++;
}
List<ActorHandle<RayServeWrappedReplica>> handles = new ArrayList<>(inFlightQueries.keySet());
List<BaseActorHandle> handles = new ArrayList<>(allActorHandles.values());
if (CollectionUtil.isEmpty(handles)) {
throw new RayServeException("ReplicaSet found no replica.");
}
int randomIndex = RandomUtils.nextInt(0, handles.size());
ActorHandle<RayServeWrappedReplica> replica =
BaseActorHandle replica =
handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries
LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
return replica
.task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
.remote();
if (language == DeploymentLanguage.PYTHON) {
return ((PyActorHandle) replica)
.task(
PyActorMethod.of("handle_request_from_java"),
query.getMetadata().toByteArray(),
query.getArgs())
.remote();
} else {
return ((ActorHandle<RayServeWrappedReplica>) replica)
.task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
.remote();
}
}
public Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> getInFlightQueries() {
public Map<String, Set<ObjectRef<Object>>> getInFlightQueries() {
return inFlightQueries;
}
}

View file

@ -0,0 +1,20 @@
# This file is used by CrossLanguageDeploymentTest.java to test cross-language
# invocation.
from ray import serve
def echo_server(v):
    """Return the argument ``v`` unchanged."""
    return v
@serve.deployment
class Counter:
    """Integer counter exposed as a Serve deployment.

    Every incoming value is coerced with ``int()``, so callers may pass
    numeric strings; ``increase`` reports the running total as a ``str``.
    """

    def __init__(self, value):
        # Starting total, coerced to int.
        self.value = int(value)

    def increase(self, delta):
        """Add ``int(delta)`` to the total and return the new total as a string."""
        self.value = self.value + int(delta)
        return str(self.value)

    def reconfigure(self, value_str):
        """Replace the current total with ``int(value_str)``."""
        self.value = int(value_str)

View file

@ -0,0 +1,95 @@
package io.ray.serve;
import io.ray.api.Ray;
import io.ray.serve.api.Serve;
import io.ray.serve.deployment.Deployment;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.handle.RayServeHandle;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Cluster tests that create Python deployments through the Java Serve API and invoke them via
 * the deployment handle.
 *
 * <p>The Python code under test is the {@code test_python_deployment.py} classpath resource.
 * {@link #beforeClass()} copies it into a temporary directory and appends that directory to the
 * {@code ray.job.code-search-path} system property.
 */
@Test(groups = {"cluster"})
public class CrossLanguageDeploymentTest extends BaseServeTest {
  private static final String PYTHON_MODULE = "test_python_deployment";

  /**
   * Stages the Python test module in a fresh temp directory and registers that directory in the
   * {@code ray.job.code-search-path} system property.
   *
   * @throws IllegalStateException if the temp directory cannot be created or the Python resource
   *     is missing from the classpath
   * @throws RuntimeException wrapping the {@link IOException} if copying the resource fails
   */
  @BeforeClass
  public void beforeClass() {
    // Delete and re-create the temp dir.
    File tempDir = new File(System.getProperty("java.io.tmpdir"), "ray_serve_cross_language_test");
    FileUtils.deleteQuietly(tempDir);
    if (!tempDir.mkdirs() && !tempDir.isDirectory()) {
      throw new IllegalStateException("Failed to create temp dir " + tempDir.getAbsolutePath());
    }
    tempDir.deleteOnExit();

    // Write the test Python file to the temp dir.
    // Classpath resource names always use '/', regardless of the platform's File.separator.
    String resourceName = "/" + PYTHON_MODULE + ".py";
    File pythonFile = new File(tempDir, PYTHON_MODULE + ".py");
    // deleteOnExit() processes entries in reverse registration order, so registering the file
    // after its directory lets the (then empty) directory be removed too. Best effort only.
    pythonFile.deleteOnExit();
    try (InputStream in = CrossLanguageDeploymentTest.class.getResourceAsStream(resourceName)) {
      if (in == null) {
        throw new IllegalStateException(
            "Test resource " + resourceName + " was not found on the classpath.");
      }
      FileUtils.copyInputStreamToFile(in, pythonFile);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    System.setProperty(
        "ray.job.code-search-path",
        System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath());
  }

  /** Deploys the Python {@code Counter} class with init arg "28" and checks 28 + 6 == "34". */
  @Test
  public void createPyClassTest() {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("createPyClassTest")
            .setDeploymentDef(PYTHON_MODULE + ".Counter")
            .setNumReplicas(1)
            .setInitArgs(new Object[] {"28"})
            .create();

    deployment.deploy(true);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "34");
  }

  /** Deploys the Python {@code echo_server} function and checks that it echoes its argument. */
  @Test
  public void createPyMethodTest() {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("createPyMethodTest")
            .setDeploymentDef(PYTHON_MODULE + ".echo_server")
            .setNumReplicas(1)
            .create();
    deployment.deploy(true);
    RayServeHandle handle = deployment.getHandle();
    Assert.assertEquals(Ray.get(handle.method("__call__").remote("6")), "6");
  }

  /**
   * Deploys {@code Counter} with user config "1", then redeploys with user config "3", and checks
   * that each call to {@code increase("6")} starts from the configured value.
   *
   * @throws InterruptedException if the wait after the redeploy is interrupted
   */
  @Test
  public void userConfigTest() throws InterruptedException {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("userConfigTest")
            .setDeploymentDef(PYTHON_MODULE + ".Counter")
            .setNumReplicas(1)
            .setUserConfig("1")
            .setInitArgs(new Object[] {"28"})
            .create();
    deployment.deploy(true);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "7");

    deployment.options().setUserConfig("3").create().deploy(true);
    // NOTE(review): fixed wait, presumably to let the new user config reach the replica --
    // confirm whether a polling wait could replace it.
    TimeUnit.SECONDS.sleep(20L);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "9");
  }
}

View file

@ -29,8 +29,7 @@ public class ReplicaSetTest extends BaseTest {
ActorNameList.Builder builder = ActorNameList.newBuilder();
replicaSet.updateWorkerReplicas(builder.build());
Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries =
replicaSet.getInFlightQueries();
Map<String, Set<ObjectRef<Object>>> inFlightQueries = replicaSet.getInFlightQueries();
Assert.assertTrue(inFlightQueries.isEmpty());
}

View file

@ -1,24 +0,0 @@
# This file is used by CrossLanguageInvocationTest.java to test cross-language
# invocation.
import ray
@ray.remote
def py_return_input(v):
return v
@ray.remote
def py_func_python_raise_exception():
1 / 0
@ray.remote
class Counter(object):
def __init__(self, value):
self.value = int(value)
def increase(self, delta):
self.value += int(delta)
return str(self.value).encode("utf-8")

View file

@ -46,6 +46,7 @@ from ray.serve._private.utils import (
format_actor_name,
get_random_letters,
msgpack_serialize,
msgpack_deserialize,
)
from ray.serve._private.version import DeploymentVersion, VersionedReplica
@ -206,7 +207,9 @@ class ActorReplicaWrapper:
# Populated in self.stop().
self._graceful_shutdown_ref: ObjectRef = None
# todo: will be confused with deployment_config.is_cross_language
self._is_cross_language = False
self._deployment_is_cross_language = False
@property
def replica_tag(self) -> str:
@ -267,25 +270,51 @@ class ActorReplicaWrapper:
)
self._actor_resources = deployment_info.replica_config.resource_dict
# it is currently not possible to create a placement group
# with no resources (https://github.com/ray-project/ray/issues/20401)
self._deployment_is_cross_language = (
deployment_info.deployment_config.is_cross_language
)
logger.debug(
f"Starting replica {self.replica_tag} for deployment "
f"{self.deployment_name}."
)
actor_def = deployment_info.actor_def
init_args = (
self.deployment_name,
self.replica_tag,
deployment_info.replica_config.serialized_deployment_def,
deployment_info.replica_config.serialized_init_args,
deployment_info.replica_config.serialized_init_kwargs,
deployment_info.deployment_config.to_proto_bytes(),
version,
self._controller_name,
self._detached,
)
# TODO(simon): unify the constructor arguments across language
if (
deployment_info.deployment_config.deployment_language
== DeploymentLanguage.PYTHON
):
if deployment_info.replica_config.serialized_init_args is None:
serialized_init_args = cloudpickle.dumps(())
else:
serialized_init_args = (
cloudpickle.dumps(
msgpack_deserialize(
deployment_info.replica_config.serialized_init_args
)
)
if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_init_args
)
init_args = (
self.deployment_name,
self.replica_tag,
cloudpickle.dumps(deployment_info.replica_config.deployment_def)
if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_deployment_def,
serialized_init_args,
deployment_info.replica_config.serialized_init_kwargs
if deployment_info.replica_config.serialized_init_kwargs
else cloudpickle.dumps({}),
deployment_info.deployment_config.to_proto_bytes(),
version,
self._controller_name,
self._detached,
)
# TODO(simon): unify the constructor arguments across language
elif (
deployment_info.deployment_config.deployment_language
== DeploymentLanguage.JAVA
):
@ -306,7 +335,7 @@ class ActorReplicaWrapper:
deployment_info.replica_config.serialized_init_args
)
)
if deployment_info.deployment_config.is_cross_language
if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_init_args,
# byte[] deploymentConfigBytes,
deployment_info.deployment_config.to_proto_bytes(),
@ -329,16 +358,17 @@ class ActorReplicaWrapper:
# Perform auto method name translation for java handles.
# See https://github.com/ray-project/ray/issues/21474
user_config = self._format_user_config(
deployment_info.deployment_config.user_config
)
if self._is_cross_language:
self._actor_handle = JavaActorHandleProxy(self._actor_handle)
self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
self._ready_obj_ref = self._actor_handle.reconfigure.remote(
deployment_info.deployment_config.user_config
)
self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config)
else:
self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
self._ready_obj_ref = self._actor_handle.reconfigure.remote(
deployment_info.deployment_config.user_config,
user_config,
# Ensure that `is_allocated` will execute before `reconfigure`,
# because `reconfigure` runs user code that could block the replica
# asyncio loop. If that happens before `is_allocated` is executed,
@ -346,12 +376,23 @@ class ActorReplicaWrapper:
self._allocated_obj_ref,
)
def _format_user_config(self, user_config: Any):
temp = copy(user_config)
if user_config is not None and self._deployment_is_cross_language:
if self._is_cross_language:
temp = msgpack_serialize(temp)
else:
temp = msgpack_deserialize(temp)
return temp
def update_user_config(self, user_config: Any):
"""
Update user config of existing actor behind current
DeploymentReplica instance.
"""
self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config)
self._ready_obj_ref = self._actor_handle.reconfigure.remote(
self._format_user_config(user_config)
)
def recover(self):
"""
@ -424,8 +465,12 @@ class ActorReplicaWrapper:
except Exception:
logger.exception(f"Exception in deployment '{self._deployment_name}'")
return ReplicaStartupStatus.FAILED, None
return ReplicaStartupStatus.SUCCEEDED, version
if self._deployment_is_cross_language:
# todo: The replica's user config, which the Java client created,
# is different from the controller's user config
return ReplicaStartupStatus.SUCCEEDED, None
else:
return ReplicaStartupStatus.SUCCEEDED, version
@property
def actor_resources(self) -> Optional[Dict[str, float]]:

View file

@ -36,6 +36,7 @@ from ray.serve._private.utils import (
parse_import_path,
parse_request_item,
wrap_to_ray_error,
merge_dict,
)
from ray.serve._private.version import DeploymentVersion
@ -183,6 +184,24 @@ def create_replica_wrapper(name: str):
query = Query(request_args, request_kwargs, request_metadata)
return await self.replica.handle_request(query)
async def handle_request_from_java(
    self,
    proto_request_metadata: bytes,
    *request_args,
    **request_kwargs,
):
    """Handle a request whose metadata arrives as serialized protobuf bytes.

    Args:
        proto_request_metadata: serialized ``RequestMetadata`` protobuf.
        *request_args: only the first positional item is used; it is
            forwarded as the query's args.
        **request_kwargs: forwarded as the query's kwargs.

    Returns:
        Whatever ``self.replica.handle_request`` returns for a query built
        with ``return_num=1``.
    """
    from ray.serve.generated.serve_pb2 import (
        RequestMetadata as RequestMetadataProto,
    )

    metadata_proto = RequestMetadataProto.FromString(proto_request_metadata)
    request_metadata: RequestMetadata = RequestMetadata(
        metadata_proto.request_id,
        metadata_proto.endpoint,
        call_method=metadata_proto.call_method,
    )
    # The first positional item carries the actual call arguments.
    call_args = request_args[0]
    java_query = Query(call_args, request_kwargs, request_metadata, return_num=1)
    return await self.replica.handle_request(java_query)
async def is_allocated(self) -> str:
"""poke the replica to check whether it's alive.
@ -349,7 +368,11 @@ class RayServeReplica:
method_stat = actor_stats.get(
f"{_format_replica_actor_name(self.deployment_name)}.handle_request"
)
return method_stat
method_stat_java = actor_stats.get(
f"{_format_replica_actor_name(self.deployment_name)}"
f".handle_request_from_java"
)
return merge_dict(method_stat, method_stat_java)
def _collect_autoscaling_metrics(self):
method_stat = self._get_handle_request_stats()
@ -487,9 +510,11 @@ class RayServeReplica:
latency_ms=latency_ms,
)
)
# Returns a small object for router to track request status.
return b"", result
if request.return_num == 1:
return result
else:
# Returns a small object for router to track request status.
return b"", result
async def prepare_for_shutdown(self):
"""Perform graceful shutdown.

View file

@ -44,6 +44,7 @@ class Query:
args: List[Any]
kwargs: Dict[Any, Any]
metadata: RequestMetadata
return_num: int = 2
class ReplicaSet:

View file

@ -24,6 +24,7 @@ from ray.exceptions import RayTaskError
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT
from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request
from ray.util.serialization import StandaloneSerializationContext
from ray._raylet import MessagePackSerializer
import __main__
@ -33,6 +34,7 @@ except ImportError:
pd = None
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
MESSAGE_PACK_OFFSET = 9
# Use a global singleton enum to emulate default options. We cannot use None
@ -239,6 +241,28 @@ def msgpack_serialize(obj):
return serialized
def msgpack_deserialize(data):
    """Decode ``data`` with ``MessagePackSerializer.loads``.

    The first ``MESSAGE_PACK_OFFSET`` bytes of ``data`` are skipped before
    decoding. Any exception raised by the decoder propagates unchanged.
    """
    # todo: Ray does not provide a msgpack deserialization api.
    payload = data[MESSAGE_PACK_OFFSET:]
    return MessagePackSerializer.loads(payload, None)
def merge_dict(dict1, dict2):
    """Merge two dicts by summing the values of matching keys.

    A key present in only one dict keeps its value (the other side counts
    as 0). A ``None`` argument is treated as an empty dict, except that
    ``None`` is returned when both arguments are ``None``.
    """
    if dict1 is None and dict2 is None:
        return None
    left = dict() if dict1 is None else dict1
    right = dict() if dict2 is None else dict2
    return {
        key: sum(side.get(key, 0) for side in (left, right))
        for key in left.keys() | right.keys()
    }
def get_deployment_import_path(
deployment, replace_main=False, enforce_importable=False
):

View file

@ -470,8 +470,8 @@ class ReplicaConfig:
return ReplicaConfig(
proto.deployment_def_name,
proto.deployment_def,
proto.init_args,
proto.init_kwargs,
proto.init_args if proto.init_args != b"" else None,
proto.init_kwargs if proto.init_kwargs != b"" else None,
json.loads(proto.ray_actor_options),
needs_pickle,
)

View file

@ -14,9 +14,39 @@ from ray.serve._private.utils import (
get_deployment_import_path,
override_runtime_envs_except_env_vars,
serve_encoders,
merge_dict,
msgpack_serialize,
msgpack_deserialize,
)
def test_serialize():
    """The integer 5 survives a msgpack_serialize / msgpack_deserialize round trip."""
    assert msgpack_deserialize(msgpack_serialize(5)) == 5
def test_merge_dict():
    """merge_dict sums shared keys, tolerates one None side, returns None for two."""
    full_stats = {"pending": 1, "running": 1, "finished": 1}
    partial_stats = {"pending": 4, "finished": 1}

    # Both sides present: values of shared keys are added.
    merged = merge_dict(full_stats, partial_stats)
    assert merged["pending"] == 5
    assert merged["running"] == 1
    assert merged["finished"] == 2

    # First side missing: the result mirrors the second dict only.
    merged = merge_dict(None, partial_stats)
    assert merged["pending"] == 4
    assert merged["finished"] == 1
    assert "running" not in merged

    # Both sides missing.
    assert merge_dict(None, None) is None
def test_bytes_encoder():
data_before = {"inp": {"nest": b"bytes"}}
data_after = {"inp": {"nest": "bytes"}}