mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
Define protobuf for RequestMetadata and HTTPRequestWrapper (#18203)
This commit is contained in:
parent
7df3441ae9
commit
ed04ab7140
8 changed files with 193 additions and 167 deletions
|
@ -1,22 +1,20 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
|
||||
/** Wrap request arguments and meta data. */
|
||||
public class Query {
|
||||
|
||||
private Object[] args;
|
||||
|
||||
private RequestMetadata metadata;
|
||||
|
||||
public Query(Object[] args, RequestMetadata requestMetadata) {
|
||||
this.args = args;
|
||||
/**
|
||||
* If this query is cross-language, the args is serialized {@link
|
||||
* io.ray.serve.generated.RequestWrapper}. Otherwise, it is Object[].
|
||||
*/
|
||||
private Object args;
|
||||
|
||||
public Query(RequestMetadata requestMetadata, Object args) {
|
||||
this.metadata = requestMetadata;
|
||||
}
|
||||
|
||||
public Object[] getArgs() {
|
||||
return args;
|
||||
}
|
||||
|
||||
public void setArgs(Object[] args) {
|
||||
this.args = args;
|
||||
}
|
||||
|
||||
|
@ -24,7 +22,7 @@ public class Query {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
public void setMetadata(RequestMetadata metadata) {
|
||||
this.metadata = metadata;
|
||||
public Object getArgs() {
|
||||
return args;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,15 +8,17 @@ import io.ray.runtime.metric.Gauge;
|
|||
import io.ray.runtime.metric.Histogram;
|
||||
import io.ray.runtime.metric.MetricConfig;
|
||||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.BackendConfig;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import io.ray.serve.poll.KeyListener;
|
||||
import io.ray.serve.poll.KeyType;
|
||||
import io.ray.serve.poll.LongPollClient;
|
||||
import io.ray.serve.poll.LongPollNamespace;
|
||||
import io.ray.serve.util.BackendConfigUtil;
|
||||
import io.ray.serve.util.LogUtil;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import io.ray.serve.util.ServeProtoUtil;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -59,7 +61,7 @@ public class RayServeReplica {
|
|||
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
|
||||
this.callable = callable;
|
||||
this.config = backendConfig;
|
||||
this.reconfigure(BackendConfigUtil.getUserConfig(backendConfig));
|
||||
this.reconfigure(ServeProtoUtil.parseUserConfig(backendConfig));
|
||||
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(
|
||||
|
@ -152,8 +154,9 @@ public class RayServeReplica {
|
|||
replicaTag,
|
||||
requestItem.getMetadata().getRequestId());
|
||||
|
||||
methodToCall = getRunnerMethod(requestItem);
|
||||
Object result = methodToCall.invoke(callable, requestItem.getArgs());
|
||||
Object[] args = parseRequestItem(requestItem);
|
||||
methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args);
|
||||
Object result = methodToCall.invoke(callable, args);
|
||||
reportMetrics(() -> requestCounter.inc(1.0));
|
||||
return result;
|
||||
} catch (Throwable e) {
|
||||
|
@ -169,12 +172,29 @@ public class RayServeReplica {
|
|||
}
|
||||
}
|
||||
|
||||
private Method getRunnerMethod(Query query) {
|
||||
String methodName = query.getMetadata().getCallMethod();
|
||||
private Object[] parseRequestItem(Query requestItem) {
|
||||
if (requestItem.getArgs() == null) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
// From Java Proxy or Handle.
|
||||
if (requestItem.getArgs() instanceof Object[]) {
|
||||
return (Object[]) requestItem.getArgs();
|
||||
}
|
||||
|
||||
// From other language Proxy or Handle.
|
||||
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
|
||||
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
return MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object[].class);
|
||||
}
|
||||
|
||||
private Method getRunnerMethod(String methodName, Object[] args) {
|
||||
|
||||
try {
|
||||
return ReflectUtil.getMethod(
|
||||
callable.getClass(), methodName, query.getArgs() == null ? null : query.getArgs());
|
||||
return ReflectUtil.getMethod(callable.getClass(), methodName, args);
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.BackendConfig;
|
||||
import io.ray.serve.util.BackendConfigUtil;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import io.ray.serve.util.ServeProtoUtil;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Optional;
|
||||
|
@ -30,7 +32,7 @@ public class RayServeWrappedReplica {
|
|||
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
|
||||
|
||||
// Parse BackendConfig.
|
||||
BackendConfig backendConfig = BackendConfigUtil.parseFrom(backendConfigBytes);
|
||||
BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(backendConfigBytes);
|
||||
|
||||
// Parse init args.
|
||||
Object[] initArgs = parseInitArgs(initArgsbytes, backendConfig);
|
||||
|
@ -73,13 +75,25 @@ public class RayServeWrappedReplica {
|
|||
/**
|
||||
* The entry method to process the request.
|
||||
*
|
||||
* @param requestMetadata request metadata
|
||||
* @param requestArgs the input parameters of the specified method of the object defined by
|
||||
* backendDef.
|
||||
* @param requestMetadata the real type is byte[] if this invocation is cross-language. Otherwise,
|
||||
* the real type is {@link io.ray.serve.generated.RequestMetadata}.
|
||||
* @param requestArgs The input parameters of the specified method of the object defined by
|
||||
* backendDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
|
||||
* this invocation is cross-language. Otherwise, the real type is Object[].
|
||||
* @return the result of request being processed
|
||||
* @throws InvalidProtocolBufferException if the protobuf deserialization fails.
|
||||
*/
|
||||
public Object handleRequest(RequestMetadata requestMetadata, Object[] requestArgs) {
|
||||
return backend.handleRequest(new Query(requestArgs, requestMetadata));
|
||||
public Object handleRequest(Object requestMetadata, Object requestArgs)
|
||||
throws InvalidProtocolBufferException {
|
||||
boolean isCrossLanguage = requestMetadata instanceof byte[];
|
||||
return backend.handleRequest(
|
||||
new Query(
|
||||
isCrossLanguage
|
||||
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
|
||||
: (RequestMetadata) requestMetadata,
|
||||
isCrossLanguage
|
||||
? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs)
|
||||
: requestArgs));
|
||||
}
|
||||
|
||||
/** Check whether this replica is ready or not. */
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/** The meta data of request. */
|
||||
/**
 * Mutable, serializable holder for the meta data of a request.
 *
 * <p>Every property starts out as {@code null}, except the call method, which starts out as
 * {@code "call"}. No validation is performed by the setters.
 */
public class RequestMetadata implements Serializable {

  private static final long serialVersionUID = -8925036926565326811L;

  /** The request id. */
  private String requestId;

  /** The endpoint. */
  private String endpoint;

  /** The call method; {@code "call"} unless overwritten. */
  private String callMethod = "call";

  /** The HTTP method. */
  private String httpMethod;

  /** The HTTP headers. */
  private Map<String, String> httpHeaders;

  // ---------------------------------------------------------------- getters

  /** @return the request id, or {@code null} if it has not been set */
  public String getRequestId() {
    return this.requestId;
  }

  /** @return the endpoint, or {@code null} if it has not been set */
  public String getEndpoint() {
    return this.endpoint;
  }

  /** @return the call method; {@code "call"} if it has not been overwritten */
  public String getCallMethod() {
    return this.callMethod;
  }

  /** @return the HTTP method, or {@code null} if it has not been set */
  public String getHttpMethod() {
    return this.httpMethod;
  }

  /** @return the very map instance that was stored, or {@code null} if none was set */
  public Map<String, String> getHttpHeaders() {
    return this.httpHeaders;
  }

  // ---------------------------------------------------------------- setters

  /** @param requestId the request id to store */
  public void setRequestId(String requestId) {
    this.requestId = requestId;
  }

  /** @param endpoint the endpoint to store */
  public void setEndpoint(String endpoint) {
    this.endpoint = endpoint;
  }

  /** @param callMethod the call method to store */
  public void setCallMethod(String callMethod) {
    this.callMethod = callMethod;
  }

  /** @param httpMethod the HTTP method to store */
  public void setHttpMethod(String httpMethod) {
    this.httpMethod = httpMethod;
  }

  /** @param httpHeaders the HTTP headers to store; the map is kept by reference, not copied */
  public void setHttpHeaders(Map<String, String> httpHeaders) {
    this.httpHeaders = httpHeaders;
  }
}
|
|
@ -1,76 +0,0 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.RayServeException;
|
||||
import io.ray.serve.generated.BackendConfig;
|
||||
import io.ray.serve.generated.BackendLanguage;
|
||||
|
||||
public class BackendConfigUtil {
|
||||
|
||||
public static BackendConfig parseFrom(byte[] backendConfigBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
|
||||
// Parse BackendConfig from byte[].
|
||||
BackendConfig inputBackendConfig = BackendConfig.parseFrom(backendConfigBytes);
|
||||
if (inputBackendConfig == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Set default values.
|
||||
BackendConfig.Builder builder = BackendConfig.newBuilder();
|
||||
|
||||
if (inputBackendConfig.getNumReplicas() == 0) {
|
||||
builder.setNumReplicas(1);
|
||||
} else {
|
||||
builder.setNumReplicas(inputBackendConfig.getNumReplicas());
|
||||
}
|
||||
|
||||
Preconditions.checkArgument(
|
||||
inputBackendConfig.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
|
||||
if (inputBackendConfig.getMaxConcurrentQueries() == 0) {
|
||||
builder.setMaxConcurrentQueries(100);
|
||||
} else {
|
||||
builder.setMaxConcurrentQueries(inputBackendConfig.getMaxConcurrentQueries());
|
||||
}
|
||||
|
||||
builder.setUserConfig(inputBackendConfig.getUserConfig());
|
||||
|
||||
if (inputBackendConfig.getExperimentalGracefulShutdownWaitLoopS() == 0) {
|
||||
builder.setExperimentalGracefulShutdownWaitLoopS(2);
|
||||
} else {
|
||||
builder.setExperimentalGracefulShutdownWaitLoopS(
|
||||
inputBackendConfig.getExperimentalGracefulShutdownWaitLoopS());
|
||||
}
|
||||
|
||||
if (inputBackendConfig.getExperimentalGracefulShutdownTimeoutS() == 0) {
|
||||
builder.setExperimentalGracefulShutdownTimeoutS(20);
|
||||
} else {
|
||||
builder.setExperimentalGracefulShutdownTimeoutS(
|
||||
inputBackendConfig.getExperimentalGracefulShutdownTimeoutS());
|
||||
}
|
||||
|
||||
builder.setIsCrossLanguage(inputBackendConfig.getIsCrossLanguage());
|
||||
|
||||
if (inputBackendConfig.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Unrecognized backend language {}. Backend language must be in {}.",
|
||||
inputBackendConfig.getBackendLanguageValue(),
|
||||
Lists.newArrayList(BackendLanguage.values())));
|
||||
} else {
|
||||
builder.setBackendLanguage(inputBackendConfig.getBackendLanguage());
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static Object getUserConfig(BackendConfig backendConfig) {
|
||||
if (backendConfig.getUserConfig() == null || backendConfig.getUserConfig().size() == 0) {
|
||||
return null;
|
||||
}
|
||||
return MessagePackSerializer.decode(backendConfig.getUserConfig().toByteArray(), Object.class);
|
||||
}
|
||||
}
|
111
java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
Normal file
111
java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java
Normal file
|
@ -0,0 +1,111 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.RayServeException;
|
||||
import io.ray.serve.generated.BackendConfig;
|
||||
import io.ray.serve.generated.BackendLanguage;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class ServeProtoUtil {
|
||||
|
||||
public static BackendConfig parseBackendConfig(byte[] backendConfigBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
|
||||
// Get a builder from BackendConfig(bytes) or create a new one.
|
||||
BackendConfig.Builder builder = null;
|
||||
if (backendConfigBytes == null) {
|
||||
builder = BackendConfig.newBuilder();
|
||||
} else {
|
||||
BackendConfig backendConfig = BackendConfig.parseFrom(backendConfigBytes);
|
||||
if (backendConfig == null) {
|
||||
builder = BackendConfig.newBuilder();
|
||||
} else {
|
||||
builder = BackendConfig.newBuilder(backendConfig);
|
||||
}
|
||||
}
|
||||
|
||||
// Set default values.
|
||||
if (builder.getNumReplicas() == 0) {
|
||||
builder.setNumReplicas(1);
|
||||
}
|
||||
|
||||
Preconditions.checkArgument(
|
||||
builder.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
|
||||
if (builder.getMaxConcurrentQueries() == 0) {
|
||||
builder.setMaxConcurrentQueries(100);
|
||||
}
|
||||
|
||||
if (builder.getExperimentalGracefulShutdownWaitLoopS() == 0) {
|
||||
builder.setExperimentalGracefulShutdownWaitLoopS(2);
|
||||
}
|
||||
|
||||
if (builder.getExperimentalGracefulShutdownTimeoutS() == 0) {
|
||||
builder.setExperimentalGracefulShutdownTimeoutS(20);
|
||||
}
|
||||
|
||||
if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Unrecognized backend language {}. Backend language must be in {}.",
|
||||
builder.getBackendLanguageValue(),
|
||||
Lists.newArrayList(BackendLanguage.values())));
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static Object parseUserConfig(BackendConfig backendConfig) {
|
||||
if (backendConfig.getUserConfig() == null || backendConfig.getUserConfig().size() == 0) {
|
||||
return null;
|
||||
}
|
||||
return MessagePackSerializer.decode(backendConfig.getUserConfig().toByteArray(), Object.class);
|
||||
}
|
||||
|
||||
public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
|
||||
// Get a builder from RequestMetadata(bytes) or create a new one.
|
||||
RequestMetadata.Builder builder = null;
|
||||
if (requestMetadataBytes == null) {
|
||||
builder = RequestMetadata.newBuilder();
|
||||
} else {
|
||||
RequestMetadata requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes);
|
||||
if (requestMetadata == null) {
|
||||
builder = RequestMetadata.newBuilder();
|
||||
} else {
|
||||
builder = RequestMetadata.newBuilder(requestMetadata);
|
||||
}
|
||||
}
|
||||
|
||||
// Set default values.
|
||||
if (StringUtils.isBlank(builder.getCallMethod())) {
|
||||
builder.setCallMethod("call");
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
|
||||
// Get a builder from HTTPRequestWrapper(bytes) or create a new one.
|
||||
RequestWrapper.Builder builder = null;
|
||||
if (httpRequestWrapperBytes == null) {
|
||||
builder = RequestWrapper.newBuilder();
|
||||
} else {
|
||||
RequestWrapper requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes);
|
||||
if (requestWrapper == null) {
|
||||
builder = RequestWrapper.newBuilder();
|
||||
} else {
|
||||
builder = RequestWrapper.newBuilder(requestWrapper);
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
}
|
|
@ -6,6 +6,8 @@ import io.ray.api.Ray;
|
|||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.generated.BackendConfig;
|
||||
import io.ray.serve.generated.BackendLanguage;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import java.io.IOException;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
@ -17,7 +19,6 @@ public class RayServeReplicaTest {
|
|||
public void test() throws IOException {
|
||||
|
||||
boolean inited = Ray.isInitialized();
|
||||
|
||||
Ray.init();
|
||||
|
||||
try {
|
||||
|
@ -52,12 +53,18 @@ public class RayServeReplicaTest {
|
|||
|
||||
backendHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
|
||||
RequestMetadata requestMetadata = new RequestMetadata();
|
||||
RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
|
||||
requestMetadata.setRequestId("RayServeReplicaTest");
|
||||
requestMetadata.setCallMethod("getBackendTag");
|
||||
|
||||
RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder();
|
||||
|
||||
ObjectRef<Object> resultRef =
|
||||
backendHandle
|
||||
.task(RayServeWrappedReplica::handleRequest, requestMetadata, (Object[]) null)
|
||||
.task(
|
||||
RayServeWrappedReplica::handleRequest,
|
||||
requestMetadata.build().toByteArray(),
|
||||
requestWrapper.build().toByteArray())
|
||||
.remote();
|
||||
|
||||
Assert.assertEquals((String) resultRef.get(), backendTag);
|
||||
|
|
|
@ -56,3 +56,15 @@ enum BackendLanguage {
|
|||
PYTHON = 0;
|
||||
JAVA = 1;
|
||||
}
|
||||
|
||||
// Metadata that accompanies a single Serve request. On the Java side it is decoded by
// ServeProtoUtil.parseRequestMetadata when the request arrives as serialized bytes.
message RequestMetadata {
|
||||
// Id of the request; the Java replica reads it when handling the request.
string request_id = 1;
|
||||
|
||||
// NOTE(review): presumably the name of the endpoint the request targets; its use is not
// visible here, confirm against the callers.
string endpoint = 2;
|
||||
|
||||
// Name of the method to look up and invoke on the backend callable. The Java parser
// substitutes "call" when this field is blank.
string call_method = 3;
|
||||
}
|
||||
|
||||
// Carries the arguments of a cross-language Serve request.
message RequestWrapper {
|
||||
// Serialized arguments. The Java replica decodes a non-empty body with MessagePack into an
// Object[]; an empty body means "no arguments".
bytes body = 1;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue