[runtime env] plugin refactor[6/n]: java api refactor (#26783)

This commit is contained in:
Guyang Song 2022-07-26 09:00:57 +08:00 committed by GitHub
parent 778a799909
commit 419e78180a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 309 additions and 97 deletions

View file

@ -198,3 +198,5 @@ try-import %workspace%/.llvm-local.bazelrc
# It picks up the system headers when someone has protobuf installed via Homebrew.
# Work around for https://github.com/bazelbuild/bazel/issues/8053
build:macos --sandbox_block_path=/usr/local/
#This option controls whether javac checks for missing direct dependencies.
build --strict_java_deps=off

View file

@ -80,6 +80,8 @@ define_java_module(
visibility = ["//visibility:public"],
deps = [
":io_ray_ray_api",
"@maven//:com_fasterxml_jackson_core_jackson_databind",
"@maven//:com_github_java_json_tools_json_schema_validator",
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",

View file

@ -0,0 +1,8 @@
package io.ray.api.exception;
public class RuntimeEnvException extends RayException {
public RuntimeEnvException(String message) {
super(message);
}
}

View file

@ -200,7 +200,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
group,
bundleIndex,
concurrencyGroups,
runtimeEnv != null ? runtimeEnv.toJsonBytes() : "",
runtimeEnv != null ? runtimeEnv.serializeToRuntimeEnvInfo() : "",
namespace,
maxPendingCalls);
}

View file

@ -26,7 +26,8 @@ public class CallOptions extends BaseTaskOptions {
this.group = group;
this.bundleIndex = bundleIndex;
this.concurrencyGroupName = concurrencyGroupName;
this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes();
this.serializedRuntimeEnvInfo =
runtimeEnv == null ? "" : runtimeEnv.serializeToRuntimeEnvInfo();
}
/** This inner class for building CallOptions. */

View file

@ -7,6 +7,7 @@ import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.WaitResult;
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.exception.RuntimeEnvException;
import io.ray.api.function.CppActorClass;
import io.ray.api.function.CppActorMethod;
import io.ray.api.function.CppFunction;
@ -295,7 +296,10 @@ public interface RayRuntime {
List<ConcurrencyGroup> extractConcurrencyGroups(RayFuncR<?> actorConstructorLambda);
/** Create runtime env instance at runtime. */
RuntimeEnv createRuntimeEnv(Map<String, String> envVars, List<String> jars);
RuntimeEnv createRuntimeEnv();
/** Deserialize runtime env instance at runtime. */
RuntimeEnv deserializeRuntimeEnv(String serializedRuntimeEnv) throws RuntimeEnvException;
/// Get the parallel actor context at runtime.
ParallelActorContext getParallelActorContext();

View file

@ -1,38 +1,91 @@
package io.ray.api.runtimeenv;
import io.ray.api.Ray;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import io.ray.api.exception.RuntimeEnvException;
import io.ray.api.runtimeenv.types.RuntimeEnvName;
/** This is an experimental API to let you set runtime environments for your actors. */
/** This class provides interfaces of setting runtime environments for job/actor/task. */
public interface RuntimeEnv {
String toJsonBytes();
/**
* Set a runtime env field by name and Object.
*
* @param name The build-in names or a runtime env plugin name.
* @see RuntimeEnvName
* @param value An object with primitive data type or plain old java object(POJO).
* @throws RuntimeEnvException
*/
void set(String name, Object value) throws RuntimeEnvException;
/**
* Set a runtime env field by name and json string.
*
* @param name The build-in names or a runtime env plugin name.
* @see RuntimeEnvName
* @param jsonStr A json string represents the runtime env field.
* @throws RuntimeEnvException
*/
public void setJsonStr(String name, String jsonStr) throws RuntimeEnvException;
/**
* Get the object of a runtime env field.
*
* @param name The build-in names or a runtime env plugin name.
* @param classOfT The class of a primitive data type or plain old java object(POJO) type.
* @return
* @param <T> A primitive data type or plain old java object(POJO) type.
* @throws RuntimeEnvException
*/
public <T> T get(String name, Class<T> classOfT) throws RuntimeEnvException;
/**
* Get the json string of a runtime env field.
*
* @param name The build-in names or a runtime env plugin name.
* @return A json string represents the runtime env field.
* @throws RuntimeEnvException
*/
public String getJsonStr(String name) throws RuntimeEnvException;
/**
* Remove a runtime env field by name.
*
* @param name The build-in names or a runtime env plugin name.
* @throws RuntimeEnvException
*/
public void remove(String name) throws RuntimeEnvException;
/**
* Serialize the runtime env to string.
*
* @return The serialized runtime env string.
* @throws RuntimeEnvException
*/
public String serialize() throws RuntimeEnvException;
/**
* Serialize the runtime env to string of RuntimeEnvInfo.
*
* @return The serialized runtime env info string.
* @throws RuntimeEnvException
*/
public String serializeToRuntimeEnvInfo() throws RuntimeEnvException;
/**
* Deserialize the runtime env from string.
*
* @param serializedRuntimeEnv The serialized runtime env string.
* @return The deserialized RuntimeEnv instance.
* @throws RuntimeEnvException
*/
public static RuntimeEnv deserialize(String serializedRuntimeEnv) throws RuntimeEnvException {
return Ray.internal().deserializeRuntimeEnv(serializedRuntimeEnv);
}
/** The builder which is used to generate a RuntimeEnv instance. */
public static class Builder {
private Map<String, String> envVars = new HashMap<>();
private List<String> jars = new ArrayList<>();
/** Add environment variable as runtime environment for the actor or job. */
public Builder addEnvVar(String key, String value) {
envVars.put(key, value);
return this;
}
/**
* Add the jars as runtime environment for the actor or job. We now support both `.jar` files
* and `.zip` files.
*/
public Builder addJars(List<String> jars) {
this.jars.addAll(jars);
return this;
}
public RuntimeEnv build() {
return Ray.internal().createRuntimeEnv(envVars, jars);
return Ray.internal().createRuntimeEnv();
}
}
}

View file

@ -0,0 +1,10 @@
package io.ray.api.runtimeenv.types;
public class RuntimeEnvName {
/** The environment variables which type is `Map<String, String>`. */
public static final String ENV_VARS = "env_vars";
/** The dependent java jars which type is `List<String>`. */
public static final String JARS = "java_jars";
}

View file

@ -4,6 +4,8 @@ load("@rules_jvm_external//:specs.bzl", "maven")
def gen_java_deps():
maven_install(
artifacts = [
"com.fasterxml.jackson.core:jackson-databind:2.13.3",
"com.github.java-json-tools:json-schema-validator:2.2.14",
"com.google.code.gson:gson:2.8.5",
"com.google.guava:guava:30.0-jre",
"com.google.protobuf:protobuf-java:3.19.4",

View file

@ -1,5 +1,8 @@
package io.ray.runtime;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.ray.api.ActorHandle;
@ -9,6 +12,7 @@ import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.WaitResult;
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.exception.RuntimeEnvException;
import io.ray.api.function.CppActorClass;
import io.ray.api.function.CppActorMethod;
import io.ray.api.function.CppFunction;
@ -50,7 +54,6 @@ import io.ray.runtime.util.ConcurrencyGroupUtils;
import io.ray.runtime.utils.parallelactor.ParallelActorContextImpl;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
@ -72,6 +75,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl();
private static final ObjectMapper MAPPER = new ObjectMapper();
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
runtimeContext = new RuntimeContextImpl(this);
@ -306,8 +311,19 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public RuntimeEnv createRuntimeEnv(Map<String, String> envVars, List<String> jars) {
return new RuntimeEnvImpl(envVars, jars);
public RuntimeEnv createRuntimeEnv() {
return new RuntimeEnvImpl();
}
@Override
public RuntimeEnv deserializeRuntimeEnv(String serializedRuntimeEnv) throws RuntimeEnvException {
RuntimeEnvImpl runtimeEnv = new RuntimeEnvImpl();
try {
runtimeEnv.runtimeEnvs = (ObjectNode) MAPPER.readTree(serializedRuntimeEnv);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
return runtimeEnv;
}
private ObjectRef callNormalFunction(

View file

@ -1,7 +1,6 @@
package io.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.gson.Gson;
import io.ray.api.BaseActorHandle;
import io.ray.api.exception.RayIntentionalSystemExitException;
import io.ray.api.id.ActorId;
@ -18,7 +17,6 @@ import io.ray.runtime.gcs.GcsClientOptions;
import io.ray.runtime.generated.Common.WorkerType;
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
import io.ray.runtime.generated.Gcs.JobConfig;
import io.ray.runtime.generated.RuntimeEnvCommon.RuntimeEnvInfo;
import io.ray.runtime.object.NativeObjectStore;
import io.ray.runtime.runner.RunManager;
import io.ray.runtime.task.NativeTaskExecutor;
@ -26,7 +24,6 @@ import io.ray.runtime.task.NativeTaskSubmitter;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.JniUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -111,23 +108,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.addAllCodeSearchPath(rayConfig.codeSearchPath)
.setRayNamespace(rayConfig.namespace);
RuntimeEnvInfo.Builder runtimeEnvInfoBuilder = RuntimeEnvInfo.newBuilder();
if (rayConfig.runtimeEnvImpl != null) {
Map<String, Object> runtimeEnvMap = new HashMap<>();
if (!rayConfig.runtimeEnvImpl.getEnvVars().isEmpty()) {
runtimeEnvMap.put("env_vars", rayConfig.runtimeEnvImpl.getEnvVars());
}
final List<String> jarUrls = rayConfig.runtimeEnvImpl.getJars();
if (jarUrls != null && !jarUrls.isEmpty()) {
runtimeEnvMap.put("java_jars", jarUrls);
}
runtimeEnvInfoBuilder.setSerializedRuntimeEnv(new Gson().toJson(runtimeEnvMap));
} else {
runtimeEnvInfoBuilder.setSerializedRuntimeEnv("{}");
}
jobConfigBuilder.setRuntimeEnvInfo(runtimeEnvInfoBuilder.build());
jobConfigBuilder.setRuntimeEnvInfo(rayConfig.runtimeEnvImpl.GenerateRuntimeEnvInfo());
jobConfigBuilder.setDefaultActorLifetime(
rayConfig.defaultActorLifetime == ActorLifetime.DETACHED
? JobConfig.ActorLifetime.DETACHED

View file

@ -8,6 +8,7 @@ import com.typesafe.config.ConfigFactory;
import com.typesafe.config.ConfigRenderOptions;
import io.ray.api.id.JobId;
import io.ray.api.options.ActorLifetime;
import io.ray.api.runtimeenv.types.RuntimeEnvName;
import io.ray.runtime.generated.Common.WorkerType;
import io.ray.runtime.runtimeenv.RuntimeEnvImpl;
import io.ray.runtime.util.NetworkUtil;
@ -215,7 +216,9 @@ public class RayConfig {
if (config.hasPath(jarsPath)) {
jarUrls = config.getStringList(jarsPath);
}
runtimeEnvImpl = new RuntimeEnvImpl(envVars, jarUrls);
runtimeEnvImpl = new RuntimeEnvImpl();
runtimeEnvImpl.set(RuntimeEnvName.ENV_VARS, envVars);
runtimeEnvImpl.set(RuntimeEnvName.JARS, jarUrls);
}
{

View file

@ -1,58 +1,90 @@
package io.ray.runtime.runtimeenv;
import com.google.gson.Gson;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.fge.jackson.JsonLoader;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.ray.api.exception.RuntimeEnvException;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.generated.RuntimeEnvCommon;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.io.IOException;
public class RuntimeEnvImpl implements RuntimeEnv {
private Map<String, String> envVars = new HashMap<>();
private static final ObjectMapper MAPPER = new ObjectMapper();
private List<String> jars = new ArrayList<>();
public ObjectNode runtimeEnvs = MAPPER.createObjectNode();
public RuntimeEnvImpl(Map<String, String> envVars, List<String> jars) {
this.envVars = envVars;
if (jars != null) {
this.jars = jars;
public RuntimeEnvImpl() {}
@Override
public void set(String name, Object value) throws RuntimeEnvException {
JsonNode node = null;
try {
node = MAPPER.valueToTree(value);
} catch (IllegalArgumentException e) {
throw new RuntimeException(e);
}
}
public Map<String, String> getEnvVars() {
return envVars;
}
public List<String> getJars() {
return jars;
runtimeEnvs.set(name, node);
}
@Override
public String toJsonBytes() {
// Get serializedRuntimeEnv
String serializedRuntimeEnv = "{}";
Map<String, Object> runtimeEnvMap = new HashMap<>();
if (!envVars.isEmpty()) {
runtimeEnvMap.put("env_vars", envVars);
}
if (!jars.isEmpty()) {
runtimeEnvMap.put("java_jars", jars);
public void setJsonStr(String name, String jsonStr) throws RuntimeEnvException {
JsonNode node = null;
try {
node = JsonLoader.fromString(jsonStr);
} catch (IOException e) {
throw new RuntimeException(e);
}
runtimeEnvs.set(name, node);
}
serializedRuntimeEnv = new Gson().toJson(runtimeEnvMap);
// Get serializedRuntimeEnvInfo
if (serializedRuntimeEnv.equals("{}") || serializedRuntimeEnv.isEmpty()) {
return "{}";
@Override
public <T> T get(String name, Class<T> classOfT) throws RuntimeEnvException {
JsonNode jsonNode = runtimeEnvs.get(name);
if (jsonNode == null) {
return null;
}
try {
return MAPPER.treeToValue(jsonNode, classOfT);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Override
public String getJsonStr(String name) throws RuntimeEnvException {
try {
return MAPPER.writeValueAsString(runtimeEnvs.get(name));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Override
public void remove(String name) {
runtimeEnvs.remove(name);
}
@Override
public String serialize() throws RuntimeEnvException {
try {
return MAPPER.writeValueAsString(runtimeEnvs);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
@Override
public String serializeToRuntimeEnvInfo() throws RuntimeEnvException {
// TODO(SongGuyang): Expose runtime env config API to users.
String serializeRuntimeEnv = serialize();
RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder =
RuntimeEnvCommon.RuntimeEnvInfo.newBuilder();
protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializedRuntimeEnv);
protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializeRuntimeEnv);
JsonFormat.Printer printer = JsonFormat.printer();
try {
return printer.print(protoRuntimeEnvInfoBuilder);
@ -60,4 +92,17 @@ public class RuntimeEnvImpl implements RuntimeEnv {
throw new RuntimeException(e);
}
}
public RuntimeEnvCommon.RuntimeEnvInfo GenerateRuntimeEnvInfo() throws RuntimeEnvException {
RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder =
RuntimeEnvCommon.RuntimeEnvInfo.newBuilder();
try {
protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(MAPPER.writeValueAsString(runtimeEnvs));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
return protoRuntimeEnvInfoBuilder.build();
}
}

View file

@ -4,7 +4,10 @@ import com.google.common.collect.ImmutableList;
import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.api.runtimeenv.types.RuntimeEnvName;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -62,12 +65,16 @@ public class RuntimeEnvTest {
public void testEnvVarsForNormalTask() {
try {
Ray.init();
RuntimeEnv runtimeEnv =
new RuntimeEnv.Builder()
.addEnvVar("KEY1", "A")
.addEnvVar("KEY2", "B")
.addEnvVar("KEY1", "C")
.build();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
Map<String, String> envMap =
new HashMap<String, String>() {
{
put("KEY1", "A");
put("KEY2", "B");
put("KEY1", "C");
}
};
runtimeEnv.set(RuntimeEnvName.ENV_VARS, envMap);
String val =
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
@ -85,7 +92,14 @@ public class RuntimeEnvTest {
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
try {
Ray.init();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build();
Map<String, String> envMap =
new HashMap<String, String>() {
{
put("KEY1", "C");
}
};
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
runtimeEnv.set(RuntimeEnvName.ENV_VARS, envMap);
/// value of KEY1 is overwritten to `C` and KEY2s is extended from job config.
String val =
@ -101,7 +115,8 @@ public class RuntimeEnvTest {
private static void testDownloadAndLoadPackage(String url) {
try {
Ray.init();
final RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addJars(ImmutableList.of(url)).build();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
runtimeEnv.set(RuntimeEnvName.JARS, ImmutableList.of(url));
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
boolean ret = actor1.task(A::findClass, FOO_CLASS_NAME).remote().get();
Assert.assertTrue(ret);
@ -133,7 +148,8 @@ public class RuntimeEnvTest {
List<String> urls, List<String> classNames) {
try {
Ray.init();
final RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addJars(urls).build();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
runtimeEnv.set(RuntimeEnvName.JARS, urls);
boolean ret =
Ray.task(RuntimeEnvTest::findClasses, classNames)
.setRuntimeEnv(runtimeEnv)
@ -178,4 +194,73 @@ public class RuntimeEnvTest {
Ray.shutdown();
}
}
private static class Pip {
private String[] packages;
private Boolean pip_check;
public String[] getPackages() {
return packages;
}
public void setPackages(String[] packages) {
this.packages = packages;
}
public Boolean getPip_check() {
return pip_check;
}
public void setPip_check(Boolean pip_check) {
this.pip_check = pip_check;
}
}
public void testRuntimeEnvAPI() {
try {
Ray.init();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
String workingDir = "https://path/to/working_dir.zip";
runtimeEnv.set("working_dir", workingDir);
String[] py_modules =
new String[] {"https://path/to/py_modules1.zip", "https://path/to/py_modules2.zip"};
runtimeEnv.set("py_modules", py_modules);
Pip pip = new Pip();
pip.setPackages(new String[] {"requests", "tensorflow"});
pip.setPip_check(true);
runtimeEnv.set("pip", pip);
String serializedRuntimeEnv = runtimeEnv.serialize();
RuntimeEnv runtimeEnv2 = RuntimeEnv.deserialize(serializedRuntimeEnv);
Assert.assertEquals(runtimeEnv2.get("working_dir", String.class), workingDir);
Assert.assertEquals(runtimeEnv2.get("py_modules", String[].class), py_modules);
Pip pip2 = runtimeEnv2.get("pip", Pip.class);
Assert.assertEquals(pip2.getPackages(), pip.getPackages());
Assert.assertEquals(pip2.getPip_check(), pip.getPip_check());
runtimeEnv2.remove("working_dir");
runtimeEnv2.remove("py_modules");
runtimeEnv2.remove("pip");
Assert.assertEquals(runtimeEnv2.get("working_dir", String.class), null);
Assert.assertEquals(runtimeEnv2.get("py_modules", String[].class), null);
Assert.assertEquals(runtimeEnv2.get("pip", Pip.class), null);
} finally {
Ray.shutdown();
}
}
public void testRuntimeEnvJsonStringAPI() {
try {
Ray.init();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build();
String pipString = "{\"packages\":[\"requests\",\"tensorflow\"],\"pip_check\":false}";
runtimeEnv.setJsonStr("pip", pipString);
String serializedRuntimeEnv = runtimeEnv.serialize();
RuntimeEnv runtimeEnv2 = RuntimeEnv.deserialize(serializedRuntimeEnv);
Assert.assertEquals(runtimeEnv2.getJsonStr("pip"), pipString);
} finally {
Ray.shutdown();
}
}
}