From b4cba9a49f96e649811f16e8144e3a29116a37b0 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Tue, 28 Aug 2018 04:11:33 +0800 Subject: [PATCH] [java] Fix the logic of generating TaskID (#2747) ## What do these changes do? Because the logic for generating `TaskID` in Java differs from Python's, many tests fail when we change the `Ray Core` code. In this change, I rewrote the Java logic for generating `TaskID` so that it is the same as Python's. In Java, we call the native method `_generateTaskId()` to generate a `TaskID`; the same logic is also used in Python. We change `computePutId()`'s logic too. ## Related issue number [#2608](https://github.com/ray-project/ray/issues/2608) --- .../src/main/java/org/ray/api/UniqueID.java | 12 +- .../main/java/org/ray/core/RayRuntime.java | 1 - .../java/org/ray/core/UniqueIdHelper.java | 304 +++--------------- .../src/main/java/org/ray/core/Worker.java | 6 +- .../java/org/ray/spi/LocalSchedulerLink.java | 2 + .../java/org/ray/spi/LocalSchedulerProxy.java | 6 +- .../java/org/ray/core/impl/RayDevRuntime.java | 2 +- .../org/ray/spi/impl/MockLocalScheduler.java | 5 + .../org/ray/core/impl/RayNativeRuntime.java | 9 +- .../spi/impl/DefaultLocalSchedulerClient.java | 16 +- .../java/org/ray/api/test/UniqueIdTest.java | 84 +++++ ...ay_spi_impl_DefaultLocalSchedulerClient.cc | 26 ++ ...ray_spi_impl_DefaultLocalSchedulerClient.h | 12 + src/ray/id.cc | 22 ++ src/ray/id.h | 13 + src/ray/raylet/node_manager.cc | 20 +- src/ray/raylet/task_spec.cc | 14 +- 17 files changed, 255 insertions(+), 299 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/UniqueIdTest.java diff --git a/java/api/src/main/java/org/ray/api/UniqueID.java b/java/api/src/main/java/org/ray/api/UniqueID.java index 09323f264..d6af0a664 100644 --- a/java/api/src/main/java/org/ray/api/UniqueID.java +++ b/java/api/src/main/java/org/ray/api/UniqueID.java @@ -7,7 +7,7 @@ import java.util.Random; import javax.xml.bind.DatatypeConverter; /** - * Unique ID for 
task, worker, function... + * Unique ID for task, worker, function. */ public class UniqueID implements Serializable { @@ -42,7 +42,8 @@ public class UniqueID implements Serializable { public UniqueID(byte[] id) { if (id.length != LENGTH) { - throw new IllegalArgumentException("Illegal argument: " + id.toString()); + throw new IllegalArgumentException("Illegal argument for UniqueID, expect " + LENGTH + + " bytes, but got " + id.length + " bytes."); } this.id = id; @@ -86,11 +87,6 @@ public class UniqueID implements Serializable { } public boolean isNil() { - for (byte b : id) { - if (b != (byte) 0xFF) { - return false; - } - } - return true; + return this.equals(NIL); } } diff --git a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java index b15664888..f0ccc6f60 100644 --- a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java +++ b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java @@ -113,7 +113,6 @@ public abstract class RayRuntime implements RayApi { RemoteFunctionManager remoteLoader, PathConfig pathManager ) { - UniqueIdHelper.setThreadRandomSeed(UniqueIdHelper.getUniqueness(params.driver_id)); remoteFunctionManager = remoteLoader; pathConfig = pathManager; diff --git a/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java b/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java index c5b217cb4..7cb4e4dd5 100644 --- a/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java +++ b/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java @@ -3,274 +3,68 @@ package org.ray.core; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.Random; -import org.apache.commons.lang3.BitField; import org.ray.api.UniqueID; -import org.ray.util.MD5Digestor; -import org.ray.util.logger.RayLog; // -// see src/common/common.h for UniqueID layout +// Helper methods for UniqueID. 
These are the same as the helper functions in src/ray/id.h. // public class UniqueIdHelper { + public static final int OBJECT_INDEX_POS = 0; + public static final int OBJECT_INDEX_LENGTH = 4; - private static final ThreadLocal longBuffer = ThreadLocal - .withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE)); - private static final ThreadLocal rand = ThreadLocal.withInitial(Random::new); - private static final ThreadLocal randSeed = new ThreadLocal<>(); - private static final int batchPos = 0; - private static final int uniquenessPos = Long.SIZE / Byte.SIZE; - private static final int typePos = 2 * Long.SIZE / Byte.SIZE; - private static final BitField typeField = new BitField(0x7); - private static final int testPos = 2 * Long.SIZE / Byte.SIZE; - private static final BitField testField = new BitField(0x1 << 3); - private static final int unionPos = 2 * Long.SIZE / Byte.SIZE; - private static final BitField multipleReturnField = new BitField(0x1 << 8); - private static final BitField isReturnIdField = new BitField(0x1 << 9); - private static final BitField withinTaskIndexField = new BitField(0xFFFFFC00); - - public static void setThreadRandomSeed(long seed) { - if (randSeed.get() != null) { - RayLog.core.error("Thread random seed is already set to " + randSeed.get() - + " and now to be overwritten to " + seed); - throw new RuntimeException("Thread random seed is already set to " + randSeed.get() - + " and now to be overwritten to " + seed); - } - - RayLog.core.debug("Thread random seed is set to " + seed); - randSeed.set(seed); - rand.get().setSeed(seed); + /** + * Compute the object ID of an object returned by the task. + * + * @param taskId The task ID of the task that created the object. + * @param returnIndex What number return value this object is in the task. + * @return The computed object ID. 
+ */ + public static UniqueID computeReturnId(UniqueID taskId, int returnIndex) { + return computeObjectId(taskId, returnIndex); } - public static Long getNextCreateThreadRandomSeed() { - UniqueID currentTaskId = WorkerContext.currentTask().taskId; - byte[] bytes; - - ByteBuffer lbuffer = longBuffer.get(); - // similar to task id generation (see nextTaskId below) - if (!currentTaskId.isNil()) { - ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes()); - rbb.order(ByteOrder.LITTLE_ENDIAN); - long cid = rbb.getLong(uniquenessPos); - byte[] cbuffer = lbuffer.putLong(cid).array(); - bytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex()); - } else { - long cid = rand.get().nextLong(); - byte[] cbuffer = lbuffer.putLong(cid).array(); - bytes = MD5Digestor.digest(cbuffer, rand.get().nextLong()); - } - lbuffer.clear(); - - lbuffer.put(bytes, 0, Long.SIZE / Byte.SIZE); - long r = lbuffer.getLong(); - lbuffer.clear(); - return r; - } - - private static Type getType(ByteBuffer bb) { - byte v = bb.get(typePos); - return Type.values()[typeField.getValue(v)]; - } - - private static boolean getIsTest(ByteBuffer bb) { - byte v = bb.get(testPos); - return testField.getValue(v) == 1; - } - - private static int getIsReturn(ByteBuffer bb) { - int v = bb.getInt(unionPos); - return isReturnIdField.getValue(v); - } - - private static int getWithinTaskIndex(ByteBuffer bb) { - int v = bb.getInt(unionPos); - return withinTaskIndexField.getValue(v); - } - - public static void setTest(UniqueID id, boolean isTest) { - ByteBuffer bb = ByteBuffer.wrap(id.getBytes()); - setIsTest(bb, isTest); - } - - private static void setIsTest(ByteBuffer bb, boolean isTest) { - byte v = bb.get(testPos); - v = (byte) testField.setValue(v, isTest ? 
1 : 0); - bb.put(testPos, v); - } - - public static long getUniqueness(UniqueID id) { - ByteBuffer bb = ByteBuffer.wrap(id.getBytes()); - bb.order(ByteOrder.LITTLE_ENDIAN); - return getUniqueness(bb); - } - - private static long getUniqueness(ByteBuffer bb) { - return bb.getLong(uniquenessPos); - } - - public static UniqueID taskComputeReturnId( - UniqueID uid, - int returnIndex, - boolean hasMultipleReturn - ) { - return objectIdFromTaskId(uid, true, hasMultipleReturn, returnIndex); - } - - private static UniqueID objectIdFromTaskId(UniqueID taskId, - boolean isReturn, - boolean hasMultipleReturn, - int index - ) { - UniqueID oid = newZero(); - ByteBuffer rbb = ByteBuffer.wrap(taskId.getBytes()); - rbb.order(ByteOrder.LITTLE_ENDIAN); - ByteBuffer wbb = ByteBuffer.wrap(oid.getBytes()); + /** + * Compute the object ID from the task ID and the index. + * @param taskId The task ID of the task that created the object. + * @param index The index which can distinguish different objects in one task. + * @return The computed object ID. + */ + private static UniqueID computeObjectId(UniqueID taskId, int index) { + byte[] objId = new byte[UniqueID.LENGTH]; + System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueID.LENGTH); + ByteBuffer wbb = ByteBuffer.wrap(objId); wbb.order(ByteOrder.LITTLE_ENDIAN); - setBatch(wbb, getBatch(rbb)); - setUniqueness(wbb, getUniqueness(rbb)); - setType(wbb, Type.OBJECT); - setHasMultipleReturn(wbb, hasMultipleReturn ? 1 : 0); - setIsReturn(wbb, isReturn ? 1 : 0); - setWithinTaskIndex(wbb, index); - return oid; + wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index); + + return new UniqueID(objId); } - private static UniqueID newZero() { - byte[] b = new byte[UniqueID.LENGTH]; - Arrays.fill(b, (byte) 0); - return new UniqueID(b); + /** + * Compute the object ID of an object put by the task. + * + * @param taskId The task ID of the task that created the object. + * @param putIndex What number put this object was created by in the task. 
+ * @return The computed object ID. + */ + public static UniqueID computePutId(UniqueID taskId, int putIndex) { + // We multiply putIndex by -1 to distinguish from returnIndex. + return computeObjectId(taskId, -1 * putIndex); } - private static void setBatch(ByteBuffer bb, long batchId) { - bb.putLong(batchPos, batchId); + /** + * Compute the task ID of the task that created the object. + * + * @param objectId The object ID. + * @return The task ID of the task that created this object. + */ + public static UniqueID computeTaskId(UniqueID objectId) { + byte[] taskId = new byte[UniqueID.LENGTH]; + System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueID.LENGTH); + Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS, + UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0); + + return new UniqueID(taskId); } - private static long getBatch(ByteBuffer bb) { - return bb.getLong(batchPos); - } - - private static void setUniqueness(ByteBuffer bb, long uniqueness) { - bb.putLong(uniquenessPos, uniqueness); - } - - private static void setUniqueness(ByteBuffer bb, byte[] uniqueness) { - for (int i = 0; i < Long.SIZE / Byte.SIZE; ++i) { - bb.put(uniquenessPos + i, uniqueness[i]); - } - } - - private static void setType(ByteBuffer bb, Type type) { - byte v = bb.get(typePos); - v = (byte) typeField.setValue(v, type.ordinal()); - bb.put(typePos, v); - } - - private static void setHasMultipleReturn(ByteBuffer bb, int hasMultipleReturnOrNot) { - int v = bb.getInt(unionPos); - v = multipleReturnField.setValue(v, hasMultipleReturnOrNot); - bb.putInt(unionPos, v); - } - - private static void setIsReturn(ByteBuffer bb, int isReturn) { - int v = bb.getInt(unionPos); - v = isReturnIdField.setValue(v, isReturn); - bb.putInt(unionPos, v); - } - - private static void setWithinTaskIndex(ByteBuffer bb, int index) { - int v = bb.getInt(unionPos); - v = withinTaskIndexField.setValue(v, index); - bb.putInt(unionPos, v); - } - - public static UniqueID 
taskComputePutId(UniqueID uid, int putIndex) { - return objectIdFromTaskId(uid, false, false, putIndex); - } - - public static boolean hasMultipleReturnOrNotFromReturnObjectId(UniqueID returnId) { - ByteBuffer bb = ByteBuffer.wrap(returnId.getBytes()); - bb.order(ByteOrder.LITTLE_ENDIAN); - return getHasMultipleReturn(bb) != 0; - } - - private static int getHasMultipleReturn(ByteBuffer bb) { - int v = bb.getInt(unionPos); - return multipleReturnField.getValue(v); - } - - public static UniqueID taskIdFromObjectId(UniqueID objectId) { - UniqueID taskId = newZero(); - ByteBuffer rbb = ByteBuffer.wrap(objectId.getBytes()); - rbb.order(ByteOrder.LITTLE_ENDIAN); - ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes()); - wbb.order(ByteOrder.LITTLE_ENDIAN); - setBatch(wbb, getBatch(rbb)); - setUniqueness(wbb, getUniqueness(rbb)); - setType(wbb, Type.TASK); - return taskId; - } - - public static UniqueID nextTaskId(long batchId) { - UniqueID taskId = newZero(); - ByteBuffer wbb = ByteBuffer.wrap(taskId.getBytes()); - wbb.order(ByteOrder.LITTLE_ENDIAN); - setType(wbb, Type.TASK); - - UniqueID currentTaskId = WorkerContext.currentTask().taskId; - ByteBuffer rbb = ByteBuffer.wrap(currentTaskId.getBytes()); - rbb.order(ByteOrder.LITTLE_ENDIAN); - - // setup batch id - if (batchId == -1) { - setBatch(wbb, getBatch(rbb)); - } else { - setBatch(wbb, batchId); - } - - // setup unique id (task id) - byte[] idBytes; - - ByteBuffer lbuffer = longBuffer.get(); - // if inside a task - if (!currentTaskId.isNil()) { - long cid = rbb.getLong(uniquenessPos); - byte[] cbuffer = lbuffer.putLong(cid).array(); - idBytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex()); - - // if not - } else { - long cid = rand.get().nextLong(); - byte[] cbuffer = lbuffer.putLong(cid).array(); - idBytes = MD5Digestor.digest(cbuffer, rand.get().nextLong()); - } - setUniqueness(wbb, idBytes); - lbuffer.clear(); - return taskId; - } - - public static void markCreateActorStage1Function(UniqueID 
functionId) { - ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes()); - wbb.order(ByteOrder.LITTLE_ENDIAN); - setUniqueness(wbb, 1); - } - - // WARNING: see hack in MethodId.java which must be aligned with here - public static boolean isNonLambdaCreateActorStage1Function(UniqueID functionId) { - ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes()); - wbb.order(ByteOrder.LITTLE_ENDIAN); - return getUniqueness(wbb) == 1; - } - - public static boolean isNonLambdaCommonFunction(UniqueID functionId) { - ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes()); - wbb.order(ByteOrder.LITTLE_ENDIAN); - return getUniqueness(wbb) == 0; - } - - public enum Type { - OBJECT, - TASK, - ACTOR, - } -} \ No newline at end of file +} diff --git a/java/runtime-common/src/main/java/org/ray/core/Worker.java b/java/runtime-common/src/main/java/org/ray/core/Worker.java index 2254b5b5d..f6d05c50f 100644 --- a/java/runtime-common/src/main/java/org/ray/core/Worker.java +++ b/java/runtime-common/src/main/java/org/ray/core/Worker.java @@ -82,7 +82,9 @@ public class Worker { public RayObject submit(RayFunc func, Object[] args) { MethodId methodId = methodIdOf(func); - UniqueID taskId = UniqueIdHelper.nextTaskId(-1); + UniqueID taskId = scheduler.generateTaskId(WorkerContext.currentTask().driverId, + WorkerContext.currentTask().taskId, + WorkerContext.nextCallIndex()); if (args.length > 0 && args[0].getClass().equals(RayActor.class)) { return actorTaskSubmit(taskId, methodId, args, (RayActor) args[0]); } else { @@ -123,7 +125,7 @@ public class Worker { } public UniqueID getCurrentTaskNextPutId() { - return UniqueIdHelper.taskComputePutId( + return UniqueIdHelper.computePutId( WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex()); } diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java index b04b2508d..d135f7891 100644 --- 
a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java +++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerLink.java @@ -21,5 +21,7 @@ public interface LocalSchedulerLink { void notifyUnblocked(); + UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex); + List wait(byte[][] objectIds, int timeoutMs, int numReturns); } diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java index d2baa0181..ea5b041df 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java +++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java @@ -50,7 +50,7 @@ public class LocalSchedulerProxy { private UniqueID[] genReturnIds(UniqueID taskId, int numReturns) { UniqueID[] ret = new UniqueID[numReturns]; for (int i = 0; i < numReturns; i++) { - ret[i] = UniqueIdHelper.taskComputeReturnId(taskId, i, false); + ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1); } return ret; } @@ -129,4 +129,8 @@ public class LocalSchedulerProxy { return new WaitResult<>(readyObjs, remainObjs); } + + public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) { + return scheduler.generateTaskId(driverId, parentTaskId, taskIndex); + } } diff --git a/java/runtime-dev/src/main/java/org/ray/core/impl/RayDevRuntime.java b/java/runtime-dev/src/main/java/org/ray/core/impl/RayDevRuntime.java index 114f533b8..451841250 100644 --- a/java/runtime-dev/src/main/java/org/ray/core/impl/RayDevRuntime.java +++ b/java/runtime-dev/src/main/java/org/ray/core/impl/RayDevRuntime.java @@ -33,7 +33,7 @@ public class RayDevRuntime extends RayRuntime { private byte[] createLocalActor(String className) { UniqueID taskId = WorkerContext.currentTask().taskId; - UniqueID actorId = UniqueIdHelper.taskComputeReturnId(taskId, 0, false); + UniqueID actorId = UniqueIdHelper.computeReturnId(taskId, 0); try { Class cls 
= Class.forName(className); diff --git a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java index 626f08473..ed08c83c6 100644 --- a/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java +++ b/java/runtime-dev/src/main/java/org/ray/spi/impl/MockLocalScheduler.java @@ -89,6 +89,11 @@ public class MockLocalScheduler implements LocalSchedulerLink { } + @Override + public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) { + throw new RuntimeException("Not implemented here."); + } + @Override public List wait(byte[][] objectIds, int timeoutMs, int numReturns) { return store.wait(objectIds, timeoutMs, numReturns); diff --git a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java index 45d26ac21..337a9e057 100644 --- a/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java +++ b/java/runtime-native/src/main/java/org/ray/core/impl/RayNativeRuntime.java @@ -243,8 +243,13 @@ public class RayNativeRuntime extends RayRuntime { @SuppressWarnings("unchecked") @Override public RayActor create(Class cls) { - UniqueID createTaskId = UniqueIdHelper.nextTaskId(-1); - UniqueID actorId = UniqueIdHelper.taskComputeReturnId(createTaskId, 0, false); + UniqueID createTaskId = localSchedulerProxy.generateTaskId( + WorkerContext.currentTask().driverId, + WorkerContext.currentTask().taskId, + WorkerContext.nextCallIndex() + ); + + UniqueID actorId = UniqueIdHelper.computeReturnId(createTaskId, 0); RayActor actor = new RayActor<>(actorId); UniqueID cursorId; diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java index 690783812..689b844ad 100644 --- 
a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java @@ -6,8 +6,11 @@ import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import java.util.Map; + +import org.ray.api.Ray; import org.ray.api.UniqueID; import org.ray.core.RayRuntime; +import org.ray.core.UniqueIdHelper; import org.ray.spi.LocalSchedulerLink; import org.ray.spi.model.FunctionArg; import org.ray.spi.model.TaskSpec; @@ -42,6 +45,8 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { private static native byte[] _computePutId(long client, byte[] taskId, int putIndex); + private static native byte[] _generateTaskId(byte[] driverId, byte[] parentTaskId, int taskIndex); + private static native void _task_done(long client); private static native boolean[] _waitObject(long conn, byte[][] objectIds, @@ -111,10 +116,19 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { @Override public void reconstructObjects(List objectIds, boolean fetchOnly) { - RayLog.core.info("reconstruct objects {}", objectIds); + if (RayLog.core.isInfoEnabled()) { + RayLog.core.info("Reconstructing objects for task {}, object IDs are {}", + UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds); + } _reconstruct_objects(client, getIdBytes(objectIds), fetchOnly); } + @Override + public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) { + byte[] bytes = _generateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex); + return new UniqueID(bytes); + } + @Override public void notifyUnblocked() { _notify_unblocked(client); diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java new file mode 100644 index 000000000..0ffa8c582 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -0,0 +1,84 @@ +package 
org.ray.api.test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.xml.bind.DatatypeConverter; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.ray.api.UniqueID; +import org.ray.core.UniqueIdHelper; + +@RunWith(MyRunner.class) +public class UniqueIdTest { + + @Test + public void testConstructUniqueId() { + // Test `fromHexString()` + UniqueID id1 = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", id1.toString()); + Assert.assertFalse(id1.isNil()); + + try { + UniqueID id2 = UniqueID.fromHexString("000000123456789ABCDEF123456789ABCDEF00"); + // This shouldn't be happened. + Assert.assertTrue(false); + } catch (IllegalArgumentException e) { + Assert.assertTrue(true); + } + + try { + UniqueID id3 = UniqueID.fromHexString("GGGGGGGGGGGGG"); + // This shouldn't be happened. + Assert.assertTrue(false); + } catch (IllegalArgumentException e) { + Assert.assertTrue(true); + } + + // Test `fromByteBuffer()` + byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567"); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); + UniqueID id4 = UniqueID.fromByteBuffer(byteBuffer); + Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); + Assert.assertEquals("0123456789ABCDEF0123456789ABCDEF01234567", id4.toString()); + + + // Test `genNil()` + UniqueID id6 = UniqueID.genNil(); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", id6.toString()); + Assert.assertTrue(id6.isNil()); + } + + @Test + public void testComputeReturnId() { + // Mock a taskId, and the lowest 4 bytes should be 0. 
+ UniqueID taskId = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + + UniqueID returnId = UniqueIdHelper.computeReturnId(taskId, 1); + Assert.assertEquals("01000000123456789ABCDEF123456789ABCDEF00", returnId.toString()); + + returnId = UniqueIdHelper.computeReturnId(taskId, 0x01020304); + Assert.assertEquals("04030201123456789ABCDEF123456789ABCDEF00", returnId.toString()); + } + + @Test + public void testComputeTaskId() { + UniqueID objId = UniqueID.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); + UniqueID taskId = UniqueIdHelper.computeTaskId(objId); + + Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", taskId.toString()); + } + + @Test + public void testComputePutId() { + // Mock a taskId, the lowest 4 bytes should be 0. + UniqueID taskId = UniqueID.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + + UniqueID putId = UniqueIdHelper.computePutId(taskId, 1); + Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00", putId.toString()); + + putId = UniqueIdHelper.computePutId(taskId, 0x01020304); + Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00", putId.toString()); + } + +} diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc index 70de5786c..ea4819b4c 100644 --- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc +++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc @@ -3,6 +3,7 @@ #include "local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h" #include "local_scheduler_client.h" #include "logging.h" +#include "ray/id.h" #ifdef __cplusplus extern "C" { @@ -299,6 +300,31 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject( return resultArray; } +JNIEXPORT jbyteArray JNICALL +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1generateTaskId( + JNIEnv *env, + jclass, + jbyteArray did, + 
jbyteArray ptid, + jint parent_task_counter) { + UniqueIdFromJByteArray o1(env, did); + ray::DriverID driver_id = *o1.PID; + + UniqueIdFromJByteArray o2(env, ptid); + ray::TaskID parent_task_id = *o2.PID; + + ray::TaskID task_id = + ray::GenerateTaskId(driver_id, parent_task_id, parent_task_counter); + jbyteArray result = env->NewByteArray(sizeof(ray::TaskID)); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, sizeof(TaskID), + reinterpret_cast(&task_id)); + + return result; +} + #ifdef __cplusplus } #endif diff --git a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h index edd6574b6..d12a65e90 100644 --- a/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h +++ b/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.h @@ -130,6 +130,18 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1waitObject(JNIEnv *, jint, jboolean); +/* + * Class: org_ray_spi_impl_DefaultLocalSchedulerClient + * Method: _generateTaskId + * Signature: ([B[BI)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1generateTaskId(JNIEnv *, + jclass, + jbyteArray, + jbyteArray, + jint); + #ifdef __cplusplus } #endif diff --git a/src/ray/id.cc b/src/ray/id.cc index 01dbf4444..0f3eb33f8 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -6,6 +6,7 @@ #include #include +#include "common/common.h" #include "ray/constants.h" #include "ray/status.h" @@ -177,6 +178,7 @@ const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts); + // We multiply put_index by -1 to distinguish from return_index. 
return ComputeObjectId(task_id, -1 * put_index); } @@ -190,6 +192,26 @@ const TaskID ComputeTaskId(const ObjectID &object_id) { return task_id; } +const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, + int parent_task_counter) { + // Compute hashes. + SHA256_CTX ctx; + sha256_init(&ctx); + sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id)); + sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id)); + sha256_update(&ctx, (BYTE *)&parent_task_counter, sizeof(parent_task_counter)); + + // Compute the final task ID from the hash. + BYTE buff[DIGEST_SIZE]; + sha256_final(&ctx, buff); + TaskID task_id; + RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE); + memcpy(&task_id, buff, sizeof(task_id)); + task_id = FinishTaskId(task_id); + + return task_id; +} + int64_t ComputeObjectIndex(const ObjectID &object_id) { const int64_t *first_bytes = reinterpret_cast(&object_id); uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; diff --git a/src/ray/id.h b/src/ray/id.h index e2f9cf05a..2f35b151a 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -10,6 +10,10 @@ #include "ray/constants.h" #include "ray/util/visibility.h" +extern "C" { +#include "sha256.h" +} + namespace ray { class RAY_EXPORT UniqueID { @@ -81,6 +85,15 @@ const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index); /// \return The task ID of the task that created this object. const TaskID ComputeTaskId(const ObjectID &object_id); +/// Generate a task ID from the given info. +/// +/// \param driver_id The driver that creates the task. +/// \param parent_task_id The parent task of this task. +/// \param parent_task_counter The task index of the worker. +/// \return The task ID generated from the given info. +const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, + int parent_task_counter); + /// Compute the index of this object in the task that created it. /// /// \param object_id The object ID. 
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 99230827a..e19b6c111 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1195,21 +1195,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) { // The task was not in the GCS task table. It must therefore be in the // lineage cache. - if (!lineage_cache_.ContainsTask(task_id)) { - // The task was not in the lineage cache. - // TODO(swang): This should not ever happen, but Java TaskIDs are - // currently computed differently from Python TaskIDs, so - // reconstruction is currently broken for Java. Once the TaskID - // generation code matches for both frontends, we should be able to - // remove this warning and make it a fatal check. - RAY_LOG(WARNING) << "Task " << task_id << " to reconstruct was not found in " - "the GCS or the lineage cache. This " - "job may hang."; - } else { - // Use a copy of the cached task spec to re-execute the task. - const Task task = lineage_cache_.GetTask(task_id); - ResubmitTask(task); - } + RAY_CHECK(lineage_cache_.ContainsTask(task_id)); + // Use a copy of the cached task spec to re-execute the task. + const Task task = lineage_cache_.GetTask(task_id); + ResubmitTask(task); + })); } diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 9e90ddcde..b9fd35f02 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -63,20 +63,8 @@ TaskSpecification::TaskSpecification( : spec_() { flatbuffers::FlatBufferBuilder fbb; - // Compute hashes. - SHA256_CTX ctx; - sha256_init(&ctx); - sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id)); - sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id)); - sha256_update(&ctx, (BYTE *)&parent_counter, sizeof(parent_counter)); + TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); - // Compute the final task ID from the hash. 
- BYTE buff[DIGEST_SIZE]; - sha256_final(&ctx, buff); - TaskID task_id; - RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE); - memcpy(&task_id, buff, sizeof(task_id)); - task_id = FinishTaskId(task_id); // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) {