Define protobuf for RequestMetadata and HTTPRequestWrapper (#18203)

This commit is contained in:
liuyang-my 2021-09-16 05:39:27 +08:00 committed by GitHub
parent 7df3441ae9
commit ed04ab7140
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 193 additions and 167 deletions

View file

@ -1,22 +1,20 @@
package io.ray.serve; package io.ray.serve;
import io.ray.serve.generated.RequestMetadata;
/** Wrap request arguments and meta data. */ /** Wrap request arguments and meta data. */
public class Query { public class Query {
private Object[] args;
private RequestMetadata metadata; private RequestMetadata metadata;
public Query(Object[] args, RequestMetadata requestMetadata) { /**
this.args = args; * If this query is cross-language, the args is serialized {@link
* io.ray.serve.generated.RequestWrapper}. Otherwise, it is Object[].
*/
private Object args;
public Query(RequestMetadata requestMetadata, Object args) {
this.metadata = requestMetadata; this.metadata = requestMetadata;
}
public Object[] getArgs() {
return args;
}
public void setArgs(Object[] args) {
this.args = args; this.args = args;
} }
@ -24,7 +22,7 @@ public class Query {
return metadata; return metadata;
} }
public void setMetadata(RequestMetadata metadata) { public Object getArgs() {
this.metadata = metadata; return args;
} }
} }

View file

@ -8,15 +8,17 @@ import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Histogram; import io.ray.runtime.metric.Histogram;
import io.ray.runtime.metric.MetricConfig; import io.ray.runtime.metric.MetricConfig;
import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.Metrics;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve; import io.ray.serve.api.Serve;
import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendConfig;
import io.ray.serve.generated.RequestWrapper;
import io.ray.serve.poll.KeyListener; import io.ray.serve.poll.KeyListener;
import io.ray.serve.poll.KeyType; import io.ray.serve.poll.KeyType;
import io.ray.serve.poll.LongPollClient; import io.ray.serve.poll.LongPollClient;
import io.ray.serve.poll.LongPollNamespace; import io.ray.serve.poll.LongPollNamespace;
import io.ray.serve.util.BackendConfigUtil;
import io.ray.serve.util.LogUtil; import io.ray.serve.util.LogUtil;
import io.ray.serve.util.ReflectUtil; import io.ray.serve.util.ReflectUtil;
import io.ray.serve.util.ServeProtoUtil;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -59,7 +61,7 @@ public class RayServeReplica {
this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag();
this.callable = callable; this.callable = callable;
this.config = backendConfig; this.config = backendConfig;
this.reconfigure(BackendConfigUtil.getUserConfig(backendConfig)); this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig));
Map<KeyType, KeyListener> keyListeners = new HashMap<>(); Map<KeyType, KeyListener> keyListeners = new HashMap<>();
keyListeners.put( keyListeners.put(
@ -152,8 +154,9 @@ public class RayServeReplica {
replicaTag, replicaTag,
requestItem.getMetadata().getRequestId()); requestItem.getMetadata().getRequestId());
methodToCall = getRunnerMethod(requestItem); Object[] args = parseRequestItem(requestItem);
Object result = methodToCall.invoke(callable, requestItem.getArgs()); methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args);
Object result = methodToCall.invoke(callable, args);
reportMetrics(() -> requestCounter.inc(1.0)); reportMetrics(() -> requestCounter.inc(1.0));
return result; return result;
} catch (Throwable e) { } catch (Throwable e) {
@ -169,12 +172,29 @@ public class RayServeReplica {
} }
} }
private Method getRunnerMethod(Query query) { private Object[] parseRequestItem(Query requestItem) {
String methodName = query.getMetadata().getCallMethod(); if (requestItem.getArgs() == null) {
return new Object[0];
}
// From Java Proxy or Handle.
if (requestItem.getArgs() instanceof Object[]) {
return (Object[]) requestItem.getArgs();
}
// From other language Proxy or Handle.
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
return new Object[0];
}
return MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object[].class);
}
private Method getRunnerMethod(String methodName, Object[] args) {
try { try {
return ReflectUtil.getMethod( return ReflectUtil.getMethod(callable.getClass(), methodName, args);
callable.getClass(), methodName, query.getArgs() == null ? null : query.getArgs());
} catch (NoSuchMethodException e) { } catch (NoSuchMethodException e) {
throw new RayServeException( throw new RayServeException(
LogUtil.format( LogUtil.format(

View file

@ -1,13 +1,15 @@
package io.ray.serve; package io.ray.serve;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.BaseActorHandle; import io.ray.api.BaseActorHandle;
import io.ray.api.Ray; import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve; import io.ray.serve.api.Serve;
import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendConfig;
import io.ray.serve.util.BackendConfigUtil; import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.util.ReflectUtil; import io.ray.serve.util.ReflectUtil;
import io.ray.serve.util.ServeProtoUtil;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.util.Optional; import java.util.Optional;
@ -30,7 +32,7 @@ public class RayServeWrappedReplica {
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
// Parse BackendConfig. // Parse BackendConfig.
BackendConfig backendConfig = BackendConfigUtil.parseFrom(backendConfigBytes); BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(backendConfigBytes);
// Parse init args. // Parse init args.
Object[] initArgs = parseInitArgs(initArgsbytes, backendConfig); Object[] initArgs = parseInitArgs(initArgsbytes, backendConfig);
@ -73,13 +75,25 @@ public class RayServeWrappedReplica {
/** /**
* The entry method to process the request. * The entry method to process the request.
* *
* @param requestMetadata request metadata * @param requestMetadata the real type is byte[] if this invocation is cross-language. Otherwise,
* @param requestArgs the input parameters of the specified method of the object defined by * the real type is {@link io.ray.serve.generated.RequestMetadata}.
* backendDef. * @param requestArgs The input parameters of the specified method of the object defined by
* backendDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
* this invocation is cross-language. Otherwise, the real type is Object[].
* @return the result of request being processed * @return the result of request being processed
* @throws InvalidProtocolBufferException if the protobuf deserialization fails.
*/ */
public Object handleRequest(RequestMetadata requestMetadata, Object[] requestArgs) { public Object handleRequest(Object requestMetadata, Object requestArgs)
return backend.handleRequest(new Query(requestArgs, requestMetadata)); throws InvalidProtocolBufferException {
boolean isCrossLanguage = requestMetadata instanceof byte[];
return backend.handleRequest(
new Query(
isCrossLanguage
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
: (RequestMetadata) requestMetadata,
isCrossLanguage
? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs)
: requestArgs));
} }
/** Check whether this replica is ready or not. */ /** Check whether this replica is ready or not. */

View file

@ -1,60 +0,0 @@
package io.ray.serve;
import java.io.Serializable;
import java.util.Map;
/**
 * Mutable, serializable holder for the metadata attached to a request.
 *
 * <p>Each property is exposed through a plain getter/setter pair with no validation or copying.
 * Only {@code callMethod} has a non-null initial value ({@code "call"}); every other property is
 * {@code null} until it is set.
 */
public class RequestMetadata implements Serializable {

  private static final long serialVersionUID = -8925036926565326811L;

  /** Identifier of the request. */
  private String requestId;

  /** Endpoint associated with the request. */
  private String endpoint;

  /** Name of the method to call; starts out as {@code "call"}. */
  private String callMethod = "call";

  /** HTTP method of the request. */
  private String httpMethod;

  /** HTTP headers of the request; the map is stored and returned by reference. */
  private Map<String, String> httpHeaders;

  // ---- requestId ----

  /** Returns the request id, or {@code null} if it has not been set. */
  public String getRequestId() {
    return this.requestId;
  }

  /** Replaces the request id. */
  public void setRequestId(String requestId) {
    this.requestId = requestId;
  }

  // ---- endpoint ----

  /** Returns the endpoint, or {@code null} if it has not been set. */
  public String getEndpoint() {
    return this.endpoint;
  }

  /** Replaces the endpoint. */
  public void setEndpoint(String endpoint) {
    this.endpoint = endpoint;
  }

  // ---- callMethod ----

  /** Returns the call method name; {@code "call"} unless it has been overwritten. */
  public String getCallMethod() {
    return this.callMethod;
  }

  /** Replaces the call method name. The value is stored as given, including {@code null}. */
  public void setCallMethod(String callMethod) {
    this.callMethod = callMethod;
  }

  // ---- httpMethod ----

  /** Returns the HTTP method, or {@code null} if it has not been set. */
  public String getHttpMethod() {
    return this.httpMethod;
  }

  /** Replaces the HTTP method. */
  public void setHttpMethod(String httpMethod) {
    this.httpMethod = httpMethod;
  }

  // ---- httpHeaders ----

  /** Returns the very map instance that was last set, or {@code null} if none was set. */
  public Map<String, String> getHttpHeaders() {
    return this.httpHeaders;
  }

  /** Replaces the HTTP headers. The given map is kept by reference, not copied. */
  public void setHttpHeaders(Map<String, String> httpHeaders) {
    this.httpHeaders = httpHeaders;
  }
}

View file

@ -1,76 +0,0 @@
package io.ray.serve.util;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.RayServeException;
import io.ray.serve.generated.BackendConfig;
import io.ray.serve.generated.BackendLanguage;
/** Static helpers for decoding a {@link BackendConfig} and its embedded user config. */
public class BackendConfigUtil {

  /**
   * Decodes a serialized {@link BackendConfig} and returns a copy in which unset (zero) numeric
   * fields are replaced by defaults: {@code num_replicas} = 1, {@code max_concurrent_queries} =
   * 100, graceful shutdown wait loop = 2 and graceful shutdown timeout = 20. The user config,
   * cross-language flag and backend language are copied over unchanged.
   *
   * @param backendConfigBytes the serialized {@link BackendConfig}
   * @return the config with defaults applied, or {@code null} if parsing yields {@code null}
   * @throws InvalidProtocolBufferException if the bytes cannot be parsed
   * @throws IllegalArgumentException if {@code max_concurrent_queries} is negative
   * @throws RayServeException if the backend language is {@code UNRECOGNIZED}
   */
  public static BackendConfig parseFrom(byte[] backendConfigBytes)
      throws InvalidProtocolBufferException {
    BackendConfig parsed = BackendConfig.parseFrom(backendConfigBytes);
    if (parsed == null) {
      return null;
    }

    // Validate first; the range check must stay ahead of the language check so that the same
    // exception wins when both are violated.
    Preconditions.checkArgument(
        parsed.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
    if (parsed.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) {
      throw new RayServeException(
          LogUtil.format(
              "Unrecognized backend language {}. Backend language must be in {}.",
              parsed.getBackendLanguageValue(),
              Lists.newArrayList(BackendLanguage.values())));
    }

    // Copy every field, substituting the default wherever the parsed value is zero.
    BackendConfig.Builder result = BackendConfig.newBuilder();
    result.setNumReplicas(parsed.getNumReplicas() == 0 ? 1 : parsed.getNumReplicas());
    result.setMaxConcurrentQueries(
        parsed.getMaxConcurrentQueries() == 0 ? 100 : parsed.getMaxConcurrentQueries());
    result.setUserConfig(parsed.getUserConfig());
    result.setExperimentalGracefulShutdownWaitLoopS(
        parsed.getExperimentalGracefulShutdownWaitLoopS() == 0
            ? 2
            : parsed.getExperimentalGracefulShutdownWaitLoopS());
    result.setExperimentalGracefulShutdownTimeoutS(
        parsed.getExperimentalGracefulShutdownTimeoutS() == 0
            ? 20
            : parsed.getExperimentalGracefulShutdownTimeoutS());
    result.setIsCrossLanguage(parsed.getIsCrossLanguage());
    result.setBackendLanguage(parsed.getBackendLanguage());
    return result.build();
  }

  /**
   * Decodes the user config stored in the given backend config with {@link
   * MessagePackSerializer}.
   *
   * @param backendConfig the backend config to read the user config from
   * @return the decoded object, or {@code null} when the user config is absent or empty
   */
  public static Object getUserConfig(BackendConfig backendConfig) {
    boolean hasUserConfig =
        backendConfig.getUserConfig() != null && backendConfig.getUserConfig().size() != 0;
    if (hasUserConfig) {
      return MessagePackSerializer.decode(
          backendConfig.getUserConfig().toByteArray(), Object.class);
    }
    return null;
  }
}

View file

@ -0,0 +1,111 @@
package io.ray.serve.util;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.RayServeException;
import io.ray.serve.generated.BackendConfig;
import io.ray.serve.generated.BackendLanguage;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.generated.RequestWrapper;
import org.apache.commons.lang3.StringUtils;
/**
 * Static helpers that decode the Serve protobuf messages ({@link BackendConfig}, {@link
 * RequestMetadata}, {@link RequestWrapper}) and fill in default values for unset fields.
 */
public class ServeProtoUtil {

  /**
   * Decodes a serialized {@link BackendConfig} and applies defaults to unset (zero) numeric
   * fields: {@code num_replicas} = 1, {@code max_concurrent_queries} = 100, graceful shutdown
   * wait loop = 2 and graceful shutdown timeout = 20. All other fields keep their parsed values.
   *
   * @param backendConfigBytes the serialized config; {@code null} is treated as an empty message
   * @return the config with defaults applied
   * @throws InvalidProtocolBufferException if the bytes cannot be parsed
   * @throws IllegalArgumentException if {@code max_concurrent_queries} is negative
   * @throws RayServeException if the backend language is {@code UNRECOGNIZED}
   */
  public static BackendConfig parseBackendConfig(byte[] backendConfigBytes)
      throws InvalidProtocolBufferException {
    // Start from the parsed message when there is one, otherwise from an empty builder.
    BackendConfig parsed =
        backendConfigBytes == null ? null : BackendConfig.parseFrom(backendConfigBytes);
    BackendConfig.Builder configBuilder =
        parsed == null ? BackendConfig.newBuilder() : BackendConfig.newBuilder(parsed);

    // Validation. The range check stays ahead of the language check so that the same exception
    // wins when both are violated.
    Preconditions.checkArgument(
        configBuilder.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
    if (configBuilder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) {
      throw new RayServeException(
          LogUtil.format(
              "Unrecognized backend language {}. Backend language must be in {}.",
              configBuilder.getBackendLanguageValue(),
              Lists.newArrayList(BackendLanguage.values())));
    }

    // Defaults for fields that were left at zero.
    if (configBuilder.getNumReplicas() == 0) {
      configBuilder.setNumReplicas(1);
    }
    if (configBuilder.getMaxConcurrentQueries() == 0) {
      configBuilder.setMaxConcurrentQueries(100);
    }
    if (configBuilder.getExperimentalGracefulShutdownWaitLoopS() == 0) {
      configBuilder.setExperimentalGracefulShutdownWaitLoopS(2);
    }
    if (configBuilder.getExperimentalGracefulShutdownTimeoutS() == 0) {
      configBuilder.setExperimentalGracefulShutdownTimeoutS(20);
    }
    return configBuilder.build();
  }

  /**
   * Decodes the user config stored in the given backend config with {@link
   * MessagePackSerializer}.
   *
   * @param backendConfig the backend config to read the user config from
   * @return the decoded object, or {@code null} when the user config is absent or empty
   */
  public static Object parseUserConfig(BackendConfig backendConfig) {
    boolean hasUserConfig =
        backendConfig.getUserConfig() != null && backendConfig.getUserConfig().size() != 0;
    if (hasUserConfig) {
      return MessagePackSerializer.decode(
          backendConfig.getUserConfig().toByteArray(), Object.class);
    }
    return null;
  }

  /**
   * Decodes a serialized {@link RequestMetadata}; a blank {@code call_method} is replaced by
   * {@code "call"}.
   *
   * @param requestMetadataBytes the serialized metadata; {@code null} is treated as an empty
   *     message
   * @return the metadata with the default call method applied
   * @throws InvalidProtocolBufferException if the bytes cannot be parsed
   */
  public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes)
      throws InvalidProtocolBufferException {
    RequestMetadata parsed =
        requestMetadataBytes == null ? null : RequestMetadata.parseFrom(requestMetadataBytes);
    RequestMetadata.Builder metadataBuilder =
        parsed == null ? RequestMetadata.newBuilder() : RequestMetadata.newBuilder(parsed);

    if (StringUtils.isBlank(metadataBuilder.getCallMethod())) {
      metadataBuilder.setCallMethod("call");
    }
    return metadataBuilder.build();
  }

  /**
   * Decodes a serialized {@link RequestWrapper}. No defaults are applied.
   *
   * @param httpRequestWrapperBytes the serialized {@link RequestWrapper}; {@code null} is treated
   *     as an empty message
   * @return the decoded wrapper
   * @throws InvalidProtocolBufferException if the bytes cannot be parsed
   */
  public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
      throws InvalidProtocolBufferException {
    RequestWrapper parsed =
        httpRequestWrapperBytes == null ? null : RequestWrapper.parseFrom(httpRequestWrapperBytes);
    RequestWrapper.Builder wrapperBuilder =
        parsed == null ? RequestWrapper.newBuilder() : RequestWrapper.newBuilder(parsed);
    return wrapperBuilder.build();
  }
}

View file

@ -6,6 +6,8 @@ import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.generated.BackendConfig; import io.ray.serve.generated.BackendConfig;
import io.ray.serve.generated.BackendLanguage; import io.ray.serve.generated.BackendLanguage;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.generated.RequestWrapper;
import java.io.IOException; import java.io.IOException;
import org.testng.Assert; import org.testng.Assert;
import org.testng.annotations.Test; import org.testng.annotations.Test;
@ -17,7 +19,6 @@ public class RayServeReplicaTest {
public void test() throws IOException { public void test() throws IOException {
boolean inited = Ray.isInitialized(); boolean inited = Ray.isInitialized();
Ray.init(); Ray.init();
try { try {
@ -52,12 +53,18 @@ public class RayServeReplicaTest {
backendHandle.task(RayServeWrappedReplica::ready).remote(); backendHandle.task(RayServeWrappedReplica::ready).remote();
RequestMetadata requestMetadata = new RequestMetadata(); RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
requestMetadata.setRequestId("RayServeReplicaTest"); requestMetadata.setRequestId("RayServeReplicaTest");
requestMetadata.setCallMethod("getBackendTag"); requestMetadata.setCallMethod("getBackendTag");
RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder();
ObjectRef<Object> resultRef = ObjectRef<Object> resultRef =
backendHandle backendHandle
.task(RayServeWrappedReplica::handleRequest, requestMetadata, (Object[]) null) .task(
RayServeWrappedReplica::handleRequest,
requestMetadata.build().toByteArray(),
requestWrapper.build().toByteArray())
.remote(); .remote();
Assert.assertEquals((String) resultRef.get(), backendTag); Assert.assertEquals((String) resultRef.get(), backendTag);

View file

@ -56,3 +56,15 @@ enum BackendLanguage {
PYTHON = 0; PYTHON = 0;
JAVA = 1; JAVA = 1;
} }
// Metadata that accompanies a single request sent to a replica.
message RequestMetadata {
// Identifier of the request; the Java replica includes it in its request log lines.
string request_id = 1;
// Endpoint the request belongs to.
// NOTE(review): no consumer of this field is visible on the Java side here - confirm usage.
string endpoint = 2;
// Name of the method to invoke on the user callable. The Java parser substitutes "call"
// when this field is blank.
string call_method = 3;
}
// Envelope for the arguments of a cross-language request.
message RequestWrapper {
// Serialized call arguments. The Java replica decodes this with MessagePack into an
// Object[] and treats an empty body as "no arguments".
bytes body = 1;
}