[Serve] Define Java Backend (#16169)

This commit is contained in:
liuyang-my 2021-07-02 11:41:17 +08:00 committed by GitHub
parent 7f78e8c014
commit 2c3ce469ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 1587 additions and 0 deletions

View file

@ -12,6 +12,7 @@ exports_files([
all_modules = [
"api",
"runtime",
"serve",
"test",
"performance_test",
]
@ -134,6 +135,27 @@ define_java_module(
],
)
# Ray Serve Java module. `define_test_lib = True` also requests a test library;
# the `all_tests` binary in this file depends on it as `:io_ray_ray_serve_test`.
define_java_module(
name = "serve",
define_test_lib = True,
# Dependencies of the test library.
test_deps = [
":io_ray_ray_api",
":io_ray_ray_serve",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:com_google_guava_guava",
"@maven//:org_slf4j_slf4j_api",
"@maven//:org_testng_testng",
],
visibility = ["//visibility:public"],
# Dependencies of the main library.
deps = [
":io_ray_ray_api",
":io_ray_ray_runtime",
"@maven//:com_google_guava_guava",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_slf4j_slf4j_api",
],
)
java_binary(
name = "all_tests",
args = ["java/testng.xml"],
@ -142,6 +164,7 @@ java_binary(
runtime_deps = [
":io_ray_ray_performance_test",
":io_ray_ray_runtime_test",
":io_ray_ray_serve_test",
":io_ray_ray_test",
],
)
@ -237,6 +260,7 @@ genrule(
cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml"
cp -f $(location //java:io_ray_ray_serve_pom) "$$WORK_DIR/java/serve/pom.xml"
date > $@
""",
local = 1,
@ -251,6 +275,7 @@ java_binary(
runtime_deps = [
"//java:io_ray_ray_api",
"//java:io_ray_ray_runtime",
"//java:io_ray_ray_serve",
"//streaming/java:io_ray_ray_streaming-api",
"//streaming/java:io_ray_ray_streaming-runtime",
],

View file

@ -56,6 +56,7 @@
<modules>
<module>api</module>
<module>runtime</module>
<module>serve</module>
<module>test</module>
</modules>

51
java/serve/pom.xml Normal file
View file

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<groupId>io.ray</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>ray-serve</artifactId>
<name>ray serve</name>
<description>java for ray serve</description>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>27.0.1-jre</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.25</version>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>7.3.0</version>
</dependency>
</dependencies>
</project>

View file

@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
{auto_gen_header}
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<groupId>io.ray</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>ray-serve</artifactId>
<name>ray serve</name>
<description>java for ray serve</description>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
{generated_bzl_deps}
</dependencies>
</project>

View file

@ -0,0 +1,86 @@
package io.ray.serve;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
/**
 * Configuration options for a backend, to be set by the user.
 *
 * <p>All numeric setters reject values that cannot be honoured, so an invalid configuration fails
 * at the point where it is set instead of later inside the replica.
 */
public class BackendConfig implements Serializable {

  private static final long serialVersionUID = 244486384449779141L;

  /**
   * The number of processes to start up that will handle requests to this backend. Defaults to 1.
   */
  private int numReplicas = 1;

  /**
   * The maximum number of queries that will be sent to a replica of this backend without receiving
   * a response. Defaults to 100.
   */
  private int maxConcurrentQueries = 100;

  /**
   * Arguments to pass to the reconfigure method of the backend. The reconfigure method is called
   * only if this value is not null.
   */
  private Object userConfig;

  /**
   * Duration, in seconds, that backend workers will wait until there is no more work to be done
   * before shutting down. Defaults to 2s.
   */
  private long experimentalGracefulShutdownWaitLoopS = 2;

  /**
   * Duration, in seconds, the controller waits before forcefully killing the replica for shutdown.
   * Defaults to 20s.
   */
  private long experimentalGracefulShutdownTimeoutS = 20;

  @Override
  public String toString() {
    return ReflectionToStringBuilder.toString(this);
  }

  public int getNumReplicas() {
    return numReplicas;
  }

  /**
   * Sets the number of replicas.
   *
   * @param numReplicas the number of replicas, must be positive
   * @throws IllegalArgumentException if {@code numReplicas} is not positive
   */
  public void setNumReplicas(int numReplicas) {
    Preconditions.checkArgument(numReplicas > 0, "num_replicas must be > 0");
    this.numReplicas = numReplicas;
  }

  public int getMaxConcurrentQueries() {
    return maxConcurrentQueries;
  }

  /**
   * Sets the maximum number of in-flight queries per replica.
   *
   * @param maxConcurrentQueries the limit, must be positive
   * @throws IllegalArgumentException if {@code maxConcurrentQueries} is not positive
   */
  public void setMaxConcurrentQueries(int maxConcurrentQueries) {
    Preconditions.checkArgument(maxConcurrentQueries > 0, "max_concurrent_queries must be > 0");
    this.maxConcurrentQueries = maxConcurrentQueries;
  }

  public Object getUserConfig() {
    return userConfig;
  }

  public void setUserConfig(Object userConfig) {
    this.userConfig = userConfig;
  }

  public long getExperimentalGracefulShutdownWaitLoopS() {
    return experimentalGracefulShutdownWaitLoopS;
  }

  /**
   * Sets the wait-loop interval used while draining pending queries.
   *
   * @param experimentalGracefulShutdownWaitLoopS the interval in seconds, must not be negative
   * @throws IllegalArgumentException if the value is negative
   */
  public void setExperimentalGracefulShutdownWaitLoopS(long experimentalGracefulShutdownWaitLoopS) {
    // The replica passes this value (times 1000) to Thread.sleep, which rejects negative values.
    Preconditions.checkArgument(
        experimentalGracefulShutdownWaitLoopS >= 0,
        "experimental_graceful_shutdown_wait_loop_s must be >= 0");
    this.experimentalGracefulShutdownWaitLoopS = experimentalGracefulShutdownWaitLoopS;
  }

  public long getExperimentalGracefulShutdownTimeoutS() {
    return experimentalGracefulShutdownTimeoutS;
  }

  /**
   * Sets the forced-shutdown timeout.
   *
   * @param experimentalGracefulShutdownTimeoutS the timeout in seconds, must not be negative
   * @throws IllegalArgumentException if the value is negative
   */
  public void setExperimentalGracefulShutdownTimeoutS(long experimentalGracefulShutdownTimeoutS) {
    Preconditions.checkArgument(
        experimentalGracefulShutdownTimeoutS >= 0,
        "experimental_graceful_shutdown_timeout_s must be >= 0");
    this.experimentalGracefulShutdownTimeoutS = experimentalGracefulShutdownTimeoutS;
  }
}

View file

@ -0,0 +1,19 @@
package io.ray.serve;
import com.google.common.collect.Lists;
import java.util.List;
/** Ray Serve common constants. */
public class Constants {

  /** Name of backend reconfiguration method implemented by user. */
  public static final String BACKEND_RECONFIGURE_METHOD = "reconfigure";

  /**
   * Default histogram buckets for latency tracker, in milliseconds.
   *
   * <p>Exposed as an unmodifiable view: the list is a shared public constant, so a caller must not
   * be able to add, remove or replace boundaries for every other user of it.
   */
  public static final List<Double> DEFAULT_LATENCY_BUCKET_MS =
      java.util.Collections.unmodifiableList(
          Lists.newArrayList(
              1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0, 500.0, 1000.0, 2000.0, 5000.0));

  /** Name of controller listen_for_change method. */
  public static final String CONTROLLER_LISTEN_FOR_CHANGE_METHOD = "listen_for_change";
}

View file

@ -0,0 +1,30 @@
package io.ray.serve;
/**
 * Bundles the arguments of a request with its {@link RequestMetadata}.
 *
 * <p>Plain mutable holder: neither the argument array nor the metadata is copied or validated.
 */
public class Query {

  /** Arguments to pass to the invoked method, stored by reference. */
  private Object[] args;

  /** Metadata describing the request. */
  private RequestMetadata metadata;

  /**
   * Creates a query from its two parts.
   *
   * @param args the request arguments
   * @param requestMetadata the request metadata
   */
  public Query(Object[] args, RequestMetadata requestMetadata) {
    this.metadata = requestMetadata;
    this.args = args;
  }

  /** @return the request arguments */
  public Object[] getArgs() {
    return this.args;
  }

  /** @return the request metadata */
  public RequestMetadata getMetadata() {
    return this.metadata;
  }

  /** @param args the new request arguments */
  public void setArgs(Object[] args) {
    this.args = args;
  }

  /** @param metadata the new request metadata */
  public void setMetadata(RequestMetadata metadata) {
    this.metadata = metadata;
  }
}

View file

@ -0,0 +1,16 @@
package io.ray.serve;
import io.ray.runtime.exception.RayException;
/** Exception type raised by Ray Serve code; a specialization of {@link RayException}. */
public class RayServeException extends RayException {

  private static final long serialVersionUID = 4673951342965950469L;

  /**
   * Creates an exception that carries only a message.
   *
   * @param message the detail message
   */
  public RayServeException(String message) {
    super(message);
  }

  /**
   * Creates an exception that carries a message and the underlying cause.
   *
   * @param message the detail message
   * @param cause the throwable that triggered this exception
   */
  public RayServeException(String message, Throwable cause) {
    super(message, cause);
  }
}

View file

@ -0,0 +1,254 @@
package io.ray.serve;
import com.google.common.collect.ImmutableMap;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.runtime.metric.Count;
import io.ray.runtime.metric.Gauge;
import io.ray.runtime.metric.Histogram;
import io.ray.runtime.metric.MetricConfig;
import io.ray.runtime.metric.Metrics;
import io.ray.serve.api.Serve;
import io.ray.serve.poll.KeyListener;
import io.ray.serve.poll.KeyType;
import io.ray.serve.poll.LongPollClient;
import io.ray.serve.poll.LongPollNamespace;
import io.ray.serve.util.LogUtil;
import io.ray.serve.util.ReflectUtil;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Handles requests with the provided callable. */
public class RayServeReplica {
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplica.class);
private String backendTag;
private String replicaTag;
private BackendConfig config;
private AtomicInteger numOngoingRequests = new AtomicInteger();
private Object callable;
private boolean metricsRegistered = false;
private Count requestCounter;
private Count errorCounter;
private Count restartCounter;
private Histogram processingLatencyTracker;
private Gauge numProcessingItems;
private LongPollClient longPollClient;
public RayServeReplica(
Object callable, BackendConfig backendConfig, BaseActorHandle actorHandle) {
this.backendTag = Serve.getReplicaContext().getBackendTag();
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
this.callable = callable;
this.config = backendConfig;
this.reconfigure(backendConfig.getUserConfig());
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
keyListeners.put(
new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag),
newConfig -> updateBackendConfigs(newConfig));
this.longPollClient = new LongPollClient(actorHandle, keyListeners);
this.longPollClient.start();
registerMetrics();
}
private void registerMetrics() {
if (!Ray.isInitialized() || Ray.getRuntimeContext().isSingleProcess()) {
return;
}
Metrics.init(MetricConfig.DEFAULT_CONFIG);
requestCounter =
Metrics.count()
.name("serve_backend_request_counter")
.description("The number of queries that have been processed in this replica.")
.unit("")
.tags(ImmutableMap.of("backend", backendTag))
.register();
errorCounter =
Metrics.count()
.name("serve_backend_error_counter")
.description("The number of exceptions that have occurred in the backend.")
.unit("")
.tags(ImmutableMap.of("backend", backendTag))
.register();
restartCounter =
Metrics.count()
.name("serve_backend_replica_starts")
.description("The number of times this replica has been restarted due to failure.")
.unit("")
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
.register();
processingLatencyTracker =
Metrics.histogram()
.name("serve_backend_processing_latency_ms")
.description("The latency for queries to be processed.")
.unit("")
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
.register();
numProcessingItems =
Metrics.gauge()
.name("serve_replica_processing_queries")
.description("The current number of queries being processed.")
.unit("")
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
.register();
metricsRegistered = true;
restartCounter.inc(1.0);
}
public Object handleRequest(Query request) {
long startTime = System.currentTimeMillis();
LOGGER.debug(
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
numOngoingRequests.incrementAndGet();
reportMetrics(() -> numProcessingItems.update(numOngoingRequests.get()));
Object result = invokeSingle(request);
numOngoingRequests.decrementAndGet();
long requestTimeMs = System.currentTimeMillis() - startTime;
LOGGER.debug(
"Replica {} finished request {} in {}ms",
replicaTag,
request.getMetadata().getRequestId(),
requestTimeMs);
return result;
}
private Object invokeSingle(Query requestItem) {
long start = System.currentTimeMillis();
Method methodToCall = null;
try {
LOGGER.debug(
"Replica {} started executing request {}",
replicaTag,
requestItem.getMetadata().getRequestId());
methodToCall = getRunnerMethod(requestItem);
Object result = methodToCall.invoke(callable, requestItem.getArgs());
reportMetrics(() -> requestCounter.inc(1.0));
return result;
} catch (Throwable e) {
reportMetrics(() -> errorCounter.inc(1.0));
throw new RayServeException(
LogUtil.format(
"Replica {} failed to invoke method {}",
replicaTag,
methodToCall == null ? "unknown" : methodToCall.getName()),
e);
} finally {
reportMetrics(() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
}
}
private Method getRunnerMethod(Query query) {
String methodName = query.getMetadata().getCallMethod();
try {
return ReflectUtil.getMethod(
callable.getClass(), methodName, query.getArgs() == null ? null : query.getArgs());
} catch (NoSuchMethodException e) {
throw new RayServeException(
LogUtil.format(
"Backend doesn't have method {} which is specified in the request. "
+ "The available methods are {}",
methodName,
ReflectUtil.getMethodStrings(callable.getClass())));
}
}
/**
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
* queued tasks to be completed and return to the controller.
*/
public void drainPendingQueries() {
while (true) {
try {
Thread.sleep(config.getExperimentalGracefulShutdownWaitLoopS() * 1000);
} catch (InterruptedException e) {
LOGGER.error(
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
}
if (numOngoingRequests.get() == 0) {
break;
} else {
LOGGER.debug(
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
config.getExperimentalGracefulShutdownWaitLoopS(),
numOngoingRequests.get());
}
}
Ray.exitActor();
}
/**
* Reconfigure user's configuration in the callable object through its reconfigure method.
*
* @param userConfig new user's configuration
*/
private void reconfigure(Object userConfig) {
if (userConfig == null) {
return;
}
try {
Method reconfigureMethod =
ReflectUtil.getMethod(
callable.getClass(),
Constants.BACKEND_RECONFIGURE_METHOD,
userConfig); // TODO cache reconfigureMethod
reconfigureMethod.invoke(callable, userConfig);
} catch (NoSuchMethodException e) {
throw new RayServeException(
LogUtil.format(
"user_config specified but backend {} missing {} method",
backendTag,
Constants.BACKEND_RECONFIGURE_METHOD));
} catch (Throwable e) {
throw new RayServeException(
LogUtil.format("Backend {} failed to reconfigure user_config {}", backendTag, userConfig),
e);
}
}
/**
* Update backend configs.
*
* @param newConfig the new configuration of backend
*/
private void updateBackendConfigs(Object newConfig) {
config = (BackendConfig) newConfig;
reconfigure(((BackendConfig) newConfig).getUserConfig());
}
private void reportMetrics(Runnable runnable) {
if (metricsRegistered) {
runnable.run();
}
}
}

View file

@ -0,0 +1,67 @@
package io.ray.serve;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.serve.api.Serve;
import io.ray.serve.util.ReflectUtil;
import java.lang.reflect.InvocationTargetException;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
/**
 * Replica class wrapping the provided class. Note that Java function is not supported now.
 *
 * <p>The constructor instantiates the user's backend class by reflection, looks up the controller
 * actor by name, publishes the replica context through {@link Serve}, and then delegates all work
 * to a {@link RayServeReplica}.
 */
public class RayServeWrappedReplica {

  /** The replica that actually serves requests. */
  private RayServeReplica backend;

  /**
   * Builds the wrapped replica.
   *
   * @param backendTag tag of the backend this replica belongs to
   * @param replicaTag tag of this replica
   * @param backendDef fully qualified name of the class to instantiate
   * @param initArgs arguments for the constructor of {@code backendDef}
   * @param backendConfig the backend configuration
   * @param controllerName name of the controller actor; must not be blank
   */
  @SuppressWarnings("rawtypes")
  public RayServeWrappedReplica(
      String backendTag,
      String replicaTag,
      String backendDef,
      Object[] initArgs,
      BackendConfig backendConfig,
      String controllerName)
      throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
          IllegalAccessException, IllegalArgumentException, InvocationTargetException {

    // Step 1: create the user's servable object from its class name.
    Class servableClass = Class.forName(backendDef);
    Object servable = ReflectUtil.getConstructor(servableClass, initArgs).newInstance(initArgs);

    // Step 2: resolve the controller actor from its name.
    boolean hasControllerName = StringUtils.isNotBlank(controllerName);
    Preconditions.checkArgument(hasControllerName, "Must provide a valid controllerName");
    Optional<BaseActorHandle> controller = Ray.getActor(controllerName);
    Preconditions.checkState(controller.isPresent(), "Controller does not exist");

    // Step 3: publish the replica context so that Serve.connect() in the user's backend code will
    // connect to the instance that this backend is running in.
    Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servable);

    // Step 4: build the worker replica.
    this.backend = new RayServeReplica(servable, backendConfig, controller.get());
  }

  /**
   * The entry method to process the request.
   *
   * @param requestMetadata request metadata
   * @param requestArgs the input parameters of the specified method of the object defined by
   *     backendDef.
   * @return the result of request being processed
   */
  public Object handle_request(RequestMetadata requestMetadata, Object[] requestArgs) {
    Query query = new Query(requestArgs, requestMetadata);
    return backend.handleRequest(query);
  }

  /** Check whether this replica is ready or not. Returning without an exception means ready. */
  public void ready() {}

  /** Wait until there is no request in processing. It is used for stopping replica gracefully. */
  public void drain_pending_queries() {
    backend.drainPendingQueries();
  }
}

View file

@ -0,0 +1,115 @@
package io.ray.serve;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
 * Configuration options for a replica.
 *
 * <p>On construction the actor options are validated and the resource requirements (CPU, GPU,
 * memory, object store memory and custom resources) are collected into {@link #getResource()}.
 */
public class ReplicaConfig implements Serializable {

  private static final long serialVersionUID = -1442657824045704226L;

  private String backendDef;

  private Object[] initArgs;

  private Map<String, Object> rayActorOptions;

  private Map<String, Double> resource;

  /**
   * Creates and validates a replica configuration.
   *
   * @param backendDef the backend definition
   * @param initArgs the initialization arguments of the backend
   * @param rayActorOptions the actor options; {@code null} is treated as no options
   * @throws IllegalArgumentException if an option is not allowed or has an invalid value
   */
  public ReplicaConfig(String backendDef, Object[] initArgs, Map<String, Object> rayActorOptions) {
    this.backendDef = backendDef;
    this.initArgs = initArgs;
    // A missing options map means "all defaults"; without this, validate() would throw an NPE.
    this.rayActorOptions = rayActorOptions != null ? rayActorOptions : new HashMap<>();
    this.resource = new HashMap<>();
    this.validate();
  }

  /** Rejects unsupported actor options and fills {@link #resource} from the remaining ones. */
  private void validate() {
    Preconditions.checkArgument(
        !rayActorOptions.containsKey("placement_group"),
        "Providing placement_group for backend actors is not currently supported.");

    Preconditions.checkArgument(
        !rayActorOptions.containsKey("lifetime"),
        "Specifying lifetime in init_args is not allowed.");

    Preconditions.checkArgument(
        !rayActorOptions.containsKey("name"), "Specifying name in init_args is not allowed.");

    Preconditions.checkArgument(
        !rayActorOptions.containsKey("max_restarts"),
        "Specifying max_restarts in init_args is not allowed.");

    // TODO Confirm num_cpus, num_gpus, memory is double in protobuf.
    // Ray defaults to zero CPUs for placement, we default to one here.
    resource.put("CPU", nonNegativeDoubleOption("num_cpus", 1.0));
    resource.put("GPU", nonNegativeDoubleOption("num_gpus", 0.0));
    resource.put("memory", nonNegativeDoubleOption("memory", 0.0));
    resource.put("object_store_memory", nonNegativeDoubleOption("object_store_memory", 0.0));

    Object customResources = rayActorOptions.getOrDefault("resources", new HashMap<>());
    Preconditions.checkArgument(
        customResources instanceof Map, "resources in ray_actor_options must be a map.");
    // Copy entry by entry so that only String keys and numeric values can enter the typed map;
    // a raw putAll would silently store arbitrary objects in a Map<String, Double>.
    for (Map.Entry<?, ?> entry : ((Map<?, ?>) customResources).entrySet()) {
      Preconditions.checkArgument(
          entry.getKey() instanceof String,
          "resources in ray_actor_options must have string keys.");
      Preconditions.checkArgument(
          entry.getValue() instanceof Number,
          "resources in ray_actor_options must have numeric values.");
      resource.put((String) entry.getKey(), ((Number) entry.getValue()).doubleValue());
    }
  }

  /**
   * Reads a numeric actor option that must be a non-negative {@link Double}.
   *
   * @param option the key in the actor options
   * @param defaultValue the value used when the key is absent
   * @return the validated option value
   * @throws IllegalArgumentException if the value is not a Double or is negative
   */
  private Double nonNegativeDoubleOption(String option, double defaultValue) {
    Object value = rayActorOptions.getOrDefault(option, defaultValue);
    Preconditions.checkArgument(
        value instanceof Double, "%s in ray_actor_options must be a double.", option);
    Preconditions.checkArgument(
        ((Double) value) >= 0, "%s in ray_actor_options must be >= 0.", option);
    return (Double) value;
  }

  public String getBackendDef() {
    return backendDef;
  }

  public void setBackendDef(String backendDef) {
    this.backendDef = backendDef;
  }

  public Object[] getInitArgs() {
    return initArgs;
  }

  public void setInitArgs(Object[] initArgs) {
    this.initArgs = initArgs;
  }

  public Map<String, Object> getRayActorOptions() {
    return rayActorOptions;
  }

  /** Replaces the actor options. The new options are not re-validated. */
  public void setRayActorOptions(Map<String, Object> rayActorOptions) {
    this.rayActorOptions = rayActorOptions;
  }

  public Map<String, Double> getResource() {
    return resource;
  }

  public void setResource(Map<String, Double> resource) {
    this.resource = resource;
  }
}

View file

@ -0,0 +1,53 @@
package io.ray.serve;
/**
 * Stores data for Serve API calls from within the user's backend code.
 *
 * <p>A plain mutable holder for the backend tag, replica tag, controller name and servable object.
 */
public class ReplicaContext {

  /** Tag of the backend. */
  private String backendTag;

  /** Tag of the replica. */
  private String replicaTag;

  /** Name of the controller, as passed to the constructor. */
  private String internalControllerName;

  /** The object that serves requests. */
  private Object servableObject;

  /**
   * Creates a context from its four parts; no argument is validated.
   *
   * @param backendTag the backend tag
   * @param replicaTag the replica tag
   * @param controllerName the controller name
   * @param servableObject the servable object
   */
  public ReplicaContext(
      String backendTag, String replicaTag, String controllerName, Object servableObject) {
    setBackendTag(backendTag);
    setReplicaTag(replicaTag);
    setInternalControllerName(controllerName);
    setServableObject(servableObject);
  }

  /** @return the backend tag */
  public String getBackendTag() {
    return this.backendTag;
  }

  /** @return the replica tag */
  public String getReplicaTag() {
    return this.replicaTag;
  }

  /** @return the controller name */
  public String getInternalControllerName() {
    return this.internalControllerName;
  }

  /** @return the servable object */
  public Object getServableObject() {
    return this.servableObject;
  }

  /** @param backendTag the new backend tag */
  public final void setBackendTag(String backendTag) {
    this.backendTag = backendTag;
  }

  /** @param replicaTag the new replica tag */
  public final void setReplicaTag(String replicaTag) {
    this.replicaTag = replicaTag;
  }

  /** @param internalControllerName the new controller name */
  public final void setInternalControllerName(String internalControllerName) {
    this.internalControllerName = internalControllerName;
  }

  /** @param servableObject the new servable object */
  public final void setServableObject(Object servableObject) {
    this.servableObject = servableObject;
  }
}

View file

@ -0,0 +1,60 @@
package io.ray.serve;
import java.io.Serializable;
import java.util.Map;
/**
 * The meta data of request.
 *
 * <p>Serializable, mutable holder. Every field starts as {@code null} except the call method,
 * which starts as {@code "__call__"}.
 */
public class RequestMetadata implements Serializable {

  private static final long serialVersionUID = -8925036926565326811L;

  /** Identifier of the request. */
  private String requestId;

  /** Endpoint of the request. */
  private String endpoint;

  /** Name of the method to call on the backend. */
  private String callMethod = "__call__";

  /** HTTP method of the request. */
  private String httpMethod;

  /** HTTP headers of the request, stored by reference. */
  private Map<String, String> httpHeaders;

  /** @return the request id */
  public String getRequestId() {
    return this.requestId;
  }

  /** @return the endpoint */
  public String getEndpoint() {
    return this.endpoint;
  }

  /** @return the name of the method to call */
  public String getCallMethod() {
    return this.callMethod;
  }

  /** @return the HTTP method */
  public String getHttpMethod() {
    return this.httpMethod;
  }

  /** @return the HTTP headers */
  public Map<String, String> getHttpHeaders() {
    return this.httpHeaders;
  }

  /** @param requestId the new request id */
  public void setRequestId(String requestId) {
    this.requestId = requestId;
  }

  /** @param endpoint the new endpoint */
  public void setEndpoint(String endpoint) {
    this.endpoint = endpoint;
  }

  /** @param callMethod the new name of the method to call */
  public void setCallMethod(String callMethod) {
    this.callMethod = callMethod;
  }

  /** @param httpMethod the new HTTP method */
  public void setHttpMethod(String httpMethod) {
    this.httpMethod = httpMethod;
  }

  /** @param httpHeaders the new HTTP headers */
  public void setHttpHeaders(Map<String, String> httpHeaders) {
    this.httpHeaders = httpHeaders;
  }
}

View file

@ -0,0 +1,38 @@
package io.ray.serve.api;
import io.ray.serve.RayServeException;
import io.ray.serve.ReplicaContext;
/** Ray Serve global API. TODO: will be enriched in the Java SDK/API PR. */
public class Serve {

  /**
   * The process-wide replica context, or {@code null} before {@link #setInternalReplicaContext}
   * has been called. Declared {@code volatile} so that a context published by one thread is
   * visible to readers on other threads.
   */
  public static volatile ReplicaContext INTERNAL_REPLICA_CONTEXT;

  /**
   * Set replica information to global context.
   *
   * @param backendTag backend tag
   * @param replicaTag replica tag
   * @param controllerName the controller actor's name
   * @param servableObject the servable object of the specified replica.
   */
  public static void setInternalReplicaContext(
      String backendTag, String replicaTag, String controllerName, Object servableObject) {
    // TODO singleton.
    INTERNAL_REPLICA_CONTEXT =
        new ReplicaContext(backendTag, replicaTag, controllerName, servableObject);
  }

  /**
   * Get the global replica context.
   *
   * @return the replica context if it exists, or throw RayServeException.
   * @throws RayServeException if no replica context has been set
   */
  public static ReplicaContext getReplicaContext() {
    // Read the shared field once, so that the value that was null-checked is the value returned
    // even if another thread reassigns the field in between.
    ReplicaContext context = INTERNAL_REPLICA_CONTEXT;
    if (context == null) {
      throw new RayServeException(
          "`Serve.getReplicaContext()` may only be called from within a Ray Serve backend.");
    }
    return context;
  }
}

View file

@ -0,0 +1,8 @@
package io.ray.serve.poll;
/**
 * Listener of long poll. It notifies changed object to the specified key.
 *
 * <p>Being a functional interface, it can be implemented with a lambda or a method reference.
 */
@FunctionalInterface
public interface KeyListener {

  /**
   * Called with the changed object.
   *
   * @param object the changed object
   */
  void notifyChanged(Object object);
}

View file

@ -0,0 +1,56 @@
package io.ray.serve.poll;
import java.io.Serializable;
import java.util.Objects;
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
/**
 * Key type of long poll: an immutable pair of a {@link LongPollNamespace} and a string key.
 *
 * <p>Instances are used as map keys, so {@link #equals(Object)} and {@link #hashCode()} are both
 * defined on the two components.
 */
public class KeyType implements Serializable {

  private static final long serialVersionUID = -8838552786234630401L;

  private final LongPollNamespace longPollNamespace;

  private final String key;

  /** Lazily computed hash code; 0 means "not computed yet". */
  private int hash;

  public KeyType(LongPollNamespace longPollNamespace, String key) {
    this.longPollNamespace = longPollNamespace;
    this.key = key;
  }

  public LongPollNamespace getLongPollNamespace() {
    return longPollNamespace;
  }

  public String getKey() {
    return key;
  }

  /**
   * Returns a hash built from the namespace's name and the key.
   *
   * <p>The enum constant's own {@code hashCode()} is deliberately not used: it is identity-based
   * and differs between JVMs, while this class is serializable and caches its hash in a
   * serialized field. Hashing the constant's name keeps the cached value valid in every process.
   */
  @Override
  public int hashCode() {
    if (hash == 0) {
      hash = Objects.hash(longPollNamespace == null ? null : longPollNamespace.name(), key);
    }
    return hash;
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
      return false;
    }
    KeyType keyType = (KeyType) obj;
    return Objects.equals(longPollNamespace, keyType.getLongPollNamespace())
        && Objects.equals(key, keyType.getKey());
  }

  @Override
  public String toString() {
    return ReflectionToStringBuilder.toString(this);
  }
}

View file

@ -0,0 +1,102 @@
package io.ray.serve.poll;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.function.PyActorMethod;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.exception.RayTaskException;
import io.ray.serve.Constants;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * The asynchronous long polling client.
 *
 * <p>A dedicated thread repeatedly calls the host actor's listen-for-change method with the last
 * known snapshot ids, records every returned update, and passes it to the listener registered for
 * its key.
 */
public class LongPollClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(LongPollClient.class);

  /** Handle to actor embedding LongPollHost. */
  private final BaseActorHandle hostActor;

  /** A map from keys to callbacks to be called on state update for the corresponding keys. */
  private final Map<KeyType, KeyListener> keyListeners;

  /** Last snapshot id seen per key; -1 until the first update for the key arrives. */
  private final Map<KeyType, Integer> snapshotIds;

  /** Last object snapshot received per key. */
  private final Map<KeyType, Object> objectSnapshots;

  /** Reference to the most recently issued poll request. */
  private ObjectRef<Object> currentRef;

  /** An async thread to post the callback into. */
  private final Thread pollThread;

  /**
   * Creates the client; polling does not begin until {@link #start()} is called.
   *
   * @param hostActor handle of the actor to poll
   * @param keyListeners the listener for each key to watch; must not be null or empty
   * @throws IllegalArgumentException if {@code keyListeners} is null or empty
   */
  public LongPollClient(BaseActorHandle hostActor, Map<KeyType, KeyListener> keyListeners) {

    Preconditions.checkArgument(keyListeners != null && keyListeners.size() != 0);

    this.hostActor = hostActor;
    this.keyListeners = keyListeners;
    this.snapshotIds = new ConcurrentHashMap<>();
    for (KeyType keyType : keyListeners.keySet()) {
      this.snapshotIds.put(keyType, -1);
    }
    this.objectSnapshots = new ConcurrentHashMap<>();
    this.pollThread =
        new Thread(
            () -> {
              while (true) {
                try {
                  pollNext();
                } catch (RayActorException e) {
                  LOGGER.debug("LongPollClient failed to connect to host. Shutting down.");
                  break;
                } catch (RayTaskException e) {
                  LOGGER.error("LongPollHost errored", e);
                } catch (Throwable e) {
                  LOGGER.error("LongPollClient failed to update object of key {}", snapshotIds, e);
                }
              }
            },
            "backend-poll-thread");
  }

  /** Starts the poll thread; only a Python controller handle is supported, others are ignored. */
  public void start() {
    if (!(hostActor instanceof PyActorHandle)) {
      LOGGER.warn("LongPollClient only support Python controller now.");
      return;
    }
    pollThread.start();
  }

  /** Poll the update. Blocks until the host answers, then processes the returned updates. */
  @SuppressWarnings("unchecked")
  public void pollNext() {
    currentRef =
        ((PyActorHandle) hostActor)
            .task(PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD), snapshotIds)
            .remote();
    processUpdate((Map<KeyType, UpdatedObject>) currentRef.get());
  }

  /**
   * Records the given updates and notifies the listener of each updated key.
   *
   * @param updates the updates keyed by their key type; {@code null} is treated as no updates
   */
  public void processUpdate(Map<KeyType, UpdatedObject> updates) {
    if (updates == null) {
      LOGGER.debug("LongPollClient received no updates.");
      return;
    }

    LOGGER.debug("LongPollClient received updates for keys: {}", updates.keySet());

    for (Map.Entry<KeyType, UpdatedObject> entry : updates.entrySet()) {
      KeyType keyType = entry.getKey();
      KeyListener listener = keyListeners.get(keyType);
      if (listener == null) {
        // Skip keys nobody subscribed to instead of failing with a NullPointerException, which
        // would drop the remaining updates of this batch.
        LOGGER.warn("LongPollClient received an update for unregistered key {}.", keyType);
        continue;
      }
      UpdatedObject update = entry.getValue();
      objectSnapshots.put(keyType, update.getObjectSnapshot());
      snapshotIds.put(keyType, update.getSnapshotId());
      listener.notifyChanged(update.getObjectSnapshot());
    }
  }

  public Map<KeyType, Integer> getSnapshotIds() {
    return snapshotIds;
  }

  public Map<KeyType, Object> getObjectSnapshots() {
    return objectSnapshots;
  }
}

View file

@ -0,0 +1,12 @@
package io.ray.serve.poll;
/**
 * The long poll namespace enum. A constant of this enum is combined with a string key to form a
 * {@link KeyType}.
 */
public enum LongPollNamespace {
REPLICA_HANDLES,
TRAFFIC_POLICIES,
// Namespace under which RayServeReplica subscribes to configuration updates of its backend.
BACKEND_CONFIGS,
ROUTE_TABLE
}

View file

@ -0,0 +1,33 @@
package io.ray.serve.poll;
import java.io.Serializable;
/**
 * The updated object that long poll client received.
 *
 * <p>Serializable, mutable pair of an object snapshot and the id of that snapshot.
 */
public class UpdatedObject implements Serializable {

  private static final long serialVersionUID = 6245682414826079438L;

  /** The snapshot of the updated object. */
  private Object objectSnapshot;

  /**
   * The identifier for the object's version. There is not sequential relation among different
   * object's snapshot_ids.
   */
  private int snapshotId;

  /** @return the object snapshot */
  public Object getObjectSnapshot() {
    return this.objectSnapshot;
  }

  /** @return the snapshot id */
  public int getSnapshotId() {
    return this.snapshotId;
  }

  /** @param objectSnapshot the new object snapshot */
  public void setObjectSnapshot(Object objectSnapshot) {
    this.objectSnapshot = objectSnapshot;
  }

  /** @param snapshotId the new snapshot id */
  public void setSnapshotId(int snapshotId) {
    this.snapshotId = snapshotId;
  }
}

View file

@ -0,0 +1,11 @@
package io.ray.serve.util;
import org.slf4j.helpers.MessageFormatter;
/** Ray Serve common log tool. */
public class LogUtil {

  /**
   * Formats a message with slf4j's {@link MessageFormatter}, so the same pattern syntax used in
   * log statements can be used to build ordinary strings such as exception messages.
   *
   * @param messagePattern the message pattern
   * @param args the arguments to substitute into the pattern
   * @return the formatted message
   */
  public static String format(String messagePattern, Object... args) {
    final String formatted = MessageFormatter.arrayFormat(messagePattern, args).getMessage();
    return formatted;
  }
}

View file

@ -0,0 +1,181 @@
package io.ray.serve.util;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
/** Tool class for reflection. */
public class ReflectUtil {
/**
 * Maps each input parameter to its runtime class.
 *
 * @param parameters The input parameter array
 * @return an array holding the runtime class of every parameter, in order, or {@code null} when
 *     {@code ArrayUtils.isEmpty} reports the input as empty
 * @throws NullPointerException if an element of a non-empty input array is {@code null}
 */
@SuppressWarnings("rawtypes")
private static Class[] getParameterTypes(Object[] parameters) {
  if (ArrayUtils.isEmpty(parameters)) {
    return null;
  }
  Class[] types = new Class[parameters.length];
  int index = 0;
  for (Object parameter : parameters) {
    types[index++] = parameter.getClass();
  }
  return types;
}
/**
 * Finds a public constructor of the target class that fits the given arguments.
 *
 * <p>Unlike {@link Class#getConstructor(Class...)}, which matches parameter classes exactly, the
 * lookup is delegated to {@code reflect}, which matches on assignability and prefers the closest
 * fitting candidate.
 *
 * @param targetClass the constructor's class
 * @param parameters the input parameters
 * @return a matching constructor of the target class
 * @throws NoSuchMethodException if a matching constructor is not found
 */
@SuppressWarnings("rawtypes")
public static Constructor getConstructor(Class targetClass, Object... parameters)
    throws NoSuchMethodException {
  Constructor[] publicConstructors = targetClass.getConstructors();
  // Every constructor is a candidate; only the parameter types decide.
  return reflect(publicConstructors, candidate -> true, ".<init>", targetClass, parameters);
}
/**
 * Finds a public method of the target class by name that fits the given arguments.
 *
 * <p>Unlike {@link Class#getMethod(String, Class...)}, which matches parameter classes exactly,
 * the lookup is delegated to {@code reflect}, which matches on assignability and prefers the
 * closest fitting candidate.
 *
 * @param targetClass the class that declares or inherits the method
 * @param name the specified method's name
 * @param parameters the input parameters
 * @return a matching method of the target class
 * @throws NoSuchMethodException if a matching method is not found
 */
@SuppressWarnings("rawtypes")
public static Method getMethod(Class targetClass, String name, Object... parameters)
    throws NoSuchMethodException {
  Method[] publicMethods = targetClass.getMethods();
  String label = "." + name;
  // Only methods carrying the requested name are candidates.
  return reflect(
      publicMethods,
      candidate -> StringUtils.equals(name, candidate.getName()),
      label,
      targetClass,
      parameters);
}
/**
* Get an object of {@link Executable} of the specified class according to initialization
* parameters. The result's every parameter's type is the same as, or is a subclass or
* subinterface of, the type of the corresponding input parameter. This method returns the result
* whose each parameter's type is the most closed to the corresponding input parameter in terms of
* Java inheritance system.
*
* @param <T> the type of result which extends {@link Executable}
* @param candidates a set of candidates
* @param filter the filter deciding whether to select the input candidate
* @param message a message representing the target executable object
* @param targetClass the constructor's class
* @param parameters the input parameters
* @return a matching executable of the target class
* @throws NoSuchMethodException if a matching method is not found
*/
@SuppressWarnings("rawtypes")
private static <T extends Executable> T reflect(
T[] candidates,
Function<T, Boolean> filter,
String message,
Class targetClass,
Object... parameters)
throws NoSuchMethodException {
Class[] parameterTypes = getParameterTypes(parameters);
T result = null;
for (int i = 0; i < candidates.length; i++) {
T candidate = candidates[i];
if (filter.apply(candidate)
&& assignable(parameterTypes, candidate.getParameterTypes())
&& (result == null
|| assignable(candidate.getParameterTypes(), result.getParameterTypes()))) {
result = candidate;
}
}
if (result == null) {
throw new NoSuchMethodException(
targetClass.getName() + message + argumentTypesToString(parameterTypes));
}
return result;
}
@SuppressWarnings({"unchecked", "rawtypes"})
private static boolean assignable(Class[] from, Class[] to) {
if (from == null) {
return to == null || to.length == 0;
}
if (to == null) {
return from.length == 0;
}
if (from.length != to.length) {
return false;
}
for (int i = 0; i < from.length; i++) {
if (!to[i].isAssignableFrom(from[i])) {
return false;
}
}
return true;
}
/**
* It is copied from {@link Class#argumentTypesToString(Class[])}.
*
* @param argTypes array of Class object
* @return Formatted string of the input Class array.
*/
private static String argumentTypesToString(Class<?>[] argTypes) {
StringBuilder buf = new StringBuilder();
buf.append("(");
if (argTypes != null) {
for (int i = 0; i < argTypes.length; i++) {
if (i > 0) {
buf.append(", ");
}
Class<?> c = argTypes[i];
buf.append((c == null) ? "null" : c.getName());
}
}
buf.append(")");
return buf.toString();
}
/**
* Get a string representing the specified class's all methods.
*
* @param targetClass the input class
* @return the formatted string of the specified class's all methods.
*/
@SuppressWarnings("rawtypes")
public static List<String> getMethodStrings(Class targetClass) {
if (targetClass == null) {
return null;
}
Method[] methods = targetClass.getMethods();
if (methods == null || methods.length == 0) {
return null;
}
List<String> methodStrings = new ArrayList<>();
for (int i = 0; i < methods.length; i++) {
methodStrings.add(methods[i].toString());
}
return methodStrings;
}
}

View file

@ -0,0 +1,58 @@
package io.ray.serve;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Exercises a {@code RayServeWrappedReplica} actor through the Ray actor API. */
public class RayServeReplicaTest {
/**
 * Creates a named {@code ReplicaContext} actor and a {@code RayServeWrappedReplica} actor, sends
 * the replica a request whose call method is {@code getBackendTag}, and checks that the returned
 * value equals the backend tag given at construction.
 */
@SuppressWarnings("unused")
@Test
public void test() {
// Remember whether Ray was already initialized so that the finally block only shuts down a
// runtime that was not initialized before this test ran.
boolean inited = Ray.isInitialized();
Ray.init();
try {
String controllerName = "RayServeReplicaTest";
String backendTag = "b_tag";
String replicaTag = "r_tag";
// An actor registered under controllerName; presumably it stands in for the Serve controller
// that the replica looks up by name -- TODO confirm against RayServeWrappedReplica.
ActorHandle<ReplicaContext> controllerHandle =
Ray.actor(ReplicaContext::new, backendTag, replicaTag, controllerName, new Object())
.setName(controllerName)
.remote();
BackendConfig backendConfig = new BackendConfig();
// The replica is given a backend class name ("io.ray.serve.ReplicaContext") together with the
// argument array for it, so the request below can call getBackendTag on that backend.
ActorHandle<RayServeWrappedReplica> backendHandle =
Ray.actor(
RayServeWrappedReplica::new,
backendTag,
replicaTag,
"io.ray.serve.ReplicaContext",
new Object[] {backendTag, replicaTag, controllerName, new Object()},
backendConfig,
controllerName)
.remote();
// The value returned by this remote() call is discarded.
backendHandle.task(RayServeWrappedReplica::ready).remote();
RequestMetadata requestMetadata = new RequestMetadata();
requestMetadata.setRequestId("RayServeReplicaTest");
requestMetadata.setCallMethod("getBackendTag");
// The request carries no arguments: a null Object[] is passed explicitly.
ObjectRef<Object> resultRef =
backendHandle
.task(RayServeWrappedReplica::handle_request, requestMetadata, (Object[]) null)
.remote();
Assert.assertEquals((String) resultRef.get(), backendTag);
} finally {
if (!inited) {
Ray.shutdown();
}
}
}
}

View file

@ -0,0 +1,69 @@
package io.ray.serve;
import java.util.HashMap;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks which actor options {@code ReplicaConfig} rejects and how it records resources. */
public class ReplicaConfigTest {

  /** An action that is expected to throw while validating its input. */
  static interface Validator {
    void validate();
  }

  /**
   * Verifies that reserved option keys and negative resource amounts are rejected with an {@link
   * IllegalArgumentException}, and that valid resource amounts are readable from the config.
   */
  @Test
  public void test() {
    Object placeholder = new Object();
    String backendDef = "io.ray.serve.ReplicaConfigTest";

    // Option keys that must be rejected whatever their value is.
    String[] reservedKeys = {"placement_group", "lifetime", "name", "max_restarts"};
    for (String reservedKey : reservedKeys) {
      expectIllegalArgumentException(
          () -> new ReplicaConfig(backendDef, null, singleOption(reservedKey, placeholder)));
    }

    // Resource options: option key on the left, resource name it is stored under on the right.
    checkResourceOption(backendDef, "num_cpus", "CPU");
    checkResourceOption(backendDef, "num_gpus", "GPU");
    checkResourceOption(backendDef, "memory", "memory");
    checkResourceOption(backendDef, "object_store_memory", "object_store_memory");
  }

  /**
   * Asserts that a negative amount for the option is rejected and that an amount of 2.0 is stored
   * under the given resource name.
   */
  private void checkResourceOption(String backendDef, String optionKey, String resourceName) {
    expectIllegalArgumentException(
        () -> new ReplicaConfig(backendDef, null, singleOption(optionKey, -1.0)));
    ReplicaConfig accepted = new ReplicaConfig(backendDef, null, singleOption(optionKey, 2.0));
    Assert.assertEquals(accepted.getResource().get(resourceName).doubleValue(), 2.0);
  }

  /** Runs the action and fails the test unless it throws {@link IllegalArgumentException}. */
  private void expectIllegalArgumentException(Validator action) {
    try {
      action.validate();
      Assert.assertTrue(false, "expect IllegalArgumentException");
    } catch (IllegalArgumentException e) {
      // This is the outcome the caller asked for.
    }
  }

  /** Builds a mutable actor-options map holding exactly one entry. */
  private Map<String, Object> singleOption(String key, Object value) {
    Map<String, Object> options = new HashMap<>();
    options.put(key, value);
    return options;
  }
}

View file

@ -0,0 +1,40 @@
package io.ray.serve.api;
import io.ray.serve.RayServeException;
import io.ray.serve.ReplicaContext;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Tests for the static replica-context accessors on {@code Serve}. */
public class ServeTest {

  /**
   * Verifies that {@code Serve.getReplicaContext()} throws {@code RayServeException} while no
   * context is set, and that after {@code Serve.setInternalReplicaContext(...)} it returns a
   * context exposing the same backend tag, replica tag and controller name.
   */
  @Test
  public void replicaContextTest() {
    // INTERNAL_REPLICA_CONTEXT is static state shared with other tests, so the previous value is
    // restored in a finally block even when one of the assertions below fails.
    ReplicaContext preContext = Serve.INTERNAL_REPLICA_CONTEXT;
    try {
      // Test null replica context.
      Serve.INTERNAL_REPLICA_CONTEXT = null;
      try {
        Serve.getReplicaContext();
        Assert.fail("expect RayServeException");
      } catch (RayServeException expected) {
        // Expected: there is no replica context to return.
      }

      // Test context setting and getting.
      String backendTag = "backendTag";
      String replicaTag = "replicaTag";
      String controllerName = "controllerName";
      Object servableObject = new Object();
      Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject);

      ReplicaContext replicaContext = Serve.getReplicaContext();
      Assert.assertNotNull(replicaContext, "no replica context");
      Assert.assertEquals(replicaContext.getBackendTag(), backendTag);
      Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag);
      Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName);
    } finally {
      Serve.INTERNAL_REPLICA_CONTEXT = preContext;
    }
  }
}

View file

@ -0,0 +1,19 @@
package io.ray.serve.poll;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks that a lambda can serve as a {@code KeyListener}. */
public class KeyListenerTest {

  /** Notifies a listener that writes 1 into the array it receives, then checks the write. */
  @Test
  public void test() throws Throwable {
    int[] cell = {0};
    KeyListener listener = (changed) -> ((int[]) changed)[0] = 1;

    listener.notifyChanged(cell);

    Assert.assertEquals(cell[0], 1);
  }
}

View file

@ -0,0 +1,31 @@
package io.ray.serve.poll;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks {@code equals} and {@code hashCode} of {@code KeyType}. */
public class KeyTypeTest {

  /**
   * Compares a reference key with itself, with an identically built key, with a key whose second
   * constructor argument is null, and with a key built from a different namespace and value.
   */
  @Test
  public void test() {
    KeyType reference = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
    KeyType twin = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
    KeyType withNullValue = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null);
    KeyType otherNamespace = new KeyType(LongPollNamespace.REPLICA_HANDLES, "k4");

    expectSame(reference, reference);
    expectSame(reference, twin);
    expectDifferent(reference, withNullValue);
    expectDifferent(reference, otherNamespace);
  }

  /** Asserts that both keys are equal and share a hash code. */
  private static void expectSame(KeyType left, KeyType right) {
    Assert.assertEquals(left, right);
    Assert.assertEquals(left.hashCode(), right.hashCode());
    Assert.assertTrue(left.equals(right));
  }

  /** Asserts that the keys are unequal and that their hash codes differ. */
  private static void expectDifferent(KeyType left, KeyType right) {
    Assert.assertNotEquals(left, right);
    Assert.assertNotEquals(left.hashCode(), right.hashCode());
    Assert.assertFalse(left.equals(right));
  }
}

View file

@ -0,0 +1,34 @@
package io.ray.serve.poll;
import java.util.HashMap;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks how {@code LongPollClient} handles a batch of updates. */
public class LongPollClientTest {

  /**
   * Feeds one update to a client built with a null first constructor argument and verifies that
   * the snapshot id and object snapshot are stored and that the key's listener received the
   * object.
   */
  @Test
  public void test() throws Throwable {
    int snapshotId = 10;
    int objectSnapshot = 20;
    KeyType key = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag");

    // The listener records the object it is notified with.
    int[] received = {0};
    Map<KeyType, KeyListener> listeners = new HashMap<>();
    listeners.put(key, (object) -> received[0] = (Integer) object);
    LongPollClient client = new LongPollClient(null, listeners);

    UpdatedObject update = new UpdatedObject();
    update.setSnapshotId(snapshotId);
    update.setObjectSnapshot(objectSnapshot);
    Map<KeyType, UpdatedObject> batch = new HashMap<>();
    batch.put(key, update);

    client.processUpdate(batch);

    Assert.assertEquals(client.getSnapshotIds().get(key).intValue(), snapshotId);
    Assert.assertEquals(((Integer) client.getObjectSnapshots().get(key)).intValue(), objectSnapshot);
    Assert.assertEquals(received[0], objectSnapshot);
  }
}

View file

@ -0,0 +1,13 @@
package io.ray.serve.util;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Checks placeholder substitution in {@code LogUtil.format}. */
public class LogUtilTest {

  /** Three {@code {}} placeholders are replaced by the three arguments in order. */
  @Test
  public void formatTest() {
    Assert.assertEquals(LogUtil.format("{},{},{}", "1", "2", "3"), "1,2,3");
  }
}

View file

@ -0,0 +1,72 @@
package io.ray.serve.util;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
/** Tests for {@code ReflectUtil}. */
public class ReflectUtilTest {

  /** Sample class offering overloaded constructors and a single-argument method. */
  static class ReflectExample {
    public ReflectExample() {}

    public ReflectExample(Integer a) {}

    public ReflectExample(String a) {}

    public void test(String a) {}
  }

  /**
   * Looks up constructors for no arguments, a null argument array, an Integer and a String, and
   * expects a failure for an argument type that no constructor accepts.
   */
  @SuppressWarnings("unchecked")
  @Test
  public void getConstructorTest() throws NoSuchMethodException {
    Constructor<ReflectExample> constructor = ReflectUtil.getConstructor(ReflectExample.class);
    Assert.assertNotNull(constructor);

    // The cast makes explicit that null is the whole varargs array, not a single null argument.
    constructor = ReflectUtil.getConstructor(ReflectExample.class, (Object[]) null);
    Assert.assertNotNull(constructor);

    constructor = ReflectUtil.getConstructor(ReflectExample.class, 2);
    Assert.assertNotNull(constructor);

    constructor = ReflectUtil.getConstructor(ReflectExample.class, "");
    Assert.assertNotNull(constructor);

    try {
      ReflectUtil.getConstructor(ReflectExample.class, new HashMap<>());
      Assert.fail("expect NoSuchMethodException");
    } catch (NoSuchMethodException expected) {
      // Expected: ReflectExample has no constructor accepting a HashMap.
    }
  }

  /** Looks up {@code test(String)} and expects a failure for a HashMap argument. */
  @Test
  public void getMethodTest() throws NoSuchMethodException {
    Method method = ReflectUtil.getMethod(ReflectExample.class, "test", "");
    Assert.assertNotNull(method);

    try {
      ReflectUtil.getMethod(ReflectExample.class, "test", new HashMap<>());
      Assert.fail("expect NoSuchMethodException");
    } catch (NoSuchMethodException expected) {
      // Expected: ReflectExample has no test method accepting a HashMap.
    }
  }

  /** Checks that the method listing of the sample class mentions its {@code test} method. */
  @Test
  public void getMethodStringsTest() {
    List<String> methodList = ReflectUtil.getMethodStrings(ReflectExample.class);
    String result = null;
    for (String method : methodList) {
      if (StringUtils.contains(method, "test")) {
        result = method;
      }
    }
    Assert.assertNotNull(result, "there should be test method");
  }
}

View file

@ -5,6 +5,7 @@
<packages>
<package name = "io.ray.runtime.*" />
<package name = "io.ray.test.*" />
<package name = "io.ray.serve.*" />
</packages>
</test>
<listeners>