mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[Serve] Define Java Backend (#16169)
This commit is contained in:
parent
7f78e8c014
commit
2c3ce469ba
30 changed files with 1587 additions and 0 deletions
|
@ -12,6 +12,7 @@ exports_files([
|
|||
all_modules = [
|
||||
"api",
|
||||
"runtime",
|
||||
"serve",
|
||||
"test",
|
||||
"performance_test",
|
||||
]
|
||||
|
@ -134,6 +135,27 @@ define_java_module(
|
|||
],
|
||||
)
|
||||
|
||||
# Declares the "serve" Java module: a library with `deps`, plus a test library
# (define_test_lib = True) built with `test_deps`.
# NOTE(review): other hunks of this change reference :io_ray_ray_serve,
# :io_ray_ray_serve_test and //java:io_ray_ray_serve_pom; presumably all three targets are
# generated by this macro call — confirm against the define_java_module definition.
define_java_module(
|
||||
name = "serve",
|
||||
define_test_lib = True,
|
||||
test_deps = [
|
||||
":io_ray_ray_api",
|
||||
":io_ray_ray_serve",
|
||||
"@maven//:org_apache_commons_commons_lang3",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
"@maven//:org_testng_testng",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":io_ray_ray_api",
|
||||
":io_ray_ray_runtime",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:org_apache_commons_commons_lang3",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
],
|
||||
)
|
||||
|
||||
java_binary(
|
||||
name = "all_tests",
|
||||
args = ["java/testng.xml"],
|
||||
|
@ -142,6 +164,7 @@ java_binary(
|
|||
runtime_deps = [
|
||||
":io_ray_ray_performance_test",
|
||||
":io_ray_ray_runtime_test",
|
||||
":io_ray_ray_serve_test",
|
||||
":io_ray_ray_test",
|
||||
],
|
||||
)
|
||||
|
@ -237,6 +260,7 @@ genrule(
|
|||
cp -f $(location //java:io_ray_ray_runtime_pom) "$$WORK_DIR/java/runtime/pom.xml"
|
||||
cp -f $(location //java:io_ray_ray_test_pom) "$$WORK_DIR/java/test/pom.xml"
|
||||
cp -f $(location //java:io_ray_ray_performance_test_pom) "$$WORK_DIR/java/performance_test/pom.xml"
|
||||
cp -f $(location //java:io_ray_ray_serve_pom) "$$WORK_DIR/java/serve/pom.xml"
|
||||
date > $@
|
||||
""",
|
||||
local = 1,
|
||||
|
@ -251,6 +275,7 @@ java_binary(
|
|||
runtime_deps = [
|
||||
"//java:io_ray_ray_api",
|
||||
"//java:io_ray_ray_runtime",
|
||||
"//java:io_ray_ray_serve",
|
||||
"//streaming/java:io_ray_ray_streaming-api",
|
||||
"//streaming/java:io_ray_ray_streaming-runtime",
|
||||
],
|
||||
|
|
|
@ -56,6 +56,7 @@
|
|||
<modules>
|
||||
<module>api</module>
|
||||
<module>runtime</module>
|
||||
<module>serve</module>
|
||||
<module>test</module>
|
||||
</modules>
|
||||
|
||||
|
|
51
java/serve/pom.xml
Normal file
51
java/serve/pom.xml
Normal file
|
@ -0,0 +1,51 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-superpom</artifactId>
|
||||
<version>1.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>ray-serve</artifactId>
|
||||
<name>ray serve</name>
|
||||
<description>java for ray serve</description>
|
||||
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-runtime</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
<version>27.0.1-jre</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>1.7.25</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>7.3.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
32
java/serve/pom_template.xml
Normal file
32
java/serve/pom_template.xml
Normal file
|
@ -0,0 +1,32 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
{auto_gen_header}
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-superpom</artifactId>
|
||||
<version>1.1.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>ray-serve</artifactId>
|
||||
<name>ray serve</name>
|
||||
<description>java for ray serve</description>
|
||||
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.ray</groupId>
|
||||
<artifactId>ray-runtime</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
{generated_bzl_deps}
|
||||
</dependencies>
|
||||
</project>
|
86
java/serve/src/main/java/io/ray/serve/BackendConfig.java
Normal file
86
java/serve/src/main/java/io/ray/serve/BackendConfig.java
Normal file
|
@ -0,0 +1,86 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
|
||||
|
||||
/** Configuration options for a backend, to be set by the user. */
|
||||
public class BackendConfig implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = 244486384449779141L;
|
||||
|
||||
/**
|
||||
* The number of processes to start up that will handle requests to this backend. Defaults to 1.
|
||||
*/
|
||||
private int numReplicas = 1;
|
||||
|
||||
/**
|
||||
* The maximum number of queries that will be sent to a replica of this backend without receiving
|
||||
* a response. Defaults to 100.
|
||||
*/
|
||||
private int maxConcurrentQueries = 100;
|
||||
|
||||
/**
|
||||
* Arguments to pass to the reconfigure method of the backend. The reconfigure method is called if
|
||||
* user_config is not None.
|
||||
*/
|
||||
private Object userConfig;
|
||||
|
||||
/**
|
||||
* Duration that backend workers will wait until there is no more work to be done before shutting
|
||||
* down. Defaults to 2s.
|
||||
*/
|
||||
private long experimentalGracefulShutdownWaitLoopS = 2;
|
||||
|
||||
/**
|
||||
* Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to
|
||||
* 20s.
|
||||
*/
|
||||
private long experimentalGracefulShutdownTimeoutS = 20;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return ReflectionToStringBuilder.toString(this);
|
||||
}
|
||||
|
||||
public int getNumReplicas() {
|
||||
return numReplicas;
|
||||
}
|
||||
|
||||
public void setNumReplicas(int numReplicas) {
|
||||
this.numReplicas = numReplicas;
|
||||
}
|
||||
|
||||
public int getMaxConcurrentQueries() {
|
||||
return maxConcurrentQueries;
|
||||
}
|
||||
|
||||
public void setMaxConcurrentQueries(int maxConcurrentQueries) {
|
||||
Preconditions.checkArgument(maxConcurrentQueries > 0, "max_concurrent_queries must be > 0");
|
||||
this.maxConcurrentQueries = maxConcurrentQueries;
|
||||
}
|
||||
|
||||
public Object getUserConfig() {
|
||||
return userConfig;
|
||||
}
|
||||
|
||||
public void setUserConfig(Object userConfig) {
|
||||
this.userConfig = userConfig;
|
||||
}
|
||||
|
||||
public long getExperimentalGracefulShutdownWaitLoopS() {
|
||||
return experimentalGracefulShutdownWaitLoopS;
|
||||
}
|
||||
|
||||
public void setExperimentalGracefulShutdownWaitLoopS(long experimentalGracefulShutdownWaitLoopS) {
|
||||
this.experimentalGracefulShutdownWaitLoopS = experimentalGracefulShutdownWaitLoopS;
|
||||
}
|
||||
|
||||
public long getExperimentalGracefulShutdownTimeoutS() {
|
||||
return experimentalGracefulShutdownTimeoutS;
|
||||
}
|
||||
|
||||
public void setExperimentalGracefulShutdownTimeoutS(long experimentalGracefulShutdownTimeoutS) {
|
||||
this.experimentalGracefulShutdownTimeoutS = experimentalGracefulShutdownTimeoutS;
|
||||
}
|
||||
}
|
19
java/serve/src/main/java/io/ray/serve/Constants.java
Normal file
19
java/serve/src/main/java/io/ray/serve/Constants.java
Normal file
|
@ -0,0 +1,19 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
|
||||
/** Ray Serve common constants. */
|
||||
public class Constants {
|
||||
|
||||
/** Name of backend reconfiguration method implemented by user. */
|
||||
public static final String BACKEND_RECONFIGURE_METHOD = "reconfigure";
|
||||
|
||||
/** Default histogram buckets for latency tracker. */
|
||||
public static final List<Double> DEFAULT_LATENCY_BUCKET_MS =
|
||||
Lists.newArrayList(
|
||||
1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0, 500.0, 1000.0, 2000.0, 5000.0);
|
||||
|
||||
/** Name of controller listen_for_change method. */
|
||||
public static final String CONTROLLER_LISTEN_FOR_CHANGE_METHOD = "listen_for_change";
|
||||
}
|
30
java/serve/src/main/java/io/ray/serve/Query.java
Normal file
30
java/serve/src/main/java/io/ray/serve/Query.java
Normal file
|
@ -0,0 +1,30 @@
|
|||
package io.ray.serve;
|
||||
|
||||
/** Wrap request arguments and meta data. */
|
||||
public class Query {
|
||||
|
||||
private Object[] args;
|
||||
|
||||
private RequestMetadata metadata;
|
||||
|
||||
public Query(Object[] args, RequestMetadata requestMetadata) {
|
||||
this.args = args;
|
||||
this.metadata = requestMetadata;
|
||||
}
|
||||
|
||||
public Object[] getArgs() {
|
||||
return args;
|
||||
}
|
||||
|
||||
public void setArgs(Object[] args) {
|
||||
this.args = args;
|
||||
}
|
||||
|
||||
public RequestMetadata getMetadata() {
|
||||
return metadata;
|
||||
}
|
||||
|
||||
public void setMetadata(RequestMetadata metadata) {
|
||||
this.metadata = metadata;
|
||||
}
|
||||
}
|
16
java/serve/src/main/java/io/ray/serve/RayServeException.java
Normal file
16
java/serve/src/main/java/io/ray/serve/RayServeException.java
Normal file
|
@ -0,0 +1,16 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import io.ray.runtime.exception.RayException;
|
||||
|
||||
public class RayServeException extends RayException {
|
||||
|
||||
private static final long serialVersionUID = 4673951342965950469L;
|
||||
|
||||
public RayServeException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public RayServeException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
254
java/serve/src/main/java/io/ray/serve/RayServeReplica.java
Normal file
254
java/serve/src/main/java/io/ray/serve/RayServeReplica.java
Normal file
|
@ -0,0 +1,254 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.runtime.metric.Count;
|
||||
import io.ray.runtime.metric.Gauge;
|
||||
import io.ray.runtime.metric.Histogram;
|
||||
import io.ray.runtime.metric.MetricConfig;
|
||||
import io.ray.runtime.metric.Metrics;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.poll.KeyListener;
|
||||
import io.ray.serve.poll.KeyType;
|
||||
import io.ray.serve.poll.LongPollClient;
|
||||
import io.ray.serve.poll.LongPollNamespace;
|
||||
import io.ray.serve.util.LogUtil;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** Handles requests with the provided callable. */
|
||||
public class RayServeReplica {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayServeReplica.class);
|
||||
|
||||
private String backendTag;
|
||||
|
||||
private String replicaTag;
|
||||
|
||||
private BackendConfig config;
|
||||
|
||||
private AtomicInteger numOngoingRequests = new AtomicInteger();
|
||||
|
||||
private Object callable;
|
||||
|
||||
private boolean metricsRegistered = false;
|
||||
|
||||
private Count requestCounter;
|
||||
|
||||
private Count errorCounter;
|
||||
|
||||
private Count restartCounter;
|
||||
|
||||
private Histogram processingLatencyTracker;
|
||||
|
||||
private Gauge numProcessingItems;
|
||||
|
||||
private LongPollClient longPollClient;
|
||||
|
||||
public RayServeReplica(
|
||||
Object callable, BackendConfig backendConfig, BaseActorHandle actorHandle) {
|
||||
this.backendTag = Serve.getReplicaContext().getBackendTag();
|
||||
this.replicaTag = Serve.getReplicaContext().getReplicaTag();
|
||||
this.callable = callable;
|
||||
this.config = backendConfig;
|
||||
this.reconfigure(backendConfig.getUserConfig());
|
||||
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(
|
||||
new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag),
|
||||
newConfig -> updateBackendConfigs(newConfig));
|
||||
this.longPollClient = new LongPollClient(actorHandle, keyListeners);
|
||||
this.longPollClient.start();
|
||||
registerMetrics();
|
||||
}
|
||||
|
||||
private void registerMetrics() {
|
||||
if (!Ray.isInitialized() || Ray.getRuntimeContext().isSingleProcess()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Metrics.init(MetricConfig.DEFAULT_CONFIG);
|
||||
requestCounter =
|
||||
Metrics.count()
|
||||
.name("serve_backend_request_counter")
|
||||
.description("The number of queries that have been processed in this replica.")
|
||||
.unit("")
|
||||
.tags(ImmutableMap.of("backend", backendTag))
|
||||
.register();
|
||||
|
||||
errorCounter =
|
||||
Metrics.count()
|
||||
.name("serve_backend_error_counter")
|
||||
.description("The number of exceptions that have occurred in the backend.")
|
||||
.unit("")
|
||||
.tags(ImmutableMap.of("backend", backendTag))
|
||||
.register();
|
||||
|
||||
restartCounter =
|
||||
Metrics.count()
|
||||
.name("serve_backend_replica_starts")
|
||||
.description("The number of times this replica has been restarted due to failure.")
|
||||
.unit("")
|
||||
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
|
||||
.register();
|
||||
|
||||
processingLatencyTracker =
|
||||
Metrics.histogram()
|
||||
.name("serve_backend_processing_latency_ms")
|
||||
.description("The latency for queries to be processed.")
|
||||
.unit("")
|
||||
.boundaries(Constants.DEFAULT_LATENCY_BUCKET_MS)
|
||||
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
|
||||
.register();
|
||||
|
||||
numProcessingItems =
|
||||
Metrics.gauge()
|
||||
.name("serve_replica_processing_queries")
|
||||
.description("The current number of queries being processed.")
|
||||
.unit("")
|
||||
.tags(ImmutableMap.of("backend", backendTag, "replica", replicaTag))
|
||||
.register();
|
||||
|
||||
metricsRegistered = true;
|
||||
|
||||
restartCounter.inc(1.0);
|
||||
}
|
||||
|
||||
public Object handleRequest(Query request) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
LOGGER.debug(
|
||||
"Replica {} received request {}", replicaTag, request.getMetadata().getRequestId());
|
||||
|
||||
numOngoingRequests.incrementAndGet();
|
||||
reportMetrics(() -> numProcessingItems.update(numOngoingRequests.get()));
|
||||
Object result = invokeSingle(request);
|
||||
numOngoingRequests.decrementAndGet();
|
||||
|
||||
long requestTimeMs = System.currentTimeMillis() - startTime;
|
||||
LOGGER.debug(
|
||||
"Replica {} finished request {} in {}ms",
|
||||
replicaTag,
|
||||
request.getMetadata().getRequestId(),
|
||||
requestTimeMs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private Object invokeSingle(Query requestItem) {
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
Method methodToCall = null;
|
||||
try {
|
||||
LOGGER.debug(
|
||||
"Replica {} started executing request {}",
|
||||
replicaTag,
|
||||
requestItem.getMetadata().getRequestId());
|
||||
|
||||
methodToCall = getRunnerMethod(requestItem);
|
||||
Object result = methodToCall.invoke(callable, requestItem.getArgs());
|
||||
reportMetrics(() -> requestCounter.inc(1.0));
|
||||
return result;
|
||||
} catch (Throwable e) {
|
||||
reportMetrics(() -> errorCounter.inc(1.0));
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Replica {} failed to invoke method {}",
|
||||
replicaTag,
|
||||
methodToCall == null ? "unknown" : methodToCall.getName()),
|
||||
e);
|
||||
} finally {
|
||||
reportMetrics(() -> processingLatencyTracker.update(System.currentTimeMillis() - start));
|
||||
}
|
||||
}
|
||||
|
||||
private Method getRunnerMethod(Query query) {
|
||||
String methodName = query.getMetadata().getCallMethod();
|
||||
|
||||
try {
|
||||
return ReflectUtil.getMethod(
|
||||
callable.getClass(), methodName, query.getArgs() == null ? null : query.getArgs());
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"Backend doesn't have method {} which is specified in the request. "
|
||||
+ "The available methods are {}",
|
||||
methodName,
|
||||
ReflectUtil.getMethodStrings(callable.getClass())));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform graceful shutdown. Trigger a graceful shutdown protocol that will wait for all the
|
||||
* queued tasks to be completed and return to the controller.
|
||||
*/
|
||||
public void drainPendingQueries() {
|
||||
while (true) {
|
||||
try {
|
||||
Thread.sleep(config.getExperimentalGracefulShutdownWaitLoopS() * 1000);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.error(
|
||||
"Replica {} was interrupted in sheep when draining pending queries", replicaTag);
|
||||
}
|
||||
if (numOngoingRequests.get() == 0) {
|
||||
break;
|
||||
} else {
|
||||
LOGGER.debug(
|
||||
"Waiting for an additional {}s to shut down because there are {} ongoing requests.",
|
||||
config.getExperimentalGracefulShutdownWaitLoopS(),
|
||||
numOngoingRequests.get());
|
||||
}
|
||||
}
|
||||
Ray.exitActor();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconfigure user's configuration in the callable object through its reconfigure method.
|
||||
*
|
||||
* @param userConfig new user's configuration
|
||||
*/
|
||||
private void reconfigure(Object userConfig) {
|
||||
if (userConfig == null) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
Method reconfigureMethod =
|
||||
ReflectUtil.getMethod(
|
||||
callable.getClass(),
|
||||
Constants.BACKEND_RECONFIGURE_METHOD,
|
||||
userConfig); // TODO cache reconfigureMethod
|
||||
reconfigureMethod.invoke(callable, userConfig);
|
||||
} catch (NoSuchMethodException e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format(
|
||||
"user_config specified but backend {} missing {} method",
|
||||
backendTag,
|
||||
Constants.BACKEND_RECONFIGURE_METHOD));
|
||||
} catch (Throwable e) {
|
||||
throw new RayServeException(
|
||||
LogUtil.format("Backend {} failed to reconfigure user_config {}", backendTag, userConfig),
|
||||
e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update backend configs.
|
||||
*
|
||||
* @param newConfig the new configuration of backend
|
||||
*/
|
||||
private void updateBackendConfigs(Object newConfig) {
|
||||
config = (BackendConfig) newConfig;
|
||||
reconfigure(((BackendConfig) newConfig).getUserConfig());
|
||||
}
|
||||
|
||||
private void reportMetrics(Runnable runnable) {
|
||||
if (metricsRegistered) {
|
||||
runnable.run();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.serve.api.Serve;
|
||||
import io.ray.serve.util.ReflectUtil;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Optional;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/** Replica class wrapping the provided class. Note that Java function is not supported now. */
|
||||
public class RayServeWrappedReplica {
|
||||
|
||||
private RayServeReplica backend;
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
public RayServeWrappedReplica(
|
||||
String backendTag,
|
||||
String replicaTag,
|
||||
String backendDef,
|
||||
Object[] initArgs,
|
||||
BackendConfig backendConfig,
|
||||
String controllerName)
|
||||
throws ClassNotFoundException, NoSuchMethodException, InstantiationException,
|
||||
IllegalAccessException, IllegalArgumentException, InvocationTargetException {
|
||||
|
||||
// Instantiate the object defined by backendDef.
|
||||
Class backendClass = Class.forName(backendDef);
|
||||
Object callable = ReflectUtil.getConstructor(backendClass, initArgs).newInstance(initArgs);
|
||||
|
||||
// Get the controller by controllerName.
|
||||
Preconditions.checkArgument(
|
||||
StringUtils.isNotBlank(controllerName), "Must provide a valid controllerName");
|
||||
Optional<BaseActorHandle> optional = Ray.getActor(controllerName);
|
||||
Preconditions.checkState(optional.isPresent(), "Controller does not exist");
|
||||
|
||||
// Set the controller name so that Serve.connect() in the user's backend code will connect to
|
||||
// the instance that this backend is running in.
|
||||
Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, callable);
|
||||
|
||||
// Construct worker replica.
|
||||
backend = new RayServeReplica(callable, backendConfig, optional.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* The entry method to process the request.
|
||||
*
|
||||
* @param requestMetadata request metadata
|
||||
* @param requestArgs the input parameters of the specified method of the object defined by
|
||||
* backendDef.
|
||||
* @return the result of request being processed
|
||||
*/
|
||||
public Object handle_request(RequestMetadata requestMetadata, Object[] requestArgs) {
|
||||
return backend.handleRequest(new Query(requestArgs, requestMetadata));
|
||||
}
|
||||
|
||||
/** Check whether this replica is ready or not. */
|
||||
public void ready() {
|
||||
return;
|
||||
}
|
||||
|
||||
/** Wait until there is no request in processing. It is used for stopping replica gracefully. */
|
||||
public void drain_pending_queries() {
|
||||
backend.drainPendingQueries();
|
||||
}
|
||||
}
|
115
java/serve/src/main/java/io/ray/serve/ReplicaConfig.java
Normal file
115
java/serve/src/main/java/io/ray/serve/ReplicaConfig.java
Normal file
|
@ -0,0 +1,115 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** Configuration options for a replica. */
|
||||
public class ReplicaConfig implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = -1442657824045704226L;
|
||||
|
||||
private String backendDef;
|
||||
|
||||
private Object[] initArgs;
|
||||
|
||||
private Map<String, Object> rayActorOptions;
|
||||
|
||||
private Map<String, Double> resource;
|
||||
|
||||
public ReplicaConfig(String backendDef, Object[] initArgs, Map<String, Object> rayActorOptions) {
|
||||
this.backendDef = backendDef;
|
||||
this.initArgs = initArgs;
|
||||
this.rayActorOptions = rayActorOptions;
|
||||
this.resource = new HashMap<>();
|
||||
this.validate();
|
||||
}
|
||||
|
||||
@SuppressWarnings({"unchecked", "rawtypes"})
|
||||
private void validate() {
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("placement_group"),
|
||||
"Providing placement_group for backend actors is not currently supported.");
|
||||
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("lifetime"),
|
||||
"Specifying lifetime in init_args is not allowed.");
|
||||
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("name"), "Specifying name in init_args is not allowed.");
|
||||
|
||||
Preconditions.checkArgument(
|
||||
!rayActorOptions.containsKey("max_restarts"),
|
||||
"Specifying max_restarts in init_args is not allowed.");
|
||||
|
||||
// TODO Confirm num_cpus, num_gpus, memory is double in protobuf.
|
||||
// Ray defaults to zero CPUs for placement, we default to one here.
|
||||
Object numCpus = rayActorOptions.getOrDefault("num_cpus", 1.0);
|
||||
Preconditions.checkArgument(
|
||||
numCpus instanceof Double, "num_cpus in ray_actor_options must be a double.");
|
||||
Preconditions.checkArgument(
|
||||
((Double) numCpus) >= 0, "num_cpus in ray_actor_options must be >= 0.");
|
||||
resource.put("CPU", (Double) numCpus);
|
||||
|
||||
Object numGpus = rayActorOptions.getOrDefault("num_gpus", 0.0);
|
||||
Preconditions.checkArgument(
|
||||
numGpus instanceof Double, "num_gpus in ray_actor_options must be a double.");
|
||||
Preconditions.checkArgument(
|
||||
((Double) numGpus) >= 0, "num_gpus in ray_actor_options must be >= 0.");
|
||||
resource.put("GPU", (Double) numGpus);
|
||||
|
||||
Object memory = rayActorOptions.getOrDefault("memory", 0.0);
|
||||
Preconditions.checkArgument(
|
||||
memory instanceof Double, "memory in ray_actor_options must be a double.");
|
||||
Preconditions.checkArgument(
|
||||
((Double) memory) >= 0, "memory in ray_actor_options must be >= 0.");
|
||||
resource.put("memory", (Double) memory);
|
||||
|
||||
Object objectStoreMemory = rayActorOptions.getOrDefault("object_store_memory", 0.0);
|
||||
Preconditions.checkArgument(
|
||||
objectStoreMemory instanceof Double,
|
||||
"object_store_memory in ray_actor_options must be a double.");
|
||||
Preconditions.checkArgument(
|
||||
((Double) objectStoreMemory) >= 0,
|
||||
"object_store_memory in ray_actor_options must be >= 0.");
|
||||
resource.put("object_store_memory", (Double) objectStoreMemory);
|
||||
|
||||
Object customResources = rayActorOptions.getOrDefault("resources", new HashMap<>());
|
||||
Preconditions.checkArgument(
|
||||
customResources instanceof Map, "resources in ray_actor_options must be a map.");
|
||||
resource.putAll((Map) customResources);
|
||||
}
|
||||
|
||||
public String getBackendDef() {
|
||||
return backendDef;
|
||||
}
|
||||
|
||||
public void setBackendDef(String backendDef) {
|
||||
this.backendDef = backendDef;
|
||||
}
|
||||
|
||||
public Object[] getInitArgs() {
|
||||
return initArgs;
|
||||
}
|
||||
|
||||
public void setInitArgs(Object[] initArgs) {
|
||||
this.initArgs = initArgs;
|
||||
}
|
||||
|
||||
public Map<String, Object> getRayActorOptions() {
|
||||
return rayActorOptions;
|
||||
}
|
||||
|
||||
public void setRayActorOptions(Map<String, Object> rayActorOptions) {
|
||||
this.rayActorOptions = rayActorOptions;
|
||||
}
|
||||
|
||||
public Map<String, Double> getResource() {
|
||||
return resource;
|
||||
}
|
||||
|
||||
public void setResource(Map<String, Double> resource) {
|
||||
this.resource = resource;
|
||||
}
|
||||
}
|
53
java/serve/src/main/java/io/ray/serve/ReplicaContext.java
Normal file
53
java/serve/src/main/java/io/ray/serve/ReplicaContext.java
Normal file
|
@ -0,0 +1,53 @@
|
|||
package io.ray.serve;
|
||||
|
||||
/**
 * Holder of the per-replica data that Serve API calls made from within the user's backend code
 * need: backend tag, replica tag, controller name and the servable object.
 */
public class ReplicaContext {

  private String backendTag;

  private String replicaTag;

  private String internalControllerName;

  private Object servableObject;

  /**
   * Creates a context.
   *
   * @param backendTag tag of the backend
   * @param replicaTag tag of the replica
   * @param controllerName name of the controller actor
   * @param servableObject the object serving requests in this replica
   */
  public ReplicaContext(
      String backendTag, String replicaTag, String controllerName, Object servableObject) {
    setBackendTag(backendTag);
    setReplicaTag(replicaTag);
    setInternalControllerName(controllerName);
    setServableObject(servableObject);
  }

  public final String getBackendTag() {
    return this.backendTag;
  }

  public final void setBackendTag(String backendTag) {
    this.backendTag = backendTag;
  }

  public final String getReplicaTag() {
    return this.replicaTag;
  }

  public final void setReplicaTag(String replicaTag) {
    this.replicaTag = replicaTag;
  }

  public final String getInternalControllerName() {
    return this.internalControllerName;
  }

  public final void setInternalControllerName(String internalControllerName) {
    this.internalControllerName = internalControllerName;
  }

  public final Object getServableObject() {
    return this.servableObject;
  }

  public final void setServableObject(Object servableObject) {
    this.servableObject = servableObject;
  }
}
|
60
java/serve/src/main/java/io/ray/serve/RequestMetadata.java
Normal file
60
java/serve/src/main/java/io/ray/serve/RequestMetadata.java
Normal file
|
@ -0,0 +1,60 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Metadata that accompanies a request: its id, the endpoint, the method to call on the backend
 * and the HTTP method and headers.
 */
public class RequestMetadata implements Serializable {

  private static final long serialVersionUID = -8925036926565326811L;

  /** Identifier of the request. */
  private String requestId;

  /** Endpoint of the request. */
  private String endpoint;

  /** Name of the method to invoke on the backend; {@code "__call__"} unless overridden. */
  private String callMethod = "__call__";

  /** HTTP method of the request. */
  private String httpMethod;

  /** HTTP headers of the request; stored by reference, not copied. */
  private Map<String, String> httpHeaders;

  public String getRequestId() {
    return this.requestId;
  }

  public void setRequestId(String requestId) {
    this.requestId = requestId;
  }

  public String getEndpoint() {
    return this.endpoint;
  }

  public void setEndpoint(String endpoint) {
    this.endpoint = endpoint;
  }

  public String getCallMethod() {
    return this.callMethod;
  }

  public void setCallMethod(String callMethod) {
    this.callMethod = callMethod;
  }

  public String getHttpMethod() {
    return this.httpMethod;
  }

  public void setHttpMethod(String httpMethod) {
    this.httpMethod = httpMethod;
  }

  public Map<String, String> getHttpHeaders() {
    return this.httpHeaders;
  }

  public void setHttpHeaders(Map<String, String> httpHeaders) {
    this.httpHeaders = httpHeaders;
  }
}
|
38
java/serve/src/main/java/io/ray/serve/api/Serve.java
Normal file
38
java/serve/src/main/java/io/ray/serve/api/Serve.java
Normal file
|
@ -0,0 +1,38 @@
|
|||
package io.ray.serve.api;
|
||||
|
||||
import io.ray.serve.RayServeException;
|
||||
import io.ray.serve.ReplicaContext;
|
||||
|
||||
/** Ray Serve global API. TODO: will be riched in the Java SDK/API PR. */
|
||||
public class Serve {
|
||||
|
||||
public static ReplicaContext INTERNAL_REPLICA_CONTEXT;
|
||||
|
||||
/**
|
||||
* Set replica information to global context.
|
||||
*
|
||||
* @param backendTag backend tag
|
||||
* @param replicaTag replica tag
|
||||
* @param controllerName the controller actor's name
|
||||
* @param servableObject the servable object of the specified replica.
|
||||
*/
|
||||
public static void setInternalReplicaContext(
|
||||
String backendTag, String replicaTag, String controllerName, Object servableObject) {
|
||||
// TODO singleton.
|
||||
INTERNAL_REPLICA_CONTEXT =
|
||||
new ReplicaContext(backendTag, replicaTag, controllerName, servableObject);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the global replica context.
|
||||
*
|
||||
* @return the replica context if it exists, or throw RayServeException.
|
||||
*/
|
||||
public static ReplicaContext getReplicaContext() {
|
||||
if (INTERNAL_REPLICA_CONTEXT == null) {
|
||||
throw new RayServeException(
|
||||
"`Serve.getReplicaContext()` may only be called from within a Ray Serve backend.");
|
||||
}
|
||||
return INTERNAL_REPLICA_CONTEXT;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
/** Listener of long poll. It notifies changed object to the specified key. */
|
||||
@FunctionalInterface
|
||||
public interface KeyListener {
|
||||
|
||||
void notifyChanged(Object object);
|
||||
}
|
56
java/serve/src/main/java/io/ray/serve/poll/KeyType.java
Normal file
56
java/serve/src/main/java/io/ray/serve/poll/KeyType.java
Normal file
|
@ -0,0 +1,56 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Objects;
|
||||
import org.apache.commons.lang3.builder.ReflectionToStringBuilder;
|
||||
|
||||
/** Key type of long poll. */
|
||||
public class KeyType implements Serializable {
|
||||
|
||||
private static final long serialVersionUID = -8838552786234630401L;
|
||||
|
||||
private final LongPollNamespace longPollNamespace;
|
||||
|
||||
private final String key;
|
||||
|
||||
private int hash;
|
||||
|
||||
public KeyType(LongPollNamespace longPollNamespace, String key) {
|
||||
this.longPollNamespace = longPollNamespace;
|
||||
this.key = key;
|
||||
}
|
||||
|
||||
public LongPollNamespace getLongPollNamespace() {
|
||||
return longPollNamespace;
|
||||
}
|
||||
|
||||
public String getKey() {
|
||||
return key;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
if (hash == 0) {
|
||||
hash = Objects.hash(longPollNamespace, key);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
KeyType keyType = (KeyType) obj;
|
||||
return Objects.equals(longPollNamespace, keyType.getLongPollNamespace())
|
||||
&& Objects.equals(key, keyType.getKey());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return ReflectionToStringBuilder.toString(this);
|
||||
}
|
||||
}
|
102
java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java
Normal file
102
java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java
Normal file
|
@ -0,0 +1,102 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.BaseActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.PyActorHandle;
|
||||
import io.ray.api.function.PyActorMethod;
|
||||
import io.ray.runtime.exception.RayActorException;
|
||||
import io.ray.runtime.exception.RayTaskException;
|
||||
import io.ray.serve.Constants;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** The asynchronous long polling client. */
|
||||
public class LongPollClient {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LongPollClient.class);
|
||||
|
||||
/** Handle to actor embedding LongPollHost. */
|
||||
private BaseActorHandle hostActor;
|
||||
|
||||
/** A set mapping keys to callbacks to be called on state update for the corresponding keys. */
|
||||
private Map<KeyType, KeyListener> keyListeners;
|
||||
|
||||
private Map<KeyType, Integer> snapshotIds;
|
||||
|
||||
private Map<KeyType, Object> objectSnapshots;
|
||||
|
||||
private ObjectRef<Object> currentRef;
|
||||
|
||||
/** An async thread to post the callback into. */
|
||||
private Thread pollThread;
|
||||
|
||||
public LongPollClient(BaseActorHandle hostActor, Map<KeyType, KeyListener> keyListeners) {
|
||||
|
||||
Preconditions.checkArgument(keyListeners != null && keyListeners.size() != 0);
|
||||
|
||||
this.hostActor = hostActor;
|
||||
this.keyListeners = keyListeners;
|
||||
this.snapshotIds = new ConcurrentHashMap<>();
|
||||
for (KeyType keyType : keyListeners.keySet()) {
|
||||
this.snapshotIds.put(keyType, -1);
|
||||
}
|
||||
this.objectSnapshots = new ConcurrentHashMap<>();
|
||||
this.pollThread =
|
||||
new Thread(
|
||||
() -> {
|
||||
while (true) {
|
||||
try {
|
||||
pollNext();
|
||||
} catch (RayActorException e) {
|
||||
LOGGER.debug("LongPollClient failed to connect to host. Shutting down.");
|
||||
break;
|
||||
} catch (RayTaskException e) {
|
||||
LOGGER.error("LongPollHost errored", e);
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error("LongPollClient failed to update object of key {}", snapshotIds, e);
|
||||
}
|
||||
}
|
||||
},
|
||||
"backend-poll-thread");
|
||||
}
|
||||
|
||||
public void start() {
|
||||
if (!(hostActor instanceof PyActorHandle)) {
|
||||
LOGGER.warn("LongPollClient only support Python controller now.");
|
||||
return;
|
||||
}
|
||||
pollThread.start();
|
||||
}
|
||||
|
||||
/** Poll the update. */
|
||||
@SuppressWarnings("unchecked")
|
||||
public void pollNext() {
|
||||
currentRef =
|
||||
((PyActorHandle) hostActor)
|
||||
.task(PyActorMethod.of(Constants.CONTROLLER_LISTEN_FOR_CHANGE_METHOD), snapshotIds)
|
||||
.remote();
|
||||
processUpdate((Map<KeyType, UpdatedObject>) currentRef.get());
|
||||
}
|
||||
|
||||
public void processUpdate(Map<KeyType, UpdatedObject> updates) {
|
||||
|
||||
LOGGER.debug("LongPollClient received updates for keys: {}", updates.keySet());
|
||||
|
||||
for (Map.Entry<KeyType, UpdatedObject> entry : updates.entrySet()) {
|
||||
objectSnapshots.put(entry.getKey(), entry.getValue().getObjectSnapshot());
|
||||
snapshotIds.put(entry.getKey(), entry.getValue().getSnapshotId());
|
||||
keyListeners.get(entry.getKey()).notifyChanged(entry.getValue().getObjectSnapshot());
|
||||
}
|
||||
}
|
||||
|
||||
public Map<KeyType, Integer> getSnapshotIds() {
|
||||
return snapshotIds;
|
||||
}
|
||||
|
||||
public Map<KeyType, Object> getObjectSnapshots() {
|
||||
return objectSnapshots;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
/** The namespaces a long poll key can belong to. */
public enum LongPollNamespace {
  REPLICA_HANDLES,
  TRAFFIC_POLICIES,
  BACKEND_CONFIGS,
  ROUTE_TABLE
}
|
|
@ -0,0 +1,33 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/** The updated object that the long poll client receives: a snapshot and its version id. */
public class UpdatedObject implements Serializable {

  private static final long serialVersionUID = 6245682414826079438L;

  /** The new state of the object. */
  private Object objectSnapshot;

  /**
   * The identifier for the object's version. There is no sequential relation among different
   * objects' snapshot ids.
   */
  private int snapshotId;

  /** Returns the snapshot of the object, or {@code null} if none has been set. */
  public Object getObjectSnapshot() {
    return this.objectSnapshot;
  }

  /** Replaces the snapshot of the object. */
  public void setObjectSnapshot(Object objectSnapshot) {
    this.objectSnapshot = objectSnapshot;
  }

  /** Returns the version identifier of the snapshot ({@code 0} if never set). */
  public int getSnapshotId() {
    return this.snapshotId;
  }

  /** Replaces the version identifier of the snapshot. */
  public void setSnapshotId(int snapshotId) {
    this.snapshotId = snapshotId;
  }
}
|
11
java/serve/src/main/java/io/ray/serve/util/LogUtil.java
Normal file
11
java/serve/src/main/java/io/ray/serve/util/LogUtil.java
Normal file
|
@ -0,0 +1,11 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import org.slf4j.helpers.MessageFormatter;
|
||||
|
||||
/** Ray Serve common log tool. */
|
||||
public class LogUtil {
|
||||
|
||||
public static String format(String messagePattern, Object... args) {
|
||||
return MessageFormatter.arrayFormat(messagePattern, args).getMessage();
|
||||
}
|
||||
}
|
181
java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
Normal file
181
java/serve/src/main/java/io/ray/serve/util/ReflectUtil.java
Normal file
|
@ -0,0 +1,181 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
 * Tool class for reflection. It looks up constructors and methods from actual argument values,
 * choosing the candidate whose parameter types are the most specific ones accepting them.
 */
public class ReflectUtil {

  /**
   * Get types of the parameters in input array, and make the types into a new array.
   *
   * @param parameters The input parameter array
   * @return Type array corresponding to the input parameter array, or {@code null} if the input
   *     is {@code null} or empty. The entry of a {@code null} parameter is {@code null}.
   */
  @SuppressWarnings("rawtypes")
  private static Class[] getParameterTypes(Object[] parameters) {
    if (parameters == null || parameters.length == 0) {
      return null;
    }
    Class[] parameterTypes = new Class[parameters.length];
    for (int i = 0; i < parameters.length; i++) {
      // A null argument carries no type; it is later matched against any reference type.
      parameterTypes[i] = parameters[i] == null ? null : parameters[i].getClass();
    }
    return parameterTypes;
  }

  /**
   * Get a constructor whose each parameter's type is the closest to the corresponding input
   * parameter in terms of Java inheritance system, while {@link Class#getConstructor(Class...)}
   * returns the one based on class's equal.
   *
   * @param targetClass the constructor's class
   * @param parameters the input parameters
   * @return a matching constructor of the target class
   * @throws NoSuchMethodException if a matching method is not found
   */
  @SuppressWarnings("rawtypes")
  public static Constructor getConstructor(Class targetClass, Object... parameters)
      throws NoSuchMethodException {
    return reflect(
        targetClass.getConstructors(), (candidate) -> true, ".<init>", targetClass, parameters);
  }

  /**
   * Get a method whose each parameter's type is the closest to the corresponding input parameter
   * in terms of Java inheritance system, while {@link Class#getMethod(String, Class...)} returns
   * the one based on class's equal.
   *
   * @param targetClass the method's class
   * @param name the specified method's name
   * @param parameters the input parameters
   * @return a matching method of the target class
   * @throws NoSuchMethodException if a matching method is not found
   */
  @SuppressWarnings("rawtypes")
  public static Method getMethod(Class targetClass, String name, Object... parameters)
      throws NoSuchMethodException {
    return reflect(
        targetClass.getMethods(),
        (candidate) -> candidate.getName().equals(name),
        "." + name,
        targetClass,
        parameters);
  }

  /**
   * Get an object of {@link Executable} of the specified class according to the input
   * parameters. Every parameter type of the result is the same as, or is a superclass or
   * superinterface of, the type of the corresponding input parameter (primitive parameter types
   * accept their wrapper values). Among the accepting candidates, the one whose parameter types
   * are the closest to the input parameters in terms of Java inheritance system is returned.
   *
   * @param <T> the type of result which extends {@link Executable}
   * @param candidates a set of candidates
   * @param filter the filter deciding whether to select the input candidate
   * @param message a message representing the target executable object
   * @param targetClass the class declaring the candidates
   * @param parameters the input parameters
   * @return a matching executable of the target class
   * @throws NoSuchMethodException if a matching method is not found
   */
  @SuppressWarnings("rawtypes")
  private static <T extends Executable> T reflect(
      T[] candidates,
      Function<T, Boolean> filter,
      String message,
      Class targetClass,
      Object... parameters)
      throws NoSuchMethodException {
    Class[] parameterTypes = getParameterTypes(parameters);
    T result = null;
    for (T candidate : candidates) {
      // Keep the candidate if it accepts the arguments and is at least as specific as the best
      // candidate found so far.
      if (filter.apply(candidate)
          && assignable(parameterTypes, candidate.getParameterTypes())
          && (result == null
              || assignable(candidate.getParameterTypes(), result.getParameterTypes()))) {
        result = candidate;
      }
    }
    if (result == null) {
      throw new NoSuchMethodException(
          targetClass.getName() + message + argumentTypesToString(parameterTypes));
    }
    return result;
  }

  /**
   * Tells whether values of the {@code from} types can be passed, position by position, to
   * parameters of the {@code to} types. A {@code null} array is treated like an empty one.
   * Primitive types are compared through their wrapper classes because reflective arguments are
   * always boxed; a {@code null} entry in {@code from} (the type of a {@code null} argument)
   * fits every non-primitive type.
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private static boolean assignable(Class[] from, Class[] to) {
    if (from == null) {
      return to == null || to.length == 0;
    }

    if (to == null) {
      return from.length == 0;
    }

    if (from.length != to.length) {
      return false;
    }

    for (int i = 0; i < from.length; i++) {
      if (from[i] == null) {
        if (to[i].isPrimitive()) {
          return false;
        }
        continue;
      }
      if (!box(to[i]).isAssignableFrom(box(from[i]))) {
        return false;
      }
    }

    return true;
  }

  /** Returns the wrapper class of a primitive type, or the type itself if it is not primitive. */
  private static Class<?> box(Class<?> type) {
    if (!type.isPrimitive()) {
      return type;
    }
    if (type == int.class) {
      return Integer.class;
    }
    if (type == long.class) {
      return Long.class;
    }
    if (type == double.class) {
      return Double.class;
    }
    if (type == float.class) {
      return Float.class;
    }
    if (type == boolean.class) {
      return Boolean.class;
    }
    if (type == char.class) {
      return Character.class;
    }
    if (type == byte.class) {
      return Byte.class;
    }
    if (type == short.class) {
      return Short.class;
    }
    return Void.class;
  }

  /**
   * It is copied from the JDK's private {@code Class#argumentTypesToString(Class[])}.
   *
   * @param argTypes array of Class object
   * @return Formatted string of the input Class array.
   */
  private static String argumentTypesToString(Class<?>[] argTypes) {
    StringBuilder buf = new StringBuilder();
    buf.append("(");
    if (argTypes != null) {
      for (int i = 0; i < argTypes.length; i++) {
        if (i > 0) {
          buf.append(", ");
        }
        Class<?> c = argTypes[i];
        buf.append((c == null) ? "null" : c.getName());
      }
    }
    buf.append(")");
    return buf.toString();
  }

  /**
   * Get the string forms of all public methods of the specified class.
   *
   * @param targetClass the input class
   * @return one {@link Method#toString()} entry per public method, or {@code null} if the class
   *     is {@code null} or has no public method.
   */
  @SuppressWarnings("rawtypes")
  public static List<String> getMethodStrings(Class targetClass) {
    if (targetClass == null) {
      return null;
    }
    Method[] methods = targetClass.getMethods();
    if (methods == null || methods.length == 0) {
      return null;
    }
    List<String> methodStrings = new ArrayList<>(methods.length);
    for (Method method : methods) {
      methodStrings.add(method.toString());
    }
    return methodStrings;
  }
}
|
|
@ -0,0 +1,58 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class RayServeReplicaTest {
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
@Test
|
||||
public void test() {
|
||||
|
||||
boolean inited = Ray.isInitialized();
|
||||
|
||||
Ray.init();
|
||||
|
||||
try {
|
||||
String controllerName = "RayServeReplicaTest";
|
||||
String backendTag = "b_tag";
|
||||
String replicaTag = "r_tag";
|
||||
|
||||
ActorHandle<ReplicaContext> controllerHandle =
|
||||
Ray.actor(ReplicaContext::new, backendTag, replicaTag, controllerName, new Object())
|
||||
.setName(controllerName)
|
||||
.remote();
|
||||
|
||||
BackendConfig backendConfig = new BackendConfig();
|
||||
ActorHandle<RayServeWrappedReplica> backendHandle =
|
||||
Ray.actor(
|
||||
RayServeWrappedReplica::new,
|
||||
backendTag,
|
||||
replicaTag,
|
||||
"io.ray.serve.ReplicaContext",
|
||||
new Object[] {backendTag, replicaTag, controllerName, new Object()},
|
||||
backendConfig,
|
||||
controllerName)
|
||||
.remote();
|
||||
|
||||
backendHandle.task(RayServeWrappedReplica::ready).remote();
|
||||
|
||||
RequestMetadata requestMetadata = new RequestMetadata();
|
||||
requestMetadata.setRequestId("RayServeReplicaTest");
|
||||
requestMetadata.setCallMethod("getBackendTag");
|
||||
ObjectRef<Object> resultRef =
|
||||
backendHandle
|
||||
.task(RayServeWrappedReplica::handle_request, requestMetadata, (Object[]) null)
|
||||
.remote();
|
||||
|
||||
Assert.assertEquals((String) resultRef.get(), backendTag);
|
||||
} finally {
|
||||
if (!inited) {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
69
java/serve/src/test/java/io/ray/serve/ReplicaConfigTest.java
Normal file
69
java/serve/src/test/java/io/ray/serve/ReplicaConfigTest.java
Normal file
|
@ -0,0 +1,69 @@
|
|||
package io.ray.serve;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ReplicaConfigTest {
|
||||
|
||||
static interface Validator {
|
||||
void validate();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
|
||||
Object dummy = new Object();
|
||||
String backendDef = "io.ray.serve.ReplicaConfigTest";
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("placement_group", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("lifetime", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("name", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("max_restarts", dummy)));
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", -1.0)));
|
||||
ReplicaConfig replicaConfig =
|
||||
new ReplicaConfig(backendDef, null, getRayActorOptions("num_cpus", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("CPU").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("num_gpus", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("GPU").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("memory", -1.0)));
|
||||
replicaConfig = new ReplicaConfig(backendDef, null, getRayActorOptions("memory", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("memory").doubleValue(), 2.0);
|
||||
|
||||
expectIllegalArgumentException(
|
||||
() -> new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", -1.0)));
|
||||
replicaConfig =
|
||||
new ReplicaConfig(backendDef, null, getRayActorOptions("object_store_memory", 2.0));
|
||||
Assert.assertEquals(replicaConfig.getResource().get("object_store_memory").doubleValue(), 2.0);
|
||||
}
|
||||
|
||||
private void expectIllegalArgumentException(Validator validator) {
|
||||
try {
|
||||
validator.validate();
|
||||
Assert.assertTrue(false, "expect IllegalArgumentException");
|
||||
} catch (IllegalArgumentException e) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> getRayActorOptions(String key, Object value) {
|
||||
Map<String, Object> rayActorOptions = new HashMap<>();
|
||||
rayActorOptions.put(key, value);
|
||||
return rayActorOptions;
|
||||
}
|
||||
}
|
40
java/serve/src/test/java/io/ray/serve/api/ServeTest.java
Normal file
40
java/serve/src/test/java/io/ray/serve/api/ServeTest.java
Normal file
|
@ -0,0 +1,40 @@
|
|||
package io.ray.serve.api;
|
||||
|
||||
import io.ray.serve.RayServeException;
|
||||
import io.ray.serve.ReplicaContext;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ServeTest {
|
||||
|
||||
@Test
|
||||
public void replicaContextTest() {
|
||||
|
||||
ReplicaContext preContext = Serve.INTERNAL_REPLICA_CONTEXT;
|
||||
ReplicaContext replicaContext;
|
||||
|
||||
// Test null replica context.
|
||||
Serve.INTERNAL_REPLICA_CONTEXT = null;
|
||||
try {
|
||||
replicaContext = Serve.getReplicaContext();
|
||||
Assert.assertTrue(false, "expect RayServeException");
|
||||
} catch (RayServeException e) {
|
||||
|
||||
}
|
||||
|
||||
// Test context setting and getting.
|
||||
String backendTag = "backendTag";
|
||||
String replicaTag = "replicaTag";
|
||||
String controllerName = "controllerName";
|
||||
Object servableObject = new Object();
|
||||
Serve.setInternalReplicaContext(backendTag, replicaTag, controllerName, servableObject);
|
||||
|
||||
replicaContext = Serve.getReplicaContext();
|
||||
Assert.assertNotNull(replicaContext, "no replica context");
|
||||
Assert.assertEquals(replicaContext.getBackendTag(), backendTag);
|
||||
Assert.assertEquals(replicaContext.getReplicaTag(), replicaTag);
|
||||
Assert.assertEquals(replicaContext.getInternalControllerName(), controllerName);
|
||||
|
||||
Serve.INTERNAL_REPLICA_CONTEXT = preContext;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class KeyListenerTest {
|
||||
|
||||
@Test
|
||||
public void test() throws Throwable {
|
||||
|
||||
int[] a = new int[] {0};
|
||||
|
||||
KeyListener keyListener = (x) -> ((int[]) x)[0] = 1;
|
||||
|
||||
keyListener.notifyChanged(a);
|
||||
|
||||
Assert.assertEquals(a[0], 1);
|
||||
}
|
||||
}
|
31
java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java
Normal file
31
java/serve/src/test/java/io/ray/serve/poll/KeyTypeTest.java
Normal file
|
@ -0,0 +1,31 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class KeyTypeTest {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
KeyType k1 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
|
||||
KeyType k2 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "k1");
|
||||
KeyType k3 = new KeyType(LongPollNamespace.BACKEND_CONFIGS, null);
|
||||
KeyType k4 = new KeyType(LongPollNamespace.REPLICA_HANDLES, "k4");
|
||||
|
||||
Assert.assertEquals(k1, k1);
|
||||
Assert.assertEquals(k1.hashCode(), k1.hashCode());
|
||||
Assert.assertTrue(k1.equals(k1));
|
||||
|
||||
Assert.assertEquals(k1, k2);
|
||||
Assert.assertEquals(k1.hashCode(), k2.hashCode());
|
||||
Assert.assertTrue(k1.equals(k2));
|
||||
|
||||
Assert.assertNotEquals(k1, k3);
|
||||
Assert.assertNotEquals(k1.hashCode(), k3.hashCode());
|
||||
Assert.assertFalse(k1.equals(k3));
|
||||
|
||||
Assert.assertNotEquals(k1, k4);
|
||||
Assert.assertNotEquals(k1.hashCode(), k4.hashCode());
|
||||
Assert.assertFalse(k1.equals(k4));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package io.ray.serve.poll;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class LongPollClientTest {
|
||||
|
||||
@Test
|
||||
public void test() throws Throwable {
|
||||
|
||||
KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag");
|
||||
int[] a = new int[] {0};
|
||||
Map<KeyType, KeyListener> keyListeners = new HashMap<>();
|
||||
keyListeners.put(keyType, (object) -> a[0] = (Integer) object);
|
||||
LongPollClient longPollClient = new LongPollClient(null, keyListeners);
|
||||
|
||||
int snapshotId = 10;
|
||||
int objectSnapshot = 20;
|
||||
UpdatedObject updatedObject = new UpdatedObject();
|
||||
updatedObject.setSnapshotId(snapshotId);
|
||||
updatedObject.setObjectSnapshot(objectSnapshot);
|
||||
|
||||
Map<KeyType, UpdatedObject> updates = new HashMap<>();
|
||||
updates.put(keyType, updatedObject);
|
||||
longPollClient.processUpdate(updates);
|
||||
|
||||
Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId);
|
||||
Assert.assertEquals(
|
||||
((Integer) longPollClient.getObjectSnapshots().get(keyType)).intValue(), objectSnapshot);
|
||||
Assert.assertEquals(a[0], objectSnapshot);
|
||||
}
|
||||
}
|
13
java/serve/src/test/java/io/ray/serve/util/LogUtilTest.java
Normal file
13
java/serve/src/test/java/io/ray/serve/util/LogUtilTest.java
Normal file
|
@ -0,0 +1,13 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class LogUtilTest {
|
||||
|
||||
@Test
|
||||
public void formatTest() {
|
||||
String result = LogUtil.format("{},{},{}", "1", "2", "3");
|
||||
Assert.assertEquals(result, "1,2,3");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package io.ray.serve.util;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ReflectUtilTest {
|
||||
|
||||
static class ReflectExample {
|
||||
public ReflectExample() {}
|
||||
|
||||
public ReflectExample(Integer a) {}
|
||||
|
||||
public ReflectExample(String a) {}
|
||||
|
||||
public void test(String a) {}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void getConstructorTest() throws NoSuchMethodException {
|
||||
|
||||
Constructor<ReflectExample> constructor = ReflectUtil.getConstructor(ReflectExample.class);
|
||||
Assert.assertNotNull(constructor);
|
||||
|
||||
constructor = ReflectUtil.getConstructor(ReflectExample.class, null);
|
||||
Assert.assertNotNull(constructor);
|
||||
|
||||
constructor = ReflectUtil.getConstructor(ReflectExample.class, 2);
|
||||
Assert.assertNotNull(constructor);
|
||||
|
||||
constructor = ReflectUtil.getConstructor(ReflectExample.class, "");
|
||||
Assert.assertNotNull(constructor);
|
||||
|
||||
try {
|
||||
constructor = ReflectUtil.getConstructor(ReflectExample.class, new HashMap<>());
|
||||
Assert.assertTrue(false, "expect NoSuchMethodException");
|
||||
} catch (NoSuchMethodException e) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getMethodTest() throws NoSuchMethodException {
|
||||
|
||||
Method method = ReflectUtil.getMethod(ReflectExample.class, "test", "");
|
||||
Assert.assertNotNull(method);
|
||||
|
||||
try {
|
||||
method = ReflectUtil.getMethod(ReflectExample.class, "test", new HashMap<>());
|
||||
Assert.assertTrue(false, "expect NoSuchMethodException");
|
||||
} catch (NoSuchMethodException e) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getMethodStringsTest() {
|
||||
List<String> methodList = ReflectUtil.getMethodStrings(ReflectExample.class);
|
||||
String result = null;
|
||||
for (String method : methodList) {
|
||||
if (StringUtils.contains(method, "test")) {
|
||||
result = method;
|
||||
}
|
||||
}
|
||||
Assert.assertNotNull(result, "there should be test method");
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@
|
|||
<packages>
|
||||
<package name = "io.ray.runtime.*" />
|
||||
<package name = "io.ray.test.*" />
|
||||
<package name = "io.ray.serve.*" />
|
||||
</packages>
|
||||
</test>
|
||||
<listeners>
|
||||
|
|
Loading…
Add table
Reference in a new issue