mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[java] Fix the logic of generating TaskID (#2747)
## What do these changes do? Because the logic of generating `TaskID` in java is different from python's, there are many tests fail when we change the `Ray Core` code. In this change, I rewrote the logic of generating `TaskID` in java which is the same as the python's. In java, we call the native method `_generateTaskId()` to generate a `TaskID` which is also used in python. We change `computePutId()`'s logic too. ## Related issue number [#2608](https://github.com/ray-project/ray/issues/2608)
This commit is contained in:
parent
f37c260bdb
commit
b4cba9a49f
17 changed files with 255 additions and 299 deletions
|
@ -7,7 +7,7 @@ import java.util.Random;
|
|||
import javax.xml.bind.DatatypeConverter;
|
||||
|
||||
/**
|
||||
* Unique ID for task, worker, function...
|
||||
* Unique ID for task, worker, function.
|
||||
*/
|
||||
public class UniqueID implements Serializable {
|
||||
|
||||
|
@ -42,7 +42,8 @@ public class UniqueID implements Serializable {
|
|||
|
||||
public UniqueID(byte[] id) {
|
||||
if (id.length != LENGTH) {
|
||||
throw new IllegalArgumentException("Illegal argument: " + id.toString());
|
||||
throw new IllegalArgumentException("Illegal argument for UniqueID, expect " + LENGTH
|
||||
+ " bytes, but got " + id.length + " bytes.");
|
||||
}
|
||||
|
||||
this.id = id;
|
||||
|
@ -86,11 +87,6 @@ public class UniqueID implements Serializable {
|
|||
}
|
||||
|
||||
public boolean isNil() {
|
||||
for (byte b : id) {
|
||||
if (b != (byte) 0xFF) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return this.equals(NIL);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,7 +113,6 @@ public abstract class RayRuntime implements RayApi {
|
|||
RemoteFunctionManager remoteLoader,
|
||||
PathConfig pathManager
|
||||
) {
|
||||
UniqueIdHelper.setThreadRandomSeed(UniqueIdHelper.getUniqueness(params.driver_id));
|
||||
remoteFunctionManager = remoteLoader;
|
||||
pathConfig = pathManager;
|
||||
|
||||
|
|
|
@ -3,274 +3,68 @@ package org.ray.core;
|
|||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import org.apache.commons.lang3.BitField;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.util.MD5Digestor;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
|
||||
//
|
||||
// see src/common/common.h for UniqueID layout
|
||||
// Helper methods for UniqueID. These are the same as the helper functions in src/ray/id.h.
|
||||
//
|
||||
public class UniqueIdHelper {
|
||||
public static final int OBJECT_INDEX_POS = 0;
|
||||
public static final int OBJECT_INDEX_LENGTH = 4;
|
||||
|
||||
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
|
||||
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
|
||||
private static final ThreadLocal<Random> rand = ThreadLocal.withInitial(Random::new);
|
||||
private static final ThreadLocal<Long> randSeed = new ThreadLocal<>();
|
||||
private static final int batchPos = 0;
|
||||
private static final int uniquenessPos = Long.SIZE / Byte.SIZE;
|
||||
private static final int typePos = 2 * Long.SIZE / Byte.SIZE;
|
||||
private static final BitField typeField = new BitField(0x7);
|
||||
private static final int testPos = 2 * Long.SIZE / Byte.SIZE;
|
||||
private static final BitField testField = new BitField(0x1 << 3);
|
||||
private static final int unionPos = 2 * Long.SIZE / Byte.SIZE;
|
||||
private static final BitField multipleReturnField = new BitField(0x1 << 8);
|
||||
private static final BitField isReturnIdField = new BitField(0x1 << 9);
|
||||
private static final BitField withinTaskIndexField = new BitField(0xFFFFFC00);
|
||||
|
||||
public static void setThreadRandomSeed(long seed) {
|
||||
if (randSeed.get() != null) {
|
||||
RayLog.core.error("Thread random seed is already set to " + randSeed.get()
|
||||
+ " and now to be overwritten to " + seed);
|
||||
throw new RuntimeException("Thread random seed is already set to " + randSeed.get()
|
||||
+ " and now to be overwritten to " + seed);
|
||||
}
|
||||
|
||||
RayLog.core.debug("Thread random seed is set to " + seed);
|
||||
randSeed.set(seed);
|
||||
rand.get().setSeed(seed);
|
||||
/**
|
||||
* Compute the object ID of an object returned by the task.
|
||||
*
|
||||
* @param taskId The task ID of the task that created the object.
|
||||
* @param returnIndex What number return value this object is in the task.
|
||||
* @return The computed object ID.
|
||||
*/
|
||||
public static UniqueID computeReturnId(UniqueID taskId, int returnIndex) {
|
||||
return computeObjectId(taskId, returnIndex);
|
||||
}
|
||||
|
||||
public static Long getNextCreateThreadRandomSeed() {
|
||||
UniqueID currentTaskId = WorkerContext.currentTask().taskId;
|
||||
byte[] bytes;
|
||||
|
||||
ByteBuffer lbuffer = longBuffer.get();
|
||||
// similar to task id generation (see nextTaskId below)
|
||||
if (!currentTaskId.isNil()) {
|
||||
ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes());
|
||||
rbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
long cid = rbb.getLong(uniquenessPos);
|
||||
byte[] cbuffer = lbuffer.putLong(cid).array();
|
||||
bytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex());
|
||||
} else {
|
||||
long cid = rand.get().nextLong();
|
||||
byte[] cbuffer = lbuffer.putLong(cid).array();
|
||||
bytes = MD5Digestor.digest(cbuffer, rand.get().nextLong());
|
||||
}
|
||||
lbuffer.clear();
|
||||
|
||||
lbuffer.put(bytes, 0, Long.SIZE / Byte.SIZE);
|
||||
long r = lbuffer.getLong();
|
||||
lbuffer.clear();
|
||||
return r;
|
||||
}
|
||||
|
||||
private static Type getType(ByteBuffer bb) {
|
||||
byte v = bb.get(typePos);
|
||||
return Type.values()[typeField.getValue(v)];
|
||||
}
|
||||
|
||||
private static boolean getIsTest(ByteBuffer bb) {
|
||||
byte v = bb.get(testPos);
|
||||
return testField.getValue(v) == 1;
|
||||
}
|
||||
|
||||
private static int getIsReturn(ByteBuffer bb) {
|
||||
int v = bb.getInt(unionPos);
|
||||
return isReturnIdField.getValue(v);
|
||||
}
|
||||
|
||||
private static int getWithinTaskIndex(ByteBuffer bb) {
|
||||
int v = bb.getInt(unionPos);
|
||||
return withinTaskIndexField.getValue(v);
|
||||
}
|
||||
|
||||
public static void setTest(UniqueID id, boolean isTest) {
|
||||
ByteBuffer bb = ByteBuffer.wrap(id.getBytes());
|
||||
setIsTest(bb, isTest);
|
||||
}
|
||||
|
||||
private static void setIsTest(ByteBuffer bb, boolean isTest) {
|
||||
byte v = bb.get(testPos);
|
||||
v = (byte) testField.setValue(v, isTest ? 1 : 0);
|
||||
bb.put(testPos, v);
|
||||
}
|
||||
|
||||
public static long getUniqueness(UniqueID id) {
|
||||
ByteBuffer bb = ByteBuffer.wrap(id.getBytes());
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return getUniqueness(bb);
|
||||
}
|
||||
|
||||
private static long getUniqueness(ByteBuffer bb) {
|
||||
return bb.getLong(uniquenessPos);
|
||||
}
|
||||
|
||||
public static UniqueID taskComputeReturnId(
|
||||
UniqueID uid,
|
||||
int returnIndex,
|
||||
boolean hasMultipleReturn
|
||||
) {
|
||||
return objectIdFromTaskId(uid, true, hasMultipleReturn, returnIndex);
|
||||
}
|
||||
|
||||
private static UniqueID objectIdFromTaskId(UniqueID taskId,
|
||||
boolean isReturn,
|
||||
boolean hasMultipleReturn,
|
||||
int index
|
||||
) {
|
||||
UniqueID oid = newZero();
|
||||
ByteBuffer rbb = ByteBuffer.wrap(taskId.getBytes());
|
||||
rbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
ByteBuffer wbb = ByteBuffer.wrap(oid.getBytes());
|
||||
/**
|
||||
* Compute the object ID from the task ID and the index.
|
||||
* @param taskId The task ID of the task that created the object.
|
||||
* @param index The index which can distinguish different objects in one task.
|
||||
* @return The computed object ID.
|
||||
*/
|
||||
private static UniqueID computeObjectId(UniqueID taskId, int index) {
|
||||
byte[] objId = new byte[UniqueID.LENGTH];
|
||||
System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueID.LENGTH);
|
||||
ByteBuffer wbb = ByteBuffer.wrap(objId);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
setBatch(wbb, getBatch(rbb));
|
||||
setUniqueness(wbb, getUniqueness(rbb));
|
||||
setType(wbb, Type.OBJECT);
|
||||
setHasMultipleReturn(wbb, hasMultipleReturn ? 1 : 0);
|
||||
setIsReturn(wbb, isReturn ? 1 : 0);
|
||||
setWithinTaskIndex(wbb, index);
|
||||
return oid;
|
||||
wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index);
|
||||
|
||||
return new UniqueID(objId);
|
||||
}
|
||||
|
||||
private static UniqueID newZero() {
|
||||
byte[] b = new byte[UniqueID.LENGTH];
|
||||
Arrays.fill(b, (byte) 0);
|
||||
return new UniqueID(b);
|
||||
/**
|
||||
* Compute the object ID of an object put by the task.
|
||||
*
|
||||
* @param taskId The task ID of the task that created the object.
|
||||
* @param putIndex What number put this object was created by in the task.
|
||||
* @return The computed object ID.
|
||||
*/
|
||||
public static UniqueID computePutId(UniqueID taskId, int putIndex) {
|
||||
// We multiply putIndex by -1 to distinguish from returnIndex.
|
||||
return computeObjectId(taskId, -1 * putIndex);
|
||||
}
|
||||
|
||||
private static void setBatch(ByteBuffer bb, long batchId) {
|
||||
bb.putLong(batchPos, batchId);
|
||||
/**
|
||||
* Compute the task ID of the task that created the object.
|
||||
*
|
||||
* @param objectId The object ID.
|
||||
* @return The task ID of the task that created this object.
|
||||
*/
|
||||
public static UniqueID computeTaskId(UniqueID objectId) {
|
||||
byte[] taskId = new byte[UniqueID.LENGTH];
|
||||
System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueID.LENGTH);
|
||||
Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS,
|
||||
UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0);
|
||||
|
||||
return new UniqueID(taskId);
|
||||
}
|
||||
|
||||
private static long getBatch(ByteBuffer bb) {
|
||||
return bb.getLong(batchPos);
|
||||
}
|
||||
|
||||
private static void setUniqueness(ByteBuffer bb, long uniqueness) {
|
||||
bb.putLong(uniquenessPos, uniqueness);
|
||||
}
|
||||
|
||||
private static void setUniqueness(ByteBuffer bb, byte[] uniqueness) {
|
||||
for (int i = 0; i < Long.SIZE / Byte.SIZE; ++i) {
|
||||
bb.put(uniquenessPos + i, uniqueness[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private static void setType(ByteBuffer bb, Type type) {
|
||||
byte v = bb.get(typePos);
|
||||
v = (byte) typeField.setValue(v, type.ordinal());
|
||||
bb.put(typePos, v);
|
||||
}
|
||||
|
||||
private static void setHasMultipleReturn(ByteBuffer bb, int hasMultipleReturnOrNot) {
|
||||
int v = bb.getInt(unionPos);
|
||||
v = multipleReturnField.setValue(v, hasMultipleReturnOrNot);
|
||||
bb.putInt(unionPos, v);
|
||||
}
|
||||
|
||||
private static void setIsReturn(ByteBuffer bb, int isReturn) {
|
||||
int v = bb.getInt(unionPos);
|
||||
v = isReturnIdField.setValue(v, isReturn);
|
||||
bb.putInt(unionPos, v);
|
||||
}
|
||||
|
||||
private static void setWithinTaskIndex(ByteBuffer bb, int index) {
|
||||
int v = bb.getInt(unionPos);
|
||||
v = withinTaskIndexField.setValue(v, index);
|
||||
bb.putInt(unionPos, v);
|
||||
}
|
||||
|
||||
public static UniqueID taskComputePutId(UniqueID uid, int putIndex) {
|
||||
return objectIdFromTaskId(uid, false, false, putIndex);
|
||||
}
|
||||
|
||||
public static boolean hasMultipleReturnOrNotFromReturnObjectId(UniqueID returnId) {
|
||||
ByteBuffer bb = ByteBuffer.wrap(returnId.getBytes());
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return getHasMultipleReturn(bb) != 0;
|
||||
}
|
||||
|
||||
private static int getHasMultipleReturn(ByteBuffer bb) {
|
||||
int v = bb.getInt(unionPos);
|
||||
return multipleReturnField.getValue(v);
|
||||
}
|
||||
|
||||
public static UniqueID taskIdFromObjectId(UniqueID objectId) {
|
||||
UniqueID taskId = newZero();
|
||||
ByteBuffer rbb = ByteBuffer.wrap(objectId.getBytes());
|
||||
rbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
setBatch(wbb, getBatch(rbb));
|
||||
setUniqueness(wbb, getUniqueness(rbb));
|
||||
setType(wbb, Type.TASK);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public static UniqueID nextTaskId(long batchId) {
|
||||
UniqueID taskId = newZero();
|
||||
ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
setType(wbb, Type.TASK);
|
||||
|
||||
UniqueID currentTaskId = WorkerContext.currentTask().taskId;
|
||||
ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes());
|
||||
rbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
// setup batch id
|
||||
if (batchId == -1) {
|
||||
setBatch(wbb, getBatch(rbb));
|
||||
} else {
|
||||
setBatch(wbb, batchId);
|
||||
}
|
||||
|
||||
// setup unique id (task id)
|
||||
byte[] idBytes;
|
||||
|
||||
ByteBuffer lbuffer = longBuffer.get();
|
||||
// if inside a task
|
||||
if (!currentTaskId.isNil()) {
|
||||
long cid = rbb.getLong(uniquenessPos);
|
||||
byte[] cbuffer = lbuffer.putLong(cid).array();
|
||||
idBytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex());
|
||||
|
||||
// if not
|
||||
} else {
|
||||
long cid = rand.get().nextLong();
|
||||
byte[] cbuffer = lbuffer.putLong(cid).array();
|
||||
idBytes = MD5Digestor.digest(cbuffer, rand.get().nextLong());
|
||||
}
|
||||
setUniqueness(wbb, idBytes);
|
||||
lbuffer.clear();
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public static void markCreateActorStage1Function(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
setUniqueness(wbb, 1);
|
||||
}
|
||||
|
||||
// WARNING: see hack in MethodId.java which must be aligned with here
|
||||
public static boolean isNonLambdaCreateActorStage1Function(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return getUniqueness(wbb) == 1;
|
||||
}
|
||||
|
||||
public static boolean isNonLambdaCommonFunction(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return getUniqueness(wbb) == 0;
|
||||
}
|
||||
|
||||
public enum Type {
|
||||
OBJECT,
|
||||
TASK,
|
||||
ACTOR,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,7 +82,9 @@ public class Worker {
|
|||
|
||||
public RayObject submit(RayFunc func, Object[] args) {
|
||||
MethodId methodId = methodIdOf(func);
|
||||
UniqueID taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
UniqueID taskId = scheduler.generateTaskId(WorkerContext.currentTask().driverId,
|
||||
WorkerContext.currentTask().taskId,
|
||||
WorkerContext.nextCallIndex());
|
||||
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
|
||||
return actorTaskSubmit(taskId, methodId, args, (RayActor<?>) args[0]);
|
||||
} else {
|
||||
|
@ -123,7 +125,7 @@ public class Worker {
|
|||
}
|
||||
|
||||
public UniqueID getCurrentTaskNextPutId() {
|
||||
return UniqueIdHelper.taskComputePutId(
|
||||
return UniqueIdHelper.computePutId(
|
||||
WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
|
||||
}
|
||||
|
||||
|
|
|
@ -21,5 +21,7 @@ public interface LocalSchedulerLink {
|
|||
|
||||
void notifyUnblocked();
|
||||
|
||||
UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex);
|
||||
|
||||
List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns);
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ public class LocalSchedulerProxy {
|
|||
private UniqueID[] genReturnIds(UniqueID taskId, int numReturns) {
|
||||
UniqueID[] ret = new UniqueID[numReturns];
|
||||
for (int i = 0; i < numReturns; i++) {
|
||||
ret[i] = UniqueIdHelper.taskComputeReturnId(taskId, i, false);
|
||||
ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -129,4 +129,8 @@ public class LocalSchedulerProxy {
|
|||
|
||||
return new WaitResult<>(readyObjs, remainObjs);
|
||||
}
|
||||
|
||||
public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) {
|
||||
return scheduler.generateTaskId(driverId, parentTaskId, taskIndex);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ public class RayDevRuntime extends RayRuntime {
|
|||
|
||||
private byte[] createLocalActor(String className) {
|
||||
UniqueID taskId = WorkerContext.currentTask().taskId;
|
||||
UniqueID actorId = UniqueIdHelper.taskComputeReturnId(taskId, 0, false);
|
||||
UniqueID actorId = UniqueIdHelper.computeReturnId(taskId, 0);
|
||||
try {
|
||||
Class<?> cls = Class.forName(className);
|
||||
|
||||
|
|
|
@ -89,6 +89,11 @@ public class MockLocalScheduler implements LocalSchedulerLink {
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) {
|
||||
throw new RuntimeException("Not implemented here.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns) {
|
||||
return store.wait(objectIds, timeoutMs, numReturns);
|
||||
|
|
|
@ -243,8 +243,13 @@ public class RayNativeRuntime extends RayRuntime {
|
|||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public <T> RayActor<T> create(Class<T> cls) {
|
||||
UniqueID createTaskId = UniqueIdHelper.nextTaskId(-1);
|
||||
UniqueID actorId = UniqueIdHelper.taskComputeReturnId(createTaskId, 0, false);
|
||||
UniqueID createTaskId = localSchedulerProxy.generateTaskId(
|
||||
WorkerContext.currentTask().driverId,
|
||||
WorkerContext.currentTask().taskId,
|
||||
WorkerContext.nextCallIndex()
|
||||
);
|
||||
|
||||
UniqueID actorId = UniqueIdHelper.computeReturnId(createTaskId, 0);
|
||||
RayActor<T> actor = new RayActor<>(actorId);
|
||||
UniqueID cursorId;
|
||||
|
||||
|
|
|
@ -6,8 +6,11 @@ import java.nio.ByteOrder;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.core.RayRuntime;
|
||||
import org.ray.core.UniqueIdHelper;
|
||||
import org.ray.spi.LocalSchedulerLink;
|
||||
import org.ray.spi.model.FunctionArg;
|
||||
import org.ray.spi.model.TaskSpec;
|
||||
|
@ -42,6 +45,8 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
|
|||
|
||||
private static native byte[] _computePutId(long client, byte[] taskId, int putIndex);
|
||||
|
||||
private static native byte[] _generateTaskId(byte[] driverId, byte[] parentTaskId, int taskIndex);
|
||||
|
||||
private static native void _task_done(long client);
|
||||
|
||||
private static native boolean[] _waitObject(long conn, byte[][] objectIds,
|
||||
|
@ -111,10 +116,19 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
|
|||
|
||||
@Override
|
||||
public void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly) {
|
||||
RayLog.core.info("reconstruct objects {}", objectIds);
|
||||
if (RayLog.core.isInfoEnabled()) {
|
||||
RayLog.core.info("Reconstructing objects for task {}, object IDs are {}",
|
||||
UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds);
|
||||
}
|
||||
_reconstruct_objects(client, getIdBytes(objectIds), fetchOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) {
|
||||
byte[] bytes = _generateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
|
||||
return new UniqueID(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyUnblocked() {
|
||||
_notify_unblocked(client);
|
||||
|
|
84
java/test/src/main/java/org/ray/api/test/UniqueIdTest.java
Normal file
84
java/test/src/main/java/org/ray/api/test/UniqueIdTest.java
Normal file
|
@ -0,0 +1,84 @@
|
|||
package org.ray.api.test;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import javax.xml.bind.DatatypeConverter;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.core.UniqueIdHelper;
|
||||
|
||||
@RunWith(MyRunner.class)
|
||||
public class UniqueIdTest {
|
||||
|
||||
@Test
|
||||
public void testConstructUniqueId() {
|
||||
// Test `fromHexString()`
|
||||
UniqueID id1 = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00");
|
||||
Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", id1.toString());
|
||||
Assert.assertFalse(id1.isNil());
|
||||
|
||||
try {
|
||||
UniqueID id2 = UniqueID.fromHexString("000000123456789ABCDEF123456789ABCDEF00");
|
||||
// This shouldn't be happened.
|
||||
Assert.assertTrue(false);
|
||||
} catch (IllegalArgumentException e) {
|
||||
Assert.assertTrue(true);
|
||||
}
|
||||
|
||||
try {
|
||||
UniqueID id3 = UniqueID.fromHexString("GGGGGGGGGGGGG");
|
||||
// This shouldn't be happened.
|
||||
Assert.assertTrue(false);
|
||||
} catch (IllegalArgumentException e) {
|
||||
Assert.assertTrue(true);
|
||||
}
|
||||
|
||||
// Test `fromByteBuffer()`
|
||||
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567");
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20);
|
||||
UniqueID id4 = UniqueID.fromByteBuffer(byteBuffer);
|
||||
Assert.assertTrue(Arrays.equals(bytes, id4.getBytes()));
|
||||
Assert.assertEquals("0123456789ABCDEF0123456789ABCDEF01234567", id4.toString());
|
||||
|
||||
|
||||
// Test `genNil()`
|
||||
UniqueID id6 = UniqueID.genNil();
|
||||
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", id6.toString());
|
||||
Assert.assertTrue(id6.isNil());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeReturnId() {
|
||||
// Mock a taskId, and the lowest 4 bytes should be 0.
|
||||
UniqueID taskId = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00");
|
||||
|
||||
UniqueID returnId = UniqueIdHelper.computeReturnId(taskId, 1);
|
||||
Assert.assertEquals("01000000123456789ABCDEF123456789ABCDEF00", returnId.toString());
|
||||
|
||||
returnId = UniqueIdHelper.computeReturnId(taskId, 0x01020304);
|
||||
Assert.assertEquals("04030201123456789ABCDEF123456789ABCDEF00", returnId.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeTaskId() {
|
||||
UniqueID objId = UniqueID.fromHexString("34421980123456789ABCDEF123456789ABCDEF00");
|
||||
UniqueID taskId = UniqueIdHelper.computeTaskId(objId);
|
||||
|
||||
Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", taskId.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputePutId() {
|
||||
// Mock a taskId, the lowest 4 bytes should be 0.
|
||||
UniqueID taskId = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00");
|
||||
|
||||
UniqueID putId = UniqueIdHelper.computePutId(taskId, 1);
|
||||
Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00", putId.toString());
|
||||
|
||||
putId = UniqueIdHelper.computePutId(taskId, 0x01020304);
|
||||
Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00", putId.toString());
|
||||
}
|
||||
|
||||
}
|
|
@ -3,6 +3,7 @@
|
|||
#include "local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h"
|
||||
#include "local_scheduler_client.h"
|
||||
#include "logging.h"
|
||||
#include "ray/id.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -299,6 +300,31 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(
|
|||
return resultArray;
|
||||
}
|
||||
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1generateTaskId(
|
||||
JNIEnv *env,
|
||||
jclass,
|
||||
jbyteArray did,
|
||||
jbyteArray ptid,
|
||||
jint parent_task_counter) {
|
||||
UniqueIdFromJByteArray o1(env, did);
|
||||
ray::DriverID driver_id = *o1.PID;
|
||||
|
||||
UniqueIdFromJByteArray o2(env, ptid);
|
||||
ray::TaskID parent_task_id = *o2.PID;
|
||||
|
||||
ray::TaskID task_id =
|
||||
ray::GenerateTaskId(driver_id, parent_task_id, parent_task_counter);
|
||||
jbyteArray result = env->NewByteArray(sizeof(ray::TaskID));
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, sizeof(TaskID),
|
||||
reinterpret_cast<jbyte *>(&task_id));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -130,6 +130,18 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(JNIEnv *,
|
|||
jint,
|
||||
jboolean);
|
||||
|
||||
/*
|
||||
* Class: org_ray_spi_impl_DefaultLocalSchedulerClient
|
||||
* Method: _generateTaskId
|
||||
* Signature: ([B[BI)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1generateTaskId(JNIEnv *,
|
||||
jclass,
|
||||
jbyteArray,
|
||||
jbyteArray,
|
||||
jint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <mutex>
|
||||
#include <random>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "ray/constants.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
|
@ -177,6 +178,7 @@ const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) {
|
|||
|
||||
const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) {
|
||||
RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts);
|
||||
// We multiply put_index by -1 to distinguish from return_index.
|
||||
return ComputeObjectId(task_id, -1 * put_index);
|
||||
}
|
||||
|
||||
|
@ -190,6 +192,26 @@ const TaskID ComputeTaskId(const ObjectID &object_id) {
|
|||
return task_id;
|
||||
}
|
||||
|
||||
const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id,
|
||||
int parent_task_counter) {
|
||||
// Compute hashes.
|
||||
SHA256_CTX ctx;
|
||||
sha256_init(&ctx);
|
||||
sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id));
|
||||
sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id));
|
||||
sha256_update(&ctx, (BYTE *)&parent_task_counter, sizeof(parent_task_counter));
|
||||
|
||||
// Compute the final task ID from the hash.
|
||||
BYTE buff[DIGEST_SIZE];
|
||||
sha256_final(&ctx, buff);
|
||||
TaskID task_id;
|
||||
RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE);
|
||||
memcpy(&task_id, buff, sizeof(task_id));
|
||||
task_id = FinishTaskId(task_id);
|
||||
|
||||
return task_id;
|
||||
}
|
||||
|
||||
int64_t ComputeObjectIndex(const ObjectID &object_id) {
|
||||
const int64_t *first_bytes = reinterpret_cast<const int64_t *>(&object_id);
|
||||
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
|
||||
|
|
13
src/ray/id.h
13
src/ray/id.h
|
@ -10,6 +10,10 @@
|
|||
#include "ray/constants.h"
|
||||
#include "ray/util/visibility.h"
|
||||
|
||||
extern "C" {
|
||||
#include "sha256.h"
|
||||
}
|
||||
|
||||
namespace ray {
|
||||
|
||||
class RAY_EXPORT UniqueID {
|
||||
|
@ -81,6 +85,15 @@ const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index);
|
|||
/// \return The task ID of the task that created this object.
|
||||
const TaskID ComputeTaskId(const ObjectID &object_id);
|
||||
|
||||
/// Generate a task ID from the given info.
|
||||
///
|
||||
/// \param driver_id The driver that creates the task.
|
||||
/// \param parent_task_id The parent task of this task.
|
||||
/// \param parent_task_counter The task index of the worker.
|
||||
/// \return The task ID generated from the given info.
|
||||
const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id,
|
||||
int parent_task_counter);
|
||||
|
||||
/// Compute the index of this object in the task that created it.
|
||||
///
|
||||
/// \param object_id The object ID.
|
||||
|
|
|
@ -1195,21 +1195,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
|
|||
[this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) {
|
||||
// The task was not in the GCS task table. It must therefore be in the
|
||||
// lineage cache.
|
||||
if (!lineage_cache_.ContainsTask(task_id)) {
|
||||
// The task was not in the lineage cache.
|
||||
// TODO(swang): This should not ever happen, but Java TaskIDs are
|
||||
// currently computed differently from Python TaskIDs, so
|
||||
// reconstruction is currently broken for Java. Once the TaskID
|
||||
// generation code matches for both frontends, we should be able to
|
||||
// remove this warning and make it a fatal check.
|
||||
RAY_LOG(WARNING) << "Task " << task_id << " to reconstruct was not found in "
|
||||
"the GCS or the lineage cache. This "
|
||||
"job may hang.";
|
||||
} else {
|
||||
// Use a copy of the cached task spec to re-execute the task.
|
||||
const Task task = lineage_cache_.GetTask(task_id);
|
||||
ResubmitTask(task);
|
||||
}
|
||||
RAY_CHECK(lineage_cache_.ContainsTask(task_id));
|
||||
// Use a copy of the cached task spec to re-execute the task.
|
||||
const Task task = lineage_cache_.GetTask(task_id);
|
||||
ResubmitTask(task);
|
||||
|
||||
}));
|
||||
}
|
||||
|
||||
|
|
|
@ -63,20 +63,8 @@ TaskSpecification::TaskSpecification(
|
|||
: spec_() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
|
||||
// Compute hashes.
|
||||
SHA256_CTX ctx;
|
||||
sha256_init(&ctx);
|
||||
sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id));
|
||||
sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id));
|
||||
sha256_update(&ctx, (BYTE *)&parent_counter, sizeof(parent_counter));
|
||||
TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter);
|
||||
|
||||
// Compute the final task ID from the hash.
|
||||
BYTE buff[DIGEST_SIZE];
|
||||
sha256_final(&ctx, buff);
|
||||
TaskID task_id;
|
||||
RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE);
|
||||
memcpy(&task_id, buff, sizeof(task_id));
|
||||
task_id = FinishTaskId(task_id);
|
||||
// Add argument object IDs.
|
||||
std::vector<flatbuffers::Offset<Arg>> arguments;
|
||||
for (auto &argument : task_arguments) {
|
||||
|
|
Loading…
Add table
Reference in a new issue