[Placement Group]Enhance create placement group java api (#11702)

* enhance create pg java api

* add state for PlacementGroup

* fix comment

* move default pg

* make default pg name private

* add bundle size and bundle resource size check when placement group create
This commit is contained in:
DK.Pino 2020-11-05 09:59:36 +08:00 committed by GitHub
parent 69145d6215
commit 50110b934c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 236 additions and 52 deletions

View file

@ -245,10 +245,16 @@ public final class Ray extends RayCall {
* to be updated and rescheduled.
* This function only works when gcs actor manager is turned on.
*
* @param bundles Preallocated resource list.
* @param name Name of the Placement Group.
* @param bundles Pre-allocated resource list.
* @param strategy Actor placement strategy.
* @return A handle to the created placement group.
*/
public static PlacementGroup createPlacementGroup(String name,
List<Map<String, Double>> bundles, PlacementStrategy strategy) {
return internal().createPlacementGroup(name, bundles, strategy);
}
public static PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
return internal().createPlacementGroup(bundles, strategy);

View file

@ -169,6 +169,9 @@ public interface RayRuntime {
PyActorHandle createActor(PyActorClass pyActorClass, Object[] args,
ActorCreationOptions options);
PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy);
PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
PlacementStrategy strategy);
@ -198,4 +201,5 @@ public interface RayRuntime {
* Intentionally exit the current actor.
*/
void exitActor();
}

View file

@ -51,6 +51,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
private static final String DEFAULT_PLACEMENT_GROUP_NAME = "unnamed_group";
protected RayConfig rayConfig;
protected TaskExecutor taskExecutor;
protected FunctionManager functionManager;
@ -165,9 +166,22 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
@Override
public PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
return taskSubmitter.createPlacementGroup(bundles, strategy);
public PlacementGroup createPlacementGroup(String name,
List<Map<String, Double>> bundles, PlacementStrategy strategy) {
boolean bundleResourceValid = bundles.stream().allMatch(
bundle -> bundle.values().stream().allMatch(resource -> resource > 0));
if (bundles.isEmpty() || !bundleResourceValid) {
throw new IllegalArgumentException(
"Bundles cannot be empty or bundle's resource must be positive.");
}
return taskSubmitter.createPlacementGroup(name, bundles, strategy);
}
@Override
public PlacementGroup createPlacementGroup(
List<Map<String, Double>> bundles, PlacementStrategy strategy) {
return createPlacementGroup(DEFAULT_PLACEMENT_GROUP_NAME, bundles, strategy);
}
@SuppressWarnings("unchecked")

View file

@ -1,28 +1,115 @@
package io.ray.runtime.placementgroup;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.placementgroup.PlacementStrategy;
import java.util.List;
import java.util.Map;
/**
* The default implementation of `PlacementGroup` interface.
*/
public class PlacementGroupImpl implements PlacementGroup {
private PlacementGroupId id;
private int bundleCount = 0;
private final PlacementGroupId id;
private final String name;
private final List<Map<String, Double>> bundles;
private final PlacementStrategy strategy;
private final PlacementGroupState state;
public PlacementGroupImpl() {
}
public PlacementGroupImpl(PlacementGroupId id, int bundleCount) {
private PlacementGroupImpl(PlacementGroupId id, String name,
List<Map<String, Double>> bundles,
PlacementStrategy strategy,
PlacementGroupState state) {
this.id = id;
this.bundleCount = bundleCount;
this.name = name;
this.bundles = bundles;
this.strategy = strategy;
this.state = state;
}
public PlacementGroupId getId() {
return id;
}
public int getBundleCount() {
return bundleCount;
public String getName() {
return name;
}
public List<Map<String, Double>> getBundles() {
return bundles;
}
public PlacementStrategy getStrategy() {
return strategy;
}
public PlacementGroupState getState() {
return state;
}
/**
* A help class for create the Placement Group.
*/
public static class Builder {
private PlacementGroupId id;
private String name;
private List<Map<String, Double>> bundles;
private PlacementStrategy strategy;
private PlacementGroupState state;
/**
* Set the Id of the Placement Group.
* @param id Id of the Placement Group.
* @return self.
*/
public Builder setId(PlacementGroupId id) {
this.id = id;
return this;
}
/**
* Set the name of the Placement Group.
* @param name Name of the Placement Group.
* @return self.
*/
public Builder setName(String name) {
this.name = name;
return this;
}
/**
* Set the bundles of the Placement Group.
* @param bundles the bundles of the Placement Group.
* @return self.
*/
public Builder setBundles(List<Map<String, Double>> bundles) {
this.bundles = bundles;
return this;
}
/**
* Set the placement strategy of the Placement Group.
* @param strategy the placement strategy of the Placement Group.
* @return self.
*/
public Builder setStrategy(PlacementStrategy strategy) {
this.strategy = strategy;
return this;
}
/**
* Set the placement state of the Placement Group.
* @param state the state of the Placement Group.
* @return self.
*/
public Builder setState(PlacementGroupState state) {
this.state = state;
return this;
}
public PlacementGroupImpl build() {
return new PlacementGroupImpl(id, name, bundles, strategy, state);
}
}
}

View file

@ -0,0 +1,32 @@
package io.ray.runtime.placementgroup;
/**
* State of Placement Group.
*/
public enum PlacementGroupState {
/**
* Wait for resource to schedule.
*/
PENDING(0),
/**
* The Placement Group has created on some node.
*/
CREATED(1),
/**
* The Placement Group has removed.
*/
REMOVED(2);
private int value = 0;
PlacementGroupState(int value) {
this.value = value;
}
public int value() {
return this.value;
}
}

View file

@ -30,7 +30,6 @@ import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.placementgroup.PlacementGroupId;
import io.ray.runtime.placementgroup.PlacementGroupImpl;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@ -171,7 +170,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
if (options.group != null) {
PlacementGroupImpl group = (PlacementGroupImpl)options.group;
Preconditions.checkArgument(options.bundleIndex >= 0
&& options.bundleIndex < group.getBundleCount(),
&& options.bundleIndex < group.getBundles().size(),
String.format("Bundle index %s is invalid", options.bundleIndex));
}
}
@ -224,9 +223,10 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
@Override
public PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
public PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
return new PlacementGroupImpl(PlacementGroupId.fromRandom(), bundles.size());
return new PlacementGroupImpl.Builder()
.setName(name).setBundles(bundles).setStrategy(strategy).build();
}
@Override

View file

@ -43,7 +43,7 @@ public class NativeTaskSubmitter implements TaskSubmitter {
if (options.group != null) {
PlacementGroupImpl group = (PlacementGroupImpl)options.group;
Preconditions.checkArgument(options.bundleIndex >= 0
&& options.bundleIndex < group.getBundleCount(),
&& options.bundleIndex < group.getBundles().size(),
String.format("Bundle index %s is invalid", options.bundleIndex));
}
@ -78,10 +78,12 @@ public class NativeTaskSubmitter implements TaskSubmitter {
}
@Override
public PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
public PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
byte[] bytes = nativeCreatePlacementGroup(bundles, strategy.value());
return new PlacementGroupImpl(PlacementGroupId.fromBytes(bytes), bundles.size());
byte[] bytes = nativeCreatePlacementGroup(name, bundles, strategy.value());
return new PlacementGroupImpl.Builder()
.setId(PlacementGroupId.fromBytes(bytes))
.setName(name).setBundles(bundles).setStrategy(strategy).build();
}
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
@ -95,6 +97,6 @@ public class NativeTaskSubmitter implements TaskSubmitter {
FunctionDescriptor functionDescriptor, int functionDescriptorHash, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
private static native byte[] nativeCreatePlacementGroup(List<Map<String, Double>> bundles,
int strategy);
private static native byte[] nativeCreatePlacementGroup(String name,
List<Map<String, Double>> bundles, int strategy);
}

View file

@ -52,11 +52,13 @@ public interface TaskSubmitter {
/**
* Create a placement group.
* @param bundles Preallocated resource list.
*
* @param name Name of the Placement Group.
* @param bundles Pre-allocated resource list.
* @param strategy Actor placement strategy.
* @return A handle to the created placement group.
*/
PlacementGroup createPlacementGroup(List<Map<String, Double>> bundles,
PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy);
BaseActorHandle getActor(ActorId actorId);

View file

@ -4,11 +4,7 @@ import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.placementgroup.PlacementStrategy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import io.ray.runtime.placementgroup.PlacementGroupImpl;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -32,12 +28,8 @@ public class PlacementGroupTest extends BaseTest {
// This test just creates a placement group with one bundle.
// It's not comprehensive to test all placement group test cases.
public void testCreateAndCallActor() {
List<Map<String, Double>> bundles = new ArrayList<>();
Map<String, Double> bundle = new HashMap<>();
bundle.put("CPU", 1.0);
bundles.add(bundle);
PlacementStrategy strategy = PlacementStrategy.PACK;
PlacementGroup placementGroup = Ray.createPlacementGroup(bundles, strategy);
PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup();
Assert.assertEquals(((PlacementGroupImpl)placementGroup).getName(),"unnamed_group");
// Test creating an actor from a constructor.
ActorHandle<Counter> actor = Ray.actor(Counter::new, 1)
@ -49,12 +41,7 @@ public class PlacementGroupTest extends BaseTest {
}
public void testCheckBundleIndex() {
List<Map<String, Double>> bundles = new ArrayList<>();
Map<String, Double> bundle = new HashMap<>();
bundle.put("CPU", 1.0);
bundles.add(bundle);
PlacementStrategy strategy = PlacementStrategy.PACK;
PlacementGroup placementGroup = Ray.createPlacementGroup(bundles, strategy);
PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup();
int exceptionCount = 0;
try {
@ -64,7 +51,6 @@ public class PlacementGroupTest extends BaseTest {
}
Assert.assertEquals(1, exceptionCount);
try {
Ray.actor(Counter::new, 1).setPlacementGroup(placementGroup, -1).remote();
} catch (IllegalArgumentException e) {
@ -72,4 +58,14 @@ public class PlacementGroupTest extends BaseTest {
}
Assert.assertEquals(2, exceptionCount);
}
@Test (expectedExceptions = { IllegalArgumentException.class })
public void testBundleSizeValidCheckWhenCreate() {
PlacementGroupTestUtils.createBundleSizeInvalidGroup();
}
@Test (expectedExceptions = { IllegalArgumentException.class })
public void testBundleResourceValidCheckWhenCreate() {
PlacementGroupTestUtils.createBundleResourceInvalidGroup();
}
}

View file

@ -0,0 +1,41 @@
package io.ray.test;
import io.ray.api.Ray;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.placementgroup.PlacementStrategy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* A utils class for Placement Group test.
*/
public class PlacementGroupTestUtils {
public static PlacementGroup createSpecifiedSimpleGroup(String resourceName, int bundleSize,
PlacementStrategy strategy, Double resourceSize) {
List<Map<String, Double>> bundles = new ArrayList<>();
for (int i = 0; i < bundleSize; i++) {
Map<String, Double> bundle = new HashMap<>();
bundle.put(resourceName, resourceSize);
bundles.add(bundle);
}
return Ray.createPlacementGroup(bundles, strategy);
}
public static PlacementGroup createSimpleGroup() {
return createSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, 1.0);
}
public static void createBundleSizeInvalidGroup() {
createSpecifiedSimpleGroup("CPU", 0, PlacementStrategy.PACK, 1.0);
}
public static void createBundleResourceInvalidGroup() {
createSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, 0.0);
}
}

View file

@ -185,7 +185,7 @@ inline ray::PlacementStrategy ConvertStrategy(jint java_strategy) {
}
inline ray::PlacementGroupCreationOptions ToPlacementGroupCreationOptions(
JNIEnv *env, jobject java_bundles, jint java_strategy) {
JNIEnv *env, jstring name, jobject java_bundles, jint java_strategy) {
std::vector<std::unordered_map<std::string, double>> bundles;
JavaListToNativeVector<std::unordered_map<std::string, double>>(
env, java_bundles, &bundles, [](JNIEnv *env, jobject java_bundle) {
@ -200,7 +200,8 @@ inline ray::PlacementGroupCreationOptions ToPlacementGroupCreationOptions(
return value;
});
});
return ray::PlacementGroupCreationOptions("", ConvertStrategy(java_strategy), bundles);
return ray::PlacementGroupCreationOptions(JavaStringToNativeString(env, name),
ConvertStrategy(java_strategy), bundles);
}
#ifdef __cplusplus
@ -272,11 +273,9 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(
}
JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *env,
jclass,
jobject bundles,
jint strategy) {
auto options = ToPlacementGroupCreationOptions(env, bundles, strategy);
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(
JNIEnv *env, jclass, jstring name, jobject bundles, jint strategy) {
auto options = ToPlacementGroupCreationOptions(env, name, bundles, strategy);
ray::PlacementGroupID placement_group_id;
auto status = ray::CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup(
options, &placement_group_id);

View file

@ -55,11 +55,12 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jcl
/*
* Class: io_ray_runtime_task_NativeTaskSubmitter
* Method: nativeCreatePlacementGroup
* Signature: (Ljava/util/List;I)[B
* Signature: (Ljava/lang/String;Ljava/util/List;I)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *, jclass,
jobject, jint);
jstring, jobject,
jint);
#ifdef __cplusplus
}