diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index 34bd0a2b3..7763fae43 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -1,5 +1,6 @@ package io.ray.api; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.UniqueId; import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.placementgroup.PlacementStrategy; @@ -245,7 +246,7 @@ public final class Ray extends RayCall { * to be updated and rescheduled. * This function only works when gcs actor manager is turned on. * - * @param name Name of the Placement Group. + * @param name Name of the placement group. * @param bundles Pre-allocated resource list. * @param strategy Actor placement strategy. * @return A handle to the created placement group. @@ -271,4 +272,30 @@ public final class Ray extends RayCall { public static void exitActor() { runtime.exitActor(); } + + /** + * Get a placement group by placement group Id. + * @param id placement group id. + * @return The placement group. + */ + public static PlacementGroup getPlacementGroup(PlacementGroupId id) { + return internal().getPlacementGroup(id); + } + + /** + * Get all placement groups in this cluster. + * @return All placement groups. + */ + public static List getAllPlacementGroups() { + return internal().getAllPlacementGroups(); + } + + /** + * Remove a placement group by id. + * Throw RayException if remove failed. + * @param id Id of the placement group. + */ + public static void removePlacementGroup(PlacementGroupId id) { + internal().removePlacementGroup(id); + } } diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java b/java/api/src/main/java/io/ray/api/id/PlacementGroupId.java similarity index 94% rename from java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java rename to java/api/src/main/java/io/ray/api/id/PlacementGroupId.java index 46005c96e..dffdcae06 100644 --- a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java +++ b/java/api/src/main/java/io/ray/api/id/PlacementGroupId.java @@ -1,6 +1,5 @@ -package io.ray.runtime.placementgroup; +package io.ray.api.id; -import io.ray.api.id.BaseId; import java.io.Serializable; import java.nio.ByteBuffer; import java.util.Arrays; diff --git a/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroupState.java b/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroupState.java new file mode 100644 index 000000000..c2aeae589 --- /dev/null +++ b/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroupState.java @@ -0,0 +1,42 @@ +package io.ray.api.placementgroup; + +/** + * State of placement group. + */ +public enum PlacementGroupState { + + /** + * Wait for resource to schedule. + */ + PENDING(0), + + /** + * The placement group has created on some node. + */ + CREATED(1), + + /** + * The placement group has removed. + */ + REMOVED(2), + + /** + * The placement group is rescheduling. + */ + RESCHEDULING(3), + + /** + * Unrecognized state. + */ + UNRECOGNIZED(-1); + + private int value = 0; + + PlacementGroupState(int value) { + this.value = value; + } + + public int value() { + return this.value; + } +} diff --git a/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java b/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java index 1cdfb2733..1fc4036f5 100644 --- a/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java +++ b/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java @@ -23,7 +23,12 @@ public enum PlacementStrategy { * Places Bundles across distinct nodes. * The group is not allowed to deploy more than one bundle on a node. */ - STRICT_SPREAD(3); + STRICT_SPREAD(3), + + /** + * Unrecognized strategy. + */ + UNRECOGNIZED(-1); private int value = 0; diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index fbd1f95cf..e689cea00 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -10,6 +10,7 @@ import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; import io.ray.api.function.RayFunc; import io.ray.api.id.ActorId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; @@ -202,4 +203,22 @@ public interface RayRuntime { */ void exitActor(); + /** + * Get a placement group by id. + * @param id placement group id. + * @return The placement group. + */ + PlacementGroup getPlacementGroup(PlacementGroupId id); + + /** + * Get all placement groups in this cluster. + * @return All placement groups. + */ + List getAllPlacementGroups(); + + /** + * Remove a placement group by id. + * @param id Id of the placement group. + */ + void removePlacementGroup(PlacementGroupId id); } diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index f5b636244..2eae3b647 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -14,6 +14,7 @@ import io.ray.api.function.PyFunction; import io.ray.api.function.RayFunc; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; import io.ray.api.placementgroup.PlacementGroup; @@ -184,6 +185,21 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return createPlacementGroup(DEFAULT_PLACEMENT_GROUP_NAME, bundles, strategy); } + @Override + public void removePlacementGroup(PlacementGroupId id) { + taskSubmitter.removePlacementGroup(id); + } + + @Override + public PlacementGroup getPlacementGroup(PlacementGroupId id) { + return gcsClient.getPlacementGroupInfo(id); + } + + @Override + public List getAllPlacementGroups() { + return gcsClient.getAllPlacementGroupInfo(); + } + @SuppressWarnings("unchecked") @Override public T getActorHandle(ActorId actorId) { diff --git a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java index 26fa0430b..281677787 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayDevRuntime.java @@ -3,12 +3,15 @@ package io.ray.runtime; import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.api.id.JobId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.UniqueId; +import io.ray.api.placementgroup.PlacementGroup; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.LocalModeWorkerContext; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.task.LocalModeTaskExecutor; import io.ray.runtime.task.LocalModeTaskSubmitter; +import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; @@ -84,6 +87,21 @@ public class RayDevRuntime extends AbstractRayRuntime { super.setAsyncContext(asyncContext); } + @Override + public PlacementGroup getPlacementGroup( + PlacementGroupId id) { + //@TODO(clay4444): We need a LocalGcsClient before implements this. + throw new UnsupportedOperationException( + "Ray doesn't support placement group operations in local mode."); + } + + @Override + public List getAllPlacementGroups() { + //@TODO(clay4444): We need a LocalGcsClient before implements this. + throw new UnsupportedOperationException( + "Ray doesn't support placement group operations in local mode."); + } + @Override public void exitActor() { diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java index 22e06ad96..97d98b137 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java @@ -6,13 +6,16 @@ import io.ray.api.Checkpointable.Checkpoint; import io.ray.api.id.ActorId; import io.ray.api.id.BaseId; import io.ray.api.id.JobId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; +import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.runtimecontext.NodeInfo; import io.ray.runtime.generated.Gcs; import io.ray.runtime.generated.Gcs.ActorCheckpointIdData; import io.ray.runtime.generated.Gcs.GcsNodeInfo; import io.ray.runtime.generated.Gcs.TablePrefix; +import io.ray.runtime.placementgroup.PlacementGroupUtils; import io.ray.runtime.util.IdUtil; import java.util.ArrayList; import java.util.HashMap; @@ -52,6 +55,30 @@ public class GcsClient { globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword); } + /** + * Get placement group by {@link PlacementGroupId} + * @param placementGroupId Id of placement group. + * @return The placement group. + */ + public PlacementGroup getPlacementGroupInfo(PlacementGroupId placementGroupId) { + byte[] result = globalStateAccessor.getPlacementGroupInfo(placementGroupId); + return PlacementGroupUtils.generatePlacementGroupFromByteArray(result); + } + + /** + * Get all placement groups in this cluster. + * @return All placement groups. + */ + public List getAllPlacementGroupInfo() { + List results = globalStateAccessor.getAllPlacementGroupInfo(); + + List placementGroups = new ArrayList<>(); + for (byte[] result : results) { + placementGroups.add(PlacementGroupUtils.generatePlacementGroupFromByteArray(result)); + } + return placementGroups; + } + public List getAllNodeInfo() { List results = globalStateAccessor.getAllNodeInfo(); diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java index c0116ec4c..4d57256b9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GlobalStateAccessor.java @@ -2,6 +2,7 @@ package io.ray.runtime.gcs; import com.google.common.base.Preconditions; import io.ray.api.id.ActorId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.UniqueId; import java.util.List; @@ -33,8 +34,7 @@ public class GlobalStateAccessor { private GlobalStateAccessor(String redisAddress, String redisPassword) { globalStateAccessorNativePointer = nativeCreateGlobalStateAccessor(redisAddress, redisPassword); - Preconditions.checkState(globalStateAccessorNativePointer != 0, - "Global state accessor native pointer must not be 0."); + validateGlobalStateAccessorPointer(); connect(); } @@ -42,14 +42,18 @@ public class GlobalStateAccessor { return this.nativeConnect(globalStateAccessorNativePointer); } + private void validateGlobalStateAccessorPointer() { + Preconditions.checkState(globalStateAccessorNativePointer != 0, + "Global state accessor native pointer must not be 0."); + } + /** * @return A list of job info with JobInfo protobuf schema. */ public List getAllJobInfo() { // Fetch a job list with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0, - "Get all job info when global state accessor have been destroyed."); + validateGlobalStateAccessorPointer(); return this.nativeGetAllJobInfo(globalStateAccessorNativePointer); } } @@ -60,8 +64,7 @@ public class GlobalStateAccessor { public List getAllNodeInfo() { // Fetch a node list with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0, - "Get all node info when global state accessor have been destroyed."); + validateGlobalStateAccessorPointer(); return this.nativeGetAllNodeInfo(globalStateAccessorNativePointer); } } @@ -72,16 +75,30 @@ public class GlobalStateAccessor { */ public byte[] getNodeResourceInfo(UniqueId nodeId) { synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0, - "Get resource info by node id when global state accessor have been destroyed."); + validateGlobalStateAccessorPointer(); return nativeGetNodeResourceInfo(globalStateAccessorNativePointer, nodeId.getBytes()); } } + public byte[] getPlacementGroupInfo(PlacementGroupId placementGroupId) { + synchronized (GlobalStateAccessor.class) { + Preconditions.checkNotNull(placementGroupId, + "PlacementGroupId can't be null when get placement group info."); + return nativeGetPlacementGroupInfo(globalStateAccessorNativePointer, + placementGroupId.getBytes()); + } + } + + public List getAllPlacementGroupInfo() { + synchronized (GlobalStateAccessor.class) { + validateGlobalStateAccessorPointer(); + return this.nativeGetAllPlacementGroupInfo(globalStateAccessorNativePointer); + } + } + public byte[] getInternalConfig() { synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0, - "Get internal config when global state accessor have been destroyed."); + validateGlobalStateAccessorPointer(); return nativeGetInternalConfig(globalStateAccessorNativePointer); } } @@ -92,7 +109,7 @@ public class GlobalStateAccessor { public List getAllActorInfo() { // Fetch a actor list with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0); + validateGlobalStateAccessorPointer(); return this.nativeGetAllActorInfo(globalStateAccessorNativePointer); } } @@ -103,7 +120,7 @@ public class GlobalStateAccessor { public byte[] getActorInfo(ActorId actorId) { // Fetch an actor with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0); + validateGlobalStateAccessorPointer(); return this.nativeGetActorInfo(globalStateAccessorNativePointer, actorId.getBytes()); } } @@ -114,7 +131,7 @@ public class GlobalStateAccessor { public byte[] getActorCheckpointId(ActorId actorId) { // Fetch an actor checkpoint id with protobuf bytes format from GCS. synchronized (GlobalStateAccessor.class) { - Preconditions.checkState(globalStateAccessorNativePointer != 0); + validateGlobalStateAccessorPointer(); return this.nativeGetActorCheckpointId(globalStateAccessorNativePointer, actorId.getBytes()); } } @@ -148,4 +165,9 @@ public class GlobalStateAccessor { private native byte[] nativeGetActorInfo(long nativePtr, byte[] actorId); private native byte[] nativeGetActorCheckpointId(long nativePtr, byte[] actorId); + + private native byte[] nativeGetPlacementGroupInfo(long nativePtr, + byte[] placementGroupId); + + private native List nativeGetAllPlacementGroupInfo(long nativePtr); } diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java index be64bfefa..633bad98c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java @@ -1,6 +1,8 @@ package io.ray.runtime.placementgroup; +import io.ray.api.id.PlacementGroupId; import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementGroupState; import io.ray.api.placementgroup.PlacementStrategy; import java.util.List; import java.util.Map; @@ -48,7 +50,7 @@ public class PlacementGroupImpl implements PlacementGroup { } /** - * A help class for create the Placement Group. + * A help class for create the placement group. */ public static class Builder { private PlacementGroupId id; @@ -58,8 +60,8 @@ public class PlacementGroupImpl implements PlacementGroup { private PlacementGroupState state; /** - * Set the Id of the Placement Group. - * @param id Id of the Placement Group. + * Set the Id of the placement group. + * @param id Id of the placement group. * @return self. */ public Builder setId(PlacementGroupId id) { @@ -68,8 +70,8 @@ public class PlacementGroupImpl implements PlacementGroup { } /** - * Set the name of the Placement Group. - * @param name Name of the Placement Group. + * Set the name of the placement group. + * @param name Name of the placement group. * @return self. */ public Builder setName(String name) { @@ -78,8 +80,8 @@ public class PlacementGroupImpl implements PlacementGroup { } /** - * Set the bundles of the Placement Group. - * @param bundles the bundles of the Placement Group. + * Set the bundles of the placement group. + * @param bundles the bundles of the placement group. * @return self. */ public Builder setBundles(List> bundles) { @@ -88,8 +90,8 @@ public class PlacementGroupImpl implements PlacementGroup { } /** - * Set the placement strategy of the Placement Group. - * @param strategy the placement strategy of the Placement Group. + * Set the placement strategy of the placement group. + * @param strategy the placement strategy of the placement group. * @return self. */ public Builder setStrategy(PlacementStrategy strategy) { @@ -98,8 +100,8 @@ public class PlacementGroupImpl implements PlacementGroup { } /** - * Set the placement state of the Placement Group. - * @param state the state of the Placement Group. + * Set the placement state of the placement group. + * @param state the state of the placement group. * @return self. */ public Builder setState(PlacementGroupState state) { diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java deleted file mode 100644 index ad6e017a5..000000000 --- a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupState.java +++ /dev/null @@ -1,32 +0,0 @@ -package io.ray.runtime.placementgroup; - -/** - * State of Placement Group. - */ -public enum PlacementGroupState { - - /** - * Wait for resource to schedule. - */ - PENDING(0), - - /** - * The Placement Group has created on some node. - */ - CREATED(1), - - /** - * The Placement Group has removed. - */ - REMOVED(2); - - private int value = 0; - - PlacementGroupState(int value) { - this.value = value; - } - - public int value() { - return this.value; - } -} diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupUtils.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupUtils.java new file mode 100644 index 000000000..0211d6a35 --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupUtils.java @@ -0,0 +1,108 @@ +package io.ray.runtime.placementgroup; + +import com.google.common.base.Preconditions; +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.api.id.PlacementGroupId; +import io.ray.api.placementgroup.PlacementGroupState; +import io.ray.api.placementgroup.PlacementStrategy; +import io.ray.runtime.generated.Common; +import io.ray.runtime.generated.Common.Bundle; +import io.ray.runtime.generated.Gcs.PlacementGroupTableData; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Utils for placement group. + */ +public class PlacementGroupUtils { + + private static List> covertToUserSpecifiedBundles(List bundles) { + List> result = new ArrayList<>(); + + // NOTE(clay4444): We need to guarantee the order here. + for (int i = 0; i < bundles.size(); i++) { + Bundle bundle = bundles.get(i); + result.add(bundle.getUnitResourcesMap()); + } + return result; + } + + private static PlacementStrategy covertToUserSpecifiedStrategy( + Common.PlacementStrategy strategy) { + switch (strategy) { + case PACK: + return PlacementStrategy.PACK; + case STRICT_PACK: + return PlacementStrategy.STRICT_PACK; + case SPREAD: + return PlacementStrategy.SPREAD; + case STRICT_SPREAD: + return PlacementStrategy.STRICT_SPREAD; + default: + return PlacementStrategy.UNRECOGNIZED; + } + } + + private static PlacementGroupState covertToUserSpecifiedState( + PlacementGroupTableData.PlacementGroupState state) { + switch (state) { + case PENDING: + return PlacementGroupState.PENDING; + case CREATED: + return PlacementGroupState.CREATED; + case REMOVED: + return PlacementGroupState.REMOVED; + case RESCHEDULING: + return PlacementGroupState.RESCHEDULING; + default: + return PlacementGroupState.UNRECOGNIZED; + } + } + + /** + * Generate a PlacementGroupImpl from placementGroupTableData protobuf data. + * @param placementGroupTableData protobuf data. + * @return placement group info {@link PlacementGroupImpl} + */ + private static PlacementGroupImpl generatePlacementGroupFromPbData( + PlacementGroupTableData placementGroupTableData) { + + PlacementGroupState state = covertToUserSpecifiedState( + placementGroupTableData.getState()); + PlacementStrategy strategy = covertToUserSpecifiedStrategy( + placementGroupTableData.getStrategy()); + + List> bundles = covertToUserSpecifiedBundles( + placementGroupTableData.getBundlesList()); + + PlacementGroupId placementGroupId = PlacementGroupId.fromByteBuffer( + placementGroupTableData.getPlacementGroupId().asReadOnlyByteBuffer()); + + return new PlacementGroupImpl.Builder() + .setId(placementGroupId).setName(placementGroupTableData.getName()) + .setState(state).setStrategy(strategy).setBundles(bundles) + .build(); + } + + /** + * Generate a PlacementGroupImpl from byte array. + * @param placementGroupByteArray bytes array from native method. + * @return placement group info {@link PlacementGroupImpl} + */ + public static PlacementGroupImpl generatePlacementGroupFromByteArray( + byte[] placementGroupByteArray) { + Preconditions.checkNotNull(placementGroupByteArray, + "Can't generate a placement group from empty byte array."); + + PlacementGroupTableData placementGroupTableData; + try { + placementGroupTableData = PlacementGroupTableData.parseFrom(placementGroupByteArray); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException( + "Received invalid placement group table protobuf data from GCS.", e); + } + + return generatePlacementGroupFromPbData(placementGroupTableData); + } +} diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index 06a8d4af3..53d7d2ae2 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -8,6 +8,7 @@ import io.ray.api.BaseActorHandle; import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; @@ -75,6 +76,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private final Map actorContexts = new ConcurrentHashMap<>(); + private final Map placementGroups = new ConcurrentHashMap<>(); + public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) { this.runtime = runtime; @@ -225,8 +228,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { @Override public PlacementGroup createPlacementGroup(String name, List> bundles, PlacementStrategy strategy) { - return new PlacementGroupImpl.Builder() - .setName(name).setBundles(bundles).setStrategy(strategy).build(); + PlacementGroupImpl placementGroup = new PlacementGroupImpl.Builder() + .setId(PlacementGroupId.fromRandom()).setName(name) + .setBundles(bundles).setStrategy(strategy).build(); + placementGroups.put(placementGroup.getId(), placementGroup); + return placementGroup; + } + + @Override + public void removePlacementGroup(PlacementGroupId id) { + placementGroups.remove(id); } @Override diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index e6aec1b16..dd2def600 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -6,13 +6,13 @@ import io.ray.api.BaseActorHandle; import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.placementgroup.PlacementStrategy; import io.ray.runtime.actor.NativeActorHandle; import io.ray.runtime.functionmanager.FunctionDescriptor; -import io.ray.runtime.placementgroup.PlacementGroupId; import io.ray.runtime.placementgroup.PlacementGroupImpl; import java.util.List; import java.util.Map; @@ -86,6 +86,11 @@ public class NativeTaskSubmitter implements TaskSubmitter { .setName(name).setBundles(bundles).setStrategy(strategy).build(); } + @Override + public void removePlacementGroup(PlacementGroupId id) { + nativeRemovePlacementGroup(id.getBytes()); + } + private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); @@ -99,4 +104,7 @@ public class NativeTaskSubmitter implements TaskSubmitter { private static native byte[] nativeCreatePlacementGroup(String name, List> bundles, int strategy); + + private static native void nativeRemovePlacementGroup(byte[] placementGroupId); + } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java index fd3b2212d..5c172caf9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java @@ -3,6 +3,7 @@ package io.ray.runtime.task; import io.ray.api.BaseActorHandle; import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; +import io.ray.api.id.PlacementGroupId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; import io.ray.api.placementgroup.PlacementGroup; @@ -53,7 +54,7 @@ public interface TaskSubmitter { /** * Create a placement group. * - * @param name Name of the Placement Group. + * @param name Name of the placement group. * @param bundles Pre-allocated resource list. * @param strategy Actor placement strategy. * @return A handle to the created placement group. @@ -61,6 +62,12 @@ public interface TaskSubmitter { PlacementGroup createPlacementGroup(String name, List> bundles, PlacementStrategy strategy); + /** + * Remove a placement group by id. + * @param id Id of the placement group. + */ + void removePlacementGroup(PlacementGroupId id); + BaseActorHandle getActor(ActorId actorId); } diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java index c933b16ef..39fea16e9 100644 --- a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java @@ -4,7 +4,10 @@ import io.ray.api.ActorHandle; import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementGroupState; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.runtime.placementgroup.PlacementGroupImpl; +import java.util.List; import org.testng.Assert; import org.testng.annotations.Test; @@ -40,6 +43,62 @@ public class PlacementGroupTest extends BaseTest { Assert.assertEquals(Integer.valueOf(1), actor.task(Counter::getValue).remote().get()); } + @Test(groups = {"cluster"}) + public void testGetPlacementGroup() { + PlacementGroupImpl firstPlacementGroup = (PlacementGroupImpl)PlacementGroupTestUtils + .createNameSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, + 1.0, "first_placement_group"); + + PlacementGroupImpl secondPlacementGroup = (PlacementGroupImpl)PlacementGroupTestUtils + .createNameSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, + 1.0, "second_placement_group"); + + PlacementGroupImpl firstPlacementGroupRes = + (PlacementGroupImpl)Ray.getPlacementGroup((firstPlacementGroup).getId()); + PlacementGroupImpl secondPlacementGroupRes = + (PlacementGroupImpl)Ray.getPlacementGroup((secondPlacementGroup).getId()); + + Assert.assertNotNull(firstPlacementGroupRes); + Assert.assertNotNull(secondPlacementGroupRes); + + Assert.assertEquals(firstPlacementGroup.getId(), firstPlacementGroupRes.getId()); + Assert.assertEquals(firstPlacementGroup.getName(), firstPlacementGroupRes.getName()); + Assert.assertEquals(firstPlacementGroupRes.getBundles().size(), 1); + Assert.assertEquals(firstPlacementGroupRes.getStrategy(), PlacementStrategy.PACK); + + List allPlacementGroup = Ray.getAllPlacementGroups(); + Assert.assertEquals(allPlacementGroup.size(), 2); + + PlacementGroupImpl placementGroupRes = (PlacementGroupImpl)allPlacementGroup.get(0); + Assert.assertNotNull(placementGroupRes.getId()); + PlacementGroupImpl expectPlacementGroup = placementGroupRes.getId() + .equals(firstPlacementGroup.getId()) ? firstPlacementGroup : secondPlacementGroup; + + Assert.assertEquals(placementGroupRes.getName(), expectPlacementGroup.getName()); + Assert.assertEquals(placementGroupRes.getBundles().size(), + expectPlacementGroup.getBundles().size()); + Assert.assertEquals(placementGroupRes.getStrategy(), expectPlacementGroup.getStrategy()); + } + + @Test(groups = {"cluster"}) + public void testRemovePlacementGroup() { + PlacementGroupTestUtils.createNameSpecifiedSimpleGroup("CPU", + 1, PlacementStrategy.PACK, 1.0, "first_placement_group"); + + PlacementGroupImpl secondPlacementGroup = (PlacementGroupImpl)PlacementGroupTestUtils + .createNameSpecifiedSimpleGroup("CPU", 1, PlacementStrategy.PACK, + 1.0, "second_placement_group"); + + List allPlacementGroup = Ray.getAllPlacementGroups(); + Assert.assertEquals(allPlacementGroup.size(), 2); + + Ray.removePlacementGroup(secondPlacementGroup.getId()); + + PlacementGroupImpl removedPlacementGroup = + (PlacementGroupImpl)Ray.getPlacementGroup((secondPlacementGroup).getId()); + Assert.assertEquals(removedPlacementGroup.getState(), PlacementGroupState.REMOVED); + } + public void testCheckBundleIndex() { PlacementGroup placementGroup = PlacementGroupTestUtils.createSimpleGroup(); diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java b/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java index 48e419df1..72ebd6629 100644 --- a/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTestUtils.java @@ -9,12 +9,12 @@ import java.util.List; import java.util.Map; /** - * A utils class for Placement Group test. + * A utils class for placement group test. */ public class PlacementGroupTestUtils { - public static PlacementGroup createSpecifiedSimpleGroup(String resourceName, int bundleSize, - PlacementStrategy strategy, Double resourceSize) { + public static PlacementGroup createNameSpecifiedSimpleGroup(String resourceName, int bundleSize, + PlacementStrategy strategy, Double resourceSize, String groupName) { List> bundles = new ArrayList<>(); for (int i = 0; i < bundleSize; i++) { @@ -23,7 +23,13 @@ public class PlacementGroupTestUtils { bundles.add(bundle); } - return Ray.createPlacementGroup(bundles, strategy); + return Ray.createPlacementGroup(groupName, bundles, strategy); + } + + public static PlacementGroup createSpecifiedSimpleGroup(String resourceName, int bundleSize, + PlacementStrategy strategy, Double resourceSize) { + return createNameSpecifiedSimpleGroup(resourceName, bundleSize, strategy, + resourceSize, "unnamed_group"); } public static PlacementGroup createSimpleGroup() { diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc index 8fe43deef..ec94fa334 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.cc @@ -129,6 +129,32 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorCheckpointId( return nullptr; } +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetPlacementGroupInfo( + JNIEnv *env, jobject o, jlong gcs_accessor_ptr, jbyteArray placement_group_id_bytes) { + const auto placement_group_id = + JavaByteArrayToId(env, placement_group_id_bytes); + auto *gcs_accessor = + reinterpret_cast(gcs_accessor_ptr); + auto placement_group = gcs_accessor->GetPlacementGroupInfo(placement_group_id); + if (placement_group) { + return NativeStringToJavaByteArray(env, *placement_group); + } + return nullptr; +} + +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllPlacementGroupInfo( + JNIEnv *env, jobject o, jlong gcs_accessor_ptr) { + auto *gcs_accessor = + reinterpret_cast(gcs_accessor_ptr); + auto placement_group_info_list = gcs_accessor->GetAllPlacementGroupInfo(); + return NativeVectorToJavaList( + env, placement_group_info_list, [](JNIEnv *env, const std::string &str) { + return NativeStringToJavaByteArray(env, str); + }); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h index 51ce96b83..0bc2dd19b 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_gcs_GlobalStateAccessor.h @@ -112,6 +112,26 @@ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetActorCheckpointId(JNIEnv *, jobject, jlong, jbyteArray); +/* + * Class: io_ray_runtime_gcs_GlobalStateAccessor + * Method: nativeGetPlacementGroupInfo + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetPlacementGroupInfo(JNIEnv *, jobject, + jlong, + jbyteArray); + +/* + * Class: io_ray_runtime_gcs_GlobalStateAccessor + * Method: nativeGetAllPlacementGroupInfo + * Signature: (J)Ljava/util/List; + */ +JNIEXPORT jobject JNICALL +Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeGetAllPlacementGroupInfo(JNIEnv *, + jobject, + jlong); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 5d7f3d89c..9115945d2 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -283,6 +283,16 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup( return IdToJavaByteArray(env, placement_group_id); } +JNIEXPORT void JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup( + JNIEnv *env, jclass p, jbyteArray placement_group_id_bytes) { + const auto placement_group_id = + JavaByteArrayToId(env, placement_group_id_bytes); + auto status = + ray::CoreWorkerProcess::GetCoreWorker().RemovePlacementGroup(placement_group_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index 8f0879529..33a46806e 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -62,6 +62,15 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv * jstring, jobject, jint); +/* + * Class: io_ray_runtime_task_NativeTaskSubmitter + * Method: nativeRemovePlacementGroup + * Signature: ([B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeRemovePlacementGroup(JNIEnv *, jclass, + jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 522c12d3d..d1d932584 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -225,9 +225,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_placement_group_class = LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl"); - java_placement_group_id = - env->GetFieldID(java_placement_group_class, "id", - "Lio/ray/runtime/placementgroup/PlacementGroupId;"); + java_placement_group_id = env->GetFieldID(java_placement_group_class, "id", + "Lio/ray/api/id/PlacementGroupId;"); java_actor_creation_options_class = LoadClass(env, "io/ray/api/options/ActorCreationOptions");