[RuntimeEnv] Support setting actor level env vars for Java worker (#22240)

This PR adds support for setting actor-level environment variables for Java workers through the runtime env.
The general API looks like this:
```java
RuntimeEnv runtimeEnv = new RuntimeEnv.Builder()
    .addEnvVar("KEY1", "A")
    .addEnvVar("KEY2", "B")
    .addEnvVar("KEY1", "C")  // This overwrites "KEY1" to "C"
    .build();

ActorHandle<A> actor1 = Ray.actor(A::new).setRuntimeEnv(runtimeEnv).remote();
```

If `num-java-workers-per-process` > 1, a worker process is never reused by actors unless they have the same runtime env.

Co-authored-by: Qing Wang <jovany.wq@antgroup.com>
This commit is contained in:
Qing Wang 2022-02-28 10:58:37 +08:00 committed by GitHub
parent 94caac8722
commit 9572bb717f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 183 additions and 6 deletions

View file

@ -4,6 +4,7 @@ import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.function.RayFuncR;
import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -54,4 +55,9 @@ public class ActorCreator<A> extends BaseActorCreator<ActorCreator<A>> {
builder.setConcurrencyGroups(list);
return this;
}
/**
 * Set the runtime env to use for the actor being created.
 *
 * @param runtimeEnv The runtime env, forwarded to the underlying options builder.
 * @return self
 */
public ActorCreator<A> setRuntimeEnv(RuntimeEnv runtimeEnv) {
  this.builder.setRuntimeEnv(runtimeEnv);
  return this;
}
}

View file

@ -3,6 +3,7 @@ package io.ray.api.options;
import io.ray.api.Ray;
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -18,6 +19,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
public final PlacementGroup group;
public final int bundleIndex;
public final List<ConcurrencyGroup> concurrencyGroups;
public final String serializedRuntimeEnv;
public final int maxPendingCalls;
private ActorCreationOptions(
@ -30,6 +32,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
PlacementGroup group,
int bundleIndex,
List<ConcurrencyGroup> concurrencyGroups,
String serializedRuntimeEnv,
int maxPendingCalls) {
super(resources);
this.name = name;
@ -40,6 +43,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
this.group = group;
this.bundleIndex = bundleIndex;
this.concurrencyGroups = concurrencyGroups;
this.serializedRuntimeEnv = serializedRuntimeEnv;
this.maxPendingCalls = maxPendingCalls;
}
@ -54,6 +58,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
private PlacementGroup group;
private int bundleIndex;
private List<ConcurrencyGroup> concurrencyGroups = new ArrayList<>();
private RuntimeEnv runtimeEnv = null;
private int maxPendingCalls = -1;
/**
@ -188,6 +193,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
group,
bundleIndex,
concurrencyGroups,
runtimeEnv != null ? runtimeEnv.toJsonBytes() : "",
maxPendingCalls);
}
@ -196,5 +202,10 @@ public class ActorCreationOptions extends BaseTaskOptions {
this.concurrencyGroups = concurrencyGroups;
return this;
}
/**
 * Set the runtime env for the actor to be created.
 *
 * @param runtimeEnv The runtime env to record on this builder; may be null, which is also the
 *     builder's default.
 * @return self
 */
public Builder setRuntimeEnv(RuntimeEnv runtimeEnv) {
  this.runtimeEnv = runtimeEnv;
  return this;
}
}
}

View file

@ -19,6 +19,7 @@ import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.ResourceValue;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -280,4 +281,7 @@ public interface RayRuntime {
ConcurrencyGroup createConcurrencyGroup(String name, int maxConcurrency, List<RayFunc> funcs);
List<ConcurrencyGroup> extractConcurrencyGroups(RayFuncR<?> actorConstructorLambda);
/**
 * Create a runtime env instance at runtime.
 *
 * @param envVars The environment variables of the runtime env, as name-to-value pairs.
 * @return The created runtime env instance.
 */
RuntimeEnv createRuntimeEnv(Map<String, String> envVars);
}

View file

@ -0,0 +1,25 @@
package io.ray.api.runtimeenv;
import io.ray.api.Ray;
import java.util.HashMap;
import java.util.Map;
/** This is an experimental API to let you set runtime environments for your actors. */
public interface RuntimeEnv {

  /**
   * Serialize this runtime env.
   *
   * @return The serialized form of this runtime env. Note that, despite the method name, the
   *     result is a {@code String} rather than a byte array.
   */
  String toJsonBytes();

  /** Builder that collects environment variables and creates {@link RuntimeEnv} instances. */
  public static class Builder {

    private final Map<String, String> envVars = new HashMap<>();

    /**
     * Add an environment variable to the runtime env being built.
     *
     * @param key The name of the environment variable. Adding the same name again overwrites the
     *     previously added value.
     * @param value The value of the environment variable.
     * @return self
     */
    public Builder addEnvVar(String key, String value) {
      envVars.put(key, value);
      return this;
    }

    /**
     * Create the runtime env from the environment variables added so far.
     *
     * @return The created runtime env instance.
     */
    public RuntimeEnv build() {
      // Hand over a snapshot instead of the builder's own map, so that further addEnvVar()
      // calls on this builder cannot alter a runtime env that has already been built.
      return Ray.internal().createRuntimeEnv(new HashMap<>(envVars));
    }
  }
}

View file

@ -21,6 +21,7 @@ import io.ray.api.options.CallOptions;
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.context.RuntimeContextImpl;
@ -33,6 +34,7 @@ import io.ray.runtime.generated.Common;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.runtimeenv.RuntimeEnvImpl;
import io.ray.runtime.task.ArgumentsBuilder;
import io.ray.runtime.task.FunctionArg;
import io.ray.runtime.task.TaskExecutor;
@ -40,6 +42,7 @@ import io.ray.runtime.task.TaskSubmitter;
import io.ray.runtime.util.ConcurrencyGroupUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
@ -285,6 +288,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return ConcurrencyGroupUtils.extractConcurrencyGroupsByAnnotations(actorConstructorLambda);
}
/** Creates the runtime env implementation that carries the given environment variables. */
@Override
public RuntimeEnv createRuntimeEnv(Map<String, String> envVars) {
  RuntimeEnv runtimeEnv = new RuntimeEnvImpl(envVars);
  return runtimeEnv;
}
private ObjectRef callNormalFunction(
FunctionDescriptor functionDescriptor,
Object[] args,

View file

@ -146,7 +146,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
numWorkersPerProcess,
rayConfig.logDir,
serializedJobConfig,
rayConfig.getStartupToken());
rayConfig.getStartupToken(),
rayConfig.runtimeEnvHash);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
@ -278,7 +279,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
int numWorkersPerProcess,
String logDir,
byte[] serializedJobConfig,
int startupToken);
int startupToken,
int runtimeEnvHash);
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);

View file

@ -52,6 +52,8 @@ public class RayConfig {
public int startupToken;
public int runtimeEnvHash;
public final ActorLifetime defaultActorLifetime;
public static class LoggerConf {
@ -201,6 +203,11 @@ public class RayConfig {
startupToken = config.getInt("ray.raylet.startup-token");
/// The driver does not need this config item.
if (workerMode == WorkerType.WORKER && config.hasPath("ray.internal.runtime-env-hash")) {
runtimeEnvHash = config.getInt("ray.internal.runtime-env-hash");
}
{
loggers = new ArrayList<>();
List<Config> loggerConfigs = (List<Config>) config.getConfigList("ray.logging.loggers");

View file

@ -0,0 +1,33 @@
package io.ray.runtime.runtimeenv;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.generated.RuntimeEnvCommon;
import java.util.HashMap;
import java.util.Map;
/** {@link RuntimeEnv} implementation that carries a set of environment variables. */
public class RuntimeEnvImpl implements RuntimeEnv {

  /** Private snapshot of the environment variables; never null. */
  private final Map<String, String> envVars;

  /**
   * Create a runtime env holding the given environment variables.
   *
   * @param envVars The environment variables as name-to-value pairs. The map is copied, so later
   *     changes made by the caller do not affect this instance. A null map is treated as empty.
   */
  public RuntimeEnvImpl(Map<String, String> envVars) {
    this.envVars = new HashMap<>();
    if (envVars != null) {
      this.envVars.putAll(envVars);
    }
  }

  /**
   * Serialize this runtime env to JSON.
   *
   * @return {@code "{}"} when no environment variables are set; otherwise the JSON printed from a
   *     {@code RuntimeEnvCommon.RuntimeEnv} protobuf message populated with the variables.
   * @throws RuntimeException if the protobuf message cannot be printed as JSON.
   */
  @Override
  public String toJsonBytes() {
    if (envVars.isEmpty()) {
      return "{}";
    }

    RuntimeEnvCommon.RuntimeEnv.Builder protoRuntimeEnvBuilder =
        RuntimeEnvCommon.RuntimeEnv.newBuilder();
    protoRuntimeEnvBuilder.putAllEnvVars(envVars);
    JsonFormat.Printer printer = JsonFormat.printer();
    try {
      return printer.print(protoRuntimeEnvBuilder);
    } catch (InvalidProtocolBufferException e) {
      // Keep the original exception as the cause so the failure stays diagnosable.
      throw new RuntimeException(e);
    }
  }
}

View file

@ -0,0 +1,65 @@
package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.util.SystemUtil;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/** Cluster test checking that env vars set through a {@link RuntimeEnv} are visible to actors. */
@Test(groups = "cluster")
public class RuntimeEnvTest extends BaseTest {

  /** Actor that reports environment variables and the pid obtained from {@code SystemUtil}. */
  private static class A {

    public int getPid() {
      return SystemUtil.pid();
    }

    public String getEnv(String key) {
      return System.getenv(key);
    }
  }

  @BeforeClass
  public void setUp() {
    // Configure two Java workers per process; this is used to test that actors with
    // runtime envs will not reuse a worker process.
    System.setProperty("ray.job.num-java-workers-per-process", "2");
  }

  public void testEnvironmentVariable() {
    final int pidWithEnv;
    final int pidWithoutEnv;

    {
      // "KEY1" is added twice on purpose: the later value must win.
      RuntimeEnv.Builder envBuilder = new RuntimeEnv.Builder();
      envBuilder.addEnvVar("KEY1", "A");
      envBuilder.addEnvVar("KEY2", "B");
      envBuilder.addEnvVar("KEY1", "C");
      RuntimeEnv env = envBuilder.build();

      ActorHandle<A> actorWithEnv = Ray.actor(A::new).setRuntimeEnv(env).remote();
      Assert.assertEquals(actorWithEnv.task(A::getEnv, "KEY1").remote().get(), "C");
      Assert.assertEquals(actorWithEnv.task(A::getEnv, "KEY2").remote().get(), "B");
      pidWithEnv = actorWithEnv.task(A::getPid).remote().get();
    }

    {
      // No runtime env is set for this actor, so both variables should be null.
      ActorHandle<A> actorWithoutEnv = Ray.actor(A::new).remote();
      Assert.assertNull(actorWithoutEnv.task(A::getEnv, "KEY1").remote().get());
      Assert.assertNull(actorWithoutEnv.task(A::getEnv, "KEY2").remote().get());
      pidWithoutEnv = actorWithoutEnv.task(A::getPid).remote().get();
    }

    // The two actors shouldn't be in one process because they have
    // different runtime envs.
    Assert.assertNotEquals(pidWithEnv, pidWithoutEnv);
  }
}

View file

@ -98,7 +98,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir,
jbyteArray jobConfig, jint startupToken) {
jbyteArray jobConfig, jint startupToken, jint runtimeEnvHash) {
auto task_execution_callback =
[](TaskType task_type, const std::string task_name, const RayFunction &ray_function,
const std::unordered_map<std::string, double> &required_resources,
@ -267,6 +267,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
options.serialized_job_config = serialized_job_config;
options.metrics_agent_port = -1;
options.startup_token = startupToken;
options.runtime_env_hash = runtimeEnvHash;
CoreWorkerProcess::Initialize(options);
}

View file

@ -25,11 +25,11 @@ extern "C" {
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BI)V
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;[BII)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
jint, jstring, jbyteArray, jint);
jint, jstring, jbyteArray, jint, jint);
/*
* Class: io_ray_runtime_RayNativeRuntime

View file

@ -151,6 +151,7 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
uint64_t max_concurrency = 1;
auto placement_options = std::make_pair(PlacementGroupID::Nil(), -1);
std::vector<ConcurrencyGroup> concurrency_groups;
std::string serialized_runtime_env = "";
int32_t max_pending_calls = -1;
if (actorCreationOptions) {
@ -223,6 +224,12 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
return ray::ConcurrencyGroup{concurrency_group_name, max_concurrency,
native_func_descriptors};
});
auto java_serialized_runtime_env = (jstring)env->GetObjectField(actorCreationOptions,
java_actor_creation_options_serialized_runtime_env);
if (java_serialized_runtime_env) {
serialized_runtime_env = JavaStringToNativeString(env, java_serialized_runtime_env);
}
max_pending_calls = static_cast<int32_t>(env->GetIntField(
actorCreationOptions, java_actor_creation_options_max_pending_calls));
}
@ -253,7 +260,7 @@ inline ActorCreationOptions ToActorCreationOptions(JNIEnv *env,
ray_namespace,
/*is_asyncio=*/false,
/*scheduling_strategy=*/scheduling_strategy,
/*serialized_runtime_env=*/"{}",
serialized_runtime_env,
concurrency_groups,
/*execute_out_of_order*/ false,
max_pending_calls};

View file

@ -105,6 +105,7 @@ jfieldID java_actor_creation_options_max_concurrency;
jfieldID java_actor_creation_options_group;
jfieldID java_actor_creation_options_bundle_index;
jfieldID java_actor_creation_options_concurrency_groups;
jfieldID java_actor_creation_options_serialized_runtime_env;
jfieldID java_actor_creation_options_max_pending_calls;
jclass java_actor_lifetime_class;
@ -332,6 +333,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
env->GetFieldID(java_actor_creation_options_class, "bundleIndex", "I");
java_actor_creation_options_concurrency_groups = env->GetFieldID(
java_actor_creation_options_class, "concurrencyGroups", "Ljava/util/List;");
java_actor_creation_options_serialized_runtime_env =
env->GetFieldID(java_actor_creation_options_class, "serializedRuntimeEnv", "Ljava/lang/String;");
java_actor_creation_options_max_pending_calls =
env->GetFieldID(java_actor_creation_options_class, "maxPendingCalls", "I");

View file

@ -189,6 +189,8 @@ extern jfieldID java_actor_creation_options_group;
extern jfieldID java_actor_creation_options_bundle_index;
/// concurrencyGroups field of ActorCreationOptions class
extern jfieldID java_actor_creation_options_concurrency_groups;
/// serializedRuntimeEnv field of ActorCreationOptions class
extern jfieldID java_actor_creation_options_serialized_runtime_env;
/// maxPendingCalls field of ActorCreationOptions class
extern jfieldID java_actor_creation_options_max_pending_calls;
/// ActorLifetime enum class

View file

@ -314,6 +314,8 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
if (language == Language::JAVA) {
options.push_back("-Dray.raylet.startup-token=" +
std::to_string(worker_startup_token_counter_));
options.push_back("-Dray.internal.runtime-env-hash=" +
std::to_string(runtime_env_hash));
}
// Append user-defined per-process options here

View file

@ -708,6 +708,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
// Ray-defined per-process options
expected_command.push_back(GetNumJavaWorkersPerProcessSystemProperty(1));
expected_command.push_back("-Dray.raylet.startup-token=0");
expected_command.push_back("-Dray.internal.runtime-env-hash=0");
// User-defined per-process options
expected_command.insert(expected_command.end(), actor_jvm_options.begin(),
actor_jvm_options.end());