mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RuntimeEnv] Support setting actor level env vars for Java worker (#22240)
This PR supports setting actor-level env vars for Java workers via the runtime env. The general API looks like this:

```java
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder()
    .addEnvVar("KEY1", "A")
    .addEnvVar("KEY2", "B")
    .addEnvVar("KEY1", "C") // This overwrites "KEY1" to "C"
    .build();
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
```

If `num-java-workers-per-process` > 1, a worker process is never reused by actors unless they have the same runtime env. Co-authored-by: Qing Wang <jovany.wq@antgroup.com>
This commit is contained in:
parent
94caac8722
commit
9572bb717f
16 changed files with 183 additions and 6 deletions
|
@ -4,6 +4,7 @@ import io.ray.api.ActorHandle;
|
|||
import io.ray.api.Ray;
|
||||
import io.ray.api.concurrencygroup.ConcurrencyGroup;
|
||||
import io.ray.api.function.RayFuncR;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
@ -54,4 +55,9 @@ public class ActorCreator<A> extends BaseActorCreator<ActorCreator<A>> {
|
|||
builder.setConcurrencyGroups(list);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
 * Sets the runtime env to use for the actor being created.
 *
 * <p>The value is forwarded unchanged to the underlying options builder.
 *
 * @param runtimeEnv the runtime env to attach to the actor.
 * @return this creator, so that calls can be chained.
 */
public ActorCreator<A> setRuntimeEnv(RuntimeEnv runtimeEnv) {
|
||||
builder.setRuntimeEnv(runtimeEnv);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package io.ray.api.options;
|
|||
import io.ray.api.Ray;
|
||||
import io.ray.api.concurrencygroup.ConcurrencyGroup;
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -18,6 +19,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
public final PlacementGroup group;
|
||||
public final int bundleIndex;
|
||||
public final List<ConcurrencyGroup> concurrencyGroups;
|
||||
public final String serializedRuntimeEnv;
|
||||
public final int maxPendingCalls;
|
||||
|
||||
private ActorCreationOptions(
|
||||
|
@ -30,6 +32,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
PlacementGroup group,
|
||||
int bundleIndex,
|
||||
List<ConcurrencyGroup> concurrencyGroups,
|
||||
String serializedRuntimeEnv,
|
||||
int maxPendingCalls) {
|
||||
super(resources);
|
||||
this.name = name;
|
||||
|
@ -40,6 +43,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
this.group = group;
|
||||
this.bundleIndex = bundleIndex;
|
||||
this.concurrencyGroups = concurrencyGroups;
|
||||
this.serializedRuntimeEnv = serializedRuntimeEnv;
|
||||
this.maxPendingCalls = maxPendingCalls;
|
||||
}
|
||||
|
||||
|
@ -54,6 +58,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
private PlacementGroup group;
|
||||
private int bundleIndex;
|
||||
private List<ConcurrencyGroup> concurrencyGroups = new ArrayList<>();
|
||||
private RuntimeEnv runtimeEnv = null;
|
||||
private int maxPendingCalls = -1;
|
||||
|
||||
/**
|
||||
|
@ -188,6 +193,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
group,
|
||||
bundleIndex,
|
||||
concurrencyGroups,
|
||||
runtimeEnv != null ? runtimeEnv.toJsonBytes() : "",
|
||||
maxPendingCalls);
|
||||
}
|
||||
|
||||
|
@ -196,5 +202,10 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
|||
this.concurrencyGroups = concurrencyGroups;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
 * Records the runtime env to use for the actor.
 *
 * <p>Only the reference is stored here; this method does not serialize or validate it.
 *
 * @param runtimeEnv the runtime env to attach to the actor.
 * @return this builder, so that calls can be chained.
 */
public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) {
|
||||
this.runtimeEnv = runtimeEnv;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import io.ray.api.options.PlacementGroupCreationOptions;
|
|||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimecontext.ResourceValue;
|
||||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
@ -280,4 +281,7 @@ public interface RayRuntime {
|
|||
ConcurrencyGroup createConcurrencyGroup(String name, int maxConcurrency, List<RayFunc> funcs);
|
||||
|
||||
List<ConcurrencyGroup> extractConcurrencyGroups(RayFuncR<?> actorConstructorLambda);
|
||||
|
||||
/** Create runtime env instance at runtime. */
|
||||
RuntimeEnv createRuntimeEnv(Map<String, String> envVars);
|
||||
}
|
||||
|
|
25
java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java
Normal file
25
java/api/src/main/java/io/ray/api/runtimeenv/RuntimeEnv.java
Normal file
|
@ -0,0 +1,25 @@
|
|||
package io.ray.api.runtimeenv;
|
||||
|
||||
import io.ray.api.Ray;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** This is an experimental API to let you set runtime environments for your actors. */
|
||||
public interface RuntimeEnv {
|
||||
|
||||
String toJsonBytes();
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private Map<String, String> envVars = new HashMap<>();
|
||||
|
||||
public Builder addEnvVar(String key, String value) {
|
||||
envVars.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RuntimeEnv build() {
|
||||
return Ray.internal().createRuntimeEnv(envVars);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -21,6 +21,7 @@ import io.ray.api.options.CallOptions;
|
|||
import io.ray.api.options.PlacementGroupCreationOptions;
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import io.ray.runtime.config.RayConfig;
|
||||
import io.ray.runtime.config.RunMode;
|
||||
import io.ray.runtime.context.RuntimeContextImpl;
|
||||
|
@ -33,6 +34,7 @@ import io.ray.runtime.generated.Common;
|
|||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import io.ray.runtime.object.ObjectStore;
|
||||
import io.ray.runtime.runtimeenv.RuntimeEnvImpl;
|
||||
import io.ray.runtime.task.ArgumentsBuilder;
|
||||
import io.ray.runtime.task.FunctionArg;
|
||||
import io.ray.runtime.task.TaskExecutor;
|
||||
|
@ -40,6 +42,7 @@ import io.ray.runtime.task.TaskSubmitter;
|
|||
import io.ray.runtime.util.ConcurrencyGroupUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -285,6 +288,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
return ConcurrencyGroupUtils.extractConcurrencyGroupsByAnnotations(actorConstructorLambda);
|
||||
}
|
||||
|
||||
/**
 * Creates a runtime env from the given environment variables.
 *
 * @param envVars the environment variables, passed as-is to the {@link RuntimeEnvImpl}
 *     constructor.
 * @return a new {@link RuntimeEnvImpl} instance.
 */
@Override
|
||||
public RuntimeEnv createRuntimeEnv(Map<String, String> envVars) {
|
||||
return new RuntimeEnvImpl(envVars);
|
||||
}
|
||||
|
||||
private ObjectRef callNormalFunction(
|
||||
FunctionDescriptor functionDescriptor,
|
||||
Object[] args,
|
||||
|
|
|
@ -146,7 +146,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
numWorkersPerProcess,
|
||||
rayConfig.logDir,
|
||||
serializedJobConfig,
|
||||
rayConfig.getStartupToken());
|
||||
rayConfig.getStartupToken(),
|
||||
rayConfig.runtimeEnvHash);
|
||||
|
||||
taskExecutor = new NativeTaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext();
|
||||
|
@ -278,7 +279,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
|||
int numWorkersPerProcess,
|
||||
String logDir,
|
||||
byte[] serializedJobConfig,
|
||||
int startupToken);
|
||||
int startupToken,
|
||||
int runtimeEnvHash);
|
||||
|
||||
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
|
||||
|
||||
|
|
|
@ -52,6 +52,8 @@ public class RayConfig {
|
|||
|
||||
public int startupToken;
|
||||
|
||||
public int runtimeEnvHash;
|
||||
|
||||
public final ActorLifetime defaultActorLifetime;
|
||||
|
||||
public static class LoggerConf {
|
||||
|
@ -201,6 +203,11 @@ public class RayConfig {
|
|||
|
||||
startupToken = config.getInt("ray.raylet.startup-token");
|
||||
|
||||
/// The driver doesn't need this config item.
|
||||
if (workerMode == WorkerType.WORKER && config.hasPath("ray.internal.runtime-env-hash")) {
|
||||
runtimeEnvHash = config.getInt("ray.internal.runtime-env-hash");
|
||||
}
|
||||
|
||||
{
|
||||
loggers = new ArrayList<>();
|
||||
List<Config> loggerConfigs = (List<Config>) config.getConfigList("ray.logging.loggers");
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package io.ray.runtime.runtimeenv;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import com.google.protobuf.util.JsonFormat;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import io.ray.runtime.generated.RuntimeEnvCommon;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class RuntimeEnvImpl implements RuntimeEnv {
|
||||
|
||||
private Map<String, String> envVars = new HashMap<>();
|
||||
|
||||
public RuntimeEnvImpl(Map<String, String> envVars) {
|
||||
this.envVars = envVars;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toJsonBytes() {
|
||||
if (!envVars.isEmpty()) {
|
||||
RuntimeEnvCommon.RuntimeEnv.Builder protoRuntimeEnvBuilder =
|
||||
RuntimeEnvCommon.RuntimeEnv.newBuilder();
|
||||
protoRuntimeEnvBuilder.putAllEnvVars(envVars);
|
||||
JsonFormat.Printer printer = JsonFormat.printer();
|
||||
try {
|
||||
return printer.print(protoRuntimeEnvBuilder);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return "{}";
|
||||
}
|
||||
}
|
65
java/test/src/main/java/io/ray/test/RuntimeEnvTest.java
Normal file
65
java/test/src/main/java/io/ray/test/RuntimeEnvTest.java
Normal file
|
@ -0,0 +1,65 @@
|
|||
package io.ray.test;
|
||||
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.runtimeenv.RuntimeEnv;
|
||||
import io.ray.runtime.util.SystemUtil;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@Test(groups = "cluster")
|
||||
public class RuntimeEnvTest extends BaseTest {
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
/// This is used to test that actors with runtime envs will not reuse worker process.
|
||||
System.setProperty("ray.job.num-java-workers-per-process", "2");
|
||||
}
|
||||
|
||||
private static class A {
|
||||
|
||||
public String getEnv(String key) {
|
||||
return System.getenv(key);
|
||||
}
|
||||
|
||||
public int getPid() {
|
||||
return SystemUtil.pid();
|
||||
}
|
||||
}
|
||||
|
||||
public void testEnvironmentVariable() {
|
||||
int pid1 = 0;
|
||||
int pid2 = 0;
|
||||
{
|
||||
RuntimeEnv runtimeEnv =
|
||||
new RuntimeEnv.Builder()
|
||||
.addEnvVar("KEY1", "A")
|
||||
.addEnvVar("KEY2", "B")
|
||||
.addEnvVar("KEY1", "C")
|
||||
.build();
|
||||
|
||||
ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
|
||||
String val = actor1.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertEquals(val, "C");
|
||||
val = actor1.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertEquals(val, "B");
|
||||
|
||||
pid1 = actor1.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
{
|
||||
/// Because we didn't set them for actor2 , all should be null.
|
||||
ActorHandle<A> actor2 = Ray.actor(A::new).remote();
|
||||
String val = actor2.task(A::getEnv, "KEY1").remote().get();
|
||||
Assert.assertNull(val);
|
||||
val = actor2.task(A::getEnv, "KEY2").remote().get();
|
||||
Assert.assertNull(val);
|
||||
pid2 = actor2.task(A::getPid).remote().get();
|
||||
}
|
||||
|
||||
// actor1 and actor2 shouldn't be in one process because they have
|
||||
// different runtime env.
|
||||
Assert.assertNotEquals(pid1, pid2);
|
||||
}
|
||||
}
|
|
@ -98,7 +98,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
|||
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
|
||||
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
|
||||
jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir,
|
||||
jbyteArray jobConfig, jint startupToken) {
|
||||
jbyteArray jobConfig, jint startupToken, jint runtimeEnvHash) {
|
||||
auto task_execution_callback =
|
||||
[](TaskType task_type, const std::string task_name, const RayFunction &ray_function,
|
||||
const std::unordered_map<std::string, double> &required_resources,
|
||||
|
@ -267,6 +267,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
|||
options.serialized_job_config = serialized_job_config;
|
||||
options.metrics_agent_port = -1;
|
||||
options.startup_token = startupToken;
|
||||
options.runtime_env_hash = runtimeEnvHash;
|
||||
|
||||
CoreWorkerProcess::Initialize(options);
|
||||
}
|
||||
|
|
|
@ -25,11 +25,11 @@ extern "C" {
|
|||
* Class: io_ray_runtime_RayNativeRuntime
|
||||
* Method: nativeInitialize
|
||||
* Signature:
|
||||
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BI)V
|
||||
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
|
||||
jint, jstring, jbyteArray, jint);
|
||||
jint, jstring, jbyteArray, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: io_ray_runtime_RayNativeRuntime
|
||||
|
|
|
@ -151,6 +151,7 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
|
|||
uint64_t max_concurrency = 1;
|
||||
auto placement_options = std::make_pair(PlacementGroupID::Nil(), -1);
|
||||
std::vector<ConcurrencyGroup> concurrency_groups;
|
||||
std::string serialized_runtime_env = "";
|
||||
int32_t max_pending_calls = -1;
|
||||
|
||||
if (actorCreationOptions) {
|
||||
|
@ -223,6 +224,12 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
|
|||
return ray::ConcurrencyGroup{concurrency_group_name, max_concurrency,
|
||||
native_func_descriptors};
|
||||
});
|
||||
auto java_serialized_runtime_env = (jstring)env->GetObjectField(actorCreationOptions,
|
||||
java_actor_creation_options_serialized_runtime_env);
|
||||
if (java_serialized_runtime_env) {
|
||||
serialized_runtime_env = JavaStringToNativeString(env, java_serialized_runtime_env);
|
||||
}
|
||||
|
||||
max_pending_calls = static_cast<int32_t>(env->GetIntField(
|
||||
actorCreationOptions, java_actor_creation_options_max_pending_calls));
|
||||
}
|
||||
|
@ -253,7 +260,7 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
|
|||
ray_namespace,
|
||||
/*is_asyncio=*/false,
|
||||
/*scheduling_strategy=*/scheduling_strategy,
|
||||
/*serialized_runtime_env=*/"{}",
|
||||
serialized_runtime_env,
|
||||
concurrency_groups,
|
||||
/*execute_out_of_order*/ false,
|
||||
max_pending_calls};
|
||||
|
|
|
@ -105,6 +105,7 @@ jfieldID java_actor_creation_options_max_concurrency;
|
|||
jfieldID java_actor_creation_options_group;
|
||||
jfieldID java_actor_creation_options_bundle_index;
|
||||
jfieldID java_actor_creation_options_concurrency_groups;
|
||||
jfieldID java_actor_creation_options_serialized_runtime_env;
|
||||
jfieldID java_actor_creation_options_max_pending_calls;
|
||||
|
||||
jclass java_actor_lifetime_class;
|
||||
|
@ -332,6 +333,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||
env->GetFieldID(java_actor_creation_options_class, "bundleIndex", "I");
|
||||
java_actor_creation_options_concurrency_groups = env->GetFieldID(
|
||||
java_actor_creation_options_class, "concurrencyGroups", "Ljava/util/List;");
|
||||
java_actor_creation_options_serialized_runtime_env =
|
||||
env->GetFieldID(java_actor_creation_options_class, "serializedRuntimeEnv", "Ljava/lang/String;");
|
||||
java_actor_creation_options_max_pending_calls =
|
||||
env->GetFieldID(java_actor_creation_options_class, "maxPendingCalls", "I");
|
||||
|
||||
|
|
|
@ -189,6 +189,8 @@ extern jfieldID java_actor_creation_options_group;
|
|||
extern jfieldID java_actor_creation_options_bundle_index;
|
||||
/// concurrencyGroups field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_concurrency_groups;
|
||||
/// serializedRuntimeEnv field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_serialized_runtime_env;
|
||||
/// maxPendingCalls field of ActorCreationOptions class
|
||||
extern jfieldID java_actor_creation_options_max_pending_calls;
|
||||
/// ActorLifetime enum class
|
||||
|
|
|
@ -314,6 +314,8 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
|
|||
if (language == Language::JAVA) {
|
||||
options.push_back("-Dray.raylet.startup-token=" +
|
||||
std::to_string(worker_startup_token_counter_));
|
||||
options.push_back("-Dray.internal.runtime-env-hash=" +
|
||||
std::to_string(runtime_env_hash));
|
||||
}
|
||||
|
||||
// Append user-defined per-process options here
|
||||
|
|
|
@ -708,6 +708,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
|
|||
// Ray-defined per-process options
|
||||
expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1));
|
||||
expected_command.push_back("-Dray.raylet.startup-token=0");
|
||||
expected_command.push_back("-Dray.internal.runtime-env-hash=0");
|
||||
// User-defined per-process options
|
||||
expected_command.insert(expected_command.end(), actor_jvm_options.begin(),
|
||||
actor_jvm_options.end());
|
||||
|
|
Loading…
Add table
Reference in a new issue