From 50110b934c24a7ab1ad08870905265777046d0d6 Mon Sep 17 00:00:00 2001 From: "DK.Pino" Date: Thu, 5 Nov 2020 09:59:36 +0800 Subject: [PATCH] [Placement Group]Enhance create placement group java api (#11702) * enhance create pg java api * add state for PlacementGroup * fix comment * move default pg * make default pg name private * add bundle size and bundle resource size check when placement group create --- java/api/src/main/java/io/ray/api/Ray.java | 8 +- .../java/io/ray/api/runtime/RayRuntime.java | 4 + .../io/ray/runtime/AbstractRayRuntime.java | 20 +++- .../placementgroup/PlacementGroupImpl.java | 105 ++++++++++++++++-- .../placementgroup/PlacementGroupState.java | 32 ++++++ .../runtime/task/LocalModeTaskSubmitter.java | 8 +- .../ray/runtime/task/NativeTaskSubmitter.java | 14 ++- .../io/ray/runtime/task/TaskSubmitter.java | 6 +- .../java/io/ray/test/PlacementGroupTest.java | 32 +++--- .../io/ray/test/PlacementGroupTestUtils.java | 41 +++++++ ...io_ray_runtime_task_NativeTaskSubmitter.cc | 13 +-- .../io_ray_runtime_task_NativeTaskSubmitter.h | 5 +- 12 files changed, 236 insertions(+), 52 deletions(-) create mode 100644 java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java create mode 100644 java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 60c4b2978..34bd0a2b3 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -245,10 +245,16 @@ public final class Ray extends RayCall { * to be updated and rescheduled. * This function only works when gcs actor manager is turned on. * - * @param bundles Preallocated resource list. + * @param name Name of the Placement Group. + * @param bundles Pre-allocated resource list. * @param strategy Actor placement strategy. * @return A handle to the created placement group. */ + public static PlacementGroup createPlacementGroup(String name, + List> bundles, PlacementStrategy strategy) { + return internal().createPlacementGroup(name, bundles, strategy); + } + public static PlacementGroup createPlacementGroup(List> bundles, PlacementStrategy strategy) { return internal().createPlacementGroup(bundles, strategy); diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 19ec5390a..fbd1f95cf 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -169,6 +169,9 @@ public interface RayRuntime { PyActorHandle createActor(PyActorClass pyActorClass, Object[] args, ActorCreationOptions options); + PlacementGroup createPlacementGroup(String name, List> bundles, + PlacementStrategy strategy); + PlacementGroup createPlacementGroup(List> bundles, PlacementStrategy strategy); @@ -198,4 +201,5 @@ public interface RayRuntime { * Intentionally exit the current actor. */ void exitActor(); + } diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index c7b0df21b..f5b636244 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -51,6 +51,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); public static final String PYTHON_INIT_METHOD_NAME = "__init__"; + private static final String DEFAULT_PLACEMENT_GROUP_NAME = "unnamed_group"; protected RayConfig rayConfig; protected TaskExecutor taskExecutor; protected FunctionManager functionManager; @@ -165,9 +166,22 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { } @Override - public PlacementGroup createPlacementGroup(List> bundles, - PlacementStrategy strategy) { - return taskSubmitter.createPlacementGroup(bundles, strategy); + public PlacementGroup createPlacementGroup(String name, + List> bundles, PlacementStrategy strategy) { + boolean bundleResourceValid = bundles.stream().allMatch( + bundle -> bundle.values().stream().allMatch(resource -> resource > 0)); + + if (bundles.isEmpty() || !bundleResourceValid) { + throw new IllegalArgumentException( + "Bundles cannot be empty or bundle's resource must be positive."); + } + return taskSubmitter.createPlacementGroup(name, bundles, strategy); + } + + @Override + public PlacementGroup createPlacementGroup( + List> bundles, PlacementStrategy strategy) { + return createPlacementGroup(DEFAULT_PLACEMENT_GROUP_NAME, bundles, strategy); } @SuppressWarnings("unchecked") diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java index 6a0fac180..be64bfefa 100644 --- a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java @@ -1,28 +1,115 @@ package io.ray.runtime.placementgroup; import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; +import java.util.List; +import java.util.Map; /** * The default implementation of `PlacementGroup` interface. */ public class PlacementGroupImpl implements PlacementGroup { - private PlacementGroupId id; - private int bundleCount = 0; + private final PlacementGroupId id; + private final String name; + private final List> bundles; + private final PlacementStrategy strategy; + private final PlacementGroupState state; - public PlacementGroupImpl() { - } - - public PlacementGroupImpl(PlacementGroupId id, int bundleCount) { + private PlacementGroupImpl(PlacementGroupId id, String name, + List> bundles, + PlacementStrategy strategy, + PlacementGroupState state) { this.id = id; - this.bundleCount = bundleCount; + this.name = name; + this.bundles = bundles; + this.strategy = strategy; + this.state = state; } public PlacementGroupId getId() { return id; } - public int getBundleCount() { - return bundleCount; + public String getName() { + return name; } + + public List> getBundles() { + return bundles; + } + + public PlacementStrategy getStrategy() { + return strategy; + } + + public PlacementGroupState getState() { + return state; + } + + /** + * A help class for create the Placement Group. + */ + public static class Builder { + private PlacementGroupId id; + private String name; + private List> bundles; + private PlacementStrategy strategy; + private PlacementGroupState state; + + /** + * Set the Id of the Placement Group. + * @param id Id of the Placement Group. + * @return self. + */ + public Builder setId(PlacementGroupId id) { + this.id = id; + return this; + } + + /** + * Set the name of the Placement Group. + * @param name Name of the Placement Group. + * @return self. + */ + public Builder setName(String name) { + this.name = name; + return this; + } + + /** + * Set the bundles of the Placement Group. + * @param bundles the bundles of the Placement Group. + * @return self. + */ + public Builder setBundles(List> bundles) { + this.bundles = bundles; + return this; + } + + /** + * Set the placement strategy of the Placement Group. + * @param strategy the placement strategy of the Placement Group. + * @return self. + */ + public Builder setStrategy(PlacementStrategy strategy) { + this.strategy = strategy; + return this; + } + + /** + * Set the placement state of the Placement Group. + * @param state the state of the Placement Group. + * @return self. + */ + public Builder setState(PlacementGroupState state) { + this.state = state; + return this; + } + + public PlacementGroupImpl build() { + return new PlacementGroupImpl(id, name, bundles, strategy, state); + } + } + } diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java new file mode 100644 index 000000000..ad6e017a5 --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java @@ -0,0 +1,32 @@ +package io.ray.runtime.placementgroup; + +/** + * State of Placement Group. + */ +public enum PlacementGroupState { + + /** + * Wait for resource to schedule. + */ + PENDING(0), + + /** + * The Placement Group has created on some node. + */ + CREATED(1), + + /** + * The Placement Group has removed. + */ + REMOVED(2); + + private int value = 0; + + PlacementGroupState(int value) { + this.value = value; + } + + public int value() { + return this.value; + } +} diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 276cb48b8..06a8d4af3 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -30,7 +30,6 @@ import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.generated.Common.TaskType; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.object.NativeRayObject; -import io.ray.runtime.placementgroup.PlacementGroupId; import io.ray.runtime.placementgroup.PlacementGroupImpl; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -171,7 +170,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { if (options.group != null) { PlacementGroupImpl group = (PlacementGroupImpl)options.group; Preconditions.checkArgument(options.bundleIndex >= 0 - && options.bundleIndex < group.getBundleCount(), + && options.bundleIndex < group.getBundles().size(), String.format("Bundle index %s is invalid", options.bundleIndex)); } } @@ -224,9 +223,10 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } @Override - public PlacementGroup createPlacementGroup(List> bundles, + public PlacementGroup createPlacementGroup(String name, List> bundles, PlacementStrategy strategy) { - return new PlacementGroupImpl(PlacementGroupId.fromRandom(), bundles.size()); + return new PlacementGroupImpl.Builder() + .setName(name).setBundles(bundles).setStrategy(strategy).build(); } @Override diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index c153ad1f1..e6aec1b16 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -43,7 +43,7 @@ public class NativeTaskSubmitter implements TaskSubmitter { if (options.group != null) { PlacementGroupImpl group = (PlacementGroupImpl)options.group; Preconditions.checkArgument(options.bundleIndex >= 0 - && options.bundleIndex < group.getBundleCount(), + && options.bundleIndex < group.getBundles().size(), String.format("Bundle index %s is invalid", options.bundleIndex)); } @@ -78,10 +78,12 @@ public class NativeTaskSubmitter implements TaskSubmitter { } @Override - public PlacementGroup createPlacementGroup(List> bundles, + public PlacementGroup createPlacementGroup(String name, List> bundles, PlacementStrategy strategy) { - byte[] bytes = nativeCreatePlacementGroup(bundles, strategy.value()); - return new PlacementGroupImpl(PlacementGroupId.fromBytes(bytes), bundles.size()); + byte[] bytes = nativeCreatePlacementGroup(name, bundles, strategy.value()); + return new PlacementGroupImpl.Builder() + .setId(PlacementGroupId.fromBytes(bytes)) + .setName(name).setBundles(bundles).setStrategy(strategy).build(); } private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, @@ -95,6 +97,6 @@ public class NativeTaskSubmitter implements TaskSubmitter { FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); - private static native byte[] nativeCreatePlacementGroup(List> bundles, - int strategy); + private static native byte[] nativeCreatePlacementGroup(String name, + List> bundles, int strategy); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java index f67b7f4d5..fd3b2212d 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java @@ -52,11 +52,13 @@ public interface TaskSubmitter { /** * Create a placement group. - * @param bundles Preallocated resource list. + * + * @param name Name of the Placement Group. + * @param bundles Pre-allocated resource list. * @param strategy Actor placement strategy. * @return A handle to the created placement group. */ - PlacementGroup createPlacementGroup(List> bundles, + PlacementGroup createPlacementGroup(String name, List> bundles, PlacementStrategy strategy); BaseActorHandle getActor(ActorId actorId); diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java index 370a518c8..c933b16ef 100644 --- a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java @@ -4,11 +4,7 @@ import io.ray.api.ActorHandle; import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.placementgroup.PlacementGroup; -import io.ray.api.placementgroup.PlacementStrategy; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import io.ray.runtime.placementgroup.PlacementGroupImpl; import org.testng.Assert; import org.testng.annotations.Test; @@ -32,12 +28,8 @@ public class PlacementGroupTest extends BaseTest { // This test just creates a placement group with one bundle. // It's not comprehensive to test all placement group test cases. public void testCreateAndCallActor() { - List> bundles = new ArrayList<>(); - Map bundle = new HashMap<>(); - bundle.put("CPU", 1.0); - bundles.add(bundle); - PlacementStrategy strategy = PlacementStrategy.PACK; - PlacementGroup placementGroup = Ray.createPlacementGroup(bundles, strategy); + PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup(); + Assert.assertEquals(((PlacementGroupImpl)placementGroup).getName(),"unnamed_group"); // Test creating an actor from a constructor. ActorHandle actor = Ray.actor(Counter::new, 1) @@ -49,12 +41,7 @@ public class PlacementGroupTest extends BaseTest { } public void testCheckBundleIndex() { - List> bundles = new ArrayList<>(); - Map bundle = new HashMap<>(); - bundle.put("CPU", 1.0); - bundles.add(bundle); - PlacementStrategy strategy = PlacementStrategy.PACK; - PlacementGroup placementGroup = Ray.createPlacementGroup(bundles, strategy); + PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup(); int exceptionCount = 0; try { @@ -64,7 +51,6 @@ public class PlacementGroupTest extends BaseTest { } Assert.assertEquals(1, exceptionCount); - try { Ray.actor(Counter::new, 1).setPlacementGroup(placementGroup, -1).remote(); } catch (IllegalArgumentException e) { @@ -72,4 +58,14 @@ public class PlacementGroupTest extends BaseTest { } Assert.assertEquals(2, exceptionCount); } + + @Test (expectedExceptions = { IllegalArgumentException.class }) + public void testBundleSizeValidCheckWhenCreate() { + PlacementGroupTestUtils.createBundleSizeInvalidGroup(); + } + + @Test (expectedExceptions = { IllegalArgumentException.class }) + public void testBundleResourceValidCheckWhenCreate() { + PlacementGroupTestUtils.createBundleResourceInvalidGroup(); + } } diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java b/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java new file mode 100644 index 000000000..48e419df1 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java @@ -0,0 +1,41 @@ +package io.ray.test; + +import io.ray.api.Ray; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A utils class for Placement Group test. + */ +public class PlacementGroupTestUtils { + + public static PlacementGroup createSpecifiedSimpleGroup(String resourceName, int bundleSize, + PlacementStrategy strategy, Double resourceSize) { + List> bundles = new ArrayList<>(); + + for (int i = 0; i < bundleSize; i++) { + Map bundle = new HashMap<>(); + bundle.put(resourceName, resourceSize); + bundles.add(bundle); + } + + return Ray.createPlacementGroup(bundles, strategy); + } + + public static PlacementGroup createSimpleGroup() { + return createSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, 1.0); + } + + public static void createBundleSizeInvalidGroup() { + createSpecifiedSimpleGroup("CPU", 0, PlacementStrategy.PACK, 1.0); + } + + public static void createBundleResourceInvalidGroup() { + createSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, 0.0); + } + +} diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 3d67ca9b1..5d7f3d89c 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -185,7 +185,7 @@ inline ray::PlacementStrategy ConvertStrategy(jint java_strategy) { } inline ray::PlacementGroupCreationOptions ToPlacementGroupCreationOptions( - JNIEnv *env, jobject java_bundles, jint java_strategy) { + JNIEnv *env, jstring name, jobject java_bundles, jint java_strategy) { std::vector> bundles; JavaListToNativeVector>( env, java_bundles, &bundles, [](JNIEnv *env, jobject java_bundle) { @@ -200,7 +200,8 @@ inline ray::PlacementGroupCreationOptions ToPlacementGroupCreationOptions( return value; }); }); - return ray::PlacementGroupCreationOptions("", ConvertStrategy(java_strategy), bundles); + return ray::PlacementGroupCreationOptions(JavaStringToNativeString(env, name), + ConvertStrategy(java_strategy), bundles); } #ifdef __cplusplus @@ -272,11 +273,9 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( } JNIEXPORT jbyteArray JNICALL -Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *env, - jclass, - jobject bundles, - jint strategy) { - auto options = ToPlacementGroupCreationOptions(env, bundles, strategy); +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup( + JNIEnv *env, jclass, jstring name, jobject bundles, jint strategy) { + auto options = ToPlacementGroupCreationOptions(env, name, bundles, strategy); ray::PlacementGroupID placement_group_id; auto status = ray::CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup( options, &placement_group_id); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index 80f1aa004..8f0879529 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -55,11 +55,12 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jcl /* * Class: io_ray_runtime_task_NativeTaskSubmitter * Method: nativeCreatePlacementGroup - * Signature: (Ljava/util/List;I)[B + * Signature: (Ljava/lang/String;Ljava/util/List;I)[B */ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *, jclass, - jobject, jint); + jstring, jobject, + jint); #ifdef __cplusplus }