mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Serve] Make Java Replica Extendable (#19463)
This commit is contained in:
parent
81f036d078
commit
efca009258
34 changed files with 1200 additions and 666 deletions
|
@ -6,8 +6,8 @@ import java.util.List;
|
|||
/** Ray Serve common constants. */
|
||||
public class Constants {
|
||||
|
||||
/** Name of backend reconfiguration method implemented by user. */
|
||||
public static final String BACKEND_RECONFIGURE_METHOD = "reconfigure";
|
||||
/** Name of deployment reconfiguration method implemented by user. */
|
||||
public static final String RECONFIGURE_METHOD = "reconfigure";
|
||||
|
||||
/** Default histogram buckets for latency tracker. */
|
||||
public static final List<Double> DEFAULT_LATENCY_BUCKET_MS =
|
||||
|
@ -19,7 +19,9 @@ public class Constants {
|
|||
|
||||
public static final String SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR";
|
||||
|
||||
public static final String DEFAULT_CALL_METHOD = "call";
|
||||
public static final String CALL_METHOD = "call";
|
||||
|
||||
public static final String UTF8 = "UTF-8";
|
||||
|
||||
public static final String CHECK_HEALTH_METHOD = "checkHealth";
|
||||
}
|
||||
|
|
87
java/serve/src/main/java/io/ray/serve/DeploymentConfig.java
Normal file
87
java/serve/src/main/java/io/ray/serve/DeploymentConfig.java
Normal file
|
@ -0,0 +1,87 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
|
||||
public class DeploymentConfig implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 4037621960087621036L;
|
||||
|
||||
private int numReplicas = 1;
|
||||
|
||||
private int maxConcurrentQueries = 100;
|
||||
|
||||
private Object userConfig;
|
||||
|
||||
private double gracefulShutdownWaitLoopS = 2;
|
||||
|
||||
private double gracefulShutdownTimeoutS = 20;
|
||||
|
||||
private boolean isCrossLanguage;
|
||||
|
||||
private int deploymentLanguage = 1;
|
||||
|
||||
public int getNumReplicas() {
|
||||
return numReplicas;
|
||||
}
|
||||
|
||||
public DeploymentConfig setNumReplicas(int numReplicas) {
|
||||
this.numReplicas = numReplicas;
|
||||
return this;
|
||||
}
|
||||
|
||||
public int getMaxConcurrentQueries() {
|
||||
return maxConcurrentQueries;
|
||||
}
|
||||
|
||||
public DeploymentConfig setMaxConcurrentQueries(int maxConcurrentQueries) {
|
||||
Preconditions.checkArgument(maxConcurrentQueries >= 0, "max_concurrent_queries must be >= 0");
|
||||
this.maxConcurrentQueries = maxConcurrentQueries;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Object getUserConfig() {
|
||||
return userConfig;
|
||||
}
|
||||
|
||||
public DeploymentConfig setUserConfig(Object userConfig) {
|
||||
this.userConfig = userConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public double getGracefulShutdownWaitLoopS() {
|
||||
return gracefulShutdownWaitLoopS;
|
||||
}
|
||||
|
||||
public DeploymentConfig setGracefulShutdownWaitLoopS(double gracefulShutdownWaitLoopS) {
|
||||
this.gracefulShutdownWaitLoopS = gracefulShutdownWaitLoopS;
|
||||
return this;
|
||||
}
|
||||
|
||||
public double getGracefulShutdownTimeoutS() {
|
||||
return gracefulShutdownTimeoutS;
|
||||
}
|
||||
|
||||
public DeploymentConfig setGracefulShutdownTimeoutS(double gracefulShutdownTimeoutS) {
|
||||
this.gracefulShutdownTimeoutS = gracefulShutdownTimeoutS;
|
||||
return this;
|
||||
}
|
||||
|
||||
public boolean isCrossLanguage() {
|
||||
return isCrossLanguage;
|
||||
}
|
||||
|
||||
public DeploymentConfig setCrossLanguage(boolean isCrossLanguage) {
|
||||
this.isCrossLanguage = isCrossLanguage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public int getDeploymentLanguage() {
|
||||
return deploymentLanguage;
|
||||
}
|
||||
|
||||
public DeploymentConfig setDeploymentLanguage(int deploymentLanguage) {
|
||||
this.deploymentLanguage = deploymentLanguage;
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -1,38 +1,75 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
public class DeploymentInfo implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = -4198364411759931955L;
|
||||
private static final long serialVersionUID = -7132135316463505391L;
|
||||
|
||||
private byte[] deploymentConfig;
|
||||
private String name;
|
||||
|
||||
private ReplicaConfig replicaConfig;
|
||||
private String deploymentDef;
|
||||
|
||||
private byte[] deploymentVersion;
|
||||
private Object[] initArgs;
|
||||
|
||||
public byte[] getDeploymentConfig() {
|
||||
private DeploymentConfig deploymentConfig;
|
||||
|
||||
private DeploymentVersion deploymentVersion;
|
||||
|
||||
private Map<String, String> config;
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public DeploymentInfo setName(String name) {
|
||||
this.name = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getDeploymentDef() {
|
||||
return deploymentDef;
|
||||
}
|
||||
|
||||
public DeploymentInfo setDeploymentDef(String deploymentDef) {
|
||||
this.deploymentDef = deploymentDef;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Object[] getInitArgs() {
|
||||
return initArgs;
|
||||
}
|
||||
|
||||
public DeploymentInfo setInitArgs(Object[] initArgs) {
|
||||
this.initArgs = initArgs;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DeploymentConfig getDeploymentConfig() {
|
||||
return deploymentConfig;
|
||||
}
|
||||
|
||||
public void setDeploymentConfig(byte[] deploymentConfig) {
|
||||
public DeploymentInfo setDeploymentConfig(DeploymentConfig deploymentConfig) {
|
||||
this.deploymentConfig = deploymentConfig;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ReplicaConfig getReplicaConfig() {
|
||||
return replicaConfig;
|
||||
}
|
||||
|
||||
public void setReplicaConfig(ReplicaConfig replicaConfig) {
|
||||
this.replicaConfig = replicaConfig;
|
||||
}
|
||||
|
||||
public byte[] getDeploymentVersion() {
|
||||
public DeploymentVersion getDeploymentVersion() {
|
||||
return deploymentVersion;
|
||||
}
|
||||
|
||||
public void setDeploymentVersion(byte[] deploymentVersion) {
|
||||
public DeploymentInfo setDeploymentVersion(DeploymentVersion deploymentVersion) {
|
||||
this.deploymentVersion = deploymentVersion;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
public DeploymentInfo setConfig(Map<String, String> config) {
|
||||
this.config = config;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
46
java/serve/src/main/java/io/ray/serve/DeploymentVersion.java
Normal file
46
java/serve/src/main/java/io/ray/serve/DeploymentVersion.java
Normal file
|
@ -0,0 +1,46 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class DeploymentVersion implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 3400261981775851058L;
|
||||
|
||||
private String codeVersion;
|
||||
|
||||
private Object userConfig;
|
||||
|
||||
private boolean unversioned;
|
||||
|
||||
public DeploymentVersion() {
|
||||
this(null, null);
|
||||
}
|
||||
|
||||
public DeploymentVersion(String codeVersion) {
|
||||
this(codeVersion, null);
|
||||
}
|
||||
|
||||
public DeploymentVersion(String codeVersion, Object userConfig) {
|
||||
if (StringUtils.isBlank(codeVersion)) {
|
||||
this.unversioned = true;
|
||||
this.codeVersion = RandomStringUtils.randomAlphabetic(6);
|
||||
} else {
|
||||
this.codeVersion = codeVersion;
|
||||
}
|
||||
this.userConfig = userConfig;
|
||||
}
|
||||
|
||||
public String getCodeVersion() {
|
||||
return codeVersion;
|
||||
}
|
||||
|
||||
public Object getUserConfig() {
|
||||
return userConfig;
|
||||
}
|
||||
|
||||
public boolean isUnversioned() {
|
||||
return unversioned;
|
||||
}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class DummyBackendReplica {
|
||||
|
||||
private AtomicInteger counter = new AtomicInteger();
|
||||
|
||||
public String call() {
|
||||
return String.valueOf(counter.incrementAndGet());
|
||||
}
|
||||
}
|
21
java/serve/src/main/java/io/ray/serve/DummyReplica.java
Normal file
21
java/serve/src/main/java/io/ray/serve/DummyReplica.java
Normal file
|
@ -0,0 +1,21 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
 * Minimal replica whose {@code call} method returns an increasing counter as a string.
 *
 * <p>The counter can be reset or re-seeded through the two {@code reconfigure} overloads.
 */
public class DummyReplica {

  /** Key under which the map-based {@code reconfigure} overload looks up the new counter value. */
  private static final String VALUE_KEY = "value";

  private final AtomicInteger counter = new AtomicInteger();

  /**
   * Increments the counter and returns its new value.
   *
   * @return the incremented counter rendered as a decimal string
   */
  public String call() {
    return String.valueOf(counter.incrementAndGet());
  }

  /**
   * Resets the counter to zero; the argument itself is not inspected.
   *
   * @param userConfig ignored
   */
  public void reconfigure(Object userConfig) {
    counter.set(0);
  }

  /**
   * Sets the counter to the integer stored under {@code "value"}.
   *
   * <p>A {@code null} map or a map without the {@code "value"} key resets the counter to zero,
   * matching the behavior of {@link #reconfigure(Object)}, instead of failing with a
   * {@code NullPointerException} / {@code NumberFormatException}.
   *
   * @param userConfig configuration map, may be {@code null}
   * @throws NumberFormatException if the stored value is present but not a valid integer
   */
  public void reconfigure(Map<String, String> userConfig) {
    String rawValue = userConfig == null ? null : userConfig.get(VALUE_KEY);
    // parseInt avoids the needless Integer boxing of Integer.valueOf.
    counter.set(rawValue == null ? 0 : Integer.parseInt(rawValue));
  }
}
|
|
@ -137,8 +137,8 @@ public class ProxyActor {
|
|||
this.proxyRouter.updateRoutes(endpointInfos);
|
||||
}
|
||||
|
||||
public void ready() {
|
||||
return;
|
||||
public boolean ready() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public void blockUntilEndpointExists(String endpoint, double timeoutS) {
|
||||
|
|
|
@ -1,6 +1,35 @@
|
|||
package io.ray.serve;
|
||||
|
||||
public class RayServeConfig {
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
public class RayServeConfig implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 5367425336296141588L;
|
||||
|
||||
public static final String PROXY_CLASS = "ray.serve.proxy.class";
|
||||
|
||||
public static final String METRICS_ENABLED = "ray.serve.metrics.enabled";
|
||||
|
||||
private String name;
|
||||
|
||||
private Map<String, String> config;
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public RayServeConfig setName(String name) {
|
||||
this.name = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
public RayServeConfig setConfig(Map<String, String> config) {
|
||||
this.config = config;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ public class RayServeHandle {
|
|||
* Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get
|
||||
* (or ``await object_ref``), respectively.
|
||||
*
|
||||
* @param parameters The input parameters of the specified method to invoke on the backend.
|
||||
* @param parameters The input parameters of the specified method to be invoked in the deployment.
|
||||
* @return ray.ObjectRef
|
||||
*/
|
||||
public ObjectRef<Object> remote(Object[] parameters) {
|
||||
|
@ -58,7 +58,7 @@ public class RayServeHandle {
|
|||
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
|
||||
requestMetadata.setEndpoint(endpointName);
|
||||
requestMetadata.setCallMethod(
|
||||
handleOptions != null ? handleOptions.getMethodName() : Constants.DEFAULT_CALL_METHOD);
|
||||
handleOptions != null ? handleOptions.getMethodName() : Constants.CALL_METHOD);
|
||||
return router.assignRequest(requestMetadata.build(), parameters);
|
||||
}
|
||||
|
||||
|
|
|
@ -14,20 +14,20 @@ public enum RayServeMetrics {
|
|||
"serve_deployment_queued_queries",
|
||||
"The current number of queries to this deployment waiting to be assigned to a replica."),
|
||||
|
||||
SERVE_BACKEND_REQUEST_COUNTER(
|
||||
"serve_backend_request_counter",
|
||||
SERVE_DEPLOYMENT_REQUEST_COUNTER(
|
||||
"serve_deployment_request_counter",
|
||||
"The number of queries that have been processed in this replica."),
|
||||
|
||||
SERVE_BACKEND_ERROR_COUNTER(
|
||||
"serve_backend_error_counter",
|
||||
SERVE_DEPLOYMENT_ERROR_COUNTER(
|
||||
"serve_deployment_error_counter",
|
||||
"The number of exceptions that have occurred in this replica."),
|
||||
|
||||
SERVE_BACKEND_REPLICA_STARTS(
|
||||
"serve_backend_replica_starts",
|
||||
SERVE_DEPLOYMENT_REPLICA_STARTS(
|
||||
"serve_deployment_replica_starts",
|
||||
"The number of times this replica has been restarted due to failure."),
|
||||
|
||||
SERVE_BACKEND_PROCESSING_LATENCY_MS(
|
||||
"serve_backend_processing_latency_ms", "The latency for queries to be processed."),
|
||||
SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS(
|
||||
"serve_deployment_processing_latency_ms", "The latency for queries to be processed."),
|
||||
|
||||
SERVE_REPLICA_PROCESSING_QUERIES(
|
||||
"serve_replica_processing_queries", "The current number of queries being processed."),
|
||||
|
@ -41,13 +41,13 @@ public enum RayServeMetrics {
|
|||
|
||||
public static final String TAG_ROUTE = "route";
|
||||
|
||||
public static final String TAG_BACKEND = "backend";
|
||||
|
||||
public static final String TAG_REPLICA = "replica";
|
||||
|
||||
private static final boolean isMetricsEnabled =
|
||||
private static final boolean canBeUsed =
|
||||
Ray.isInitialized() && !Ray.getRuntimeContext().isSingleProcess();
|
||||
|
||||
private static volatile boolean enabled = true;
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
@ -58,7 +58,7 @@ public enum RayServeMetrics {
|
|||
}
|
||||
|
||||
public static void execute(Runnable runnable) {
|
||||
if (!isMetricsEnabled) {
|
||||
if (!enabled || !canBeUsed) {
|
||||
return;
|
||||
}
|
||||
runnable.run();
|
||||
|
@ -71,4 +71,12 @@ public enum RayServeMetrics {
|
|||
public String getDescription() {
|
||||
return description;
|
||||
}
|
||||
|
||||
public static void enable() {
|
||||
enabled = true;
|
||||
}
|
||||
|
||||
public static void disable() {
|
||||
enabled = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,330 +1,18 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.runtime.metric.Count;
|
||||
import io.ray.runtime.metric.Gauge;
|
||||
import io.ray.runtime.metric.Histogram;
|
||||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import io.ray.serve.poll.KeyListener;
|
||||
import io.ray.serve.poll.KeyType;
|
||||
import io.ray.serve.poll.LongPollClient;
|
||||
import io.ray.serve.poll.LongPollNamespace;
|
||||
import io.ray.serve.util.LogUtil;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
public interface RayServeReplica {
|
||||
|
||||
/** Handles requests with the provided callable. */
|
||||
public class RayServeReplica {
|
||||
Object handleRequest(Object requestMetadata, Object requestArgs);
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplica.class);
|
||||
|
||||
private String deploymentName;
|
||||
|
||||
private String replicaTag;
|
||||
|
||||
private DeploymentConfig config;
|
||||
|
||||
private AtomicInteger numOngoingRequests = new AtomicInteger();
|
||||
|
||||
private Object callable;
|
||||
|
||||
private Count requestCounter;
|
||||
|
||||
private Count errorCounter;
|
||||
|
||||
private Count restartCounter;
|
||||
|
||||
private Histogram processingLatencyTracker;
|
||||
|
||||
private Gauge numProcessingItems;
|
||||
|
||||
private LongPollClient longPollClient;
|
||||
|
||||
private DeploymentVersion version;
|
||||
|
||||
private boolean isDeleted = false;
|
||||
|
||||
public RayServeReplica(
|
||||
Object callable,
|
||||
DeploymentConfig deploymentConfig,
|
||||
DeploymentVersion version,
|
||||
BaseActorHandle actorHandle) {
|
||||
this.deploymentName = Serve.getReplicaContext().getDeploymentName();
|
||||
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
|
||||
this.callable = callable;
|
||||
this.config = deploymentConfig;
|
||||
this.version = version;
|
||||
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(
|
||||
new KeyType(LongPollNamespace.BACKEND_CONFIGS, deploymentName),
|
||||
newConfig -> updateDeploymentConfigs(newConfig));
|
||||
this.longPollClient = new LongPollClient(actorHandle, keyListeners);
|
||||
this.longPollClient.start();
|
||||
registerMetrics();
|
||||
default Object reconfigure(Object userConfig) {
|
||||
return new DeploymentVersion(null, userConfig);
|
||||
}
|
||||
|
||||
private void registerMetrics() {
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
requestCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getName())
|
||||
.description(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_BACKEND,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
errorCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getName())
|
||||
.description(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_BACKEND,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
restartCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getName())
|
||||
.description(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_BACKEND,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
processingLatencyTracker =
|
||||
Metrics.histogram()
|
||||
.name(RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getName())
|
||||
.description(
|
||||
RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getDescription())
|
||||
.unit("")
|
||||
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_BACKEND,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
numProcessingItems =
|
||||
Metrics.gauge()
|
||||
.name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName())
|
||||
.description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_BACKEND,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(() -> restartCounter.inc(1.0));
|
||||
}
|
||||
|
||||
public Object handleRequest(Query request) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
LOGGER.debug(
|
||||
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
|
||||
|
||||
numOngoingRequests.incrementAndGet();
|
||||
RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get()));
|
||||
Object result = invokeSingle(request);
|
||||
numOngoingRequests.decrementAndGet();
|
||||
|
||||
long requestTimeMs = System.currentTimeMillis() - startTime;
|
||||
LOGGER.debug(
|
||||
"Replica {} finished request {} in {}ms",
|
||||
replicaTag,
|
||||
request.getMetadata().getRequestId(),
|
||||
requestTimeMs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private Object invokeSingle(Query requestItem) {
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
Method methodToCall = null;
|
||||
try {
|
||||
LOGGER.debug(
|
||||
"Replica {} started executing request {}",
|
||||
replicaTag,
|
||||
requestItem.getMetadata().getRequestId());
|
||||
|
||||
Object[] args = parseRequestItem(requestItem);
|
||||
methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args);
|
||||
Object result = methodToCall.invoke(callable, args);
|
||||
RayServeMetrics.execute(() -> requestCounter.inc(1.0));
|
||||
return result;
|
||||
} catch (Throwable e) {
|
||||
RayServeMetrics.execute(() -> errorCounter.inc(1.0));
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Replica {} failed to invoke method {}",
|
||||
replicaTag,
|
||||
methodToCall == null ? "unknown" : methodToCall.getName()),
|
||||
e);
|
||||
} finally {
|
||||
RayServeMetrics.execute(
|
||||
() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
|
||||
}
|
||||
}
|
||||
|
||||
private Object[] parseRequestItem(Query requestItem) {
|
||||
if (requestItem.getArgs() == null) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
// From Java Proxy or Handle.
|
||||
if (requestItem.getArgs() instanceof Object[]) {
|
||||
return (Object[]) requestItem.getArgs();
|
||||
}
|
||||
|
||||
// From other language Proxy or Handle.
|
||||
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
|
||||
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
return MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object[].class);
|
||||
}
|
||||
|
||||
private Method getRunnerMethod(String methodName, Object[] args) {
|
||||
|
||||
try {
|
||||
return ReflectUtil.getMethod(callable.getClass(), methodName, args);
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Backend doesn't have method {} which is specified in the request. "
|
||||
+ "The available methods are {}",
|
||||
methodName,
|
||||
ReflectUtil.getMethodStrings(callable.getClass())));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
|
||||
* queued tasks to be completed and return to the controller.
|
||||
*/
|
||||
public synchronized boolean prepareForShutdown() {
|
||||
while (true) {
|
||||
// Sleep first because we want to make sure all the routers receive the notification to remove
|
||||
// this replica first.
|
||||
try {
|
||||
Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.error(
|
||||
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
|
||||
}
|
||||
if (numOngoingRequests.get() == 0) {
|
||||
break;
|
||||
} else {
|
||||
LOGGER.info(
|
||||
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
|
||||
config.getGracefulShutdownWaitLoopS(),
|
||||
numOngoingRequests.get());
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly call the del method to trigger clean up. We set isDeleted = true after
|
||||
// successfully calling it so the destructor is called only once.
|
||||
try {
|
||||
if (!isDeleted) {
|
||||
ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
|
||||
}
|
||||
} catch (NoSuchMethodException e) {
|
||||
LOGGER.warn("Deployment {} has no del method.", deploymentName);
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error("Exception during graceful shutdown of replica.");
|
||||
} finally {
|
||||
isDeleted = true;
|
||||
}
|
||||
default boolean checkHealth() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconfigure user's configuration in the callable object through its reconfigure method.
|
||||
*
|
||||
* @param userConfig new user's configuration
|
||||
*/
|
||||
public DeploymentVersion reconfigure(Object userConfig) {
|
||||
DeploymentVersion.Builder builder = DeploymentVersion.newBuilder();
|
||||
builder.setCodeVersion(version.getCodeVersion());
|
||||
if (userConfig != null) {
|
||||
builder.setUserConfig(ByteString.copyFrom((byte[]) userConfig));
|
||||
}
|
||||
version = builder.build();
|
||||
|
||||
try {
|
||||
Method reconfigureMethod =
|
||||
ReflectUtil.getMethod(
|
||||
callable.getClass(),
|
||||
Constants.BACKEND_RECONFIGURE_METHOD,
|
||||
userConfig != null
|
||||
? MessagePackSerializer.decode((byte[]) userConfig, Object[].class)
|
||||
: new Object[0]); // TODO cache reconfigure method
|
||||
reconfigureMethod.invoke(callable, userConfig);
|
||||
} catch (NoSuchMethodException e) {
|
||||
LOGGER.warn(
|
||||
"user_config specified but backend {} missing {} method",
|
||||
deploymentName,
|
||||
Constants.BACKEND_RECONFIGURE_METHOD);
|
||||
} catch (Throwable e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Backend {} failed to reconfigure user_config {}", deploymentName, userConfig),
|
||||
e);
|
||||
}
|
||||
return version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update backend configs.
|
||||
*
|
||||
* @param newConfig the new configuration of backend
|
||||
*/
|
||||
private void updateDeploymentConfigs(Object newConfig) {
|
||||
config = (DeploymentConfig) newConfig;
|
||||
}
|
||||
|
||||
public DeploymentVersion getVersion() {
|
||||
return version;
|
||||
default boolean prepareForShutdown() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
380
java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java
Normal file
380
java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java
Normal file
|
@ -0,0 +1,380 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.runtime.metric.Count;
|
||||
import io.ray.runtime.metric.Gauge;
|
||||
import io.ray.runtime.metric.Histogram;
|
||||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import io.ray.serve.util.LogUtil;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** Handles requests with the provided callable. */
|
||||
public class RayServeReplicaImpl implements RayServeReplica {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplicaImpl.class);
|
||||
|
||||
private String deploymentName;
|
||||
|
||||
private String replicaTag;
|
||||
|
||||
private DeploymentConfig config;
|
||||
|
||||
private AtomicInteger numOngoingRequests = new AtomicInteger();
|
||||
|
||||
private Object callable;
|
||||
|
||||
private Count requestCounter;
|
||||
|
||||
private Count errorCounter;
|
||||
|
||||
private Count restartCounter;
|
||||
|
||||
private Histogram processingLatencyTracker;
|
||||
|
||||
private Gauge numProcessingItems;
|
||||
|
||||
private DeploymentVersion version;
|
||||
|
||||
private boolean isDeleted = false;
|
||||
|
||||
private final Method checkHealthMethod;
|
||||
|
||||
private final Method callMethod;
|
||||
|
||||
/**
 * Creates a replica around the user callable, resolves the optional health-check and call
 * methods by reflection, and registers the replica metrics.
 *
 * @param callable user object whose methods are invoked to serve requests
 * @param deploymentConfig configuration of the deployment this replica belongs to
 * @param version deployment version this replica starts with
 * @param actorHandle actor handle; not referenced inside this constructor
 */
public RayServeReplicaImpl(
    Object callable,
    DeploymentConfig deploymentConfig,
    DeploymentVersion version,
    BaseActorHandle actorHandle) {
  // The reflective lookups further down inspect this.callable, so it must be assigned first.
  this.callable = callable;
  this.version = version;
  this.config = deploymentConfig;

  this.replicaTag = Serve.getReplicaContext().getReplicaTag();
  this.deploymentName = Serve.getReplicaContext().getDeploymentName();

  // Both lookups are nullable: a missing method only produces a warning and a null field.
  this.checkHealthMethod = getRunnerMethod(Constants.CHECK_HEALTH_METHOD, null, true);
  this.callMethod = getRunnerMethod(Constants.CALL_METHOD, new Object[] {new Object()}, true);

  registerMetrics();
}
|
||||
|
||||
private void registerMetrics() {
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
requestCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_DEPLOYMENT_REQUEST_COUNTER.getName())
|
||||
.description(RayServeMetrics.SERVE_DEPLOYMENT_REQUEST_COUNTER.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_DEPLOYMENT,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
errorCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_DEPLOYMENT_ERROR_COUNTER.getName())
|
||||
.description(RayServeMetrics.SERVE_DEPLOYMENT_ERROR_COUNTER.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_DEPLOYMENT,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
restartCounter =
|
||||
Metrics.count()
|
||||
.name(RayServeMetrics.SERVE_DEPLOYMENT_REPLICA_STARTS.getName())
|
||||
.description(RayServeMetrics.SERVE_DEPLOYMENT_REPLICA_STARTS.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_DEPLOYMENT,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
processingLatencyTracker =
|
||||
Metrics.histogram()
|
||||
.name(RayServeMetrics.SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS.getName())
|
||||
.description(
|
||||
RayServeMetrics.SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS.getDescription())
|
||||
.unit("")
|
||||
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_DEPLOYMENT,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(
|
||||
() ->
|
||||
numProcessingItems =
|
||||
Metrics.gauge()
|
||||
.name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName())
|
||||
.description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription())
|
||||
.unit("")
|
||||
.tags(
|
||||
ImmutableMap.of(
|
||||
RayServeMetrics.TAG_DEPLOYMENT,
|
||||
deploymentName,
|
||||
RayServeMetrics.TAG_REPLICA,
|
||||
replicaTag))
|
||||
.register());
|
||||
|
||||
RayServeMetrics.execute(() -> restartCounter.inc(1.0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object handleRequest(Object requestMetadata, Object requestArgs) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
Query request = new Query((RequestMetadata) requestMetadata, requestArgs);
|
||||
LOGGER.debug(
|
||||
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
|
||||
|
||||
numOngoingRequests.incrementAndGet();
|
||||
RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get()));
|
||||
Object result = invokeSingle(request);
|
||||
numOngoingRequests.decrementAndGet();
|
||||
|
||||
long requestTimeMs = System.currentTimeMillis() - startTime;
|
||||
LOGGER.debug(
|
||||
"Replica {} finished request {} in {}ms",
|
||||
replicaTag,
|
||||
request.getMetadata().getRequestId(),
|
||||
requestTimeMs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private Object invokeSingle(Query requestItem) {
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
Method methodToCall = null;
|
||||
try {
|
||||
LOGGER.debug(
|
||||
"Replica {} started executing request {}",
|
||||
replicaTag,
|
||||
requestItem.getMetadata().getRequestId());
|
||||
|
||||
Object[] args = parseRequestItem(requestItem);
|
||||
methodToCall =
|
||||
args.length == 1 && callMethod != null
|
||||
? callMethod
|
||||
: getRunnerMethod(requestItem.getMetadata().getCallMethod(), args, false);
|
||||
Object result = methodToCall.invoke(callable, args);
|
||||
RayServeMetrics.execute(() -> requestCounter.inc(1.0));
|
||||
return result;
|
||||
} catch (Throwable e) {
|
||||
RayServeMetrics.execute(() -> errorCounter.inc(1.0));
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Replica {} failed to invoke method {}",
|
||||
replicaTag,
|
||||
methodToCall == null ? "unknown" : methodToCall.getName()),
|
||||
e);
|
||||
} finally {
|
||||
RayServeMetrics.execute(
|
||||
() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
|
||||
}
|
||||
}
|
||||
|
||||
private Object[] parseRequestItem(Query requestItem) {
|
||||
if (requestItem.getArgs() == null) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
// From Java Proxy or Handle.
|
||||
if (requestItem.getArgs() instanceof Object[]) {
|
||||
return (Object[]) requestItem.getArgs();
|
||||
}
|
||||
|
||||
// From other language Proxy or Handle.
|
||||
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
|
||||
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
|
||||
return new Object[0];
|
||||
}
|
||||
|
||||
return new Object[] {
|
||||
MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object.class)
|
||||
};
|
||||
}
|
||||
|
||||
private Method getRunnerMethod(String methodName, Object[] args, boolean isNullable) {
|
||||
try {
|
||||
return ReflectUtil.getMethod(callable.getClass(), methodName, args);
|
||||
} catch (NoSuchMethodException e) {
|
||||
String errMsg =
|
||||
LogUtil.format(
|
||||
"Tried to call a method {} that does not exist. Available methods: {}",
|
||||
methodName,
|
||||
ReflectUtil.getMethodStrings(callable.getClass()));
|
||||
if (isNullable) {
|
||||
LOGGER.warn(errMsg);
|
||||
return null;
|
||||
} else {
|
||||
LOGGER.error(errMsg, e);
|
||||
throw new RayServeException(errMsg, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
|
||||
* queued tasks to be completed and return to the controller.
|
||||
*
|
||||
* @return true if it is ready for shutdown.
|
||||
*/
|
||||
@Override
|
||||
public synchronized boolean prepareForShutdown() {
|
||||
while (true) {
|
||||
// Sleep first because we want to make sure all the routers receive the notification to remove
|
||||
// this replica first.
|
||||
try {
|
||||
Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.error(
|
||||
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
|
||||
}
|
||||
if (numOngoingRequests.get() == 0) {
|
||||
break;
|
||||
} else {
|
||||
LOGGER.info(
|
||||
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
|
||||
config.getGracefulShutdownWaitLoopS(),
|
||||
numOngoingRequests.get());
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly call the del method to trigger clean up. We set isDeleted = true after
|
||||
// succssifully calling it so the destructor is called only once.
|
||||
try {
|
||||
if (!isDeleted) {
|
||||
ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
|
||||
}
|
||||
} catch (NoSuchMethodException e) {
|
||||
LOGGER.warn("Deployment {} has no del method.", deploymentName);
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error("Exception during graceful shutdown of replica.");
|
||||
} finally {
|
||||
isDeleted = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DeploymentVersion reconfigure(Object userConfig) {
|
||||
DeploymentVersion deploymentVersion =
|
||||
new DeploymentVersion(version.getCodeVersion(), userConfig);
|
||||
version = deploymentVersion;
|
||||
if (userConfig == null) {
|
||||
return deploymentVersion;
|
||||
}
|
||||
|
||||
LOGGER.info(
|
||||
"Replica {} of deployment {} reconfigure userConfig: {}",
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
userConfig);
|
||||
try {
|
||||
ReflectUtil.getMethod(callable.getClass(), Constants.RECONFIGURE_METHOD, userConfig)
|
||||
.invoke(callable, userConfig);
|
||||
return version;
|
||||
} catch (NoSuchMethodException e) {
|
||||
String errMsg =
|
||||
LogUtil.format(
|
||||
"userConfig specified but deployment {} missing {} method",
|
||||
deploymentName,
|
||||
Constants.RECONFIGURE_METHOD);
|
||||
LOGGER.error(errMsg);
|
||||
throw new RayServeException(errMsg, e);
|
||||
} catch (Throwable e) {
|
||||
String errMsg =
|
||||
LogUtil.format(
|
||||
"Replica {} of deployment {} failed to reconfigure userConfig {}",
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
userConfig);
|
||||
LOGGER.error(errMsg);
|
||||
throw new RayServeException(errMsg, e);
|
||||
} finally {
|
||||
LOGGER.info(
|
||||
"Replica {} of deployment {} finished reconfiguring userConfig: {}",
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
userConfig);
|
||||
}
|
||||
}
|
||||
|
||||
public DeploymentVersion getVersion() {
|
||||
return version;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkHealth() {
|
||||
if (checkHealthMethod == null) {
|
||||
return true;
|
||||
}
|
||||
boolean result = true;
|
||||
try {
|
||||
LOGGER.info(
|
||||
"Replica {} of deployment {} check health of {}",
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
callable.getClass().getName());
|
||||
Object isHealthy = checkHealthMethod.invoke(callable);
|
||||
if (!(isHealthy instanceof Boolean)) {
|
||||
LOGGER.error(
|
||||
"The health check result {} of {} in replica {} of deployment {} is illegal.",
|
||||
isHealthy == null ? "null" : isHealthy.getClass().getName() + ":" + isHealthy,
|
||||
callable.getClass().getName(),
|
||||
replicaTag,
|
||||
deploymentName);
|
||||
result = false;
|
||||
} else {
|
||||
result = (boolean) isHealthy;
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error(
|
||||
"Replica {} of deployment {} failed to check health of {}",
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
callable.getClass().getName(),
|
||||
e);
|
||||
result = false;
|
||||
} finally {
|
||||
LOGGER.info(
|
||||
"The health check result of {} in replica {} of deployment {} is {}.",
|
||||
callable.getClass().getName(),
|
||||
replicaTag,
|
||||
deploymentName,
|
||||
result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public Object getCallable() {
|
||||
return callable;
|
||||
}
|
||||
}
|
|
@ -1,82 +1,135 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.util.LogUtil;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import io.ray.serve.util.ServeProtoUtil;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** Replica class wrapping the provided class. Note that Java function is not supported now. */
|
||||
public class RayServeWrappedReplica {
|
||||
public class RayServeWrappedReplica implements RayServeReplica {
|
||||
|
||||
private RayServeReplica backend;
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplicaImpl.class);
|
||||
|
||||
private DeploymentInfo deploymentInfo;
|
||||
|
||||
private RayServeReplicaImpl replica;
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
public RayServeWrappedReplica(
|
||||
String deploymentName,
|
||||
String replicaTag,
|
||||
String backendDef,
|
||||
String deploymentDef,
|
||||
byte[] initArgsbytes,
|
||||
byte[] deploymentConfigBytes,
|
||||
byte[] deploymentVersionBytes,
|
||||
String controllerName)
|
||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
||||
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
|
||||
String controllerName) {
|
||||
|
||||
// Parse DeploymentConfig.
|
||||
DeploymentConfig deploymentConfig = ServeProtoUtil.parseDeploymentConfig(deploymentConfigBytes);
|
||||
|
||||
// Parse init args.
|
||||
Object[] initArgs = parseInitArgs(initArgsbytes, deploymentConfig);
|
||||
Object[] initArgs = null;
|
||||
try {
|
||||
initArgs = parseInitArgs(initArgsbytes, deploymentConfig);
|
||||
} catch (IOException e) {
|
||||
String errMsg =
|
||||
LogUtil.format(
|
||||
"Failed to initialize replica {} of deployment {}",
|
||||
replicaTag,
|
||||
deploymentInfo.getName());
|
||||
LOGGER.error(errMsg, e);
|
||||
throw new RayServeException(errMsg, e);
|
||||
}
|
||||
|
||||
// Instantiate the object defined by backendDef.
|
||||
Class backendClass = Class.forName(backendDef);
|
||||
Object callable = ReflectUtil.getConstructor(backendClass, initArgs).newInstance(initArgs);
|
||||
|
||||
// Get the controller by controllerName.
|
||||
Preconditions.checkArgument(
|
||||
StringUtils.isNotBlank(controllerName), "Must provide a valid controllerName");
|
||||
Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
|
||||
Preconditions.checkState(optional.isPresent(), "Controller does not exist");
|
||||
|
||||
// Set the controller name so that Serve.connect() in the user's backend code will connect to
|
||||
// the instance that this backend is running in.
|
||||
Serve.setInternalReplicaContext(deploymentName, replicaTag, controllerName, callable);
|
||||
|
||||
// Construct worker replica.
|
||||
backend =
|
||||
new RayServeReplica(
|
||||
callable,
|
||||
deploymentConfig,
|
||||
ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes),
|
||||
optional.get());
|
||||
// Init replica.
|
||||
init(
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(deploymentConfig)
|
||||
.setDeploymentVersion(ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes))
|
||||
.setDeploymentDef(deploymentDef)
|
||||
.setInitArgs(initArgs),
|
||||
replicaTag,
|
||||
controllerName,
|
||||
null);
|
||||
}
|
||||
|
||||
public RayServeWrappedReplica(
|
||||
String deploymentName,
|
||||
String replicaTag,
|
||||
DeploymentInfo deploymentInfo,
|
||||
String controllerName)
|
||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
||||
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
|
||||
this(
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo.getReplicaConfig().getBackendDef(),
|
||||
deploymentInfo.getReplicaConfig().getInitArgs(),
|
||||
deploymentInfo.getDeploymentConfig(),
|
||||
deploymentInfo.getDeploymentVersion(),
|
||||
controllerName);
|
||||
String replicaTag,
|
||||
String controllerName,
|
||||
RayServeConfig rayServeConfig) {
|
||||
init(deploymentInfo, replicaTag, controllerName, rayServeConfig);
|
||||
}
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
private void init(
|
||||
DeploymentInfo deploymentInfo,
|
||||
String replicaTag,
|
||||
String controllerName,
|
||||
RayServeConfig rayServeConfig) {
|
||||
try {
|
||||
// Set the controller name so that Serve.connect() in the user's code will connect to the
|
||||
// instance that this deployment is running in.
|
||||
Serve.setInternalReplicaContext(deploymentInfo.getName(), replicaTag, controllerName, null);
|
||||
Serve.getReplicaContext().setRayServeConfig(rayServeConfig);
|
||||
|
||||
// Instantiate the object defined by deploymentDef.
|
||||
Class deploymentClass = Class.forName(deploymentInfo.getDeploymentDef());
|
||||
Object callable =
|
||||
ReflectUtil.getConstructor(deploymentClass, deploymentInfo.getInitArgs())
|
||||
.newInstance(deploymentInfo.getInitArgs());
|
||||
Serve.getReplicaContext().setServableObject(callable);
|
||||
|
||||
// Get the controller by controllerName.
|
||||
Preconditions.checkArgument(
|
||||
StringUtils.isNotBlank(controllerName), "Must provide a valid controllerName");
|
||||
Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
|
||||
Preconditions.checkState(optional.isPresent(), "Controller does not exist");
|
||||
|
||||
// Enable metrics.
|
||||
enableMetrics(deploymentInfo.getConfig());
|
||||
|
||||
// Construct worker replica.
|
||||
this.replica =
|
||||
new RayServeReplicaImpl(
|
||||
callable,
|
||||
deploymentInfo.getDeploymentConfig(),
|
||||
deploymentInfo.getDeploymentVersion(),
|
||||
optional.get());
|
||||
this.deploymentInfo = deploymentInfo;
|
||||
} catch (Throwable e) {
|
||||
String errMsg =
|
||||
LogUtil.format(
|
||||
"Failed to initialize replica {} of deployment {}",
|
||||
replicaTag,
|
||||
deploymentInfo.getName());
|
||||
LOGGER.error(errMsg, e);
|
||||
throw new RayServeException(errMsg, e);
|
||||
}
|
||||
}
|
||||
|
||||
private void enableMetrics(Map<String, String> config) {
|
||||
Optional.ofNullable(config)
|
||||
.map(conf -> conf.get(RayServeConfig.METRICS_ENABLED))
|
||||
.ifPresent(
|
||||
enabled -> {
|
||||
if (Boolean.valueOf(enabled)) {
|
||||
RayServeMetrics.enable();
|
||||
} else {
|
||||
RayServeMetrics.disable();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private Object[] parseInitArgs(byte[] initArgsbytes, DeploymentConfig deploymentConfig)
|
||||
|
@ -86,13 +139,13 @@ public class RayServeWrappedReplica {
|
|||
return new Object[0];
|
||||
}
|
||||
|
||||
if (!deploymentConfig.getIsCrossLanguage()) {
|
||||
if (deploymentConfig.isCrossLanguage()) {
|
||||
// For other language like Python API, not support Array type.
|
||||
return new Object[] {MessagePackSerializer.decode(initArgsbytes, Object.class)};
|
||||
} else {
|
||||
// If the construction request is from Java API, deserialize initArgsbytes to Object[]
|
||||
// directly.
|
||||
return MessagePackSerializer.decode(initArgsbytes, Object[].class);
|
||||
} else {
|
||||
// For other language like Python API, not support Array type.
|
||||
return new Object[] {MessagePackSerializer.decode(initArgsbytes, Object.class)};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,27 +155,28 @@ public class RayServeWrappedReplica {
|
|||
* @param requestMetadata the real type is byte[] if this invocation is cross-language. Otherwise,
|
||||
* the real type is {@link io.ray.serve.generated.RequestMetadata}.
|
||||
* @param requestArgs The input parameters of the specified method of the object defined by
|
||||
* backendDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
|
||||
* deploymentDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
|
||||
* this invocation is cross-language. Otherwise, the real type is Object[].
|
||||
* @return the result of request being processed
|
||||
* @throws InvalidProtocolBufferException if the protobuf deserialization fails.
|
||||
*/
|
||||
public Object handleRequest(Object requestMetadata, Object requestArgs)
|
||||
throws InvalidProtocolBufferException {
|
||||
@Override
|
||||
public Object handleRequest(Object requestMetadata, Object requestArgs) {
|
||||
boolean isCrossLanguage = requestMetadata instanceof byte[];
|
||||
return backend.handleRequest(
|
||||
new Query(
|
||||
isCrossLanguage
|
||||
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
|
||||
: (RequestMetadata) requestMetadata,
|
||||
isCrossLanguage
|
||||
? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs)
|
||||
: requestArgs));
|
||||
return replica.handleRequest(
|
||||
isCrossLanguage
|
||||
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
|
||||
: (RequestMetadata) requestMetadata,
|
||||
isCrossLanguage ? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs) : requestArgs);
|
||||
}
|
||||
|
||||
/** Check whether this replica is ready or not. */
|
||||
public void ready() {
|
||||
return;
|
||||
/**
|
||||
* Check if the actor is healthy.
|
||||
*
|
||||
* @return true if the actor is health, or return false.
|
||||
*/
|
||||
@Override
|
||||
public boolean checkHealth() {
|
||||
return replica.checkHealth();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -130,16 +184,44 @@ public class RayServeWrappedReplica {
|
|||
*
|
||||
* @return true if it is ready for shutdown.
|
||||
*/
|
||||
@Override
|
||||
public boolean prepareForShutdown() {
|
||||
return backend.prepareForShutdown();
|
||||
return replica.prepareForShutdown();
|
||||
}
|
||||
|
||||
public byte[] reconfigure(Object userConfig) {
|
||||
DeploymentVersion deploymentVersion = backend.reconfigure(userConfig);
|
||||
return deploymentVersion.toByteArray();
|
||||
/**
|
||||
* Reconfigure user's configuration in the callable object through its reconfigure method.
|
||||
*
|
||||
* @param userConfig new user's configuration
|
||||
* @return DeploymentVersion. If the current invocation is crossing language, the
|
||||
* DeploymentVersion is serialized to protobuf byte[].
|
||||
*/
|
||||
@Override
|
||||
public Object reconfigure(Object userConfig) {
|
||||
DeploymentVersion deploymentVersion =
|
||||
replica.reconfigure(
|
||||
deploymentInfo.getDeploymentConfig().isCrossLanguage() && userConfig != null
|
||||
? MessagePackSerializer.decode((byte[]) userConfig, Object.class)
|
||||
: userConfig);
|
||||
return deploymentInfo.getDeploymentConfig().isCrossLanguage()
|
||||
? ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray()
|
||||
: deploymentVersion;
|
||||
}
|
||||
|
||||
public byte[] getVersion() {
|
||||
return backend.getVersion().toByteArray();
|
||||
/**
|
||||
* Get the deployment version of the current replica.
|
||||
*
|
||||
* @return DeploymentVersion. If the current invocation is crossing language, the
|
||||
* DeploymentVersion is serialized to protobuf byte[].
|
||||
*/
|
||||
public Object getVersion() {
|
||||
DeploymentVersion deploymentVersion = replica.getVersion();
|
||||
return deploymentInfo.getDeploymentConfig().isCrossLanguage()
|
||||
? ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray()
|
||||
: deploymentVersion;
|
||||
}
|
||||
|
||||
public Object getCallable() {
|
||||
return replica.getCallable();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ public class ReplicaConfig implements Serializable {
|
|||
|
||||
private static final long serialVersionUID = -1442657824045704226L;
|
||||
|
||||
private String backendDef;
|
||||
private String deploymentDef;
|
||||
|
||||
private byte[] initArgs;
|
||||
|
||||
|
@ -18,8 +18,8 @@ public class ReplicaConfig implements Serializable {
|
|||
|
||||
private Map<String, Double> resource;
|
||||
|
||||
public ReplicaConfig(String backendDef, byte[] initArgs, Map<String, Object> rayActorOptions) {
|
||||
this.backendDef = backendDef;
|
||||
public ReplicaConfig(String deploymentDef, byte[] initArgs, Map<String, Object> rayActorOptions) {
|
||||
this.deploymentDef = deploymentDef;
|
||||
this.initArgs = initArgs;
|
||||
this.rayActorOptions = rayActorOptions;
|
||||
this.resource = new HashMap<>();
|
||||
|
@ -30,7 +30,7 @@ public class ReplicaConfig implements Serializable {
|
|||
private void validate() {
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("placement_group"),
|
||||
"Providing placement_group for backend actors is not currently supported.");
|
||||
"Providing placement_group for deployment actors is not currently supported.");
|
||||
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("lifetime"),
|
||||
|
@ -81,12 +81,12 @@ public class ReplicaConfig implements Serializable {
|
|||
resource.putAll((Map) customResources);
|
||||
}
|
||||
|
||||
public String getBackendDef() {
|
||||
return backendDef;
|
||||
public String getDeploymentDef() {
|
||||
return deploymentDef;
|
||||
}
|
||||
|
||||
public void setBackendDef(String backendDef) {
|
||||
this.backendDef = backendDef;
|
||||
public void setDeploymentDef(String deploymentDef) {
|
||||
this.deploymentDef = deploymentDef;
|
||||
}
|
||||
|
||||
public byte[] getInitArgs() {
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package io.ray.serve;
|
||||
|
||||
/** Stores data for Serve API calls from within the user's backend code. */
|
||||
/** Stores data for Serve API calls from within deployments. */
|
||||
public class ReplicaContext {
|
||||
|
||||
private String deploymentName; // TODO deployment
|
||||
private String deploymentName;
|
||||
|
||||
private String replicaTag;
|
||||
|
||||
|
@ -11,6 +11,8 @@ public class ReplicaContext {
|
|||
|
||||
private Object servableObject;
|
||||
|
||||
private RayServeConfig rayServeConfig;
|
||||
|
||||
public ReplicaContext(
|
||||
String deploymentName, String replicaTag, String controllerName, Object servableObject) {
|
||||
this.deploymentName = deploymentName;
|
||||
|
@ -50,4 +52,12 @@ public class ReplicaContext {
|
|||
public void setServableObject(Object servableObject) {
|
||||
this.servableObject = servableObject;
|
||||
}
|
||||
|
||||
public RayServeConfig getRayServeConfig() {
|
||||
return rayServeConfig;
|
||||
}
|
||||
|
||||
public void setRayServeConfig(RayServeConfig rayServeConfig) {
|
||||
this.rayServeConfig = rayServeConfig;
|
||||
}
|
||||
}
|
||||
|
|
30
java/serve/src/main/java/io/ray/serve/ReplicaName.java
Normal file
30
java/serve/src/main/java/io/ray/serve/ReplicaName.java
Normal file
|
@ -0,0 +1,30 @@
|
|||
package io.ray.serve;
|
||||
|
||||
public class ReplicaName {
|
||||
|
||||
private String deploymentTag;
|
||||
|
||||
private String replicaSuffix;
|
||||
|
||||
private String replicaTag = "";
|
||||
|
||||
private String delimiter = "#";
|
||||
|
||||
public ReplicaName(String deploymentTag, String replicaSuffix) {
|
||||
this.deploymentTag = deploymentTag;
|
||||
this.replicaSuffix = replicaSuffix;
|
||||
this.replicaTag = deploymentTag + this.delimiter + replicaSuffix;
|
||||
}
|
||||
|
||||
public String getDeploymentTag() {
|
||||
return deploymentTag;
|
||||
}
|
||||
|
||||
public String getReplicaSuffix() {
|
||||
return replicaSuffix;
|
||||
}
|
||||
|
||||
public String getReplicaTag() {
|
||||
return replicaTag;
|
||||
}
|
||||
}
|
|
@ -9,7 +9,6 @@ import io.ray.runtime.metric.Gauge;
|
|||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.runtime.metric.TagKey;
|
||||
import io.ray.serve.generated.ActorSet;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.util.CollectionUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
|
@ -27,8 +26,6 @@ public class ReplicaSet {
|
|||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
|
||||
|
||||
private volatile int maxConcurrentQueries = 8;
|
||||
|
||||
private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
|
||||
|
||||
private AtomicInteger numQueuedQueries = new AtomicInteger();
|
||||
|
@ -48,18 +45,6 @@ public class ReplicaSet {
|
|||
.register());
|
||||
}
|
||||
|
||||
public void setMaxConcurrentQueries(Object deploymentConfig) {
|
||||
int newValue = ((DeploymentConfig) deploymentConfig).getMaxConcurrentQueries();
|
||||
if (newValue != this.maxConcurrentQueries) {
|
||||
this.maxConcurrentQueries = newValue;
|
||||
LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue);
|
||||
}
|
||||
}
|
||||
|
||||
public int getMaxConcurrentQueries() {
|
||||
return maxConcurrentQueries;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public synchronized void updateWorkerReplicas(Object actorSet) {
|
||||
List<String> actorNames = ((ActorSet) actorSet).getNamesList();
|
||||
|
@ -86,7 +71,7 @@ public class ReplicaSet {
|
|||
/**
|
||||
* Given a query, submit it to a replica and return the object ref. This method will keep track of
|
||||
* the in flight queries for each replicas and only send a query to available replicas (determined
|
||||
* by the backend max_concurrent_quries value.)
|
||||
* by the max_concurrent_quries value.)
|
||||
*
|
||||
* @param query the incoming query.
|
||||
* @return ray.ObjectRef
|
||||
|
@ -98,7 +83,7 @@ public class ReplicaSet {
|
|||
() ->
|
||||
numQueuedQueriesGauge.update(
|
||||
numQueuedQueries.get(),
|
||||
TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
|
||||
ImmutableMap.of(new TagKey(RayServeMetrics.TAG_ENDPOINT), endpoint)));
|
||||
ObjectRef<Object> assignedRef =
|
||||
tryAssignReplica(query); // TODO controll concurrency using maxConcurrentQueries
|
||||
numQueuedQueries.decrementAndGet();
|
||||
|
@ -106,7 +91,7 @@ public class ReplicaSet {
|
|||
() ->
|
||||
numQueuedQueriesGauge.update(
|
||||
numQueuedQueries.get(),
|
||||
TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
|
||||
ImmutableMap.of(new TagKey(RayServeMetrics.TAG_ENDPOINT), endpoint)));
|
||||
return assignedRef;
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import io.ray.serve.poll.LongPollNamespace;
|
|||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** Router process incoming queries: choose backend, and assign replica. */
|
||||
/** Router process incoming queries: assign a replica. */
|
||||
public class Router {
|
||||
|
||||
private ReplicaSet replicaSet;
|
||||
|
@ -36,9 +36,6 @@ public class Router {
|
|||
.register());
|
||||
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(
|
||||
new KeyType(LongPollNamespace.BACKEND_CONFIGS, deploymentName),
|
||||
deploymentConfig -> replicaSet.setMaxConcurrentQueries(deploymentConfig)); // cross language
|
||||
keyListeners.put(
|
||||
new KeyType(LongPollNamespace.REPLICA_HANDLES, deploymentName),
|
||||
workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language
|
||||
|
|
|
@ -19,7 +19,7 @@ public class Serve {
|
|||
/**
|
||||
* Set replica information to global context.
|
||||
*
|
||||
* @param deploymentName backend tag
|
||||
* @param deploymentName deployment name
|
||||
* @param replicaTag replica tag
|
||||
* @param controllerName the controller actor's name
|
||||
* @param servableObject the servable object of the specified replica.
|
||||
|
@ -42,7 +42,7 @@ public class Serve {
|
|||
public static ReplicaContext getReplicaContext() {
|
||||
if (INTERNAL_REPLICA_CONTEXT == null) {
|
||||
throw new RayServeException(
|
||||
"`Serve.getReplicaContext()` may only be called from within a Ray Serve backend.");
|
||||
"`Serve.getReplicaContext()` may only be called from within a Ray Serve deployment.");
|
||||
}
|
||||
return INTERNAL_REPLICA_CONTEXT;
|
||||
}
|
||||
|
|
|
@ -46,10 +46,7 @@ public class LongPollClient {
|
|||
new HashMap<>();
|
||||
|
||||
static {
|
||||
DESERIALIZERS.put(
|
||||
LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseDeploymentConfig(body));
|
||||
DESERIALIZERS.put(
|
||||
LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body));
|
||||
DESERIALIZERS.put(LongPollNamespace.ROUTE_TABLE, body -> ServeProtoUtil.parseEndpointSet(body));
|
||||
DESERIALIZERS.put(
|
||||
LongPollNamespace.REPLICA_HANDLES,
|
||||
body -> {
|
||||
|
@ -89,7 +86,7 @@ public class LongPollClient {
|
|||
}
|
||||
}
|
||||
},
|
||||
"backend-poll-thread");
|
||||
"ray-serve-long-poll-thread");
|
||||
}
|
||||
|
||||
public void start() {
|
||||
|
|
|
@ -4,7 +4,5 @@ package io.ray.serve.poll;
|
|||
public enum LongPollNamespace {
|
||||
REPLICA_HANDLES,
|
||||
|
||||
BACKEND_CONFIGS,
|
||||
|
||||
ROUTE_TABLE;
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.Constants;
|
||||
import io.ray.serve.DeploymentConfig;
|
||||
import io.ray.serve.DeploymentVersion;
|
||||
import io.ray.serve.RayServeException;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.EndpointInfo;
|
||||
import io.ray.serve.generated.EndpointSet;
|
||||
import io.ray.serve.generated.LongPollResult;
|
||||
|
@ -27,71 +27,67 @@ public class ServeProtoUtil {
|
|||
|
||||
public static DeploymentConfig parseDeploymentConfig(byte[] deploymentConfigBytes) {
|
||||
|
||||
// Get a builder from DeploymentConfig(bytes) or create a new one.
|
||||
DeploymentConfig.Builder builder = null;
|
||||
DeploymentConfig deploymentConfig = new DeploymentConfig();
|
||||
if (deploymentConfigBytes == null) {
|
||||
builder = DeploymentConfig.newBuilder();
|
||||
} else {
|
||||
DeploymentConfig deploymentConfig = null;
|
||||
try {
|
||||
deploymentConfig = DeploymentConfig.parseFrom(deploymentConfigBytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e);
|
||||
}
|
||||
if (deploymentConfig == null) {
|
||||
builder = DeploymentConfig.newBuilder();
|
||||
} else {
|
||||
builder = DeploymentConfig.newBuilder(deploymentConfig);
|
||||
}
|
||||
return deploymentConfig;
|
||||
}
|
||||
|
||||
// Set default values.
|
||||
if (builder.getNumReplicas() == 0) {
|
||||
builder.setNumReplicas(1);
|
||||
io.ray.serve.generated.DeploymentConfig pbDeploymentConfig = null;
|
||||
try {
|
||||
pbDeploymentConfig = io.ray.serve.generated.DeploymentConfig.parseFrom(deploymentConfigBytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e);
|
||||
}
|
||||
|
||||
Preconditions.checkArgument(
|
||||
builder.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
|
||||
if (builder.getMaxConcurrentQueries() == 0) {
|
||||
builder.setMaxConcurrentQueries(100);
|
||||
if (pbDeploymentConfig == null) {
|
||||
return deploymentConfig;
|
||||
}
|
||||
|
||||
if (builder.getGracefulShutdownWaitLoopS() == 0) {
|
||||
builder.setGracefulShutdownWaitLoopS(2);
|
||||
if (pbDeploymentConfig.getNumReplicas() != 0) {
|
||||
deploymentConfig.setNumReplicas(pbDeploymentConfig.getNumReplicas());
|
||||
}
|
||||
|
||||
if (builder.getGracefulShutdownTimeoutS() == 0) {
|
||||
builder.setGracefulShutdownTimeoutS(20);
|
||||
if (pbDeploymentConfig.getMaxConcurrentQueries() != 0) {
|
||||
deploymentConfig.setMaxConcurrentQueries(pbDeploymentConfig.getMaxConcurrentQueries());
|
||||
}
|
||||
|
||||
if (builder.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) {
|
||||
if (pbDeploymentConfig.getGracefulShutdownWaitLoopS() != 0) {
|
||||
deploymentConfig.setGracefulShutdownWaitLoopS(
|
||||
pbDeploymentConfig.getGracefulShutdownWaitLoopS());
|
||||
}
|
||||
if (pbDeploymentConfig.getGracefulShutdownTimeoutS() != 0) {
|
||||
deploymentConfig.setGracefulShutdownTimeoutS(
|
||||
pbDeploymentConfig.getGracefulShutdownTimeoutS());
|
||||
}
|
||||
deploymentConfig.setCrossLanguage(pbDeploymentConfig.getIsCrossLanguage());
|
||||
if (pbDeploymentConfig.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Unrecognized backend language {}. Backend language must be in {}.",
|
||||
builder.getDeploymentLanguageValue(),
|
||||
"Unrecognized deployment language {}. Deployment language must be in {}.",
|
||||
pbDeploymentConfig.getDeploymentLanguage(),
|
||||
Lists.newArrayList(DeploymentLanguage.values())));
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static Object parseUserConfig(DeploymentConfig deploymentConfig) {
|
||||
if (deploymentConfig.getUserConfig() == null || deploymentConfig.getUserConfig().size() == 0) {
|
||||
return null;
|
||||
deploymentConfig.setDeploymentLanguage(pbDeploymentConfig.getDeploymentLanguageValue());
|
||||
if (pbDeploymentConfig.getUserConfig() != null
|
||||
&& pbDeploymentConfig.getUserConfig().size() != 0) {
|
||||
deploymentConfig.setUserConfig(
|
||||
MessagePackSerializer.decode(
|
||||
pbDeploymentConfig.getUserConfig().toByteArray(), Object.class));
|
||||
}
|
||||
return MessagePackSerializer.decode(
|
||||
deploymentConfig.getUserConfig().toByteArray(), Object.class);
|
||||
return deploymentConfig;
|
||||
}
|
||||
|
||||
public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) {
|
||||
|
||||
// Get a builder from RequestMetadata(bytes) or create a new one.
|
||||
RequestMetadata.Builder builder = null;
|
||||
if (requestMetadataBytes == null) {
|
||||
builder = RequestMetadata.newBuilder();
|
||||
} else {
|
||||
RequestMetadata requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes);
|
||||
RequestMetadata requestMetadata = null;
|
||||
try {
|
||||
requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RayServeException("Failed to parse RequestMetadata from protobuf bytes.", e);
|
||||
}
|
||||
if (requestMetadata == null) {
|
||||
builder = RequestMetadata.newBuilder();
|
||||
} else {
|
||||
|
@ -101,21 +97,25 @@ public class ServeProtoUtil {
|
|||
|
||||
// Set default values.
|
||||
if (StringUtils.isBlank(builder.getCallMethod())) {
|
||||
builder.setCallMethod(Constants.DEFAULT_CALL_METHOD);
|
||||
builder.setCallMethod(Constants.CALL_METHOD);
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
|
||||
throws InvalidProtocolBufferException {
|
||||
public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes) {
|
||||
|
||||
// Get a builder from HTTPRequestWrapper(bytes) or create a new one.
|
||||
RequestWrapper.Builder builder = null;
|
||||
if (httpRequestWrapperBytes == null) {
|
||||
builder = RequestWrapper.newBuilder();
|
||||
} else {
|
||||
RequestWrapper requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes);
|
||||
RequestWrapper requestWrapper = null;
|
||||
try {
|
||||
requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RayServeException("Failed to parse RequestWrapper from protobuf bytes.", e);
|
||||
}
|
||||
if (requestWrapper == null) {
|
||||
builder = RequestWrapper.newBuilder();
|
||||
} else {
|
||||
|
@ -160,12 +160,46 @@ public class ServeProtoUtil {
|
|||
|
||||
public static DeploymentVersion parseDeploymentVersion(byte[] deploymentVersionBytes) {
|
||||
if (deploymentVersionBytes == null) {
|
||||
return null;
|
||||
return new DeploymentVersion();
|
||||
}
|
||||
|
||||
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion = null;
|
||||
try {
|
||||
return DeploymentVersion.parseFrom(deploymentVersionBytes);
|
||||
pbDeploymentVersion =
|
||||
io.ray.serve.generated.DeploymentVersion.parseFrom(deploymentVersionBytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RayServeException("Failed to parse DeploymentVersion from protobuf bytes.", e);
|
||||
}
|
||||
if (pbDeploymentVersion == null) {
|
||||
return new DeploymentVersion();
|
||||
}
|
||||
return new DeploymentVersion(
|
||||
pbDeploymentVersion.getCodeVersion(),
|
||||
pbDeploymentVersion.getUserConfig() != null
|
||||
&& pbDeploymentVersion.getUserConfig().size() != 0
|
||||
? new Object[] {
|
||||
MessagePackSerializer.decode(
|
||||
pbDeploymentVersion.getUserConfig().toByteArray(), Object.class)
|
||||
}
|
||||
: null);
|
||||
}
|
||||
|
||||
public static io.ray.serve.generated.DeploymentVersion toProtobuf(
|
||||
DeploymentVersion deploymentVersion) {
|
||||
io.ray.serve.generated.DeploymentVersion.Builder pbDeploymentVersion =
|
||||
io.ray.serve.generated.DeploymentVersion.newBuilder();
|
||||
if (deploymentVersion == null) {
|
||||
return pbDeploymentVersion.build();
|
||||
}
|
||||
|
||||
if (StringUtils.isNotBlank(deploymentVersion.getCodeVersion())) {
|
||||
pbDeploymentVersion.setCodeVersion(deploymentVersion.getCodeVersion());
|
||||
}
|
||||
if (deploymentVersion.getUserConfig() != null) {
|
||||
pbDeploymentVersion.setUserConfig(
|
||||
ByteString.copyFrom(
|
||||
MessagePackSerializer.encode(deploymentVersion.getUserConfig()).getLeft()));
|
||||
}
|
||||
return pbDeploymentVersion.build();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ public class HttpProxyTest {
|
|||
try (CloseableHttpResponse httpResponse =
|
||||
(CloseableHttpResponse) httpClient.execute(httpPost)) {
|
||||
|
||||
// No Backend replica, so error is expected.
|
||||
// No replica, so error is expected.
|
||||
int status = httpResponse.getCode();
|
||||
Assert.assertEquals(status, HttpURLConnection.HTTP_INTERNAL_ERROR);
|
||||
}
|
||||
|
|
|
@ -5,8 +5,7 @@ import io.ray.api.Ray;
|
|||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.ActorSet;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.EndpointInfo;
|
||||
import io.ray.serve.util.CommonUtil;
|
||||
import java.io.IOException;
|
||||
|
@ -50,26 +49,29 @@ public class ProxyActorTest {
|
|||
controller.task(DummyServeController::setEndpoints, endpointInfos).remote();
|
||||
|
||||
// Replica
|
||||
DeploymentInfo deploymentInfo = new DeploymentInfo();
|
||||
deploymentInfo.setDeploymentConfig(DeploymentConfig.newBuilder().build().toByteArray());
|
||||
deploymentInfo.setDeploymentVersion(
|
||||
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
|
||||
deploymentInfo.setReplicaConfig(
|
||||
new ReplicaConfig(DummyBackendReplica.class.getName(), null, new HashMap<>()));
|
||||
DeploymentInfo deploymentInfo =
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(
|
||||
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()))
|
||||
.setDeploymentVersion(new DeploymentVersion(version))
|
||||
.setDeploymentDef(DummyReplica.class.getName());
|
||||
|
||||
ActorHandle<RayServeWrappedReplica> replica =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo,
|
||||
controllerName)
|
||||
replicaTag,
|
||||
controllerName,
|
||||
(RayServeConfig) null)
|
||||
.setName(replicaTag)
|
||||
.remote();
|
||||
replica.task(RayServeWrappedReplica::ready).remote();
|
||||
Assert.assertTrue(replica.task(RayServeWrappedReplica::checkHealth).remote().get());
|
||||
|
||||
// ProxyActor
|
||||
ProxyActor proxyActor = new ProxyActor(controllerName, null);
|
||||
Assert.assertTrue(proxyActor.ready());
|
||||
|
||||
proxyActor.getProxyRouter().updateRoutes(endpointInfos);
|
||||
proxyActor
|
||||
.getProxyRouter()
|
||||
|
|
|
@ -3,12 +3,9 @@ package io.ray.serve;
|
|||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.ActorSet;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import java.util.HashMap;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
|
@ -31,30 +28,29 @@ public class RayServeHandleTest {
|
|||
Ray.actor(DummyServeController::new).setName(controllerName).remote();
|
||||
|
||||
// Replica
|
||||
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
|
||||
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
|
||||
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
|
||||
DeploymentConfig deploymentConfig =
|
||||
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
|
||||
|
||||
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
|
||||
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
|
||||
|
||||
DeploymentInfo deploymentInfo = new DeploymentInfo();
|
||||
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
|
||||
deploymentInfo.setDeploymentVersion(
|
||||
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
|
||||
deploymentInfo.setReplicaConfig(
|
||||
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
|
||||
DeploymentInfo deploymentInfo =
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(deploymentConfig)
|
||||
.setDeploymentVersion(new DeploymentVersion(version))
|
||||
.setDeploymentDef("io.ray.serve.ReplicaContext")
|
||||
.setInitArgs(initArgs);
|
||||
|
||||
ActorHandle<RayServeWrappedReplica> replicaHandle =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo,
|
||||
controllerName)
|
||||
replicaTag,
|
||||
controllerName,
|
||||
(RayServeConfig) null)
|
||||
.setName(actorName)
|
||||
.remote();
|
||||
replicaHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
|
||||
|
||||
// RayServeHandle
|
||||
RayServeHandle rayServeHandle =
|
||||
|
@ -71,6 +67,7 @@ public class RayServeHandleTest {
|
|||
if (!inited) {
|
||||
Ray.shutdown();
|
||||
}
|
||||
Serve.setInternalReplicaContext(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
@ -32,64 +30,84 @@ public class RayServeReplicaTest {
|
|||
ActorHandle<DummyServeController> controllerHandle =
|
||||
Ray.actor(DummyServeController::new).setName(controllerName).remote();
|
||||
|
||||
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
|
||||
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
|
||||
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
|
||||
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
|
||||
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
|
||||
DeploymentConfig deploymentConfig =
|
||||
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
|
||||
DeploymentInfo deploymentInfo =
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(deploymentConfig)
|
||||
.setDeploymentVersion(new DeploymentVersion(version))
|
||||
.setDeploymentDef(DummyReplica.class.getName());
|
||||
|
||||
DeploymentInfo deploymentInfo = new DeploymentInfo();
|
||||
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
|
||||
deploymentInfo.setDeploymentVersion(
|
||||
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
|
||||
deploymentInfo.setReplicaConfig(
|
||||
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
|
||||
|
||||
ActorHandle<RayServeWrappedReplica> backendHandle =
|
||||
ActorHandle<RayServeWrappedReplica> replicHandle =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo,
|
||||
controllerName)
|
||||
replicaTag,
|
||||
controllerName,
|
||||
(RayServeConfig) null)
|
||||
.remote();
|
||||
|
||||
// ready
|
||||
backendHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
Assert.assertTrue(replicHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
|
||||
|
||||
// handle request
|
||||
RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
|
||||
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
|
||||
requestMetadata.setCallMethod("getDeploymentName");
|
||||
requestMetadata.setCallMethod(Constants.CALL_METHOD);
|
||||
RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder();
|
||||
|
||||
ObjectRef<Object> resultRef =
|
||||
backendHandle
|
||||
replicHandle
|
||||
.task(
|
||||
RayServeWrappedReplica::handleRequest,
|
||||
requestMetadata.build().toByteArray(),
|
||||
requestWrapper.build().toByteArray())
|
||||
.remote();
|
||||
Assert.assertEquals((String) resultRef.get(), deploymentName);
|
||||
Assert.assertEquals((String) resultRef.get(), "1");
|
||||
|
||||
// reconfigure
|
||||
ObjectRef<byte[]> versionRef =
|
||||
backendHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote();
|
||||
Assert.assertEquals(DeploymentVersion.parseFrom(versionRef.get()).getCodeVersion(), version);
|
||||
ObjectRef<Object> versionRef =
|
||||
replicHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote();
|
||||
Assert.assertEquals(((DeploymentVersion) versionRef.get()).getCodeVersion(), version);
|
||||
|
||||
replicHandle.task(RayServeWrappedReplica::reconfigure, new Object()).remote().get();
|
||||
resultRef =
|
||||
replicHandle
|
||||
.task(
|
||||
RayServeWrappedReplica::handleRequest,
|
||||
requestMetadata.build().toByteArray(),
|
||||
requestWrapper.build().toByteArray())
|
||||
.remote();
|
||||
Assert.assertEquals((String) resultRef.get(), "1");
|
||||
|
||||
replicHandle
|
||||
.task(RayServeWrappedReplica::reconfigure, ImmutableMap.of("value", "100"))
|
||||
.remote()
|
||||
.get();
|
||||
resultRef =
|
||||
replicHandle
|
||||
.task(
|
||||
RayServeWrappedReplica::handleRequest,
|
||||
requestMetadata.build().toByteArray(),
|
||||
requestWrapper.build().toByteArray())
|
||||
.remote();
|
||||
Assert.assertEquals((String) resultRef.get(), "101");
|
||||
|
||||
// get version
|
||||
versionRef = backendHandle.task(RayServeWrappedReplica::getVersion).remote();
|
||||
Assert.assertEquals(DeploymentVersion.parseFrom(versionRef.get()).getCodeVersion(), version);
|
||||
versionRef = replicHandle.task(RayServeWrappedReplica::getVersion).remote();
|
||||
Assert.assertEquals(((DeploymentVersion) versionRef.get()).getCodeVersion(), version);
|
||||
|
||||
// prepare for shutdown
|
||||
ObjectRef<Boolean> shutdownRef =
|
||||
backendHandle.task(RayServeWrappedReplica::prepareForShutdown).remote();
|
||||
replicHandle.task(RayServeWrappedReplica::prepareForShutdown).remote();
|
||||
Assert.assertTrue(shutdownRef.get());
|
||||
|
||||
} finally {
|
||||
if (!inited) {
|
||||
Ray.shutdown();
|
||||
}
|
||||
Serve.setInternalReplicaContext(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,40 +15,42 @@ public class ReplicaConfigTest {
|
|||
public void test() {
|
||||
|
||||
Object dummy = new Object();
|
||||
String backendDef = "io.ray.serve.ReplicaConfigTest";
|
||||
String deploymentDef = "io.ray.serve.ReplicaConfigTest";
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("placement_group", dummy)));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("placement_group", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("lifetime", dummy)));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("lifetime", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("name", dummy)));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("name", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("max_restarts", dummy)));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("max_restarts", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", -1.0)));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_cpus", -1.0)));
|
||||
ReplicaConfig replicaConfig =
|
||||
new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", 2.0));
|
||||
new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_cpus", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("CPU").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", 2.0));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_gpus", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_gpus", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("GPU").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("memory", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("memory", 2.0));
|
||||
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("memory", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(deploymentDef, null, getRayActorOptions("memory", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("memory").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", -1.0)));
|
||||
() ->
|
||||
new ReplicaConfig(
|
||||
deploymentDef, null, getRayActorOptions("object_store_memory", -1.0)));
|
||||
replicaConfig =
|
||||
new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", 2.0));
|
||||
new ReplicaConfig(deploymentDef, null, getRayActorOptions("object_store_memory", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("object_store_memory").doubleValue(), 2.0);
|
||||
}
|
||||
|
||||
|
|
19
java/serve/src/test/java/io/ray/serve/ReplicaNameTest.java
Normal file
19
java/serve/src/test/java/io/ray/serve/ReplicaNameTest.java
Normal file
|
@ -0,0 +1,19 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ReplicaNameTest {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
String deploymentTag = "ReplicaNameTest";
|
||||
String replicaSuffix = RandomStringUtils.randomAlphabetic(6);
|
||||
ReplicaName replicaName = new ReplicaName(deploymentTag, replicaSuffix);
|
||||
|
||||
Assert.assertEquals(replicaName.getDeploymentTag(), deploymentTag);
|
||||
Assert.assertEquals(replicaName.getReplicaSuffix(), replicaSuffix);
|
||||
Assert.assertEquals(replicaName.getReplicaTag(), deploymentTag + "#" + replicaSuffix);
|
||||
}
|
||||
}
|
|
@ -3,13 +3,10 @@ package io.ray.serve;
|
|||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.ActorSet;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
|
@ -20,16 +17,6 @@ public class ReplicaSetTest {
|
|||
|
||||
private String deploymentName = "ReplicaSetTest";
|
||||
|
||||
@Test
|
||||
public void setMaxConcurrentQueriesTest() {
|
||||
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
|
||||
DeploymentConfig.Builder builder = DeploymentConfig.newBuilder();
|
||||
builder.setMaxConcurrentQueries(200);
|
||||
|
||||
replicaSet.setMaxConcurrentQueries(builder.build());
|
||||
Assert.assertEquals(replicaSet.getMaxConcurrentQueries(), 200);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void updateWorkerReplicasTest() {
|
||||
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
|
||||
|
@ -58,30 +45,29 @@ public class ReplicaSetTest {
|
|||
Ray.actor(DummyServeController::new).setName(controllerName).remote();
|
||||
|
||||
// Replica
|
||||
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
|
||||
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
|
||||
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
|
||||
DeploymentConfig deploymentConfig =
|
||||
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
|
||||
|
||||
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
|
||||
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
|
||||
|
||||
DeploymentInfo deploymentInfo = new DeploymentInfo();
|
||||
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
|
||||
deploymentInfo.setDeploymentVersion(
|
||||
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
|
||||
deploymentInfo.setReplicaConfig(
|
||||
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
|
||||
DeploymentInfo deploymentInfo =
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(deploymentConfig)
|
||||
.setDeploymentVersion(new DeploymentVersion(version))
|
||||
.setDeploymentDef("io.ray.serve.ReplicaContext")
|
||||
.setInitArgs(initArgs);
|
||||
|
||||
ActorHandle<RayServeWrappedReplica> replicaHandle =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo,
|
||||
controllerName)
|
||||
replicaTag,
|
||||
controllerName,
|
||||
(RayServeConfig) null)
|
||||
.setName(actorName)
|
||||
.remote();
|
||||
replicaHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
|
||||
|
||||
// ReplicaSet
|
||||
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
|
||||
|
@ -90,7 +76,6 @@ public class ReplicaSetTest {
|
|||
replicaSet.updateWorkerReplicas(builder.build());
|
||||
|
||||
// assign
|
||||
|
||||
RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
|
||||
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
|
||||
requestMetadata.setCallMethod("getDeploymentName");
|
||||
|
@ -103,6 +88,7 @@ public class ReplicaSetTest {
|
|||
if (!inited) {
|
||||
Ray.shutdown();
|
||||
}
|
||||
Serve.setInternalReplicaContext(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,13 +3,10 @@ package io.ray.serve;
|
|||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.serializer.MessagePackSerializer;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.generated.ActorSet;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.DeploymentLanguage;
|
||||
import io.ray.serve.generated.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import java.util.HashMap;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
@ -33,30 +30,29 @@ public class RouterTest {
|
|||
Ray.actor(DummyServeController::new).setName(controllerName).remote();
|
||||
|
||||
// Replica
|
||||
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
|
||||
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
|
||||
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
|
||||
DeploymentConfig deploymentConfig =
|
||||
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
|
||||
|
||||
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
|
||||
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
|
||||
|
||||
DeploymentInfo deploymentInfo = new DeploymentInfo();
|
||||
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
|
||||
deploymentInfo.setDeploymentVersion(
|
||||
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
|
||||
deploymentInfo.setReplicaConfig(
|
||||
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
|
||||
DeploymentInfo deploymentInfo =
|
||||
new DeploymentInfo()
|
||||
.setName(deploymentName)
|
||||
.setDeploymentConfig(deploymentConfig)
|
||||
.setDeploymentVersion(new DeploymentVersion(version))
|
||||
.setDeploymentDef("io.ray.serve.ReplicaContext")
|
||||
.setInitArgs(initArgs);
|
||||
|
||||
ActorHandle<RayServeWrappedReplica> replicaHandle =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
deploymentName,
|
||||
replicaTag,
|
||||
deploymentInfo,
|
||||
controllerName)
|
||||
replicaTag,
|
||||
controllerName,
|
||||
(RayServeConfig) null)
|
||||
.setName(actorName)
|
||||
.remote();
|
||||
replicaHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
|
||||
|
||||
// Router
|
||||
Router router = new Router(controllerHandle, deploymentName);
|
||||
|
@ -75,6 +71,7 @@ public class RouterTest {
|
|||
if (!inited) {
|
||||
Ray.shutdown();
|
||||
}
|
||||
Serve.setInternalReplicaContext(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,10 +10,10 @@ public class KeyTypeTest {
|
|||
|
||||
@Test
|
||||
public void hashTest() {
|
||||
KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
|
||||
KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
|
||||
KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null);
|
||||
KeyType k4 = new KeyType(LongPollNamespace.REPLICA_HANDLES, "k4");
|
||||
KeyType k1 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
|
||||
KeyType k2 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
|
||||
KeyType k3 = new KeyType(LongPollNamespace.ROUTE_TABLE, null);
|
||||
KeyType k4 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k4");
|
||||
|
||||
Assert.assertEquals(k1, k1);
|
||||
Assert.assertEquals(k1.hashCode(), k1.hashCode());
|
||||
|
@ -34,7 +34,7 @@ public class KeyTypeTest {
|
|||
|
||||
@Test
|
||||
public void jsonTest() {
|
||||
KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
|
||||
KeyType k1 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
|
||||
String json = GSON.toJson(k1);
|
||||
|
||||
KeyType k2 = GSON.fromJson(json, KeyType.class);
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.ray.serve.generated.DeploymentConfig;
|
||||
import io.ray.serve.generated.EndpointInfo;
|
||||
import io.ray.serve.generated.EndpointSet;
|
||||
import io.ray.serve.generated.UpdatedObject;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -10,27 +11,29 @@ import org.testng.annotations.Test;
|
|||
|
||||
public class LongPollClientTest {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void test() throws Throwable {
|
||||
|
||||
String[] a = new String[] {"test"};
|
||||
|
||||
// Construct a listener map.
|
||||
KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "deploymentName");
|
||||
KeyType keyType = new KeyType(LongPollNamespace.ROUTE_TABLE, null);
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(
|
||||
keyType, (object) -> a[0] = String.valueOf(((DeploymentConfig) object).getNumReplicas()));
|
||||
keyType, (object) -> a[0] = String.valueOf(((Map<String, EndpointInfo>) object).size()));
|
||||
|
||||
// Initialize LongPollClient.
|
||||
LongPollClient longPollClient = new LongPollClient(null, keyListeners);
|
||||
|
||||
// Construct updated object.
|
||||
DeploymentConfig.Builder deploymentConfig = DeploymentConfig.newBuilder();
|
||||
deploymentConfig.setNumReplicas(20);
|
||||
EndpointSet.Builder endpointSet = EndpointSet.newBuilder();
|
||||
endpointSet.putEndpoints("1", EndpointInfo.newBuilder().build());
|
||||
endpointSet.putEndpoints("2", EndpointInfo.newBuilder().build());
|
||||
int snapshotId = 10;
|
||||
UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder();
|
||||
updatedObject.setSnapshotId(snapshotId);
|
||||
updatedObject.setObjectSnapshot(ByteString.copyFrom(deploymentConfig.build().toByteArray()));
|
||||
updatedObject.setObjectSnapshot(ByteString.copyFrom(endpointSet.build().toByteArray()));
|
||||
|
||||
// Process update.
|
||||
Map<KeyType, UpdatedObject> updates = new HashMap<>();
|
||||
|
@ -40,8 +43,7 @@ public class LongPollClientTest {
|
|||
// Validation.
|
||||
Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId);
|
||||
Assert.assertEquals(
|
||||
((DeploymentConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(),
|
||||
deploymentConfig.getNumReplicas());
|
||||
Assert.assertEquals(a[0], String.valueOf(deploymentConfig.getNumReplicas()));
|
||||
((Map<String, EndpointInfo>) longPollClient.getObjectSnapshots().get(keyType)).size(), 2);
|
||||
Assert.assertEquals(a[0], String.valueOf(endpointSet.getEndpointsMap().size()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.ray.serve.DeploymentConfig;
|
||||
import io.ray.serve.DeploymentVersion;
|
||||
import io.ray.serve.generated.RequestMetadata;
|
||||
import io.ray.serve.generated.RequestWrapper;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ServeProtoUtilTest {
|
||||
|
||||
@Test
|
||||
public void parseDeploymentConfigTest() {
|
||||
int numReplicas = 10;
|
||||
io.ray.serve.generated.DeploymentConfig pbDeploymentConfig =
|
||||
io.ray.serve.generated.DeploymentConfig.newBuilder().setNumReplicas(numReplicas).build();
|
||||
|
||||
DeploymentConfig deploymentConfig =
|
||||
ServeProtoUtil.parseDeploymentConfig(pbDeploymentConfig.toByteArray());
|
||||
Assert.assertNotNull(deploymentConfig);
|
||||
Assert.assertEquals(deploymentConfig.getNumReplicas(), numReplicas);
|
||||
Assert.assertEquals(deploymentConfig.getDeploymentLanguage(), 0);
|
||||
Assert.assertEquals(deploymentConfig.getGracefulShutdownTimeoutS(), 20);
|
||||
Assert.assertEquals(deploymentConfig.getGracefulShutdownWaitLoopS(), 2);
|
||||
Assert.assertEquals(deploymentConfig.getMaxConcurrentQueries(), 100);
|
||||
Assert.assertNull(deploymentConfig.getUserConfig());
|
||||
Assert.assertEquals(deploymentConfig.isCrossLanguage(), false);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseRequestMetadataTest() {
|
||||
String prefix = "parseRequestMetadataTest";
|
||||
String requestId = RandomStringUtils.randomAlphabetic(10);
|
||||
String callMethod = prefix + "_method";
|
||||
String endpoint = prefix + "_endpoint";
|
||||
String context = prefix + "_context";
|
||||
RequestMetadata requestMetadata =
|
||||
RequestMetadata.newBuilder()
|
||||
.setRequestId(requestId)
|
||||
.setCallMethod(callMethod)
|
||||
.setEndpoint(endpoint)
|
||||
.putContext("context", context)
|
||||
.build();
|
||||
|
||||
RequestMetadata result = ServeProtoUtil.parseRequestMetadata(requestMetadata.toByteArray());
|
||||
Assert.assertNotNull(result);
|
||||
Assert.assertEquals(result.getCallMethod(), callMethod);
|
||||
Assert.assertEquals(result.getEndpoint(), endpoint);
|
||||
Assert.assertEquals(result.getRequestId(), requestId);
|
||||
Assert.assertEquals(result.getContextMap().get("context"), context);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseRequestWrapperTest() {
|
||||
byte[] body = new byte[] {1, 2};
|
||||
RequestWrapper requestWrapper =
|
||||
RequestWrapper.newBuilder().setBody(ByteString.copyFrom(body)).build();
|
||||
|
||||
RequestWrapper result = ServeProtoUtil.parseRequestWrapper(requestWrapper.toByteArray());
|
||||
Assert.assertNotNull(result);
|
||||
byte[] rstBody = result.getBody().toByteArray();
|
||||
Assert.assertEquals(rstBody[0], 1);
|
||||
Assert.assertEquals(rstBody[1], 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseDeploymentVersionTest() {
|
||||
String codeVersion = "parseDeploymentVersionTest";
|
||||
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion =
|
||||
io.ray.serve.generated.DeploymentVersion.newBuilder().setCodeVersion(codeVersion).build();
|
||||
|
||||
DeploymentVersion deploymentVersion =
|
||||
ServeProtoUtil.parseDeploymentVersion(pbDeploymentVersion.toByteArray());
|
||||
Assert.assertNotNull(deploymentVersion);
|
||||
Assert.assertEquals(deploymentVersion.getCodeVersion(), codeVersion);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void toDeploymentVersionProtobufTest() {
|
||||
String codeVersion = "toDeploymentVersionProtobufTest";
|
||||
DeploymentVersion deploymentVersion = new DeploymentVersion(codeVersion);
|
||||
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion =
|
||||
ServeProtoUtil.toProtobuf(deploymentVersion);
|
||||
|
||||
Assert.assertNotNull(pbDeploymentVersion);
|
||||
Assert.assertEquals(pbDeploymentVersion.getCodeVersion(), codeVersion);
|
||||
}
|
||||
}
|
|
@ -95,6 +95,8 @@ message RequestMetadata {
|
|||
string endpoint = 2;
|
||||
|
||||
string call_method = 3;
|
||||
|
||||
map<string, string> context = 4;
|
||||
}
|
||||
|
||||
message RequestWrapper {
|
||||
|
|
Loading…
Add table
Reference in a new issue