From 419e78180a8fa7a53918ff62eab6042be0d70ac3 Mon Sep 17 00:00:00 2001 From: Guyang Song Date: Tue, 26 Jul 2022 09:00:57 +0800 Subject: [PATCH] [runtime env] plugin refactor[6/n]: java api refactor (#26783) --- .bazelrc | 2 + java/BUILD.bazel | 2 + .../api/exception/RuntimeEnvException.java | 8 ++ .../ray/api/options/ActorCreationOptions.java | 2 +- .../java/io/ray/api/options/CallOptions.java | 3 +- .../java/io/ray/api/runtime/RayRuntime.java | 6 +- .../io/ray/api/runtimeenv/RuntimeEnv.java | 105 ++++++++++++---- .../api/runtimeenv/types/RuntimeEnvName.java | 10 ++ java/dependencies.bzl | 2 + .../io/ray/runtime/AbstractRayRuntime.java | 22 +++- .../java/io/ray/runtime/RayNativeRuntime.java | 21 +--- .../java/io/ray/runtime/config/RayConfig.java | 5 +- .../runtime/runtimeenv/RuntimeEnvImpl.java | 115 ++++++++++++------ .../main/java/io/ray/test/RuntimeEnvTest.java | 103 ++++++++++++++-- 14 files changed, 309 insertions(+), 97 deletions(-) create mode 100644 java/api/src/main/java/io/ray/api/exception/RuntimeEnvException.java create mode 100644 java/api/src/main/java/io/ray/api/runtimeenv/types/RuntimeEnvName.java diff --git a/.bazelrc b/.bazelrc index 0786d8177..1c2c7ec0d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -198,3 +198,5 @@ try-import %workspace%/.llvm-local.bazelrc # It picks up the system headers when someone has protobuf installed via Homebrew. # Work around for https://github.com/bazelbuild/bazel/issues/8053 build:macos --sandbox_block_path=/usr/local/ +#This option controls whether javac checks for missing direct dependencies. +build --strict_java_deps=off diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 33c2b7b44..7fe1bb8f6 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -80,6 +80,8 @@ define_java_module( visibility = ["//visibility:public"], deps = [ ":io_ray_ray_api", + "@maven//:com_fasterxml_jackson_core_jackson_databind", + "@maven//:com_github_java_json_tools_json_schema_validator", "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", diff --git a/java/api/src/main/java/io/ray/api/exception/RuntimeEnvException.java b/java/api/src/main/java/io/ray/api/exception/RuntimeEnvException.java new file mode 100644 index 000000000..e5b1024a6 --- /dev/null +++ b/java/api/src/main/java/io/ray/api/exception/RuntimeEnvException.java @@ -0,0 +1,8 @@ +package io.ray.api.exception; + +public class RuntimeEnvException extends RayException { + + public RuntimeEnvException(String message) { + super(message); + } +} diff --git a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java index 33aaa2315..779efb494 100644 --- a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java @@ -200,7 +200,7 @@ public class ActorCreationOptions extends BaseTaskOptions { group, bundleIndex, concurrencyGroups, - runtimeEnv != null ? runtimeEnv.toJsonBytes() : "", + runtimeEnv != null ? runtimeEnv.serializeToRuntimeEnvInfo() : "", namespace, maxPendingCalls); } diff --git a/java/api/src/main/java/io/ray/api/options/CallOptions.java b/java/api/src/main/java/io/ray/api/options/CallOptions.java index e646887ec..e0af44e0c 100644 --- a/java/api/src/main/java/io/ray/api/options/CallOptions.java +++ b/java/api/src/main/java/io/ray/api/options/CallOptions.java @@ -26,7 +26,8 @@ public class CallOptions extends BaseTaskOptions { this.group = group; this.bundleIndex = bundleIndex; this.concurrencyGroupName = concurrencyGroupName; - this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes(); + this.serializedRuntimeEnvInfo = + runtimeEnv == null ? "" : runtimeEnv.serializeToRuntimeEnvInfo(); } /** This inner class for building CallOptions. */ diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 2616a9813..b8a649448 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -7,6 +7,7 @@ import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; import io.ray.api.WaitResult; import io.ray.api.concurrencygroup.ConcurrencyGroup; +import io.ray.api.exception.RuntimeEnvException; import io.ray.api.function.CppActorClass; import io.ray.api.function.CppActorMethod; import io.ray.api.function.CppFunction; @@ -295,7 +296,10 @@ public interface RayRuntime { List extractConcurrencyGroups(RayFuncR actorConstructorLambda); /** Create runtime env instance at runtime. */ - RuntimeEnv createRuntimeEnv(Map envVars, List jars); + RuntimeEnv createRuntimeEnv(); + + /** Deserialize runtime env instance at runtime. */ + RuntimeEnv deserializeRuntimeEnv(String serializedRuntimeEnv) throws RuntimeEnvException; /// Get the parallel actor context at runtime. ParallelActorContext getParallelActorContext(); diff --git a/java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java b/java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java index 4e3ca3089..730f47ac7 100644 --- a/java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java +++ b/java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java @@ -1,38 +1,91 @@ package io.ray.api.runtimeenv; import io.ray.api.Ray; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import io.ray.api.exception.RuntimeEnvException; +import io.ray.api.runtimeenv.types.RuntimeEnvName; -/** This is an experimental API to let you set runtime environments for your actors. */ +/** This class provides interfaces of setting runtime environments for job/actor/task. */ public interface RuntimeEnv { - String toJsonBytes(); + /** + * Set a runtime env field by name and Object. + * + * @param name The build-in names or a runtime env plugin name. + * @see RuntimeEnvName + * @param value An object with primitive data type or plain old java object(POJO). + * @throws RuntimeEnvException + */ + void set(String name, Object value) throws RuntimeEnvException; + /** + * Set a runtime env field by name and json string. + * + * @param name The build-in names or a runtime env plugin name. + * @see RuntimeEnvName + * @param jsonStr A json string represents the runtime env field. + * @throws RuntimeEnvException + */ + public void setJsonStr(String name, String jsonStr) throws RuntimeEnvException; + + /** + * Get the object of a runtime env field. + * + * @param name The build-in names or a runtime env plugin name. + * @param classOfT The class of a primitive data type or plain old java object(POJO) type. + * @return + * @param A primitive data type or plain old java object(POJO) type. + * @throws RuntimeEnvException + */ + public T get(String name, Class classOfT) throws RuntimeEnvException; + + /** + * Get the json string of a runtime env field. + * + * @param name The build-in names or a runtime env plugin name. + * @return A json string represents the runtime env field. + * @throws RuntimeEnvException + */ + public String getJsonStr(String name) throws RuntimeEnvException; + + /** + * Remove a runtime env field by name. + * + * @param name The build-in names or a runtime env plugin name. + * @throws RuntimeEnvException + */ + public void remove(String name) throws RuntimeEnvException; + + /** + * Serialize the runtime env to string. + * + * @return The serialized runtime env string. + * @throws RuntimeEnvException + */ + public String serialize() throws RuntimeEnvException; + + /** + * Serialize the runtime env to string of RuntimeEnvInfo. + * + * @return The serialized runtime env info string. + * @throws RuntimeEnvException + */ + public String serializeToRuntimeEnvInfo() throws RuntimeEnvException; + + /** + * Deserialize the runtime env from string. + * + * @param serializedRuntimeEnv The serialized runtime env string. + * @return The deserialized RuntimeEnv instance. + * @throws RuntimeEnvException + */ + public static RuntimeEnv deserialize(String serializedRuntimeEnv) throws RuntimeEnvException { + return Ray.internal().deserializeRuntimeEnv(serializedRuntimeEnv); + } + + /** The builder which is used to generate a RuntimeEnv instance. */ public static class Builder { - - private Map envVars = new HashMap<>(); - private List jars = new ArrayList<>(); - - /** Add environment variable as runtime environment for the actor or job. */ - public Builder addEnvVar(String key, String value) { - envVars.put(key, value); - return this; - } - - /** - * Add the jars as runtime environment for the actor or job. We now support both `.jar` files - * and `.zip` files. - */ - public Builder addJars(List jars) { - this.jars.addAll(jars); - return this; - } - public RuntimeEnv build() { - return Ray.internal().createRuntimeEnv(envVars, jars); + return Ray.internal().createRuntimeEnv(); } } } diff --git a/java/api/src/main/java/io/ray/api/runtimeenv/types/RuntimeEnvName.java b/java/api/src/main/java/io/ray/api/runtimeenv/types/RuntimeEnvName.java new file mode 100644 index 000000000..afc215285 --- /dev/null +++ b/java/api/src/main/java/io/ray/api/runtimeenv/types/RuntimeEnvName.java @@ -0,0 +1,10 @@ +package io.ray.api.runtimeenv.types; + +public class RuntimeEnvName { + + /** The environment variables which type is `Map`. */ + public static final String ENV_VARS = "env_vars"; + + /** The dependent java jars which type is `List`. */ + public static final String JARS = "java_jars"; +} diff --git a/java/dependencies.bzl b/java/dependencies.bzl index eb79c1de4..4a2c47c52 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -4,6 +4,8 @@ load("@rules_jvm_external//:specs.bzl", "maven") def gen_java_deps(): maven_install( artifacts = [ + "com.fasterxml.jackson.core:jackson-databind:2.13.3", + "com.github.java-json-tools:json-schema-validator:2.2.14", "com.google.code.gson:gson:2.8.5", "com.google.guava:guava:30.0-jre", "com.google.protobuf:protobuf-java:3.19.4", diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index 3ff7fbff8..36065df4b 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -1,5 +1,8 @@ package io.ray.runtime; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; @@ -9,6 +12,7 @@ import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; import io.ray.api.WaitResult; import io.ray.api.concurrencygroup.ConcurrencyGroup; +import io.ray.api.exception.RuntimeEnvException; import io.ray.api.function.CppActorClass; import io.ray.api.function.CppActorMethod; import io.ray.api.function.CppFunction; @@ -50,7 +54,6 @@ import io.ray.runtime.util.ConcurrencyGroupUtils; import io.ray.runtime.utils.parallelactor.ParallelActorContextImpl; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -72,6 +75,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { private static ParallelActorContextImpl parallelActorContextImpl = new ParallelActorContextImpl(); + private static final ObjectMapper MAPPER = new ObjectMapper(); + public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; runtimeContext = new RuntimeContextImpl(this); @@ -306,8 +311,19 @@ public abstract class AbstractRayRuntime implements RayRuntime { } @Override - public RuntimeEnv createRuntimeEnv(Map envVars, List jars) { - return new RuntimeEnvImpl(envVars, jars); + public RuntimeEnv createRuntimeEnv() { + return new RuntimeEnvImpl(); + } + + @Override + public RuntimeEnv deserializeRuntimeEnv(String serializedRuntimeEnv) throws RuntimeEnvException { + RuntimeEnvImpl runtimeEnv = new RuntimeEnvImpl(); + try { + runtimeEnv.runtimeEnvs = (ObjectNode) MAPPER.readTree(serializedRuntimeEnv); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + return runtimeEnv; } private ObjectRef callNormalFunction( diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 4f471b890..801d311d5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -1,7 +1,6 @@ package io.ray.runtime; import com.google.common.base.Preconditions; -import com.google.gson.Gson; import io.ray.api.BaseActorHandle; import io.ray.api.exception.RayIntentionalSystemExitException; import io.ray.api.id.ActorId; @@ -18,7 +17,6 @@ import io.ray.runtime.gcs.GcsClientOptions; import io.ray.runtime.generated.Common.WorkerType; import io.ray.runtime.generated.Gcs.GcsNodeInfo; import io.ray.runtime.generated.Gcs.JobConfig; -import io.ray.runtime.generated.RuntimeEnvCommon.RuntimeEnvInfo; import io.ray.runtime.object.NativeObjectStore; import io.ray.runtime.runner.RunManager; import io.ray.runtime.task.NativeTaskExecutor; @@ -26,7 +24,6 @@ import io.ray.runtime.task.NativeTaskSubmitter; import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.util.BinaryFileUtil; import io.ray.runtime.util.JniUtils; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -111,23 +108,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) .addAllCodeSearchPath(rayConfig.codeSearchPath) .setRayNamespace(rayConfig.namespace); - RuntimeEnvInfo.Builder runtimeEnvInfoBuilder = RuntimeEnvInfo.newBuilder(); - if (rayConfig.runtimeEnvImpl != null) { - Map runtimeEnvMap = new HashMap<>(); - if (!rayConfig.runtimeEnvImpl.getEnvVars().isEmpty()) { - runtimeEnvMap.put("env_vars", rayConfig.runtimeEnvImpl.getEnvVars()); - } - - final List jarUrls = rayConfig.runtimeEnvImpl.getJars(); - if (jarUrls != null && !jarUrls.isEmpty()) { - runtimeEnvMap.put("java_jars", jarUrls); - } - runtimeEnvInfoBuilder.setSerializedRuntimeEnv(new Gson().toJson(runtimeEnvMap)); - - } else { - runtimeEnvInfoBuilder.setSerializedRuntimeEnv("{}"); - } - jobConfigBuilder.setRuntimeEnvInfo(runtimeEnvInfoBuilder.build()); + jobConfigBuilder.setRuntimeEnvInfo(rayConfig.runtimeEnvImpl.GenerateRuntimeEnvInfo()); jobConfigBuilder.setDefaultActorLifetime( rayConfig.defaultActorLifetime == ActorLifetime.DETACHED ? JobConfig.ActorLifetime.DETACHED diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index 20c7a8706..ce1dec5dc 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -8,6 +8,7 @@ import com.typesafe.config.ConfigFactory; import com.typesafe.config.ConfigRenderOptions; import io.ray.api.id.JobId; import io.ray.api.options.ActorLifetime; +import io.ray.api.runtimeenv.types.RuntimeEnvName; import io.ray.runtime.generated.Common.WorkerType; import io.ray.runtime.runtimeenv.RuntimeEnvImpl; import io.ray.runtime.util.NetworkUtil; @@ -215,7 +216,9 @@ public class RayConfig { if (config.hasPath(jarsPath)) { jarUrls = config.getStringList(jarsPath); } - runtimeEnvImpl = new RuntimeEnvImpl(envVars, jarUrls); + runtimeEnvImpl = new RuntimeEnvImpl(); + runtimeEnvImpl.set(RuntimeEnvName.ENV_VARS, envVars); + runtimeEnvImpl.set(RuntimeEnvName.JARS, jarUrls); } { diff --git a/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java b/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java index 60c651485..d404c3e3a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/runtimeenv/RuntimeEnvImpl.java @@ -1,58 +1,90 @@ package io.ray.runtime.runtimeenv; -import com.google.gson.Gson; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.fge.jackson.JsonLoader; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.ray.api.exception.RuntimeEnvException; import io.ray.api.runtimeenv.RuntimeEnv; import io.ray.runtime.generated.RuntimeEnvCommon; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.IOException; public class RuntimeEnvImpl implements RuntimeEnv { - private Map envVars = new HashMap<>(); + private static final ObjectMapper MAPPER = new ObjectMapper(); - private List jars = new ArrayList<>(); + public ObjectNode runtimeEnvs = MAPPER.createObjectNode(); - public RuntimeEnvImpl(Map envVars, List jars) { - this.envVars = envVars; - if (jars != null) { - this.jars = jars; + public RuntimeEnvImpl() {} + + @Override + public void set(String name, Object value) throws RuntimeEnvException { + JsonNode node = null; + try { + node = MAPPER.valueToTree(value); + } catch (IllegalArgumentException e) { + throw new RuntimeException(e); } - } - - public Map getEnvVars() { - return envVars; - } - - public List getJars() { - return jars; + runtimeEnvs.set(name, node); } @Override - public String toJsonBytes() { - // Get serializedRuntimeEnv - String serializedRuntimeEnv = "{}"; - - Map runtimeEnvMap = new HashMap<>(); - if (!envVars.isEmpty()) { - runtimeEnvMap.put("env_vars", envVars); - } - if (!jars.isEmpty()) { - runtimeEnvMap.put("java_jars", jars); + public void setJsonStr(String name, String jsonStr) throws RuntimeEnvException { + JsonNode node = null; + try { + node = JsonLoader.fromString(jsonStr); + } catch (IOException e) { + throw new RuntimeException(e); } + runtimeEnvs.set(name, node); + } - serializedRuntimeEnv = new Gson().toJson(runtimeEnvMap); - - // Get serializedRuntimeEnvInfo - if (serializedRuntimeEnv.equals("{}") || serializedRuntimeEnv.isEmpty()) { - return "{}"; + @Override + public T get(String name, Class classOfT) throws RuntimeEnvException { + JsonNode jsonNode = runtimeEnvs.get(name); + if (jsonNode == null) { + return null; } + try { + return MAPPER.treeToValue(jsonNode, classOfT); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getJsonStr(String name) throws RuntimeEnvException { + try { + return MAPPER.writeValueAsString(runtimeEnvs.get(name)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public void remove(String name) { + runtimeEnvs.remove(name); + } + + @Override + public String serialize() throws RuntimeEnvException { + try { + return MAPPER.writeValueAsString(runtimeEnvs); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public String serializeToRuntimeEnvInfo() throws RuntimeEnvException { + // TODO(SongGuyang): Expose runtime env config API to users. + String serializeRuntimeEnv = serialize(); RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder = RuntimeEnvCommon.RuntimeEnvInfo.newBuilder(); - protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializedRuntimeEnv); + protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(serializeRuntimeEnv); JsonFormat.Printer printer = JsonFormat.printer(); try { return printer.print(protoRuntimeEnvInfoBuilder); @@ -60,4 +92,17 @@ public class RuntimeEnvImpl implements RuntimeEnv { throw new RuntimeException(e); } } + + public RuntimeEnvCommon.RuntimeEnvInfo GenerateRuntimeEnvInfo() throws RuntimeEnvException { + RuntimeEnvCommon.RuntimeEnvInfo.Builder protoRuntimeEnvInfoBuilder = + RuntimeEnvCommon.RuntimeEnvInfo.newBuilder(); + + try { + protoRuntimeEnvInfoBuilder.setSerializedRuntimeEnv(MAPPER.writeValueAsString(runtimeEnvs)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + return protoRuntimeEnvInfoBuilder.build(); + } } diff --git a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java index 8419f754a..485e4d17b 100644 --- a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java +++ b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java @@ -4,7 +4,10 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; import io.ray.api.Ray; import io.ray.api.runtimeenv.RuntimeEnv; +import io.ray.api.runtimeenv.types.RuntimeEnvName; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.testng.Assert; import org.testng.annotations.Test; @@ -62,12 +65,16 @@ public class RuntimeEnvTest { public void testEnvVarsForNormalTask() { try { Ray.init(); - RuntimeEnv runtimeEnv = - new RuntimeEnv.Builder() - .addEnvVar("KEY1", "A") - .addEnvVar("KEY2", "B") - .addEnvVar("KEY1", "C") - .build(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + Map envMap = + new HashMap() { + { + put("KEY1", "A"); + put("KEY2", "B"); + put("KEY1", "C"); + } + }; + runtimeEnv.set(RuntimeEnvName.ENV_VARS, envMap); String val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get(); @@ -85,7 +92,14 @@ public class RuntimeEnvTest { System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); try { Ray.init(); - RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build(); + Map envMap = + new HashMap() { + { + put("KEY1", "C"); + } + }; + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + runtimeEnv.set(RuntimeEnvName.ENV_VARS, envMap); /// value of KEY1 is overwritten to `C` and KEY2s is extended from job config. String val = @@ -101,7 +115,8 @@ public class RuntimeEnvTest { private static void testDownloadAndLoadPackage(String url) { try { Ray.init(); - final RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addJars(ImmutableList.of(url)).build(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + runtimeEnv.set(RuntimeEnvName.JARS, ImmutableList.of(url)); ActorHandle actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote(); boolean ret = actor1.task(A::findClass, FOO_CLASS_NAME).remote().get(); Assert.assertTrue(ret); @@ -133,7 +148,8 @@ public class RuntimeEnvTest { List urls, List classNames) { try { Ray.init(); - final RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addJars(urls).build(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + runtimeEnv.set(RuntimeEnvName.JARS, urls); boolean ret = Ray.task(RuntimeEnvTest::findClasses, classNames) .setRuntimeEnv(runtimeEnv) @@ -178,4 +194,73 @@ public class RuntimeEnvTest { Ray.shutdown(); } } + + private static class Pip { + private String[] packages; + private Boolean pip_check; + + public String[] getPackages() { + return packages; + } + + public void setPackages(String[] packages) { + this.packages = packages; + } + + public Boolean getPip_check() { + return pip_check; + } + + public void setPip_check(Boolean pip_check) { + this.pip_check = pip_check; + } + } + + public void testRuntimeEnvAPI() { + try { + Ray.init(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + String workingDir = "https://path/to/working_dir.zip"; + runtimeEnv.set("working_dir", workingDir); + String[] py_modules = + new String[] {"https://path/to/py_modules1.zip", "https://path/to/py_modules2.zip"}; + runtimeEnv.set("py_modules", py_modules); + Pip pip = new Pip(); + pip.setPackages(new String[] {"requests", "tensorflow"}); + pip.setPip_check(true); + runtimeEnv.set("pip", pip); + String serializedRuntimeEnv = runtimeEnv.serialize(); + + RuntimeEnv runtimeEnv2 = RuntimeEnv.deserialize(serializedRuntimeEnv); + Assert.assertEquals(runtimeEnv2.get("working_dir", String.class), workingDir); + Assert.assertEquals(runtimeEnv2.get("py_modules", String[].class), py_modules); + Pip pip2 = runtimeEnv2.get("pip", Pip.class); + Assert.assertEquals(pip2.getPackages(), pip.getPackages()); + Assert.assertEquals(pip2.getPip_check(), pip.getPip_check()); + + runtimeEnv2.remove("working_dir"); + runtimeEnv2.remove("py_modules"); + runtimeEnv2.remove("pip"); + Assert.assertEquals(runtimeEnv2.get("working_dir", String.class), null); + Assert.assertEquals(runtimeEnv2.get("py_modules", String[].class), null); + Assert.assertEquals(runtimeEnv2.get("pip", Pip.class), null); + } finally { + Ray.shutdown(); + } + } + + public void testRuntimeEnvJsonStringAPI() { + try { + Ray.init(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().build(); + String pipString = "{\"packages\":[\"requests\",\"tensorflow\"],\"pip_check\":false}"; + runtimeEnv.setJsonStr("pip", pipString); + String serializedRuntimeEnv = runtimeEnv.serialize(); + + RuntimeEnv runtimeEnv2 = RuntimeEnv.deserialize(serializedRuntimeEnv); + Assert.assertEquals(runtimeEnv2.getJsonStr("pip"), pipString); + } finally { + Ray.shutdown(); + } + } }