[Java] Add set runtime env api for normal task. (#23412)

This PR adds the API `setRuntimeEnv` for submitting a normal task, for the usage:
```java
RuntimeEnv runtimeEnv =
    new RuntimeEnv.Builder()
        .addEnvVar("KEY1", "A")
        .build();

/// Return `A`
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
```
This commit is contained in:
Qing Wang 2022-03-24 15:57:24 +08:00 committed by GitHub
parent 26f1a7ef7d
commit ef5b9b87d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 4 deletions

View file

@ -2,6 +2,7 @@ package io.ray.api.call;
import io.ray.api.options.CallOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.Map;
/**
@ -75,6 +76,17 @@ public class BaseTaskCaller<T extends BaseTaskCaller<T>> {
return setPlacementGroup(group, -1);
}
/**
* Set the runtime env for this task to run the task in a specific environment.
*
* @param runtimeEnv The runtime env of this task.
* @return self
*/
public T setRuntimeEnv(RuntimeEnv runtimeEnv) {
builder.setRuntimeEnv(runtimeEnv);
return self();
}
@SuppressWarnings("unchecked")
private T self() {
return (T) this;

View file

@ -1,6 +1,7 @@
package io.ray.api.options;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.HashMap;
import java.util.Map;
@ -11,18 +12,21 @@ public class CallOptions extends BaseTaskOptions {
public final PlacementGroup group;
public final int bundleIndex;
public final String concurrencyGroupName;
private final String serializedRuntimeEnvInfo;
private CallOptions(
String name,
Map<String, Double> resources,
PlacementGroup group,
int bundleIndex,
String concurrencyGroupName) {
String concurrencyGroupName,
RuntimeEnv runtimeEnv) {
super(resources);
this.name = name;
this.group = group;
this.bundleIndex = bundleIndex;
this.concurrencyGroupName = concurrencyGroupName;
this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes();
}
/** This inner class for building CallOptions. */
@ -33,6 +37,7 @@ public class CallOptions extends BaseTaskOptions {
private PlacementGroup group;
private int bundleIndex;
private String concurrencyGroupName = "";
private RuntimeEnv runtimeEnv = null;
/**
* Set a name for this task.
@ -88,8 +93,13 @@ public class CallOptions extends BaseTaskOptions {
return this;
}
public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) {
this.runtimeEnv = runtimeEnv;
return this;
}
public CallOptions build() {
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName);
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName, runtimeEnv);
}
}
}

View file

@ -118,4 +118,47 @@ public class RuntimeEnvTest {
Ray.shutdown();
}
}
private static String getEnvVar(String key) {
return System.getenv(key);
}
public void testEnvVarsForNormalTask() {
try {
Ray.init();
RuntimeEnv runtimeEnv =
new RuntimeEnv.Builder()
.addEnvVar("KEY1", "A")
.addEnvVar("KEY2", "B")
.addEnvVar("KEY1", "C")
.build();
String val =
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
Assert.assertEquals(val, "C");
val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get();
Assert.assertEquals(val, "B");
} finally {
Ray.shutdown();
}
}
/// overwrite the runtime env from job config.
public void testPerTaskEnvVarsOverwritePerJobEnvVars() {
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
try {
Ray.init();
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build();
/// value of KEY1 is overwritten to `C` and KEY2s is extended from job config.
String val =
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
Assert.assertEquals(val, "C");
val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get();
Assert.assertEquals(val, "B");
} finally {
Ray.shutdown();
}
}
}

View file

@ -122,6 +122,8 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio
std::unordered_map<std::string, double> resources;
std::string name = "";
std::string concurrency_group_name = "";
std::string serialzied_runtime_env_info = "";
if (callOptions) {
jobject java_resources =
env->GetObjectField(callOptions, java_base_task_options_resources);
@ -137,9 +139,19 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio
if (java_concurrency_group_name) {
concurrency_group_name = JavaStringToNativeString(env, java_concurrency_group_name);
}
auto java_serialized_runtime_env_info = reinterpret_cast<jstring>(
env->GetObjectField(callOptions, java_call_options_serialized_runtime_env_info));
RAY_CHECK_JAVA_EXCEPTION(env);
RAY_CHECK(java_serialized_runtime_env_info != nullptr);
if (java_serialized_runtime_env_info) {
serialzied_runtime_env_info =
JavaStringToNativeString(env, java_serialized_runtime_env_info);
}
}
TaskOptions task_options{name, numReturns, resources, concurrency_group_name};
TaskOptions task_options{
name, numReturns, resources, concurrency_group_name, serialzied_runtime_env_info};
return task_options;
}

View file

@ -99,6 +99,7 @@ jfieldID java_call_options_name;
jfieldID java_task_creation_options_group;
jfieldID java_task_creation_options_bundle_index;
jfieldID java_call_options_concurrency_group_name;
jfieldID java_call_options_serialized_runtime_env_info;
jclass java_actor_creation_options_class;
jfieldID java_actor_creation_options_name;
@ -309,6 +310,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
env->GetFieldID(java_call_options_class, "bundleIndex", "I");
java_call_options_concurrency_group_name = env->GetFieldID(
java_call_options_class, "concurrencyGroupName", "Ljava/lang/String;");
java_call_options_serialized_runtime_env_info = env->GetFieldID(
java_call_options_class, "serializedRuntimeEnvInfo", "Ljava/lang/String;");
java_placement_group_class =
LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl");

View file

@ -175,6 +175,8 @@ extern jfieldID java_task_creation_options_group;
extern jfieldID java_task_creation_options_bundle_index;
/// concurrencyGroupName field of CallOptions class
extern jfieldID java_call_options_concurrency_group_name;
/// serializedRuntimeEnvInfo field of CallOptions class
extern jfieldID java_call_options_serialized_runtime_env_info;
/// ActorCreationOptions class
extern jclass java_actor_creation_options_class;
@ -549,7 +551,7 @@ inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
if (!buffer) {
return nullptr;
}
auto buffer_size = buffer->Size();
jbyteArray java_byte_array = env->NewByteArray(buffer_size);
if (buffer_size > 0) {