diff --git a/java/api/src/main/java/io/ray/api/call/BaseTaskCaller.java b/java/api/src/main/java/io/ray/api/call/BaseTaskCaller.java index d1bcf0d61..482c430f7 100644 --- a/java/api/src/main/java/io/ray/api/call/BaseTaskCaller.java +++ b/java/api/src/main/java/io/ray/api/call/BaseTaskCaller.java @@ -2,6 +2,7 @@ package io.ray.api.call; import io.ray.api.options.CallOptions; import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.runtimeenv.RuntimeEnv; import java.util.Map; /** @@ -75,6 +76,17 @@ public class BaseTaskCaller> { return setPlacementGroup(group, -1); } + /** + * Set the runtime env for this task to run the task in a specific environment. + * + * @param runtimeEnv The runtime env of this task. + * @return self + */ + public T setRuntimeEnv(RuntimeEnv runtimeEnv) { + builder.setRuntimeEnv(runtimeEnv); + return self(); + } + @SuppressWarnings("unchecked") private T self() { return (T) this; diff --git a/java/api/src/main/java/io/ray/api/options/CallOptions.java b/java/api/src/main/java/io/ray/api/options/CallOptions.java index f1a48dd99..e646887ec 100644 --- a/java/api/src/main/java/io/ray/api/options/CallOptions.java +++ b/java/api/src/main/java/io/ray/api/options/CallOptions.java @@ -1,6 +1,7 @@ package io.ray.api.options; import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.runtimeenv.RuntimeEnv; import java.util.HashMap; import java.util.Map; @@ -11,18 +12,21 @@ public class CallOptions extends BaseTaskOptions { public final PlacementGroup group; public final int bundleIndex; public final String concurrencyGroupName; + private final String serializedRuntimeEnvInfo; private CallOptions( String name, Map resources, PlacementGroup group, int bundleIndex, - String concurrencyGroupName) { + String concurrencyGroupName, + RuntimeEnv runtimeEnv) { super(resources); this.name = name; this.group = group; this.bundleIndex = bundleIndex; this.concurrencyGroupName = concurrencyGroupName; + this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes(); } /** This inner class for building CallOptions. */ @@ -33,6 +37,7 @@ public class CallOptions extends BaseTaskOptions { private PlacementGroup group; private int bundleIndex; private String concurrencyGroupName = ""; + private RuntimeEnv runtimeEnv = null; /** * Set a name for this task. @@ -88,8 +93,13 @@ public class CallOptions extends BaseTaskOptions { return this; } + public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) { + this.runtimeEnv = runtimeEnv; + return this; + } + public CallOptions build() { - return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName); + return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName, runtimeEnv); } } } diff --git a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java index 12726f921..b5b06234f 100644 --- a/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java +++ b/java/test/src/main/java/io/ray/test/RuntimeEnvTest.java @@ -118,4 +118,47 @@ public class RuntimeEnvTest { Ray.shutdown(); } } + + private static String getEnvVar(String key) { + return System.getenv(key); + } + + public void testEnvVarsForNormalTask() { + try { + Ray.init(); + RuntimeEnv runtimeEnv = + new RuntimeEnv.Builder() + .addEnvVar("KEY1", "A") + .addEnvVar("KEY2", "B") + .addEnvVar("KEY1", "C") + .build(); + + String val = + Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get(); + Assert.assertEquals(val, "C"); + val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get(); + Assert.assertEquals(val, "B"); + } finally { + Ray.shutdown(); + } + } + + /// overwrite the runtime env from job config. + public void testPerTaskEnvVarsOverwritePerJobEnvVars() { + System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A"); + System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B"); + try { + Ray.init(); + RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build(); + + /// value of KEY1 is overwritten to `C` and KEY2s is extended from job config. + String val = + Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get(); + Assert.assertEquals(val, "C"); + val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get(); + Assert.assertEquals(val, "B"); + } finally { + Ray.shutdown(); + } + } } diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 829950ba9..858723da1 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -122,6 +122,8 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio std::unordered_map resources; std::string name = ""; std::string concurrency_group_name = ""; + std::string serialzied_runtime_env_info = ""; + if (callOptions) { jobject java_resources = env->GetObjectField(callOptions, java_base_task_options_resources); @@ -137,9 +139,19 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio if (java_concurrency_group_name) { concurrency_group_name = JavaStringToNativeString(env, java_concurrency_group_name); } + + auto java_serialized_runtime_env_info = reinterpret_cast( + env->GetObjectField(callOptions, java_call_options_serialized_runtime_env_info)); + RAY_CHECK_JAVA_EXCEPTION(env); + RAY_CHECK(java_serialized_runtime_env_info != nullptr); + if (java_serialized_runtime_env_info) { + serialzied_runtime_env_info = + JavaStringToNativeString(env, java_serialized_runtime_env_info); + } } - TaskOptions task_options{name, numReturns, resources, concurrency_group_name}; + TaskOptions task_options{ + name, numReturns, resources, concurrency_group_name, serialzied_runtime_env_info}; return task_options; } diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 0f0891ae6..b9cd4fd2e 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -99,6 +99,7 @@ jfieldID java_call_options_name; jfieldID java_task_creation_options_group; jfieldID java_task_creation_options_bundle_index; jfieldID java_call_options_concurrency_group_name; +jfieldID java_call_options_serialized_runtime_env_info; jclass java_actor_creation_options_class; jfieldID java_actor_creation_options_name; @@ -309,6 +310,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetFieldID(java_call_options_class, "bundleIndex", "I"); java_call_options_concurrency_group_name = env->GetFieldID( java_call_options_class, "concurrencyGroupName", "Ljava/lang/String;"); + java_call_options_serialized_runtime_env_info = env->GetFieldID( + java_call_options_class, "serializedRuntimeEnvInfo", "Ljava/lang/String;"); java_placement_group_class = LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index c15bddf9e..ba592cf11 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -175,6 +175,8 @@ extern jfieldID java_task_creation_options_group; extern jfieldID java_task_creation_options_bundle_index; /// concurrencyGroupName field of CallOptions class extern jfieldID java_call_options_concurrency_group_name; +/// serializedRuntimeEnvInfo field of CallOptions class +extern jfieldID java_call_options_serialized_runtime_env_info; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; @@ -549,7 +551,7 @@ inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, if (!buffer) { return nullptr; } - + auto buffer_size = buffer->Size(); jbyteArray java_byte_array = env->NewByteArray(buffer_size); if (buffer_size > 0) {