Revert "Revert "[serve][xlang]Support deploying Python deployment from Java. …" (#27945)

This commit is contained in:
xiaofeng 2022-08-19 08:57:37 +08:00 committed by GitHub
parent a6b7189ab3
commit af488e1cc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 354 additions and 95 deletions

View file

@ -160,6 +160,7 @@ define_java_module(
"@maven//:com_google_code_gson_gson", "@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_protobuf_protobuf_java",
"@maven//:commons_io_commons_io",
"@maven//:org_apache_commons_commons_lang3", "@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_httpcomponents_client5_httpclient5", "@maven//:org_apache_httpcomponents_client5_httpclient5",
"@maven//:org_apache_httpcomponents_client5_httpclient5_fluent", "@maven//:org_apache_httpcomponents_client5_httpclient5_fluent",

View file

@ -16,6 +16,7 @@ import io.ray.serve.deployment.DeploymentCreator;
import io.ray.serve.deployment.DeploymentRoute; import io.ray.serve.deployment.DeploymentRoute;
import io.ray.serve.exception.RayServeException; import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList; import io.ray.serve.generated.ActorNameList;
import io.ray.serve.poll.LongPollClientFactory;
import io.ray.serve.replica.ReplicaContext; import io.ray.serve.replica.ReplicaContext;
import io.ray.serve.util.CollectionUtil; import io.ray.serve.util.CollectionUtil;
import io.ray.serve.util.CommonUtil; import io.ray.serve.util.CommonUtil;
@ -143,7 +144,10 @@ public class Serve {
} }
client.shutdown(); client.shutdown();
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
setGlobalClient(null); setGlobalClient(null);
setInternalReplicaContext(null);
} }
/** /**

View file

@ -193,15 +193,15 @@ public class DeploymentConfig implements Serializable {
io.ray.serve.generated.DeploymentConfig.newBuilder() io.ray.serve.generated.DeploymentConfig.newBuilder()
.setNumReplicas(numReplicas) .setNumReplicas(numReplicas)
.setMaxConcurrentQueries(maxConcurrentQueries) .setMaxConcurrentQueries(maxConcurrentQueries)
.setUserConfig(
ByteString.copyFrom(
MessagePackSerializer.encode(userConfig).getKey())) // TODO-xlang
.setGracefulShutdownWaitLoopS(gracefulShutdownWaitLoopS) .setGracefulShutdownWaitLoopS(gracefulShutdownWaitLoopS)
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS) .setHealthCheckTimeoutS(healthCheckTimeoutS)
.setIsCrossLanguage(isCrossLanguage) .setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage); .setDeploymentLanguage(deploymentLanguage);
if (null != userConfig) {
builder.setUserConfig(ByteString.copyFrom(MessagePackSerializer.encode(userConfig).getKey()));
}
if (null != autoscalingConfig) { if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto()); builder.setAutoscalingConfig(autoscalingConfig.toProto());
} }

View file

@ -142,10 +142,14 @@ public class ReplicaConfig {
if (proto == null) { if (proto == null) {
return null; return null;
} }
Object[] initArgs = null;
if (0 != proto.getInitArgs().toByteArray().length) {
initArgs = MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null);
}
ReplicaConfig replicaConfig = ReplicaConfig replicaConfig =
new ReplicaConfig( new ReplicaConfig(
proto.getDeploymentDefName(), proto.getDeploymentDefName(),
MessagePackSerializer.decode(proto.getInitArgs().toByteArray(), null), // TODO-xlang initArgs,
gson.fromJson(proto.getRayActorOptions(), Map.class)); gson.fromJson(proto.getRayActorOptions(), Map.class));
return replicaConfig; return replicaConfig;
} }

View file

@ -126,7 +126,8 @@ public class Deployment {
.setGracefulShutdownWaitLoopS(this.config.getGracefulShutdownWaitLoopS()) .setGracefulShutdownWaitLoopS(this.config.getGracefulShutdownWaitLoopS())
.setGracefulShutdownTimeoutS(this.config.getGracefulShutdownTimeoutS()) .setGracefulShutdownTimeoutS(this.config.getGracefulShutdownTimeoutS())
.setHealthCheckPeriodS(this.config.getHealthCheckPeriodS()) .setHealthCheckPeriodS(this.config.getHealthCheckPeriodS())
.setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS()); .setHealthCheckTimeoutS(this.config.getHealthCheckTimeoutS())
.setLanguage(this.config.getDeploymentLanguage());
} }
public String getDeploymentDef() { public String getDeploymentDef() {

View file

@ -79,7 +79,7 @@ public class DeploymentCreator {
private boolean routed; private boolean routed;
private DeploymentLanguage deploymentLanguage; private DeploymentLanguage language;
public Deployment create() { public Deployment create() {
@ -97,7 +97,7 @@ public class DeploymentCreator {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS) .setHealthCheckTimeoutS(healthCheckTimeoutS)
.setDeploymentLanguage(deploymentLanguage); .setDeploymentLanguage(language);
return new Deployment( return new Deployment(
deploymentDef, deploymentDef,
@ -246,11 +246,12 @@ public class DeploymentCreator {
return this; return this;
} }
public DeploymentLanguage getDeploymentLanguage() { public DeploymentLanguage getLanguage() {
return deploymentLanguage; return language;
} }
public void setDeploymentLanguage(DeploymentLanguage deploymentLanguage) { public DeploymentCreator setLanguage(DeploymentLanguage language) {
this.deploymentLanguage = deploymentLanguage; this.language = language;
return this;
} }
} }

View file

@ -220,7 +220,7 @@ public class LongPollClientFactory {
scheduledExecutorService.shutdown(); scheduledExecutorService.shutdown();
} }
inited = false; inited = false;
LOGGER.info("LongPollClient was shopped."); LOGGER.info("LongPollClient was stopped.");
} }
public static boolean isInitialized() { public static boolean isInitialized() {

View file

@ -3,14 +3,20 @@ package io.ray.serve.router;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import io.ray.api.ActorHandle; import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef; import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray; import io.ray.api.Ray;
import io.ray.api.function.PyActorMethod;
import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey; import io.ray.runtime.metric.TagKey;
import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants; import io.ray.serve.common.Constants;
import io.ray.serve.deployment.Deployment;
import io.ray.serve.exception.RayServeException; import io.ray.serve.exception.RayServeException;
import io.ray.serve.generated.ActorNameList; import io.ray.serve.generated.ActorNameList;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.metrics.RayServeMetrics; import io.ray.serve.metrics.RayServeMetrics;
import io.ray.serve.replica.RayServeWrappedReplica; import io.ray.serve.replica.RayServeWrappedReplica;
import io.ray.serve.util.CollectionUtil; import io.ray.serve.util.CollectionUtil;
@ -18,6 +24,7 @@ import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -31,7 +38,14 @@ public class ReplicaSet {
private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class); private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries; // The key is the name of the actor, and the value is a set of all flight queries objectrefs of
// the actor.
private final Map<String, Set<ObjectRef<Object>>> inFlightQueries;
// Map the actor name to the handle of the actor.
private final Map<String, BaseActorHandle> allActorHandles;
private DeploymentLanguage language;
private AtomicInteger numQueuedQueries = new AtomicInteger(); private AtomicInteger numQueuedQueries = new AtomicInteger();
@ -41,6 +55,16 @@ public class ReplicaSet {
public ReplicaSet(String deploymentName) { public ReplicaSet(String deploymentName) {
this.inFlightQueries = new ConcurrentHashMap<>(); this.inFlightQueries = new ConcurrentHashMap<>();
this.allActorHandles = new ConcurrentHashMap<>();
try {
Deployment deployment = Serve.getDeployment(deploymentName);
this.language = deployment.getConfig().getDeploymentLanguage();
} catch (Exception e) {
LOGGER.warn(
"Failed to get language from controller. Set it to Java as default value. The exception is ",
e);
this.language = DeploymentLanguage.JAVA;
}
RayServeMetrics.execute( RayServeMetrics.execute(
() -> () ->
this.numQueuedQueriesGauge = this.numQueuedQueriesGauge =
@ -54,27 +78,27 @@ public class ReplicaSet {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public synchronized void updateWorkerReplicas(Object actorSet) { public synchronized void updateWorkerReplicas(Object actorSet) {
List<String> actorNames = ((ActorNameList) actorSet).getNamesList(); if (null != actorSet) {
Set<ActorHandle<RayServeWrappedReplica>> workerReplicas = new HashSet<>(); Set<String> actorNameSet = new HashSet<>(((ActorNameList) actorSet).getNamesList());
if (!CollectionUtil.isEmpty(actorNames)) { Set<String> added = new HashSet<>(Sets.difference(actorNameSet, inFlightQueries.keySet()));
actorNames.forEach( Set<String> removed = new HashSet<>(Sets.difference(inFlightQueries.keySet(), actorNameSet));
name -> added.forEach(
workerReplicas.add( name -> {
(ActorHandle<RayServeWrappedReplica>) Optional<BaseActorHandle> handleOptional =
Ray.getActor(name, Constants.SERVE_NAMESPACE).get())); Ray.getActor(name, Constants.SERVE_NAMESPACE);
if (handleOptional.isPresent()) {
allActorHandles.put(name, handleOptional.get());
inFlightQueries.put(name, Sets.newConcurrentHashSet());
} else {
LOGGER.warn("Can not get actor handle. actor name is {}", name);
} }
});
Set<ActorHandle<RayServeWrappedReplica>> added =
new HashSet<>(Sets.difference(workerReplicas, inFlightQueries.keySet()));
Set<ActorHandle<RayServeWrappedReplica>> removed =
new HashSet<>(Sets.difference(inFlightQueries.keySet(), workerReplicas));
added.forEach(actorHandle -> inFlightQueries.put(actorHandle, Sets.newConcurrentHashSet()));
removed.forEach(inFlightQueries::remove); removed.forEach(inFlightQueries::remove);
removed.forEach(allActorHandles::remove);
if (added.size() > 0 || removed.size() > 0) { if (added.size() > 0 || removed.size() > 0) {
LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size()); LOGGER.info("ReplicaSet: +{}, -{} replicas.", added.size(), removed.size());
} }
}
hasPullReplica = true; hasPullReplica = true;
} }
@ -121,20 +145,29 @@ public class ReplicaSet {
} }
loopCount++; loopCount++;
} }
List<ActorHandle<RayServeWrappedReplica>> handles = new ArrayList<>(inFlightQueries.keySet()); List<BaseActorHandle> handles = new ArrayList<>(allActorHandles.values());
if (CollectionUtil.isEmpty(handles)) { if (CollectionUtil.isEmpty(handles)) {
throw new RayServeException("ReplicaSet found no replica."); throw new RayServeException("ReplicaSet found no replica.");
} }
int randomIndex = RandomUtils.nextInt(0, handles.size()); int randomIndex = RandomUtils.nextInt(0, handles.size());
ActorHandle<RayServeWrappedReplica> replica = BaseActorHandle replica =
handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries handles.get(randomIndex); // TODO controll concurrency using maxConcurrentQueries
LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica); LOGGER.debug("Assigned query {} to replica {}.", query.getMetadata().getRequestId(), replica);
return replica if (language == DeploymentLanguage.PYTHON) {
return ((PyActorHandle) replica)
.task(
PyActorMethod.of("handle_request_from_java"),
query.getMetadata().toByteArray(),
query.getArgs())
.remote();
} else {
return ((ActorHandle<RayServeWrappedReplica>) replica)
.task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs()) .task(RayServeWrappedReplica::handleRequest, query.getMetadata(), query.getArgs())
.remote(); .remote();
} }
}
public Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> getInFlightQueries() { public Map<String, Set<ObjectRef<Object>>> getInFlightQueries() {
return inFlightQueries; return inFlightQueries;
} }
} }

View file

@ -0,0 +1,20 @@
# This file is used by CrossLanguageDeploymentTest.java to test cross-language
# invocation.
from ray import serve
def echo_server(v):
    """Return the request argument unchanged.

    Used as a function-style deployment target by the cross-language test.
    """
    return v
@serve.deployment
class Counter:
    """Counter whose inputs are int-convertible values and whose output is a string.

    Every incoming value is passed through ``int()`` before use, and
    ``increase`` returns the running total as ``str``.
    """

    def __init__(self, value):
        # Initial count; converted with int() so a string such as "28" works.
        self.value = int(value)

    def increase(self, delta):
        """Add ``int(delta)`` to the counter and return the new total as a string."""
        self.value += int(delta)
        return str(self.value)

    def reconfigure(self, value_str):
        """Replace the current count with ``int(value_str)``."""
        self.value = int(value_str)

View file

@ -4,7 +4,6 @@ import io.ray.api.Ray;
import io.ray.serve.api.Serve; import io.ray.serve.api.Serve;
import io.ray.serve.api.ServeControllerClient; import io.ray.serve.api.ServeControllerClient;
import io.ray.serve.config.RayServeConfig; import io.ray.serve.config.RayServeConfig;
import io.ray.serve.poll.LongPollClientFactory;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Map; import java.util.Map;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -38,8 +37,5 @@ public abstract class BaseServeTest {
} catch (Exception e) { } catch (Exception e) {
LOGGER.error("ray shutdown error", e); LOGGER.error("ray shutdown error", e);
} }
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
Serve.setInternalReplicaContext(null);
} }
} }

View file

@ -3,6 +3,7 @@ package io.ray.serve;
import io.ray.api.Ray; import io.ray.api.Ray;
import io.ray.serve.api.Serve; import io.ray.serve.api.Serve;
import io.ray.serve.common.Constants; import io.ray.serve.common.Constants;
import io.ray.serve.poll.LongPollClientFactory;
public class BaseTest { public class BaseTest {
@ -18,6 +19,9 @@ public class BaseTest {
} }
protected void shutdown() { protected void shutdown() {
LongPollClientFactory.stop();
LongPollClientFactory.clearAllCache();
Serve.setInternalReplicaContext(null);
if (!previousInited) { if (!previousInited) {
Ray.shutdown(); Ray.shutdown();
} }

View file

@ -0,0 +1,95 @@
package io.ray.serve;
import io.ray.api.Ray;
import io.ray.serve.api.Serve;
import io.ray.serve.deployment.Deployment;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.handle.RayServeHandle;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Cluster tests that create deployments with {@link DeploymentLanguage#PYTHON} through the Java
 * Serve API and call them through their handles.
 *
 * <p>The Python code under test lives in the classpath resource {@code
 * test_python_deployment.py}; {@link #beforeClass()} copies it into a temp directory and appends
 * that directory to the {@code ray.job.code-search-path} system property.
 */
@Test(groups = {"cluster"})
public class CrossLanguageDeploymentTest extends BaseServeTest {
  private static final String PYTHON_MODULE = "test_python_deployment";

  /**
   * Copies the Python test module out of the classpath into a fresh temp directory and registers
   * that directory in the {@code ray.job.code-search-path} system property.
   *
   * @throws IllegalStateException if the temp directory cannot be created or the Python resource
   *     is missing from the classpath
   * @throws RuntimeException if copying the resource fails with an {@link IOException}
   */
  @BeforeClass
  public void beforeClass() {
    // Delete and re-create the temp dir.
    File tempDir =
        new File(
            System.getProperty("java.io.tmpdir")
                + File.separator
                + "ray_serve_cross_language_test");
    FileUtils.deleteQuietly(tempDir);
    if (!tempDir.mkdirs() && !tempDir.isDirectory()) {
      throw new IllegalStateException("Failed to create temp dir " + tempDir.getAbsolutePath());
    }
    tempDir.deleteOnExit();

    // Write the test Python file to the temp dir.
    // Classpath resource names are always '/'-separated, regardless of the platform's
    // File.separator.
    String resourceName = "/" + PYTHON_MODULE + ".py";
    File pythonFile = new File(tempDir, PYTHON_MODULE + ".py");
    // Registered after tempDir: deleteOnExit processes entries in reverse registration order, so
    // the file is removed first and the then-empty directory can be deleted.
    pythonFile.deleteOnExit();
    try (InputStream in = CrossLanguageDeploymentTest.class.getResourceAsStream(resourceName)) {
      if (in == null) {
        throw new IllegalStateException(
            "Python test resource " + resourceName + " was not found on the classpath.");
      }
      FileUtils.copyInputStreamToFile(in, pythonFile);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    System.setProperty(
        "ray.job.code-search-path",
        System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath());
  }

  /** Deploys the Python {@code Counter} class with init arg "28" and expects 28 + 6 = "34". */
  @Test
  public void createPyClassTest() {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("createPyClassTest")
            .setDeploymentDef(PYTHON_MODULE + ".Counter")
            .setNumReplicas(1)
            .setInitArgs(new Object[] {"28"})
            .create();

    deployment.deploy(true);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "34");
  }

  /** Deploys the Python {@code echo_server} function and expects the argument to be echoed. */
  @Test
  public void createPyMethodTest() {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("createPyMethodTest")
            .setDeploymentDef(PYTHON_MODULE + ".echo_server")
            .setNumReplicas(1)
            .create();
    deployment.deploy(true);
    RayServeHandle handle = deployment.getHandle();
    Assert.assertEquals(Ray.get(handle.method("__call__").remote("6")), "6");
  }

  /**
   * Deploys the Python {@code Counter} with a user config and then redeploys it with a different
   * one. The expected results ("7", then "9") correspond to the counter value being replaced by
   * the user config ("1", then "3") before 6 is added.
   */
  @Test
  public void userConfigTest() throws InterruptedException {
    Deployment deployment =
        Serve.deployment()
            .setLanguage(DeploymentLanguage.PYTHON)
            .setName("userConfigTest")
            .setDeploymentDef(PYTHON_MODULE + ".Counter")
            .setNumReplicas(1)
            .setUserConfig("1")
            .setInitArgs(new Object[] {"28"})
            .create();
    deployment.deploy(true);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "7");
    deployment.options().setUserConfig("3").create().deploy(true);
    // NOTE(review): fixed wait, presumably to let the new user config reach the replica before
    // the next call — confirm whether a readiness check could replace it.
    TimeUnit.SECONDS.sleep(20L);
    Assert.assertEquals(Ray.get(deployment.getHandle().method("increase").remote("6")), "9");
  }
}

View file

@ -29,8 +29,7 @@ public class ReplicaSetTest extends BaseTest {
ActorNameList.Builder builder = ActorNameList.newBuilder(); ActorNameList.Builder builder = ActorNameList.newBuilder();
replicaSet.updateWorkerReplicas(builder.build()); replicaSet.updateWorkerReplicas(builder.build());
Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries = Map<String, Set<ObjectRef<Object>>> inFlightQueries = replicaSet.getInFlightQueries();
replicaSet.getInFlightQueries();
Assert.assertTrue(inFlightQueries.isEmpty()); Assert.assertTrue(inFlightQueries.isEmpty());
} }

View file

@ -1,24 +0,0 @@
# This file is used by CrossLanguageInvocationTest.java to test cross-language
# invocation.
import ray
@ray.remote
def py_return_input(v):
return v
@ray.remote
def py_func_python_raise_exception():
1 / 0
@ray.remote
class Counter(object):
def __init__(self, value):
self.value = int(value)
def increase(self, delta):
self.value += int(delta)
return str(self.value).encode("utf-8")

View file

@ -46,6 +46,7 @@ from ray.serve._private.utils import (
format_actor_name, format_actor_name,
get_random_letters, get_random_letters,
msgpack_serialize, msgpack_serialize,
msgpack_deserialize,
) )
from ray.serve._private.version import DeploymentVersion, VersionedReplica from ray.serve._private.version import DeploymentVersion, VersionedReplica
@ -206,7 +207,9 @@ class ActorReplicaWrapper:
# Populated in self.stop(). # Populated in self.stop().
self._graceful_shutdown_ref: ObjectRef = None self._graceful_shutdown_ref: ObjectRef = None
# todo: will be confused with deployment_config.is_cross_language
self._is_cross_language = False self._is_cross_language = False
self._deployment_is_cross_language = False
@property @property
def replica_tag(self) -> str: def replica_tag(self) -> str:
@ -267,25 +270,51 @@ class ActorReplicaWrapper:
) )
self._actor_resources = deployment_info.replica_config.resource_dict self._actor_resources = deployment_info.replica_config.resource_dict
# it is currently not possible to create a placement group
# with no resources (https://github.com/ray-project/ray/issues/20401)
self._deployment_is_cross_language = (
deployment_info.deployment_config.is_cross_language
)
logger.debug( logger.debug(
f"Starting replica {self.replica_tag} for deployment " f"Starting replica {self.replica_tag} for deployment "
f"{self.deployment_name}." f"{self.deployment_name}."
) )
actor_def = deployment_info.actor_def actor_def = deployment_info.actor_def
if (
deployment_info.deployment_config.deployment_language
== DeploymentLanguage.PYTHON
):
if deployment_info.replica_config.serialized_init_args is None:
serialized_init_args = cloudpickle.dumps(())
else:
serialized_init_args = (
cloudpickle.dumps(
msgpack_deserialize(
deployment_info.replica_config.serialized_init_args
)
)
if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_init_args
)
init_args = ( init_args = (
self.deployment_name, self.deployment_name,
self.replica_tag, self.replica_tag,
deployment_info.replica_config.serialized_deployment_def, cloudpickle.dumps(deployment_info.replica_config.deployment_def)
deployment_info.replica_config.serialized_init_args, if self._deployment_is_cross_language
deployment_info.replica_config.serialized_init_kwargs, else deployment_info.replica_config.serialized_deployment_def,
serialized_init_args,
deployment_info.replica_config.serialized_init_kwargs
if deployment_info.replica_config.serialized_init_kwargs
else cloudpickle.dumps({}),
deployment_info.deployment_config.to_proto_bytes(), deployment_info.deployment_config.to_proto_bytes(),
version, version,
self._controller_name, self._controller_name,
self._detached, self._detached,
) )
# TODO(simon): unify the constructor arguments across language # TODO(simon): unify the constructor arguments across language
if ( elif (
deployment_info.deployment_config.deployment_language deployment_info.deployment_config.deployment_language
== DeploymentLanguage.JAVA == DeploymentLanguage.JAVA
): ):
@ -306,7 +335,7 @@ class ActorReplicaWrapper:
deployment_info.replica_config.serialized_init_args deployment_info.replica_config.serialized_init_args
) )
) )
if deployment_info.deployment_config.is_cross_language if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_init_args, else deployment_info.replica_config.serialized_init_args,
# byte[] deploymentConfigBytes, # byte[] deploymentConfigBytes,
deployment_info.deployment_config.to_proto_bytes(), deployment_info.deployment_config.to_proto_bytes(),
@ -329,16 +358,17 @@ class ActorReplicaWrapper:
# Perform auto method name translation for java handles. # Perform auto method name translation for java handles.
# See https://github.com/ray-project/ray/issues/21474 # See https://github.com/ray-project/ray/issues/21474
user_config = self._format_user_config(
deployment_info.deployment_config.user_config
)
if self._is_cross_language: if self._is_cross_language:
self._actor_handle = JavaActorHandleProxy(self._actor_handle) self._actor_handle = JavaActorHandleProxy(self._actor_handle)
self._allocated_obj_ref = self._actor_handle.is_allocated.remote() self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
self._ready_obj_ref = self._actor_handle.reconfigure.remote( self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config)
deployment_info.deployment_config.user_config
)
else: else:
self._allocated_obj_ref = self._actor_handle.is_allocated.remote() self._allocated_obj_ref = self._actor_handle.is_allocated.remote()
self._ready_obj_ref = self._actor_handle.reconfigure.remote( self._ready_obj_ref = self._actor_handle.reconfigure.remote(
deployment_info.deployment_config.user_config, user_config,
# Ensure that `is_allocated` will execute before `reconfigure`, # Ensure that `is_allocated` will execute before `reconfigure`,
# because `reconfigure` runs user code that could block the replica # because `reconfigure` runs user code that could block the replica
# asyncio loop. If that happens before `is_allocated` is executed, # asyncio loop. If that happens before `is_allocated` is executed,
@ -346,12 +376,23 @@ class ActorReplicaWrapper:
self._allocated_obj_ref, self._allocated_obj_ref,
) )
def _format_user_config(self, user_config: Any):
    """Return ``user_config`` in the form to pass to the replica's ``reconfigure``.

    A copy of the config is returned unchanged when it is ``None`` or when
    ``self._deployment_is_cross_language`` is false. Otherwise the copy is
    msgpack-serialized if ``self._is_cross_language`` is set, and
    msgpack-deserialized if it is not.
    """
    formatted = copy(user_config)
    if user_config is None or not self._deployment_is_cross_language:
        return formatted
    if self._is_cross_language:
        return msgpack_serialize(formatted)
    return msgpack_deserialize(formatted)
def update_user_config(self, user_config: Any): def update_user_config(self, user_config: Any):
""" """
Update user config of existing actor behind current Update user config of existing actor behind current
DeploymentReplica instance. DeploymentReplica instance.
""" """
self._ready_obj_ref = self._actor_handle.reconfigure.remote(user_config) self._ready_obj_ref = self._actor_handle.reconfigure.remote(
self._format_user_config(user_config)
)
def recover(self): def recover(self):
""" """
@ -424,7 +465,11 @@ class ActorReplicaWrapper:
except Exception: except Exception:
logger.exception(f"Exception in deployment '{self._deployment_name}'") logger.exception(f"Exception in deployment '{self._deployment_name}'")
return ReplicaStartupStatus.FAILED, None return ReplicaStartupStatus.FAILED, None
if self._deployment_is_cross_language:
# todo: The replica's userconfig, which was created by the Java client,
# is different from the controller's userconfig
return ReplicaStartupStatus.SUCCEEDED, None
else:
return ReplicaStartupStatus.SUCCEEDED, version return ReplicaStartupStatus.SUCCEEDED, version
@property @property

View file

@ -36,6 +36,7 @@ from ray.serve._private.utils import (
parse_import_path, parse_import_path,
parse_request_item, parse_request_item,
wrap_to_ray_error, wrap_to_ray_error,
merge_dict,
) )
from ray.serve._private.version import DeploymentVersion from ray.serve._private.version import DeploymentVersion
@ -183,6 +184,24 @@ def create_replica_wrapper(name: str):
query = Query(request_args, request_kwargs, request_metadata) query = Query(request_args, request_kwargs, request_metadata)
return await self.replica.handle_request(query) return await self.replica.handle_request(query)
async def handle_request_from_java(
    self,
    proto_request_metadata: bytes,
    *request_args,
    **request_kwargs,
):
    """Handle a request whose metadata arrives as serialized protobuf bytes.

    ``proto_request_metadata`` is parsed as a ``RequestMetadata`` proto and
    converted to the Python ``RequestMetadata``. Only the first positional
    argument is forwarded as the query's args. The query is built with
    ``return_num=1``.
    """
    from ray.serve.generated.serve_pb2 import (
        RequestMetadata as RequestMetadataProto,
    )

    metadata_proto = RequestMetadataProto.FromString(proto_request_metadata)
    metadata: RequestMetadata = RequestMetadata(
        metadata_proto.request_id,
        metadata_proto.endpoint,
        call_method=metadata_proto.call_method,
    )
    # Only the first positional argument is used as the call arguments.
    call_args = request_args[0]
    return await self.replica.handle_request(
        Query(call_args, request_kwargs, metadata, return_num=1)
    )
async def is_allocated(self) -> str: async def is_allocated(self) -> str:
"""poke the replica to check whether it's alive. """poke the replica to check whether it's alive.
@ -349,7 +368,11 @@ class RayServeReplica:
method_stat = actor_stats.get( method_stat = actor_stats.get(
f"{_format_replica_actor_name(self.deployment_name)}.handle_request" f"{_format_replica_actor_name(self.deployment_name)}.handle_request"
) )
return method_stat method_stat_java = actor_stats.get(
f"{_format_replica_actor_name(self.deployment_name)}"
f".handle_request_from_java"
)
return merge_dict(method_stat, method_stat_java)
def _collect_autoscaling_metrics(self): def _collect_autoscaling_metrics(self):
method_stat = self._get_handle_request_stats() method_stat = self._get_handle_request_stats()
@ -487,7 +510,9 @@ class RayServeReplica:
latency_ms=latency_ms, latency_ms=latency_ms,
) )
) )
if request.return_num == 1:
return result
else:
# Returns a small object for router to track request status. # Returns a small object for router to track request status.
return b"", result return b"", result

View file

@ -45,6 +45,7 @@ class Query:
args: List[Any] args: List[Any]
kwargs: Dict[Any, Any] kwargs: Dict[Any, Any]
metadata: RequestMetadata metadata: RequestMetadata
return_num: int = 2
async def resolve_async_tasks(self): async def resolve_async_tasks(self):
"""Find all unresolved asyncio.Task and gather them all at once.""" """Find all unresolved asyncio.Task and gather them all at once."""

View file

@ -24,6 +24,7 @@ from ray.exceptions import RayTaskError
from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, RAY_GCS_RPC_TIMEOUT_S
from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request from ray.serve._private.http_util import HTTPRequestWrapper, build_starlette_request
from ray.util.serialization import StandaloneSerializationContext from ray.util.serialization import StandaloneSerializationContext
from ray._raylet import MessagePackSerializer
import __main__ import __main__
@ -33,6 +34,7 @@ except ImportError:
pd = None pd = None
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
MESSAGE_PACK_OFFSET = 9
# Use a global singleton enum to emulate default options. We cannot use None # Use a global singleton enum to emulate default options. We cannot use None
@ -241,6 +243,28 @@ def msgpack_serialize(obj):
return serialized return serialized
def msgpack_deserialize(data):
    """Decode a MessagePack payload, skipping its leading offset bytes.

    The first ``MESSAGE_PACK_OFFSET`` bytes of ``data`` are dropped and the
    remainder is passed to ``MessagePackSerializer.loads``. Any exception
    raised by the decoder propagates to the caller unchanged.

    NOTE(review): the skipped prefix is presumably the header written by the
    cross-language serializer — confirm against ``msgpack_serialize``.
    """
    # todo: Ray does not provide a msgpack deserialization api.
    return MessagePackSerializer.loads(data[MESSAGE_PACK_OFFSET:], None)
def merge_dict(dict1, dict2):
    """Merge two dicts of numeric values by summing the values of shared keys.

    Args:
        dict1: First mapping, or ``None``.
        dict2: Second mapping, or ``None``.

    Returns:
        ``None`` when both inputs are ``None``. Otherwise a new dict holding
        every key of either input, with the values of shared keys added
        together. Keys appear in a deterministic order: those of ``dict1``
        first, followed by keys found only in ``dict2``. Neither input is
        modified.
    """
    if dict1 is None and dict2 is None:
        return None
    result = {}
    for source in (dict1, dict2):
        if source is None:
            continue
        for key, value in source.items():
            # Start from 0 so a key present on only one side keeps its value.
            result[key] = result.get(key, 0) + value
    return result
def get_deployment_import_path( def get_deployment_import_path(
deployment, replace_main=False, enforce_importable=False deployment, replace_main=False, enforce_importable=False
): ):

View file

@ -470,8 +470,8 @@ class ReplicaConfig:
return ReplicaConfig( return ReplicaConfig(
proto.deployment_def_name, proto.deployment_def_name,
proto.deployment_def, proto.deployment_def,
proto.init_args, proto.init_args if proto.init_args != b"" else None,
proto.init_kwargs, proto.init_kwargs if proto.init_kwargs != b"" else None,
json.loads(proto.ray_actor_options), json.loads(proto.ray_actor_options),
needs_pickle, needs_pickle,
) )

View file

@ -14,9 +14,39 @@ from ray.serve._private.utils import (
get_deployment_import_path, get_deployment_import_path,
override_runtime_envs_except_env_vars, override_runtime_envs_except_env_vars,
serve_encoders, serve_encoders,
merge_dict,
msgpack_serialize,
msgpack_deserialize,
) )
def test_serialize():
    """Round-trip an integer through msgpack_serialize / msgpack_deserialize."""
    encoded = msgpack_serialize(5)
    decoded = msgpack_deserialize(encoded)
    assert 5 == decoded
def test_merge_dict():
    """Check merge_dict for two dicts, one ``None`` input, and two ``None`` inputs."""
    dict1 = {"pending": 1, "running": 1, "finished": 1}
    dict2 = {"pending": 4, "finished": 1}
    merge = merge_dict(dict1, dict2)
    assert merge["pending"] == 5
    assert merge["running"] == 1
    assert merge["finished"] == 2

    # With dict1 missing, the result contains only dict2's keys.
    dict1 = None
    merge = merge_dict(dict1, dict2)
    assert merge["pending"] == 4
    assert merge["finished"] == 1
    assert "running" not in merge

    # Both inputs missing yields None rather than an empty dict.
    dict2 = None
    merge = merge_dict(dict1, dict2)
    assert merge is None
def test_bytes_encoder(): def test_bytes_encoder():
data_before = {"inp": {"nest": b"bytes"}} data_before = {"inp": {"nest": b"bytes"}}
data_after = {"inp": {"nest": "bytes"}} data_after = {"inp": {"nest": "bytes"}}