[Serve] Make Java Replica Extendable (#19463)

This commit is contained in:
liuyang-my 2021-11-11 07:05:37 +08:00 committed by GitHub
parent 81f036d078
commit efca009258
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 1200 additions and 666 deletions

View file

@ -6,8 +6,8 @@ import java.util.List;
/** Ray Serve common constants. */
public class Constants {
/** Name of backend reconfiguration method implemented by user. */
public static final String BACKEND_RECONFIGURE_METHOD = "reconfigure";
/** Name of deployment reconfiguration method implemented by user. */
public static final String RECONFIGURE_METHOD = "reconfigure";
/** Default histogram buckets for latency tracker. */
public static final List<Double> DEFAULT_LATENCY_BUCKET_MS =
@ -19,7 +19,9 @@ public class Constants {
public static final String SERVE_CONTROLLER_NAME = "SERVE_CONTROLLER_ACTOR";
public static final String DEFAULT_CALL_METHOD = "call";
public static final String CALL_METHOD = "call";
public static final String UTF8 = "UTF-8";
public static final String CHECK_HEALTH_METHOD = "checkHealth";
}

View file

@ -0,0 +1,87 @@
package io.ray.serve;
import com.google.common.base.Preconditions;
import java.io.Serializable;
public class DeploymentConfig implements Serializable {
private static final long serialVersionUID = 4037621960087621036L;
private int numReplicas = 1;
private int maxConcurrentQueries = 100;
private Object userConfig;
private double gracefulShutdownWaitLoopS = 2;
private double gracefulShutdownTimeoutS = 20;
private boolean isCrossLanguage;
private int deploymentLanguage = 1;
public int getNumReplicas() {
return numReplicas;
}
public DeploymentConfig setNumReplicas(int numReplicas) {
this.numReplicas = numReplicas;
return this;
}
public int getMaxConcurrentQueries() {
return maxConcurrentQueries;
}
public DeploymentConfig setMaxConcurrentQueries(int maxConcurrentQueries) {
Preconditions.checkArgument(maxConcurrentQueries >= 0, "max_concurrent_queries must be >= 0");
this.maxConcurrentQueries = maxConcurrentQueries;
return this;
}
public Object getUserConfig() {
return userConfig;
}
public DeploymentConfig setUserConfig(Object userConfig) {
this.userConfig = userConfig;
return this;
}
public double getGracefulShutdownWaitLoopS() {
return gracefulShutdownWaitLoopS;
}
public DeploymentConfig setGracefulShutdownWaitLoopS(double gracefulShutdownWaitLoopS) {
this.gracefulShutdownWaitLoopS = gracefulShutdownWaitLoopS;
return this;
}
public double getGracefulShutdownTimeoutS() {
return gracefulShutdownTimeoutS;
}
public DeploymentConfig setGracefulShutdownTimeoutS(double gracefulShutdownTimeoutS) {
this.gracefulShutdownTimeoutS = gracefulShutdownTimeoutS;
return this;
}
public boolean isCrossLanguage() {
return isCrossLanguage;
}
public DeploymentConfig setCrossLanguage(boolean isCrossLanguage) {
this.isCrossLanguage = isCrossLanguage;
return this;
}
public int getDeploymentLanguage() {
return deploymentLanguage;
}
public DeploymentConfig setDeploymentLanguage(int deploymentLanguage) {
this.deploymentLanguage = deploymentLanguage;
return this;
}
}

View file

@ -1,38 +1,75 @@
package io.ray.serve;
import java.io.Serializable;
import java.util.Map;
public class DeploymentInfo implements Serializable {
private static final long serialVersionUID = -4198364411759931955L;
private static final long serialVersionUID = -7132135316463505391L;
private byte[] deploymentConfig;
private String name;
private ReplicaConfig replicaConfig;
private String deploymentDef;
private byte[] deploymentVersion;
private Object[] initArgs;
public byte[] getDeploymentConfig() {
private DeploymentConfig deploymentConfig;
private DeploymentVersion deploymentVersion;
private Map<String, String> config;
public String getName() {
return name;
}
public DeploymentInfo setName(String name) {
this.name = name;
return this;
}
public String getDeploymentDef() {
return deploymentDef;
}
public DeploymentInfo setDeploymentDef(String deploymentDef) {
this.deploymentDef = deploymentDef;
return this;
}
public Object[] getInitArgs() {
return initArgs;
}
public DeploymentInfo setInitArgs(Object[] initArgs) {
this.initArgs = initArgs;
return this;
}
public DeploymentConfig getDeploymentConfig() {
return deploymentConfig;
}
public void setDeploymentConfig(byte[] deploymentConfig) {
public DeploymentInfo setDeploymentConfig(DeploymentConfig deploymentConfig) {
this.deploymentConfig = deploymentConfig;
return this;
}
public ReplicaConfig getReplicaConfig() {
return replicaConfig;
}
public void setReplicaConfig(ReplicaConfig replicaConfig) {
this.replicaConfig = replicaConfig;
}
public byte[] getDeploymentVersion() {
public DeploymentVersion getDeploymentVersion() {
return deploymentVersion;
}
public void setDeploymentVersion(byte[] deploymentVersion) {
public DeploymentInfo setDeploymentVersion(DeploymentVersion deploymentVersion) {
this.deploymentVersion = deploymentVersion;
return this;
}
public Map<String, String> getConfig() {
return config;
}
public DeploymentInfo setConfig(Map<String, String> config) {
this.config = config;
return this;
}
}

View file

@ -0,0 +1,46 @@
package io.ray.serve;
import java.io.Serializable;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
public class DeploymentVersion implements Serializable {
private static final long serialVersionUID = 3400261981775851058L;
private String codeVersion;
private Object userConfig;
private boolean unversioned;
public DeploymentVersion() {
this(null, null);
}
public DeploymentVersion(String codeVersion) {
this(codeVersion, null);
}
public DeploymentVersion(String codeVersion, Object userConfig) {
if (StringUtils.isBlank(codeVersion)) {
this.unversioned = true;
this.codeVersion = RandomStringUtils.randomAlphabetic(6);
} else {
this.codeVersion = codeVersion;
}
this.userConfig = userConfig;
}
public String getCodeVersion() {
return codeVersion;
}
public Object getUserConfig() {
return userConfig;
}
public boolean isUnversioned() {
return unversioned;
}
}

View file

@ -1,12 +0,0 @@
package io.ray.serve;
import java.util.concurrent.atomic.AtomicInteger;
public class DummyBackendReplica {
private AtomicInteger counter = new AtomicInteger();
public String call() {
return String.valueOf(counter.incrementAndGet());
}
}

View file

@ -0,0 +1,21 @@
package io.ray.serve;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
public class DummyReplica {
private AtomicInteger counter = new AtomicInteger();
public String call() {
return String.valueOf(counter.incrementAndGet());
}
public void reconfigure(Object userConfig) {
counter.set(0);
}
public void reconfigure(Map<String, String> userConfig) {
counter.set(Integer.valueOf(userConfig.get("value")));
}
}

View file

@ -137,8 +137,8 @@ public class ProxyActor {
this.proxyRouter.updateRoutes(endpointInfos);
}
public void ready() {
return;
public boolean ready() {
return true;
}
public void blockUntilEndpointExists(String endpoint, double timeoutS) {

View file

@ -1,6 +1,35 @@
package io.ray.serve;
public class RayServeConfig {
import java.io.Serializable;
import java.util.Map;
public class RayServeConfig implements Serializable {
private static final long serialVersionUID = 5367425336296141588L;
public static final String PROXY_CLASS = "ray.serve.proxy.class";
public static final String METRICS_ENABLED = "ray.serve.metrics.enabled";
private String name;
private Map<String, String> config;
public String getName() {
return name;
}
public RayServeConfig setName(String name) {
this.name = name;
return this;
}
public Map<String, String> getConfig() {
return config;
}
public RayServeConfig setConfig(Map<String, String> config) {
this.config = config;
return this;
}
}

View file

@ -49,7 +49,7 @@ public class RayServeHandle {
* Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get
* (or ``await object_ref``), respectively.
*
* @param parameters The input parameters of the specified method to invoke on the backend.
* @param parameters The input parameters of the specified method to be invoked in the deployment.
* @return ray.ObjectRef
*/
public ObjectRef<Object> remote(Object[] parameters) {
@ -58,7 +58,7 @@ public class RayServeHandle {
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
requestMetadata.setEndpoint(endpointName);
requestMetadata.setCallMethod(
handleOptions != null ? handleOptions.getMethodName() : Constants.DEFAULT_CALL_METHOD);
handleOptions != null ? handleOptions.getMethodName() : Constants.CALL_METHOD);
return router.assignRequest(requestMetadata.build(), parameters);
}

View file

@ -14,20 +14,20 @@ public enum RayServeMetrics {
"serve_deployment_queued_queries",
"The current number of queries to this deployment waiting to be assigned to a replica."),
SERVE_BACKEND_REQUEST_COUNTER(
"serve_backend_request_counter",
SERVE_DEPLOYMENT_REQUEST_COUNTER(
"serve_deployment_request_counter",
"The number of queries that have been processed in this replica."),
SERVE_BACKEND_ERROR_COUNTER(
"serve_backend_error_counter",
SERVE_DEPLOYMENT_ERROR_COUNTER(
"serve_deployment_error_counter",
"The number of exceptions that have occurred in this replica."),
SERVE_BACKEND_REPLICA_STARTS(
"serve_backend_replica_starts",
SERVE_DEPLOYMENT_REPLICA_STARTS(
"serve_deployment_replica_starts",
"The number of times this replica has been restarted due to failure."),
SERVE_BACKEND_PROCESSING_LATENCY_MS(
"serve_backend_processing_latency_ms", "The latency for queries to be processed."),
SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS(
"serve_deployment_processing_latency_ms", "The latency for queries to be processed."),
SERVE_REPLICA_PROCESSING_QUERIES(
"serve_replica_processing_queries", "The current number of queries being processed."),
@ -41,13 +41,13 @@ public enum RayServeMetrics {
public static final String TAG_ROUTE = "route";
public static final String TAG_BACKEND = "backend";
public static final String TAG_REPLICA = "replica";
private static final boolean isMetricsEnabled =
private static final boolean canBeUsed =
Ray.isInitialized() && !Ray.getRuntimeContext().isSingleProcess();
private static volatile boolean enabled = true;
private String name;
private String description;
@ -58,7 +58,7 @@ public enum RayServeMetrics {
}
public static void execute(Runnable runnable) {
if (!isMetricsEnabled) {
if (!enabled || !canBeUsed) {
return;
}
runnable.run();
@ -71,4 +71,12 @@ public enum RayServeMetrics {
public String getDescription() {
return description;
}
public static void enable() {
enabled = true;
}
public static void disable() {
enabled = false;
}
}

View file

@ -1,330 +1,18 @@
package io.ray.serve;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.ByteString;
import io.ray.api.BaseActorHandle;
import io.ray.runtime.metric.Count;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Histogram;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.RequestWrapper;
import io.ray.serve.poll.KeyListener;
import io.ray.serve.poll.KeyType;
import io.ray.serve.poll.LongPollClient;
import io.ray.serve.poll.LongPollNamespace;
import io.ray.serve.util.LogUtil;
import io.ray.serve.util.ReflectUtil;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public interface RayServeReplica {
/** Handles requests with the provided callable. */
public class RayServeReplica {
Object handleRequest(Object requestMetadata, Object requestArgs);
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplica.class);
private String deploymentName;
private String replicaTag;
private DeploymentConfig config;
private AtomicInteger numOngoingRequests = new AtomicInteger();
private Object callable;
private Count requestCounter;
private Count errorCounter;
private Count restartCounter;
private Histogram processingLatencyTracker;
private Gauge numProcessingItems;
private LongPollClient longPollClient;
private DeploymentVersion version;
private boolean isDeleted = false;
public RayServeReplica(
Object callable,
DeploymentConfig deploymentConfig,
DeploymentVersion version,
BaseActorHandle actorHandle) {
this.deploymentName = Serve.getReplicaContext().getDeploymentName();
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
this.callable = callable;
this.config = deploymentConfig;
this.version = version;
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
keyListeners.put(
new KeyType(LongPollNamespace.BACKEND_CONFIGS, deploymentName),
newConfig -> updateDeploymentConfigs(newConfig));
this.longPollClient = new LongPollClient(actorHandle, keyListeners);
this.longPollClient.start();
registerMetrics();
default Object reconfigure(Object userConfig) {
return new DeploymentVersion(null, userConfig);
}
private void registerMetrics() {
RayServeMetrics.execute(
() ->
requestCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getName())
.description(RayServeMetrics.SERVE_BACKEND_REQUEST_COUNTER.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_BACKEND,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
errorCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getName())
.description(RayServeMetrics.SERVE_BACKEND_ERROR_COUNTER.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_BACKEND,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
restartCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getName())
.description(RayServeMetrics.SERVE_BACKEND_REPLICA_STARTS.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_BACKEND,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
processingLatencyTracker =
Metrics.histogram()
.name(RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getName())
.description(
RayServeMetrics.SERVE_BACKEND_PROCESSING_LATENCY_MS.getDescription())
.unit("")
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_BACKEND,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
numProcessingItems =
Metrics.gauge()
.name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName())
.description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_BACKEND,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(() -> restartCounter.inc(1.0));
}
public Object handleRequest(Query request) {
long startTime = System.currentTimeMillis();
LOGGER.debug(
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
numOngoingRequests.incrementAndGet();
RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get()));
Object result = invokeSingle(request);
numOngoingRequests.decrementAndGet();
long requestTimeMs = System.currentTimeMillis() - startTime;
LOGGER.debug(
"Replica {} finished request {} in {}ms",
replicaTag,
request.getMetadata().getRequestId(),
requestTimeMs);
return result;
}
private Object invokeSingle(Query requestItem) {
long start = System.currentTimeMillis();
Method methodToCall = null;
try {
LOGGER.debug(
"Replica {} started executing request {}",
replicaTag,
requestItem.getMetadata().getRequestId());
Object[] args = parseRequestItem(requestItem);
methodToCall = getRunnerMethod(requestItem.getMetadata().getCallMethod(), args);
Object result = methodToCall.invoke(callable, args);
RayServeMetrics.execute(() -> requestCounter.inc(1.0));
return result;
} catch (Throwable e) {
RayServeMetrics.execute(() -> errorCounter.inc(1.0));
throw new RayServeException(
LogUtil.format(
"Replica {} failed to invoke method {}",
replicaTag,
methodToCall == null ? "unknown" : methodToCall.getName()),
e);
} finally {
RayServeMetrics.execute(
() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
}
}
private Object[] parseRequestItem(Query requestItem) {
if (requestItem.getArgs() == null) {
return new Object[0];
}
// From Java Proxy or Handle.
if (requestItem.getArgs() instanceof Object[]) {
return (Object[]) requestItem.getArgs();
}
// From other language Proxy or Handle.
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
return new Object[0];
}
return MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object[].class);
}
private Method getRunnerMethod(String methodName, Object[] args) {
try {
return ReflectUtil.getMethod(callable.getClass(), methodName, args);
} catch (NoSuchMethodException e) {
throw new RayServeException(
LogUtil.format(
"Backend doesn't have method {} which is specified in the request. "
+ "The available methods are {}",
methodName,
ReflectUtil.getMethodStrings(callable.getClass())));
}
}
/**
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
* queued tasks to be completed and return to the controller.
*/
public synchronized boolean prepareForShutdown() {
while (true) {
// Sleep first because we want to make sure all the routers receive the notification to remove
// this replica first.
try {
Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
} catch (InterruptedException e) {
LOGGER.error(
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
}
if (numOngoingRequests.get() == 0) {
break;
} else {
LOGGER.info(
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
config.getGracefulShutdownWaitLoopS(),
numOngoingRequests.get());
}
}
// Explicitly call the del method to trigger clean up. We set isDeleted = true after
// succssifully calling it so the destructor is called only once.
try {
if (!isDeleted) {
ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
}
} catch (NoSuchMethodException e) {
LOGGER.warn("Deployment {} has no del method.", deploymentName);
} catch (Throwable e) {
LOGGER.error("Exception during graceful shutdown of replica.");
} finally {
isDeleted = true;
}
default boolean checkHealth() {
return true;
}
/**
* Reconfigure user's configuration in the callable object through its reconfigure method.
*
* @param userConfig new user's configuration
*/
public DeploymentVersion reconfigure(Object userConfig) {
DeploymentVersion.Builder builder = DeploymentVersion.newBuilder();
builder.setCodeVersion(version.getCodeVersion());
if (userConfig != null) {
builder.setUserConfig(ByteString.copyFrom((byte[]) userConfig));
}
version = builder.build();
try {
Method reconfigureMethod =
ReflectUtil.getMethod(
callable.getClass(),
Constants.BACKEND_RECONFIGURE_METHOD,
userConfig != null
? MessagePackSerializer.decode((byte[]) userConfig, Object[].class)
: new Object[0]); // TODO cache reconfigure method
reconfigureMethod.invoke(callable, userConfig);
} catch (NoSuchMethodException e) {
LOGGER.warn(
"user_config specified but backend {} missing {} method",
deploymentName,
Constants.BACKEND_RECONFIGURE_METHOD);
} catch (Throwable e) {
throw new RayServeException(
LogUtil.format(
"Backend {} failed to reconfigure user_config {}", deploymentName, userConfig),
e);
}
return version;
}
/**
* Update backend configs.
*
* @param newConfig the new configuration of backend
*/
private void updateDeploymentConfigs(Object newConfig) {
config = (DeploymentConfig) newConfig;
}
public DeploymentVersion getVersion() {
return version;
default boolean prepareForShutdown() {
return true;
}
}

View file

@ -0,0 +1,380 @@
package io.ray.serve;
import com.google.common.collect.ImmutableMap;
import io.ray.api.BaseActorHandle;
import io.ray.runtime.metric.Count;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Histogram;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.generated.RequestWrapper;
import io.ray.serve.util.LogUtil;
import io.ray.serve.util.ReflectUtil;
import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Handles requests with the provided callable. */
public class RayServeReplicaImpl implements RayServeReplica {
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplicaImpl.class);
private String deploymentName;
private String replicaTag;
private DeploymentConfig config;
private AtomicInteger numOngoingRequests = new AtomicInteger();
private Object callable;
private Count requestCounter;
private Count errorCounter;
private Count restartCounter;
private Histogram processingLatencyTracker;
private Gauge numProcessingItems;
private DeploymentVersion version;
private boolean isDeleted = false;
private final Method checkHealthMethod;
private final Method callMethod;
public RayServeReplicaImpl(
Object callable,
DeploymentConfig deploymentConfig,
DeploymentVersion version,
BaseActorHandle actorHandle) {
this.deploymentName = Serve.getReplicaContext().getDeploymentName();
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
this.callable = callable;
this.config = deploymentConfig;
this.version = version;
this.checkHealthMethod = getRunnerMethod(Constants.CHECK_HEALTH_METHOD, null, true);
this.callMethod = getRunnerMethod(Constants.CALL_METHOD, new Object[] {new Object()}, true);
registerMetrics();
}
private void registerMetrics() {
RayServeMetrics.execute(
() ->
requestCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_DEPLOYMENT_REQUEST_COUNTER.getName())
.description(RayServeMetrics.SERVE_DEPLOYMENT_REQUEST_COUNTER.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_DEPLOYMENT,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
errorCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_DEPLOYMENT_ERROR_COUNTER.getName())
.description(RayServeMetrics.SERVE_DEPLOYMENT_ERROR_COUNTER.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_DEPLOYMENT,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
restartCounter =
Metrics.count()
.name(RayServeMetrics.SERVE_DEPLOYMENT_REPLICA_STARTS.getName())
.description(RayServeMetrics.SERVE_DEPLOYMENT_REPLICA_STARTS.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_DEPLOYMENT,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
processingLatencyTracker =
Metrics.histogram()
.name(RayServeMetrics.SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS.getName())
.description(
RayServeMetrics.SERVE_DEPLOYMENT_PROCESSING_LATENCY_MS.getDescription())
.unit("")
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_DEPLOYMENT,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(
() ->
numProcessingItems =
Metrics.gauge()
.name(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getName())
.description(RayServeMetrics.SERVE_REPLICA_PROCESSING_QUERIES.getDescription())
.unit("")
.tags(
ImmutableMap.of(
RayServeMetrics.TAG_DEPLOYMENT,
deploymentName,
RayServeMetrics.TAG_REPLICA,
replicaTag))
.register());
RayServeMetrics.execute(() -> restartCounter.inc(1.0));
}
@Override
public Object handleRequest(Object requestMetadata, Object requestArgs) {
long startTime = System.currentTimeMillis();
Query request = new Query((RequestMetadata) requestMetadata, requestArgs);
LOGGER.debug(
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
numOngoingRequests.incrementAndGet();
RayServeMetrics.execute(() -> numProcessingItems.update(numOngoingRequests.get()));
Object result = invokeSingle(request);
numOngoingRequests.decrementAndGet();
long requestTimeMs = System.currentTimeMillis() - startTime;
LOGGER.debug(
"Replica {} finished request {} in {}ms",
replicaTag,
request.getMetadata().getRequestId(),
requestTimeMs);
return result;
}
private Object invokeSingle(Query requestItem) {
long start = System.currentTimeMillis();
Method methodToCall = null;
try {
LOGGER.debug(
"Replica {} started executing request {}",
replicaTag,
requestItem.getMetadata().getRequestId());
Object[] args = parseRequestItem(requestItem);
methodToCall =
args.length == 1 && callMethod != null
? callMethod
: getRunnerMethod(requestItem.getMetadata().getCallMethod(), args, false);
Object result = methodToCall.invoke(callable, args);
RayServeMetrics.execute(() -> requestCounter.inc(1.0));
return result;
} catch (Throwable e) {
RayServeMetrics.execute(() -> errorCounter.inc(1.0));
throw new RayServeException(
LogUtil.format(
"Replica {} failed to invoke method {}",
replicaTag,
methodToCall == null ? "unknown" : methodToCall.getName()),
e);
} finally {
RayServeMetrics.execute(
() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
}
}
private Object[] parseRequestItem(Query requestItem) {
if (requestItem.getArgs() == null) {
return new Object[0];
}
// From Java Proxy or Handle.
if (requestItem.getArgs() instanceof Object[]) {
return (Object[]) requestItem.getArgs();
}
// From other language Proxy or Handle.
RequestWrapper requestWrapper = (RequestWrapper) requestItem.getArgs();
if (requestWrapper.getBody() == null || requestWrapper.getBody().isEmpty()) {
return new Object[0];
}
return new Object[] {
MessagePackSerializer.decode(requestWrapper.getBody().toByteArray(), Object.class)
};
}
private Method getRunnerMethod(String methodName, Object[] args, boolean isNullable) {
try {
return ReflectUtil.getMethod(callable.getClass(), methodName, args);
} catch (NoSuchMethodException e) {
String errMsg =
LogUtil.format(
"Tried to call a method {} that does not exist. Available methods: {}",
methodName,
ReflectUtil.getMethodStrings(callable.getClass()));
if (isNullable) {
LOGGER.warn(errMsg);
return null;
} else {
LOGGER.error(errMsg, e);
throw new RayServeException(errMsg, e);
}
}
}
/**
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
* queued tasks to be completed and return to the controller.
*
* @return true if it is ready for shutdown.
*/
@Override
public synchronized boolean prepareForShutdown() {
while (true) {
// Sleep first because we want to make sure all the routers receive the notification to remove
// this replica first.
try {
Thread.sleep((long) (config.getGracefulShutdownWaitLoopS() * 1000));
} catch (InterruptedException e) {
LOGGER.error(
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
}
if (numOngoingRequests.get() == 0) {
break;
} else {
LOGGER.info(
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
config.getGracefulShutdownWaitLoopS(),
numOngoingRequests.get());
}
}
// Explicitly call the del method to trigger clean up. We set isDeleted = true after
// succssifully calling it so the destructor is called only once.
try {
if (!isDeleted) {
ReflectUtil.getMethod(callable.getClass(), "del").invoke(callable);
}
} catch (NoSuchMethodException e) {
LOGGER.warn("Deployment {} has no del method.", deploymentName);
} catch (Throwable e) {
LOGGER.error("Exception during graceful shutdown of replica.");
} finally {
isDeleted = true;
}
return true;
}
@Override
public DeploymentVersion reconfigure(Object userConfig) {
DeploymentVersion deploymentVersion =
new DeploymentVersion(version.getCodeVersion(), userConfig);
version = deploymentVersion;
if (userConfig == null) {
return deploymentVersion;
}
LOGGER.info(
"Replica {} of deployment {} reconfigure userConfig: {}",
replicaTag,
deploymentName,
userConfig);
try {
ReflectUtil.getMethod(callable.getClass(), Constants.RECONFIGURE_METHOD, userConfig)
.invoke(callable, userConfig);
return version;
} catch (NoSuchMethodException e) {
String errMsg =
LogUtil.format(
"userConfig specified but deployment {} missing {} method",
deploymentName,
Constants.RECONFIGURE_METHOD);
LOGGER.error(errMsg);
throw new RayServeException(errMsg, e);
} catch (Throwable e) {
String errMsg =
LogUtil.format(
"Replica {} of deployment {} failed to reconfigure userConfig {}",
replicaTag,
deploymentName,
userConfig);
LOGGER.error(errMsg);
throw new RayServeException(errMsg, e);
} finally {
LOGGER.info(
"Replica {} of deployment {} finished reconfiguring userConfig: {}",
replicaTag,
deploymentName,
userConfig);
}
}
public DeploymentVersion getVersion() {
return version;
}
@Override
public boolean checkHealth() {
if (checkHealthMethod == null) {
return true;
}
boolean result = true;
try {
LOGGER.info(
"Replica {} of deployment {} check health of {}",
replicaTag,
deploymentName,
callable.getClass().getName());
Object isHealthy = checkHealthMethod.invoke(callable);
if (!(isHealthy instanceof Boolean)) {
LOGGER.error(
"The health check result {} of {} in replica {} of deployment {} is illegal.",
isHealthy == null ? "null" : isHealthy.getClass().getName() + ":" + isHealthy,
callable.getClass().getName(),
replicaTag,
deploymentName);
result = false;
} else {
result = (boolean) isHealthy;
}
} catch (Throwable e) {
LOGGER.error(
"Replica {} of deployment {} failed to check health of {}",
replicaTag,
deploymentName,
callable.getClass().getName(),
e);
result = false;
} finally {
LOGGER.info(
"The health check result of {} in replica {} of deployment {} is {}.",
callable.getClass().getName(),
replicaTag,
deploymentName,
result);
}
return result;
}
public Object getCallable() {
return callable;
}
}

View file

@ -1,82 +1,135 @@
package io.ray.serve;
import com.google.common.base.Preconditions;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.util.LogUtil;
import io.ray.serve.util.ReflectUtil;
import io.ray.serve.util.ServeProtoUtil;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Replica class wrapping the provided class. Note that Java function is not supported now. */
public class RayServeWrappedReplica {
public class RayServeWrappedReplica implements RayServeReplica {
private RayServeReplica backend;
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplicaImpl.class);
private DeploymentInfo deploymentInfo;
private RayServeReplicaImpl replica;
@SuppressWarnings("rawtypes")
public RayServeWrappedReplica(
String deploymentName,
String replicaTag,
String backendDef,
String deploymentDef,
byte[] initArgsbytes,
byte[] deploymentConfigBytes,
byte[] deploymentVersionBytes,
String controllerName)
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
String controllerName) {
// Parse DeploymentConfig.
DeploymentConfig deploymentConfig = ServeProtoUtil.parseDeploymentConfig(deploymentConfigBytes);
// Parse init args.
Object[] initArgs = parseInitArgs(initArgsbytes, deploymentConfig);
Object[] initArgs = null;
try {
initArgs = parseInitArgs(initArgsbytes, deploymentConfig);
} catch (IOException e) {
String errMsg =
LogUtil.format(
"Failed to initialize replica {} of deployment {}",
replicaTag,
deploymentInfo.getName());
LOGGER.error(errMsg, e);
throw new RayServeException(errMsg, e);
}
// Instantiate the object defined by backendDef.
Class backendClass = Class.forName(backendDef);
Object callable = ReflectUtil.getConstructor(backendClass, initArgs).newInstance(initArgs);
// Get the controller by controllerName.
Preconditions.checkArgument(
StringUtils.isNotBlank(controllerName), "Must provide a valid controllerName");
Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
Preconditions.checkState(optional.isPresent(), "Controller does not exist");
// Set the controller name so that Serve.connect() in the user's backend code will connect to
// the instance that this backend is running in.
Serve.setInternalReplicaContext(deploymentName, replicaTag, controllerName, callable);
// Construct worker replica.
backend =
new RayServeReplica(
callable,
deploymentConfig,
ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes),
optional.get());
// Init replica.
init(
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(deploymentConfig)
.setDeploymentVersion(ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes))
.setDeploymentDef(deploymentDef)
.setInitArgs(initArgs),
replicaTag,
controllerName,
null);
}
public RayServeWrappedReplica(
String deploymentName,
String replicaTag,
DeploymentInfo deploymentInfo,
String controllerName)
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
this(
deploymentName,
replicaTag,
deploymentInfo.getReplicaConfig().getBackendDef(),
deploymentInfo.getReplicaConfig().getInitArgs(),
deploymentInfo.getDeploymentConfig(),
deploymentInfo.getDeploymentVersion(),
controllerName);
String replicaTag,
String controllerName,
RayServeConfig rayServeConfig) {
init(deploymentInfo, replicaTag, controllerName, rayServeConfig);
}
@SuppressWarnings("rawtypes")
private void init(
DeploymentInfo deploymentInfo,
String replicaTag,
String controllerName,
RayServeConfig rayServeConfig) {
try {
// Set the controller name so that Serve.connect() in the user's code will connect to the
// instance that this deployment is running in.
Serve.setInternalReplicaContext(deploymentInfo.getName(), replicaTag, controllerName, null);
Serve.getReplicaContext().setRayServeConfig(rayServeConfig);
// Instantiate the object defined by deploymentDef.
Class deploymentClass = Class.forName(deploymentInfo.getDeploymentDef());
Object callable =
ReflectUtil.getConstructor(deploymentClass, deploymentInfo.getInitArgs())
.newInstance(deploymentInfo.getInitArgs());
Serve.getReplicaContext().setServableObject(callable);
// Get the controller by controllerName.
Preconditions.checkArgument(
StringUtils.isNotBlank(controllerName), "Must provide a valid controllerName");
Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
Preconditions.checkState(optional.isPresent(), "Controller does not exist");
// Enable metrics.
enableMetrics(deploymentInfo.getConfig());
// Construct worker replica.
this.replica =
new RayServeReplicaImpl(
callable,
deploymentInfo.getDeploymentConfig(),
deploymentInfo.getDeploymentVersion(),
optional.get());
this.deploymentInfo = deploymentInfo;
} catch (Throwable e) {
String errMsg =
LogUtil.format(
"Failed to initialize replica {} of deployment {}",
replicaTag,
deploymentInfo.getName());
LOGGER.error(errMsg, e);
throw new RayServeException(errMsg, e);
}
}
private void enableMetrics(Map<String, String> config) {
Optional.ofNullable(config)
.map(conf -> conf.get(RayServeConfig.METRICS_ENABLED))
.ifPresent(
enabled -> {
if (Boolean.valueOf(enabled)) {
RayServeMetrics.enable();
} else {
RayServeMetrics.disable();
}
});
}
private Object[] parseInitArgs(byte[] initArgsbytes, DeploymentConfig deploymentConfig)
@ -86,13 +139,13 @@ public class RayServeWrappedReplica {
return new Object[0];
}
if (!deploymentConfig.getIsCrossLanguage()) {
if (deploymentConfig.isCrossLanguage()) {
// For other language like Python API, not support Array type.
return new Object[] {MessagePackSerializer.decode(initArgsbytes, Object.class)};
} else {
// If the construction request is from Java API, deserialize initArgsbytes to Object[]
// directly.
return MessagePackSerializer.decode(initArgsbytes, Object[].class);
} else {
// For other language like Python API, not support Array type.
return new Object[] {MessagePackSerializer.decode(initArgsbytes, Object.class)};
}
}
@ -102,27 +155,28 @@ public class RayServeWrappedReplica {
* @param requestMetadata the real type is byte[] if this invocation is cross-language. Otherwise,
* the real type is {@link io.ray.serve.generated.RequestMetadata}.
* @param requestArgs The input parameters of the specified method of the object defined by
* backendDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
* deploymentDef. The real type is serialized {@link io.ray.serve.generated.RequestWrapper} if
* this invocation is cross-language. Otherwise, the real type is Object[].
* @return the result of request being processed
* @throws InvalidProtocolBufferException if the protobuf deserialization fails.
*/
public Object handleRequest(Object requestMetadata, Object requestArgs)
throws InvalidProtocolBufferException {
@Override
public Object handleRequest(Object requestMetadata, Object requestArgs) {
boolean isCrossLanguage = requestMetadata instanceof byte[];
return backend.handleRequest(
new Query(
isCrossLanguage
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
: (RequestMetadata) requestMetadata,
isCrossLanguage
? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs)
: requestArgs));
return replica.handleRequest(
isCrossLanguage
? ServeProtoUtil.parseRequestMetadata((byte[]) requestMetadata)
: (RequestMetadata) requestMetadata,
isCrossLanguage ? ServeProtoUtil.parseRequestWrapper((byte[]) requestArgs) : requestArgs);
}
/** Check whether this replica is ready or not. */
public void ready() {
return;
/**
* Check if the actor is healthy.
*
* @return true if the actor is health, or return false.
*/
@Override
public boolean checkHealth() {
return replica.checkHealth();
}
/**
@ -130,16 +184,44 @@ public class RayServeWrappedReplica {
*
* @return true if it is ready for shutdown.
*/
@Override
public boolean prepareForShutdown() {
return backend.prepareForShutdown();
return replica.prepareForShutdown();
}
public byte[] reconfigure(Object userConfig) {
DeploymentVersion deploymentVersion = backend.reconfigure(userConfig);
return deploymentVersion.toByteArray();
/**
* Reconfigure user's configuration in the callable object through its reconfigure method.
*
* @param userConfig new user's configuration
* @return DeploymentVersion. If the current invocation is crossing language, the
* DeploymentVersion is serialized to protobuf byte[].
*/
@Override
public Object reconfigure(Object userConfig) {
DeploymentVersion deploymentVersion =
replica.reconfigure(
deploymentInfo.getDeploymentConfig().isCrossLanguage() && userConfig != null
? MessagePackSerializer.decode((byte[]) userConfig, Object.class)
: userConfig);
return deploymentInfo.getDeploymentConfig().isCrossLanguage()
? ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray()
: deploymentVersion;
}
public byte[] getVersion() {
return backend.getVersion().toByteArray();
/**
* Get the deployment version of the current replica.
*
* @return DeploymentVersion. If the current invocation is crossing language, the
* DeploymentVersion is serialized to protobuf byte[].
*/
public Object getVersion() {
DeploymentVersion deploymentVersion = replica.getVersion();
return deploymentInfo.getDeploymentConfig().isCrossLanguage()
? ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray()
: deploymentVersion;
}
public Object getCallable() {
return replica.getCallable();
}
}

View file

@ -10,7 +10,7 @@ public class ReplicaConfig implements Serializable {
private static final long serialVersionUID = -1442657824045704226L;
private String backendDef;
private String deploymentDef;
private byte[] initArgs;
@ -18,8 +18,8 @@ public class ReplicaConfig implements Serializable {
private Map<String, Double> resource;
public ReplicaConfig(String backendDef, byte[] initArgs, Map<String, Object> rayActorOptions) {
this.backendDef = backendDef;
public ReplicaConfig(String deploymentDef, byte[] initArgs, Map<String, Object> rayActorOptions) {
this.deploymentDef = deploymentDef;
this.initArgs = initArgs;
this.rayActorOptions = rayActorOptions;
this.resource = new HashMap<>();
@ -30,7 +30,7 @@ public class ReplicaConfig implements Serializable {
private void validate() {
Preconditions.checkArgument(
!rayActorOptions.containsKey("placement_group"),
"Providing placement_group for backend actors is not currently supported.");
"Providing placement_group for deployment actors is not currently supported.");
Preconditions.checkArgument(
!rayActorOptions.containsKey("lifetime"),
@ -81,12 +81,12 @@ public class ReplicaConfig implements Serializable {
resource.putAll((Map) customResources);
}
public String getBackendDef() {
return backendDef;
public String getDeploymentDef() {
return deploymentDef;
}
public void setBackendDef(String backendDef) {
this.backendDef = backendDef;
public void setDeploymentDef(String deploymentDef) {
this.deploymentDef = deploymentDef;
}
public byte[] getInitArgs() {

View file

@ -1,9 +1,9 @@
package io.ray.serve;
/** Stores data for Serve API calls from within the user's backend code. */
/** Stores data for Serve API calls from within deployments. */
public class ReplicaContext {
private String deploymentName; // TODO deployment
private String deploymentName;
private String replicaTag;
@ -11,6 +11,8 @@ public class ReplicaContext {
private Object servableObject;
private RayServeConfig rayServeConfig;
public ReplicaContext(
String deploymentName, String replicaTag, String controllerName, Object servableObject) {
this.deploymentName = deploymentName;
@ -50,4 +52,12 @@ public class ReplicaContext {
public void setServableObject(Object servableObject) {
this.servableObject = servableObject;
}
public RayServeConfig getRayServeConfig() {
return rayServeConfig;
}
public void setRayServeConfig(RayServeConfig rayServeConfig) {
this.rayServeConfig = rayServeConfig;
}
}

View file

@ -0,0 +1,30 @@
package io.ray.serve;
public class ReplicaName {
private String deploymentTag;
private String replicaSuffix;
private String replicaTag = "";
private String delimiter = "#";
public ReplicaName(String deploymentTag, String replicaSuffix) {
this.deploymentTag = deploymentTag;
this.replicaSuffix = replicaSuffix;
this.replicaTag = deploymentTag + this.delimiter + replicaSuffix;
}
public String getDeploymentTag() {
return deploymentTag;
}
public String getReplicaSuffix() {
return replicaSuffix;
}
public String getReplicaTag() {
return replicaTag;
}
}

View file

@ -9,7 +9,6 @@ import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Metrics;
import io.ray.runtime.metric.TagKey;
import io.ray.serve.generated.ActorSet;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.util.CollectionUtil;
import java.util.ArrayList;
import java.util.HashSet;
@ -27,8 +26,6 @@ public class ReplicaSet {
private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaSet.class);
private volatile int maxConcurrentQueries = 8;
private final Map<ActorHandle<RayServeWrappedReplica>, Set<ObjectRef<Object>>> inFlightQueries;
private AtomicInteger numQueuedQueries = new AtomicInteger();
@ -48,18 +45,6 @@ public class ReplicaSet {
.register());
}
public void setMaxConcurrentQueries(Object deploymentConfig) {
int newValue = ((DeploymentConfig) deploymentConfig).getMaxConcurrentQueries();
if (newValue != this.maxConcurrentQueries) {
this.maxConcurrentQueries = newValue;
LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue);
}
}
public int getMaxConcurrentQueries() {
return maxConcurrentQueries;
}
@SuppressWarnings("unchecked")
public synchronized void updateWorkerReplicas(Object actorSet) {
List<String> actorNames = ((ActorSet) actorSet).getNamesList();
@ -86,7 +71,7 @@ public class ReplicaSet {
/**
* Given a query, submit it to a replica and return the object ref. This method will keep track of
* the in flight queries for each replicas and only send a query to available replicas (determined
* by the backend max_concurrent_quries value.)
* by the max_concurrent_quries value.)
*
* @param query the incoming query.
* @return ray.ObjectRef
@ -98,7 +83,7 @@ public class ReplicaSet {
() ->
numQueuedQueriesGauge.update(
numQueuedQueries.get(),
TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
ImmutableMap.of(new TagKey(RayServeMetrics.TAG_ENDPOINT), endpoint)));
ObjectRef<Object> assignedRef =
tryAssignReplica(query); // TODO controll concurrency using maxConcurrentQueries
numQueuedQueries.decrementAndGet();
@ -106,7 +91,7 @@ public class ReplicaSet {
() ->
numQueuedQueriesGauge.update(
numQueuedQueries.get(),
TagKey.tagsFromMap(ImmutableMap.of(RayServeMetrics.TAG_ENDPOINT, endpoint))));
ImmutableMap.of(new TagKey(RayServeMetrics.TAG_ENDPOINT), endpoint)));
return assignedRef;
}

View file

@ -13,7 +13,7 @@ import io.ray.serve.poll.LongPollNamespace;
import java.util.HashMap;
import java.util.Map;
/** Router process incoming queries: choose backend, and assign replica. */
/** Router process incoming queries: assign a replica. */
public class Router {
private ReplicaSet replicaSet;
@ -36,9 +36,6 @@ public class Router {
.register());
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
keyListeners.put(
new KeyType(LongPollNamespace.BACKEND_CONFIGS, deploymentName),
deploymentConfig -> replicaSet.setMaxConcurrentQueries(deploymentConfig)); // cross language
keyListeners.put(
new KeyType(LongPollNamespace.REPLICA_HANDLES, deploymentName),
workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language

View file

@ -19,7 +19,7 @@ public class Serve {
/**
* Set replica information to global context.
*
* @param deploymentName backend tag
* @param deploymentName deployment name
* @param replicaTag replica tag
* @param controllerName the controller actor's name
* @param servableObject the servable object of the specified replica.
@ -42,7 +42,7 @@ public class Serve {
public static ReplicaContext getReplicaContext() {
if (INTERNAL_REPLICA_CONTEXT == null) {
throw new RayServeException(
"`Serve.getReplicaContext()` may only be called from within a Ray Serve backend.");
"`Serve.getReplicaContext()` may only be called from within a Ray Serve deployment.");
}
return INTERNAL_REPLICA_CONTEXT;
}

View file

@ -46,10 +46,7 @@ public class LongPollClient {
new HashMap<>();
static {
DESERIALIZERS.put(
LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseDeploymentConfig(body));
DESERIALIZERS.put(
LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body));
DESERIALIZERS.put(LongPollNamespace.ROUTE_TABLE, body -> ServeProtoUtil.parseEndpointSet(body));
DESERIALIZERS.put(
LongPollNamespace.REPLICA_HANDLES,
body -> {
@ -89,7 +86,7 @@ public class LongPollClient {
}
}
},
"backend-poll-thread");
"ray-serve-long-poll-thread");
}
public void start() {

View file

@ -4,7 +4,5 @@ package io.ray.serve.poll;
public enum LongPollNamespace {
REPLICA_HANDLES,
BACKEND_CONFIGS,
ROUTE_TABLE;
}

View file

@ -1,15 +1,15 @@
package io.ray.serve.util;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.Constants;
import io.ray.serve.DeploymentConfig;
import io.ray.serve.DeploymentVersion;
import io.ray.serve.RayServeException;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.EndpointInfo;
import io.ray.serve.generated.EndpointSet;
import io.ray.serve.generated.LongPollResult;
@ -27,71 +27,67 @@ public class ServeProtoUtil {
public static DeploymentConfig parseDeploymentConfig(byte[] deploymentConfigBytes) {
// Get a builder from DeploymentConfig(bytes) or create a new one.
DeploymentConfig.Builder builder = null;
DeploymentConfig deploymentConfig = new DeploymentConfig();
if (deploymentConfigBytes == null) {
builder = DeploymentConfig.newBuilder();
} else {
DeploymentConfig deploymentConfig = null;
try {
deploymentConfig = DeploymentConfig.parseFrom(deploymentConfigBytes);
} catch (InvalidProtocolBufferException e) {
throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e);
}
if (deploymentConfig == null) {
builder = DeploymentConfig.newBuilder();
} else {
builder = DeploymentConfig.newBuilder(deploymentConfig);
}
return deploymentConfig;
}
// Set default values.
if (builder.getNumReplicas() == 0) {
builder.setNumReplicas(1);
io.ray.serve.generated.DeploymentConfig pbDeploymentConfig = null;
try {
pbDeploymentConfig = io.ray.serve.generated.DeploymentConfig.parseFrom(deploymentConfigBytes);
} catch (InvalidProtocolBufferException e) {
throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e);
}
Preconditions.checkArgument(
builder.getMaxConcurrentQueries() >= 0, "max_concurrent_queries must be >= 0");
if (builder.getMaxConcurrentQueries() == 0) {
builder.setMaxConcurrentQueries(100);
if (pbDeploymentConfig == null) {
return deploymentConfig;
}
if (builder.getGracefulShutdownWaitLoopS() == 0) {
builder.setGracefulShutdownWaitLoopS(2);
if (pbDeploymentConfig.getNumReplicas() != 0) {
deploymentConfig.setNumReplicas(pbDeploymentConfig.getNumReplicas());
}
if (builder.getGracefulShutdownTimeoutS() == 0) {
builder.setGracefulShutdownTimeoutS(20);
if (pbDeploymentConfig.getMaxConcurrentQueries() != 0) {
deploymentConfig.setMaxConcurrentQueries(pbDeploymentConfig.getMaxConcurrentQueries());
}
if (builder.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) {
if (pbDeploymentConfig.getGracefulShutdownWaitLoopS() != 0) {
deploymentConfig.setGracefulShutdownWaitLoopS(
pbDeploymentConfig.getGracefulShutdownWaitLoopS());
}
if (pbDeploymentConfig.getGracefulShutdownTimeoutS() != 0) {
deploymentConfig.setGracefulShutdownTimeoutS(
pbDeploymentConfig.getGracefulShutdownTimeoutS());
}
deploymentConfig.setCrossLanguage(pbDeploymentConfig.getIsCrossLanguage());
if (pbDeploymentConfig.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) {
throw new RayServeException(
LogUtil.format(
"Unrecognized backend language {}. Backend language must be in {}.",
builder.getDeploymentLanguageValue(),
"Unrecognized deployment language {}. Deployment language must be in {}.",
pbDeploymentConfig.getDeploymentLanguage(),
Lists.newArrayList(DeploymentLanguage.values())));
}
return builder.build();
}
public static Object parseUserConfig(DeploymentConfig deploymentConfig) {
if (deploymentConfig.getUserConfig() == null || deploymentConfig.getUserConfig().size() == 0) {
return null;
deploymentConfig.setDeploymentLanguage(pbDeploymentConfig.getDeploymentLanguageValue());
if (pbDeploymentConfig.getUserConfig() != null
&& pbDeploymentConfig.getUserConfig().size() != 0) {
deploymentConfig.setUserConfig(
MessagePackSerializer.decode(
pbDeploymentConfig.getUserConfig().toByteArray(), Object.class));
}
return MessagePackSerializer.decode(
deploymentConfig.getUserConfig().toByteArray(), Object.class);
return deploymentConfig;
}
public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes)
throws InvalidProtocolBufferException {
public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) {
// Get a builder from RequestMetadata(bytes) or create a new one.
RequestMetadata.Builder builder = null;
if (requestMetadataBytes == null) {
builder = RequestMetadata.newBuilder();
} else {
RequestMetadata requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes);
RequestMetadata requestMetadata = null;
try {
requestMetadata = RequestMetadata.parseFrom(requestMetadataBytes);
} catch (InvalidProtocolBufferException e) {
throw new RayServeException("Failed to parse RequestMetadata from protobuf bytes.", e);
}
if (requestMetadata == null) {
builder = RequestMetadata.newBuilder();
} else {
@ -101,21 +97,25 @@ public class ServeProtoUtil {
// Set default values.
if (StringUtils.isBlank(builder.getCallMethod())) {
builder.setCallMethod(Constants.DEFAULT_CALL_METHOD);
builder.setCallMethod(Constants.CALL_METHOD);
}
return builder.build();
}
public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes)
throws InvalidProtocolBufferException {
public static RequestWrapper parseRequestWrapper(byte[] httpRequestWrapperBytes) {
// Get a builder from HTTPRequestWrapper(bytes) or create a new one.
RequestWrapper.Builder builder = null;
if (httpRequestWrapperBytes == null) {
builder = RequestWrapper.newBuilder();
} else {
RequestWrapper requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes);
RequestWrapper requestWrapper = null;
try {
requestWrapper = RequestWrapper.parseFrom(httpRequestWrapperBytes);
} catch (InvalidProtocolBufferException e) {
throw new RayServeException("Failed to parse RequestWrapper from protobuf bytes.", e);
}
if (requestWrapper == null) {
builder = RequestWrapper.newBuilder();
} else {
@ -160,12 +160,46 @@ public class ServeProtoUtil {
public static DeploymentVersion parseDeploymentVersion(byte[] deploymentVersionBytes) {
if (deploymentVersionBytes == null) {
return null;
return new DeploymentVersion();
}
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion = null;
try {
return DeploymentVersion.parseFrom(deploymentVersionBytes);
pbDeploymentVersion =
io.ray.serve.generated.DeploymentVersion.parseFrom(deploymentVersionBytes);
} catch (InvalidProtocolBufferException e) {
throw new RayServeException("Failed to parse DeploymentVersion from protobuf bytes.", e);
}
if (pbDeploymentVersion == null) {
return new DeploymentVersion();
}
return new DeploymentVersion(
pbDeploymentVersion.getCodeVersion(),
pbDeploymentVersion.getUserConfig() != null
&& pbDeploymentVersion.getUserConfig().size() != 0
? new Object[] {
MessagePackSerializer.decode(
pbDeploymentVersion.getUserConfig().toByteArray(), Object.class)
}
: null);
}
public static io.ray.serve.generated.DeploymentVersion toProtobuf(
DeploymentVersion deploymentVersion) {
io.ray.serve.generated.DeploymentVersion.Builder pbDeploymentVersion =
io.ray.serve.generated.DeploymentVersion.newBuilder();
if (deploymentVersion == null) {
return pbDeploymentVersion.build();
}
if (StringUtils.isNotBlank(deploymentVersion.getCodeVersion())) {
pbDeploymentVersion.setCodeVersion(deploymentVersion.getCodeVersion());
}
if (deploymentVersion.getUserConfig() != null) {
pbDeploymentVersion.setUserConfig(
ByteString.copyFrom(
MessagePackSerializer.encode(deploymentVersion.getUserConfig()).getLeft()));
}
return pbDeploymentVersion.build();
}
}

View file

@ -58,7 +58,7 @@ public class HttpProxyTest {
try (CloseableHttpResponse httpResponse =
(CloseableHttpResponse) httpClient.execute(httpPost)) {
// No Backend replica, so error is expected.
// No replica, so error is expected.
int status = httpResponse.getCode();
Assert.assertEquals(status, HttpURLConnection.HTTP_INTERNAL_ERROR);
}

View file

@ -5,8 +5,7 @@ import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.ActorSet;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.EndpointInfo;
import io.ray.serve.util.CommonUtil;
import java.io.IOException;
@ -50,26 +49,29 @@ public class ProxyActorTest {
controller.task(DummyServeController::setEndpoints, endpointInfos).remote();
// Replica
DeploymentInfo deploymentInfo = new DeploymentInfo();
deploymentInfo.setDeploymentConfig(DeploymentConfig.newBuilder().build().toByteArray());
deploymentInfo.setDeploymentVersion(
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
deploymentInfo.setReplicaConfig(
new ReplicaConfig(DummyBackendReplica.class.getName(), null, new HashMap<>()));
DeploymentInfo deploymentInfo =
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()))
.setDeploymentVersion(new DeploymentVersion(version))
.setDeploymentDef(DummyReplica.class.getName());
ActorHandle<RayServeWrappedReplica> replica =
Ray.actor(
RayServeWrappedReplica::new,
deploymentName,
replicaTag,
deploymentInfo,
controllerName)
replicaTag,
controllerName,
(RayServeConfig) null)
.setName(replicaTag)
.remote();
replica.task(RayServeWrappedReplica::ready).remote();
Assert.assertTrue(replica.task(RayServeWrappedReplica::checkHealth).remote().get());
// ProxyActor
ProxyActor proxyActor = new ProxyActor(controllerName, null);
Assert.assertTrue(proxyActor.ready());
proxyActor.getProxyRouter().updateRoutes(endpointInfos);
proxyActor
.getProxyRouter()

View file

@ -3,12 +3,9 @@ package io.ray.serve;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.ActorSet;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.DeploymentVersion;
import java.util.HashMap;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -31,30 +28,29 @@ public class RayServeHandleTest {
Ray.actor(DummyServeController::new).setName(controllerName).remote();
// Replica
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
DeploymentConfig deploymentConfig =
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
DeploymentInfo deploymentInfo = new DeploymentInfo();
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
deploymentInfo.setDeploymentVersion(
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
deploymentInfo.setReplicaConfig(
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
DeploymentInfo deploymentInfo =
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(deploymentConfig)
.setDeploymentVersion(new DeploymentVersion(version))
.setDeploymentDef("io.ray.serve.ReplicaContext")
.setInitArgs(initArgs);
ActorHandle<RayServeWrappedReplica> replicaHandle =
Ray.actor(
RayServeWrappedReplica::new,
deploymentName,
replicaTag,
deploymentInfo,
controllerName)
replicaTag,
controllerName,
(RayServeConfig) null)
.setName(actorName)
.remote();
replicaHandle.task(RayServeWrappedReplica::ready).remote();
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
// RayServeHandle
RayServeHandle rayServeHandle =
@ -71,6 +67,7 @@ public class RayServeHandleTest {
if (!inited) {
Ray.shutdown();
}
Serve.setInternalReplicaContext(null);
}
}
}

View file

@ -1,16 +1,14 @@
package io.ray.serve;
import com.google.common.collect.ImmutableMap;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.generated.RequestWrapper;
import java.io.IOException;
import java.util.HashMap;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -32,64 +30,84 @@ public class RayServeReplicaTest {
ActorHandle<DummyServeController> controllerHandle =
Ray.actor(DummyServeController::new).setName(controllerName).remote();
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
DeploymentConfig deploymentConfig =
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
DeploymentInfo deploymentInfo =
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(deploymentConfig)
.setDeploymentVersion(new DeploymentVersion(version))
.setDeploymentDef(DummyReplica.class.getName());
DeploymentInfo deploymentInfo = new DeploymentInfo();
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
deploymentInfo.setDeploymentVersion(
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
deploymentInfo.setReplicaConfig(
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
ActorHandle<RayServeWrappedReplica> backendHandle =
ActorHandle<RayServeWrappedReplica> replicHandle =
Ray.actor(
RayServeWrappedReplica::new,
deploymentName,
replicaTag,
deploymentInfo,
controllerName)
replicaTag,
controllerName,
(RayServeConfig) null)
.remote();
// ready
backendHandle.task(RayServeWrappedReplica::ready).remote();
Assert.assertTrue(replicHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
// handle request
RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
requestMetadata.setCallMethod("getDeploymentName");
requestMetadata.setCallMethod(Constants.CALL_METHOD);
RequestWrapper.Builder requestWrapper = RequestWrapper.newBuilder();
ObjectRef<Object> resultRef =
backendHandle
replicHandle
.task(
RayServeWrappedReplica::handleRequest,
requestMetadata.build().toByteArray(),
requestWrapper.build().toByteArray())
.remote();
Assert.assertEquals((String) resultRef.get(), deploymentName);
Assert.assertEquals((String) resultRef.get(), "1");
// reconfigure
ObjectRef<byte[]> versionRef =
backendHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote();
Assert.assertEquals(DeploymentVersion.parseFrom(versionRef.get()).getCodeVersion(), version);
ObjectRef<Object> versionRef =
replicHandle.task(RayServeWrappedReplica::reconfigure, (Object) null).remote();
Assert.assertEquals(((DeploymentVersion) versionRef.get()).getCodeVersion(), version);
replicHandle.task(RayServeWrappedReplica::reconfigure, new Object()).remote().get();
resultRef =
replicHandle
.task(
RayServeWrappedReplica::handleRequest,
requestMetadata.build().toByteArray(),
requestWrapper.build().toByteArray())
.remote();
Assert.assertEquals((String) resultRef.get(), "1");
replicHandle
.task(RayServeWrappedReplica::reconfigure, ImmutableMap.of("value", "100"))
.remote()
.get();
resultRef =
replicHandle
.task(
RayServeWrappedReplica::handleRequest,
requestMetadata.build().toByteArray(),
requestWrapper.build().toByteArray())
.remote();
Assert.assertEquals((String) resultRef.get(), "101");
// get version
versionRef = backendHandle.task(RayServeWrappedReplica::getVersion).remote();
Assert.assertEquals(DeploymentVersion.parseFrom(versionRef.get()).getCodeVersion(), version);
versionRef = replicHandle.task(RayServeWrappedReplica::getVersion).remote();
Assert.assertEquals(((DeploymentVersion) versionRef.get()).getCodeVersion(), version);
// prepare for shutdown
ObjectRef<Boolean> shutdownRef =
backendHandle.task(RayServeWrappedReplica::prepareForShutdown).remote();
replicHandle.task(RayServeWrappedReplica::prepareForShutdown).remote();
Assert.assertTrue(shutdownRef.get());
} finally {
if (!inited) {
Ray.shutdown();
}
Serve.setInternalReplicaContext(null);
}
}
}

View file

@ -15,40 +15,42 @@ public class ReplicaConfigTest {
public void test() {
Object dummy = new Object();
String backendDef = "io.ray.serve.ReplicaConfigTest";
String deploymentDef = "io.ray.serve.ReplicaConfigTest";
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("placement_group", dummy)));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("placement_group", dummy)));
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("lifetime", dummy)));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("lifetime", dummy)));
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("name", dummy)));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("name", dummy)));
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("max_restarts", dummy)));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("max_restarts", dummy)));
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", -1.0)));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_cpus", -1.0)));
ReplicaConfig replicaConfig =
new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", 2.0));
new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_cpus", 2.0));
Assert.assertEquals(replicaConfig.getResource().get("CPU").doubleValue(), 2.0);
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", -1.0)));
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", 2.0));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_gpus", -1.0)));
replicaConfig = new ReplicaConfig(deploymentDef, null, getRayActorOptions("num_gpus", 2.0));
Assert.assertEquals(replicaConfig.getResource().get("GPU").doubleValue(), 2.0);
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("memory", -1.0)));
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("memory", 2.0));
() -> new ReplicaConfig(deploymentDef, null, getRayActorOptions("memory", -1.0)));
replicaConfig = new ReplicaConfig(deploymentDef, null, getRayActorOptions("memory", 2.0));
Assert.assertEquals(replicaConfig.getResource().get("memory").doubleValue(), 2.0);
expectIllegalArgumentException(
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", -1.0)));
() ->
new ReplicaConfig(
deploymentDef, null, getRayActorOptions("object_store_memory", -1.0)));
replicaConfig =
new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", 2.0));
new ReplicaConfig(deploymentDef, null, getRayActorOptions("object_store_memory", 2.0));
Assert.assertEquals(replicaConfig.getResource().get("object_store_memory").doubleValue(), 2.0);
}

View file

@ -0,0 +1,19 @@
package io.ray.serve;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
public class ReplicaNameTest {
@Test
public void test() {
String deploymentTag = "ReplicaNameTest";
String replicaSuffix = RandomStringUtils.randomAlphabetic(6);
ReplicaName replicaName = new ReplicaName(deploymentTag, replicaSuffix);
Assert.assertEquals(replicaName.getDeploymentTag(), deploymentTag);
Assert.assertEquals(replicaName.getReplicaSuffix(), replicaSuffix);
Assert.assertEquals(replicaName.getReplicaTag(), deploymentTag + "#" + replicaSuffix);
}
}

View file

@ -3,13 +3,10 @@ package io.ray.serve;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.ActorSet;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.RequestMetadata;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.RandomStringUtils;
@ -20,16 +17,6 @@ public class ReplicaSetTest {
private String deploymentName = "ReplicaSetTest";
@Test
public void setMaxConcurrentQueriesTest() {
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
DeploymentConfig.Builder builder = DeploymentConfig.newBuilder();
builder.setMaxConcurrentQueries(200);
replicaSet.setMaxConcurrentQueries(builder.build());
Assert.assertEquals(replicaSet.getMaxConcurrentQueries(), 200);
}
@Test
public void updateWorkerReplicasTest() {
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
@ -58,30 +45,29 @@ public class ReplicaSetTest {
Ray.actor(DummyServeController::new).setName(controllerName).remote();
// Replica
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
DeploymentConfig deploymentConfig =
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
DeploymentInfo deploymentInfo = new DeploymentInfo();
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
deploymentInfo.setDeploymentVersion(
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
deploymentInfo.setReplicaConfig(
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
DeploymentInfo deploymentInfo =
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(deploymentConfig)
.setDeploymentVersion(new DeploymentVersion(version))
.setDeploymentDef("io.ray.serve.ReplicaContext")
.setInitArgs(initArgs);
ActorHandle<RayServeWrappedReplica> replicaHandle =
Ray.actor(
RayServeWrappedReplica::new,
deploymentName,
replicaTag,
deploymentInfo,
controllerName)
replicaTag,
controllerName,
(RayServeConfig) null)
.setName(actorName)
.remote();
replicaHandle.task(RayServeWrappedReplica::ready).remote();
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
// ReplicaSet
ReplicaSet replicaSet = new ReplicaSet(deploymentName);
@ -90,7 +76,6 @@ public class ReplicaSetTest {
replicaSet.updateWorkerReplicas(builder.build());
// assign
RequestMetadata.Builder requestMetadata = RequestMetadata.newBuilder();
requestMetadata.setRequestId(RandomStringUtils.randomAlphabetic(10));
requestMetadata.setCallMethod("getDeploymentName");
@ -103,6 +88,7 @@ public class ReplicaSetTest {
if (!inited) {
Ray.shutdown();
}
Serve.setInternalReplicaContext(null);
}
}
}

View file

@ -3,13 +3,10 @@ package io.ray.serve;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.serializer.MessagePackSerializer;
import io.ray.serve.api.Serve;
import io.ray.serve.generated.ActorSet;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.DeploymentLanguage;
import io.ray.serve.generated.DeploymentVersion;
import io.ray.serve.generated.RequestMetadata;
import java.util.HashMap;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -33,30 +30,29 @@ public class RouterTest {
Ray.actor(DummyServeController::new).setName(controllerName).remote();
// Replica
DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder();
deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA);
byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray();
DeploymentConfig deploymentConfig =
new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber());
Object[] initArgs = new Object[] {deploymentName, replicaTag, controllerName, new Object()};
byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft();
DeploymentInfo deploymentInfo = new DeploymentInfo();
deploymentInfo.setDeploymentConfig(deploymentConfigBytes);
deploymentInfo.setDeploymentVersion(
DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray());
deploymentInfo.setReplicaConfig(
new ReplicaConfig("io.ray.serve.ReplicaContext", initArgsBytes, new HashMap<>()));
DeploymentInfo deploymentInfo =
new DeploymentInfo()
.setName(deploymentName)
.setDeploymentConfig(deploymentConfig)
.setDeploymentVersion(new DeploymentVersion(version))
.setDeploymentDef("io.ray.serve.ReplicaContext")
.setInitArgs(initArgs);
ActorHandle<RayServeWrappedReplica> replicaHandle =
Ray.actor(
RayServeWrappedReplica::new,
deploymentName,
replicaTag,
deploymentInfo,
controllerName)
replicaTag,
controllerName,
(RayServeConfig) null)
.setName(actorName)
.remote();
replicaHandle.task(RayServeWrappedReplica::ready).remote();
Assert.assertTrue(replicaHandle.task(RayServeWrappedReplica::checkHealth).remote().get());
// Router
Router router = new Router(controllerHandle, deploymentName);
@ -75,6 +71,7 @@ public class RouterTest {
if (!inited) {
Ray.shutdown();
}
Serve.setInternalReplicaContext(null);
}
}
}

View file

@ -10,10 +10,10 @@ public class KeyTypeTest {
@Test
public void hashTest() {
KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null);
KeyType k4 = new KeyType(LongPollNamespace.REPLICA_HANDLES, "k4");
KeyType k1 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
KeyType k2 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
KeyType k3 = new KeyType(LongPollNamespace.ROUTE_TABLE, null);
KeyType k4 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k4");
Assert.assertEquals(k1, k1);
Assert.assertEquals(k1.hashCode(), k1.hashCode());
@ -34,7 +34,7 @@ public class KeyTypeTest {
@Test
public void jsonTest() {
KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
KeyType k1 = new KeyType(LongPollNamespace.ROUTE_TABLE, "k1");
String json = GSON.toJson(k1);
KeyType k2 = GSON.fromJson(json, KeyType.class);

View file

@ -1,7 +1,8 @@
package io.ray.serve.poll;
import com.google.protobuf.ByteString;
import io.ray.serve.generated.DeploymentConfig;
import io.ray.serve.generated.EndpointInfo;
import io.ray.serve.generated.EndpointSet;
import io.ray.serve.generated.UpdatedObject;
import java.util.HashMap;
import java.util.Map;
@ -10,27 +11,29 @@ import org.testng.annotations.Test;
public class LongPollClientTest {
@SuppressWarnings("unchecked")
@Test
public void test() throws Throwable {
String[] a = new String[] {"test"};
// Construct a listener map.
KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "deploymentName");
KeyType keyType = new KeyType(LongPollNamespace.ROUTE_TABLE, null);
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
keyListeners.put(
keyType, (object) -> a[0] = String.valueOf(((DeploymentConfig) object).getNumReplicas()));
keyType, (object) -> a[0] = String.valueOf(((Map<String, EndpointInfo>) object).size()));
// Initialize LongPollClient.
LongPollClient longPollClient = new LongPollClient(null, keyListeners);
// Construct updated object.
DeploymentConfig.Builder deploymentConfig = DeploymentConfig.newBuilder();
deploymentConfig.setNumReplicas(20);
EndpointSet.Builder endpointSet = EndpointSet.newBuilder();
endpointSet.putEndpoints("1", EndpointInfo.newBuilder().build());
endpointSet.putEndpoints("2", EndpointInfo.newBuilder().build());
int snapshotId = 10;
UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder();
updatedObject.setSnapshotId(snapshotId);
updatedObject.setObjectSnapshot(ByteString.copyFrom(deploymentConfig.build().toByteArray()));
updatedObject.setObjectSnapshot(ByteString.copyFrom(endpointSet.build().toByteArray()));
// Process update.
Map<KeyType, UpdatedObject> updates = new HashMap<>();
@ -40,8 +43,7 @@ public class LongPollClientTest {
// Validation.
Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId);
Assert.assertEquals(
((DeploymentConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(),
deploymentConfig.getNumReplicas());
Assert.assertEquals(a[0], String.valueOf(deploymentConfig.getNumReplicas()));
((Map<String, EndpointInfo>) longPollClient.getObjectSnapshots().get(keyType)).size(), 2);
Assert.assertEquals(a[0], String.valueOf(endpointSet.getEndpointsMap().size()));
}
}

View file

@ -0,0 +1,90 @@
package io.ray.serve.util;
import com.google.protobuf.ByteString;
import io.ray.serve.DeploymentConfig;
import io.ray.serve.DeploymentVersion;
import io.ray.serve.generated.RequestMetadata;
import io.ray.serve.generated.RequestWrapper;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
public class ServeProtoUtilTest {
@Test
public void parseDeploymentConfigTest() {
int numReplicas = 10;
io.ray.serve.generated.DeploymentConfig pbDeploymentConfig =
io.ray.serve.generated.DeploymentConfig.newBuilder().setNumReplicas(numReplicas).build();
DeploymentConfig deploymentConfig =
ServeProtoUtil.parseDeploymentConfig(pbDeploymentConfig.toByteArray());
Assert.assertNotNull(deploymentConfig);
Assert.assertEquals(deploymentConfig.getNumReplicas(), numReplicas);
Assert.assertEquals(deploymentConfig.getDeploymentLanguage(), 0);
Assert.assertEquals(deploymentConfig.getGracefulShutdownTimeoutS(), 20);
Assert.assertEquals(deploymentConfig.getGracefulShutdownWaitLoopS(), 2);
Assert.assertEquals(deploymentConfig.getMaxConcurrentQueries(), 100);
Assert.assertNull(deploymentConfig.getUserConfig());
Assert.assertEquals(deploymentConfig.isCrossLanguage(), false);
}
@Test
public void parseRequestMetadataTest() {
String prefix = "parseRequestMetadataTest";
String requestId = RandomStringUtils.randomAlphabetic(10);
String callMethod = prefix + "_method";
String endpoint = prefix + "_endpoint";
String context = prefix + "_context";
RequestMetadata requestMetadata =
RequestMetadata.newBuilder()
.setRequestId(requestId)
.setCallMethod(callMethod)
.setEndpoint(endpoint)
.putContext("context", context)
.build();
RequestMetadata result = ServeProtoUtil.parseRequestMetadata(requestMetadata.toByteArray());
Assert.assertNotNull(result);
Assert.assertEquals(result.getCallMethod(), callMethod);
Assert.assertEquals(result.getEndpoint(), endpoint);
Assert.assertEquals(result.getRequestId(), requestId);
Assert.assertEquals(result.getContextMap().get("context"), context);
}
@Test
public void parseRequestWrapperTest() {
byte[] body = new byte[] {1, 2};
RequestWrapper requestWrapper =
RequestWrapper.newBuilder().setBody(ByteString.copyFrom(body)).build();
RequestWrapper result = ServeProtoUtil.parseRequestWrapper(requestWrapper.toByteArray());
Assert.assertNotNull(result);
byte[] rstBody = result.getBody().toByteArray();
Assert.assertEquals(rstBody[0], 1);
Assert.assertEquals(rstBody[1], 2);
}
@Test
public void parseDeploymentVersionTest() {
String codeVersion = "parseDeploymentVersionTest";
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion =
io.ray.serve.generated.DeploymentVersion.newBuilder().setCodeVersion(codeVersion).build();
DeploymentVersion deploymentVersion =
ServeProtoUtil.parseDeploymentVersion(pbDeploymentVersion.toByteArray());
Assert.assertNotNull(deploymentVersion);
Assert.assertEquals(deploymentVersion.getCodeVersion(), codeVersion);
}
@Test
public void toDeploymentVersionProtobufTest() {
String codeVersion = "toDeploymentVersionProtobufTest";
DeploymentVersion deploymentVersion = new DeploymentVersion(codeVersion);
io.ray.serve.generated.DeploymentVersion pbDeploymentVersion =
ServeProtoUtil.toProtobuf(deploymentVersion);
Assert.assertNotNull(pbDeploymentVersion);
Assert.assertEquals(pbDeploymentVersion.getCodeVersion(), codeVersion);
}
}

View file

@ -95,6 +95,8 @@ message RequestMetadata {
string endpoint = 2;
string call_method = 3;
map<string, string> context = 4;
}
message RequestWrapper {