mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Java] Add set runtime env api for normal task. (#23412)
This PR adds the API `setRuntimeEnv` for submitting a normal task, for the usage: ```java RuntimeEnv runtimeEnv = new RuntimeEnv.Builder() .addEnvVar("KEY1", "A") .build(); /// Return `A` Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get(); ```
This commit is contained in:
parent
26f1a7ef7d
commit
ef5b9b87d3
6 changed files with 86 additions and 4 deletions
|
@ -2,6 +2,7 @@ package io.ray.api.call;
|
|||
|
||||
import io.ray.api.options.CallOptions;
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -75,6 +76,17 @@ public class BaseTaskCaller<T extends BaseTaskCaller<T>> {
|
|||
return setPlacementGroup(group, -1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the runtime env for this task to run the task in a specific environment.
|
||||
*
|
||||
* @param runtimeEnv The runtime env of this task.
|
||||
* @return self
|
||||
*/
|
||||
public T setRuntimeEnv(RuntimeEnv runtimeEnv) {
|
||||
builder.setRuntimeEnv(runtimeEnv);
|
||||
return self();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private T self() {
|
||||
return (T) this;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package io.ray.api.options;
|
||||
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -11,18 +12,21 @@ public class CallOptions extends BaseTaskOptions {
|
|||
public final PlacementGroup group;
|
||||
public final int bundleIndex;
|
||||
public final String concurrencyGroupName;
|
||||
private final String serializedRuntimeEnvInfo;
|
||||
|
||||
private CallOptions(
|
||||
String name,
|
||||
Map<String, Double> resources,
|
||||
PlacementGroup group,
|
||||
int bundleIndex,
|
||||
String concurrencyGroupName) {
|
||||
String concurrencyGroupName,
|
||||
RuntimeEnv runtimeEnv) {
|
||||
super(resources);
|
||||
this.name = name;
|
||||
this.group = group;
|
||||
this.bundleIndex = bundleIndex;
|
||||
this.concurrencyGroupName = concurrencyGroupName;
|
||||
this.serializedRuntimeEnvInfo = runtimeEnv == null ? "" : runtimeEnv.toJsonBytes();
|
||||
}
|
||||
|
||||
/** This inner class for building CallOptions. */
|
||||
|
@ -33,6 +37,7 @@ public class CallOptions extends BaseTaskOptions {
|
|||
private PlacementGroup group;
|
||||
private int bundleIndex;
|
||||
private String concurrencyGroupName = "";
|
||||
private RuntimeEnv runtimeEnv = null;
|
||||
|
||||
/**
|
||||
* Set a name for this task.
|
||||
|
@ -88,8 +93,13 @@ public class CallOptions extends BaseTaskOptions {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) {
|
||||
this.runtimeEnv = runtimeEnv;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CallOptions build() {
|
||||
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName);
|
||||
return new CallOptions(name, resources, group, bundleIndex, concurrencyGroupName, runtimeEnv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -118,4 +118,47 @@ public class RuntimeEnvTest {
|
|||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
private static String getEnvVar(String key) {
|
||||
return System.getenv(key);
|
||||
}
|
||||
|
||||
public void testEnvVarsForNormalTask() {
|
||||
try {
|
||||
Ray.init();
|
||||
RuntimeEnv runtimeEnv =
|
||||
new RuntimeEnv.Builder()
|
||||
.addEnvVar("KEY1", "A")
|
||||
.addEnvVar("KEY2", "B")
|
||||
.addEnvVar("KEY1", "C")
|
||||
.build();
|
||||
|
||||
String val =
|
||||
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
|
||||
Assert.assertEquals(val, "C");
|
||||
val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
} finally {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
/// overwrite the runtime env from job config.
|
||||
public void testPerTaskEnvVarsOverwritePerJobEnvVars() {
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY1", "A");
|
||||
System.setProperty("ray.job.runtime-env.env-vars.KEY2", "B");
|
||||
try {
|
||||
Ray.init();
|
||||
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder().addEnvVar("KEY1", "C").build();
|
||||
|
||||
/// value of KEY1 is overwritten to `C` and KEY2s is extended from job config.
|
||||
String val =
|
||||
Ray.task(RuntimeEnvTest::getEnvVar, "KEY1").setRuntimeEnv(runtimeEnv).remote().get();
|
||||
Assert.assertEquals(val, "C");
|
||||
val = Ray.task(RuntimeEnvTest::getEnvVar, "KEY2").setRuntimeEnv(runtimeEnv).remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
} finally {
|
||||
Ray.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -122,6 +122,8 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio
|
|||
std::unordered_map<std::string, double> resources;
|
||||
std::string name = "";
|
||||
std::string concurrency_group_name = "";
|
||||
std::string serialzied_runtime_env_info = "";
|
||||
|
||||
if (callOptions) {
|
||||
jobject java_resources =
|
||||
env->GetObjectField(callOptions, java_base_task_options_resources);
|
||||
|
@ -137,9 +139,19 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio
|
|||
if (java_concurrency_group_name) {
|
||||
concurrency_group_name = JavaStringToNativeString(env, java_concurrency_group_name);
|
||||
}
|
||||
|
||||
auto java_serialized_runtime_env_info = reinterpret_cast<jstring>(
|
||||
env->GetObjectField(callOptions, java_call_options_serialized_runtime_env_info));
|
||||
RAY_CHECK_JAVA_EXCEPTION(env);
|
||||
RAY_CHECK(java_serialized_runtime_env_info != nullptr);
|
||||
if (java_serialized_runtime_env_info) {
|
||||
serialzied_runtime_env_info =
|
||||
JavaStringToNativeString(env, java_serialized_runtime_env_info);
|
||||
}
|
||||
}
|
||||
|
||||
TaskOptions task_options{name, numReturns, resources, concurrency_group_name};
|
||||
TaskOptions task_options{
|
||||
name, numReturns, resources, concurrency_group_name, serialzied_runtime_env_info};
|
||||
return task_options;
|
||||
}
|
||||
|
||||
|
|
|
@ -99,6 +99,7 @@ jfieldID java_call_options_name;
|
|||
jfieldID java_task_creation_options_group;
|
||||
jfieldID java_task_creation_options_bundle_index;
|
||||
jfieldID java_call_options_concurrency_group_name;
|
||||
jfieldID java_call_options_serialized_runtime_env_info;
|
||||
|
||||
jclass java_actor_creation_options_class;
|
||||
jfieldID java_actor_creation_options_name;
|
||||
|
@ -309,6 +310,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||
env->GetFieldID(java_call_options_class, "bundleIndex", "I");
|
||||
java_call_options_concurrency_group_name = env->GetFieldID(
|
||||
java_call_options_class, "concurrencyGroupName", "Ljava/lang/String;");
|
||||
java_call_options_serialized_runtime_env_info = env->GetFieldID(
|
||||
java_call_options_class, "serializedRuntimeEnvInfo", "Ljava/lang/String;");
|
||||
|
||||
java_placement_group_class =
|
||||
LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl");
|
||||
|
|
|
@ -175,6 +175,8 @@ extern jfieldID java_task_creation_options_group;
|
|||
extern jfieldID java_task_creation_options_bundle_index;
|
||||
/// concurrencyGroupName field of CallOptions class
|
||||
extern jfieldID java_call_options_concurrency_group_name;
|
||||
/// serializedRuntimeEnvInfo field of CallOptions class
|
||||
extern jfieldID java_call_options_serialized_runtime_env_info;
|
||||
|
||||
/// ActorCreationOptions class
|
||||
extern jclass java_actor_creation_options_class;
|
||||
|
@ -549,7 +551,7 @@ inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env,
|
|||
if (!buffer) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
auto buffer_size = buffer->Size();
|
||||
jbyteArray java_byte_array = env->NewByteArray(buffer_size);
|
||||
if (buffer_size > 0) {
|
||||
|
|
Loading…
Add table
Reference in a new issue