diff --git a/java/serve/src/main/java/io/ray/serve/Query.java b/java/serve/src/main/java/io/ray/serve/Query.java index c611e0669..e88480506 100644 --- a/java/serve/src/main/java/io/ray/serve/Query.java +++ b/java/serve/src/main/java/io/ray/serve/Query.java @@ -1,22 +1,20 @@ package io.ray.serve; +import io.ray.serve.generated.RequestMetadata; + /** Wrap request arguments and meta data. */ public class Query { - private Object[] args; - private RequestMetadata metadata; - public Query(Object[] args, RequestMetadata requestMetadata) { - this.args = args; + /** + * If this query is cross-language, the args is serialized {@link + * io.ray.serve.generated.RequestWrapper}. Otherwise, it is Object[]. + */ + private Object args; + + public Query(RequestMetadata requestMetadata, Object args) { this.metadata = requestMetadata; - } - - public Object[] getArgs() { - return args; - } - - public void setArgs(Object[] args) { this.args = args; } @@ -24,7 +22,7 @@ public class Query { return metadata; } - public void setMetadata(RequestMetadata metadata) { - this.metadata = metadata; + public Object getArgs() { + return args; } } diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java index 17fdda014..9949115fb 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java @@ -8,15 +8,17 @@ import io.ray.runtime.metric.Gauge; import io.ray.runtime.metric.Histogram; import io.ray.runtime.metric.MetricConfig; import io.ray.runtime.metric.Metrics; +import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.RequestWrapper; import io.ray.serve.poll.KeyListener; import io.ray.serve.poll.KeyType; import io.ray.serve.poll.LongPollClient; import io.ray.serve.poll.LongPollNamespace; -import io.ray.serve.util.BackendConfigUtil; import io.ray.serve.util.LogUtil; import io.ray.serve.util.ReflectUtil; +import io.ray.serve.util.ServeProtoUtil; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; @@ -59,7 +61,7 @@ public class RayServeReplica { this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; this.config = backendConfig; - this.reconfigure(BackendConfigUtil.getUserConfig(backendConfig)); + this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig)); Map keyListeners = new HashMap<>(); keyListeners.put( @@ -152,8 +154,9 @@ public class RayServeReplica { replicaTag, requestItem.getMetadata().getRequestId()); - methodToCall = getRunnerMethod(requestItem); - Object result = methodToCall.invoke(callable, requestItem.getArgs()); + Object[] args = parseRequestItem(requestItem); + methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args); + Object result = methodToCall.invoke(callable, args); reportMetrics(() -> requestCounter.inc(1.0)); return result; } catch (Throwable e) { @@ -169,12 +172,29 @@ public class RayServeReplica { } } - private Method getRunnerMethod(Query query) { - String methodName = query.getMetadata().getCallMethod(); + private Object[] parseRequestItem(Query requestItem) { + if (requestItem.getArgs() == null) { + return new Object[0]; + } + + // From Java Proxy or Handle. + if (requestItem.getArgs() instanceof Object[]) { + return (Object[]) requestItem.getArgs(); + } + + // From other language Proxy or Handle. + RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs(); + if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) { + return new Object[0]; + } + + return MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object[].class); + } + + private Method getRunnerMethod(String methodName, Object[] args) { try { - return ReflectUtil.getMethod( - callable.getClass(), methodName, query.getArgs() == null ? null : query.getArgs()); + return ReflectUtil.getMethod(callable.getClass(), methodName, args); } catch (NoSuchMethodException e) { throw new RayServeException( LogUtil.format( diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index 2c49814b0..9ccc6c6f7 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -1,13 +1,15 @@ package io.ray.serve; import com.google.common.base.Preconditions; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.BaseActorHandle; import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.BackendConfig; -import io.ray.serve.util.BackendConfigUtil; +import io.ray.serve.generated.RequestMetadata; import io.ray.serve.util.ReflectUtil; +import io.ray.serve.util.ServeProtoUtil; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.Optional; @@ -30,7 +32,7 @@ public class RayServeWrappedReplica { IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { // Parse BackendConfig. - BackendConfig backendConfig = BackendConfigUtil.parseFrom(backendConfigBytes); + BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(backendConfigBytes); // Parse init args. Object[] initArgs = parseInitArgs(initArgsbytes, backendConfig); @@ -73,13 +75,25 @@ public class RayServeWrappedReplica { /** * The entry method to process the request. * - * @param requestMetadata request metadata - * @param requestArgs the input parameters of the specified method of the object defined by - * backendDef. + * @param requestMetadata the real type is byte[] if this invocation is cross-language. Otherwise, + * the real type is {@link io.ray.serve.generated.RequestMetadata}. + * @param requestArgs The input parameters of the specified method of the object defined by + * backendDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if + * this invocation is cross-language. Otherwise, the real type is Object[]. * @return the result of request being processed + * @throws InvalidProtocolBufferException if the protobuf deserialization fails. */ - public Object handleRequest(RequestMetadata requestMetadata, Object[] requestArgs) { - return backend.handleRequest(new Query(requestArgs, requestMetadata)); + public Object handleRequest(Object requestMetadata, Object requestArgs) + throws InvalidProtocolBufferException { + boolean isCrossLanguage = requestMetadata instanceof byte[]; + return backend.handleRequest( + new Query( + isCrossLanguage + ? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata) + : (RequestMetadata) requestMetadata, + isCrossLanguage + ? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs) + : requestArgs)); } /** Check whether this replica is ready or not. */ diff --git a/java/serve/src/main/java/io/ray/serve/RequestMetadata.java b/java/serve/src/main/java/io/ray/serve/RequestMetadata.java deleted file mode 100644 index a903ef650..000000000 --- a/java/serve/src/main/java/io/ray/serve/RequestMetadata.java +++ /dev/null @@ -1,60 +0,0 @@ -package io.ray.serve; - -import java.io.Serializable; -import java.util.Map; - -/** The meta data of request. */ -public class RequestMetadata implements Serializable { - - private static final long serialVersionUID = -8925036926565326811L; - - private String requestId; - - private String endpoint; - - private String callMethod = "call"; - - private String httpMethod; - - private Map httpHeaders; - - public String getRequestId() { - return requestId; - } - - public void setRequestId(String requestId) { - this.requestId = requestId; - } - - public String getEndpoint() { - return endpoint; - } - - public void setEndpoint(String endpoint) { - this.endpoint = endpoint; - } - - public String getCallMethod() { - return callMethod; - } - - public void setCallMethod(String callMethod) { - this.callMethod = callMethod; - } - - public String getHttpMethod() { - return httpMethod; - } - - public void setHttpMethod(String httpMethod) { - this.httpMethod = httpMethod; - } - - public Map getHttpHeaders() { - return httpHeaders; - } - - public void setHttpHeaders(Map httpHeaders) { - this.httpHeaders = httpHeaders; - } -} diff --git a/java/serve/src/main/java/io/ray/serve/util/BackendConfigUtil.java b/java/serve/src/main/java/io/ray/serve/util/BackendConfigUtil.java deleted file mode 100644 index 6de202892..000000000 --- a/java/serve/src/main/java/io/ray/serve/util/BackendConfigUtil.java +++ /dev/null @@ -1,76 +0,0 @@ -package io.ray.serve.util; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import com.google.protobuf.InvalidProtocolBufferException; -import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.RayServeException; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; - -public class BackendConfigUtil { - - public static BackendConfig parseFrom(byte[] backendConfigBytes) - throws InvalidProtocolBufferException { - - // Parse BackendConfig from byte[]. - BackendConfig inputBackendConfig = BackendConfig.parseFrom(backendConfigBytes); - if (inputBackendConfig == null) { - return null; - } - - // Set default values. - BackendConfig.Builder builder = BackendConfig.newBuilder(); - - if (inputBackendConfig.getNumReplicas() == 0) { - builder.setNumReplicas(1); - } else { - builder.setNumReplicas(inputBackendConfig.getNumReplicas()); - } - - Preconditions.checkArgument( - inputBackendConfig.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0"); - if (inputBackendConfig.getMaxConcurrentQueries() == 0) { - builder.setMaxConcurrentQueries(100); - } else { - builder.setMaxConcurrentQueries(inputBackendConfig.getMaxConcurrentQueries()); - } - - builder.setUserConfig(inputBackendConfig.getUserConfig()); - - if (inputBackendConfig.getExperimentalGracefulShutdownWaitLoopS() == 0) { - builder.setExperimentalGracefulShutdownWaitLoopS(2); - } else { - builder.setExperimentalGracefulShutdownWaitLoopS( - inputBackendConfig.getExperimentalGracefulShutdownWaitLoopS()); - } - - if (inputBackendConfig.getExperimentalGracefulShutdownTimeoutS() == 0) { - builder.setExperimentalGracefulShutdownTimeoutS(20); - } else { - builder.setExperimentalGracefulShutdownTimeoutS( - inputBackendConfig.getExperimentalGracefulShutdownTimeoutS()); - } - - builder.setIsCrossLanguage(inputBackendConfig.getIsCrossLanguage()); - - if (inputBackendConfig.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { - throw new RayServeException( - LogUtil.format( - "Unrecognized backend language {}. Backend language must be in {}.", - inputBackendConfig.getBackendLanguageValue(), - Lists.newArrayList(BackendLanguage.values()))); - } else { - builder.setBackendLanguage(inputBackendConfig.getBackendLanguage()); - } - - return builder.build(); - } - - public static Object getUserConfig(BackendConfig backendConfig) { - if (backendConfig.getUserConfig() == null || backendConfig.getUserConfig().size() == 0) { - return null; - } - return MessagePackSerializer.decode(backendConfig.getUserConfig().toByteArray(), Object.class); - } -} diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java new file mode 100644 index 000000000..b1d02a046 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java @@ -0,0 +1,111 @@ +package io.ray.serve.util; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.runtime.serializer.MessagePackSerializer; +import io.ray.serve.RayServeException; +import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.RequestMetadata; +import io.ray.serve.generated.RequestWrapper; +import org.apache.commons.lang3.StringUtils; + +public class ServeProtoUtil { + + public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) + throws InvalidProtocolBufferException { + + // Get a builder from BackendConfig(bytes) or create a new one. + BackendConfig.Builder builder = null; + if (backendConfigBytes == null) { + builder = BackendConfig.newBuilder(); + } else { + BackendConfig backendConfig = BackendConfig.parseFrom(backendConfigBytes); + if (backendConfig == null) { + builder = BackendConfig.newBuilder(); + } else { + builder = BackendConfig.newBuilder(backendConfig); + } + } + + // Set default values. + if (builder.getNumReplicas() == 0) { + builder.setNumReplicas(1); + } + + Preconditions.checkArgument( + builder.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0"); + if (builder.getMaxConcurrentQueries() == 0) { + builder.setMaxConcurrentQueries(100); + } + + if (builder.getExperimentalGracefulShutdownWaitLoopS() == 0) { + builder.setExperimentalGracefulShutdownWaitLoopS(2); + } + + if (builder.getExperimentalGracefulShutdownTimeoutS() == 0) { + builder.setExperimentalGracefulShutdownTimeoutS(20); + } + + if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { + throw new RayServeException( + LogUtil.format( + "Unrecognized backend language {}. Backend language must be in {}.", + builder.getBackendLanguageValue(), + Lists.newArrayList(BackendLanguage.values()))); + } + + return builder.build(); + } + + public static Object parseUserConfig(BackendConfig backendConfig) { + if (backendConfig.getUserConfig() == null || backendConfig.getUserConfig().size() == 0) { + return null; + } + return MessagePackSerializer.decode(backendConfig.getUserConfig().toByteArray(), Object.class); + } + + public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) + throws InvalidProtocolBufferException { + + // Get a builder from RequestMetadata(bytes) or create a new one. + RequestMetadata.Builder builder = null; + if (requestMetadataBytes == null) { + builder = RequestMetadata.newBuilder(); + } else { + RequestMetadata requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes); + if (requestMetadata == null) { + builder = RequestMetadata.newBuilder(); + } else { + builder = RequestMetadata.newBuilder(requestMetadata); + } + } + + // Set default values. + if (StringUtils.isBlank(builder.getCallMethod())) { + builder.setCallMethod("call"); + } + + return builder.build(); + } + + public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes) + throws InvalidProtocolBufferException { + + // Get a builder from HTTPRequestWrapper(bytes) or create a new one. + RequestWrapper.Builder builder = null; + if (httpRequestWrapperBytes == null) { + builder = RequestWrapper.newBuilder(); + } else { + RequestWrapper requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes); + if (requestWrapper == null) { + builder = RequestWrapper.newBuilder(); + } else { + builder = RequestWrapper.newBuilder(requestWrapper); + } + } + + return builder.build(); + } +} diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index e25947a07..7cc7746ff 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -6,6 +6,8 @@ import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.RequestMetadata; +import io.ray.serve.generated.RequestWrapper; import java.io.IOException; import org.testng.Assert; import org.testng.annotations.Test; @@ -17,7 +19,6 @@ public class RayServeReplicaTest { public void test() throws IOException { boolean inited = Ray.isInitialized(); - Ray.init(); try { @@ -52,12 +53,18 @@ public class RayServeReplicaTest { backendHandle.task(RayServeWrappedReplica::ready).remote(); - RequestMetadata requestMetadata = new RequestMetadata(); + RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder(); requestMetadata.setRequestId("RayServeReplicaTest"); requestMetadata.setCallMethod("getBackendTag"); + + RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder(); + ObjectRef resultRef = backendHandle - .task(RayServeWrappedReplica::handleRequest, requestMetadata, (Object[]) null) + .task( + RayServeWrappedReplica::handleRequest, + requestMetadata.build().toByteArray(), + requestWrapper.build().toByteArray()) .remote(); Assert.assertEquals((String) resultRef.get(), backendTag); diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index fd48f28df..f6de71593 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -56,3 +56,15 @@ enum BackendLanguage { PYTHON = 0; JAVA = 1; } + +message RequestMetadata { + string request_id = 1; + + string endpoint = 2; + + string call_method = 3; +} + +message RequestWrapper { + bytes body = 1; +}