diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 831de1acf..78257c699 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -10,13 +10,9 @@ import java.lang.reflect.Field; import java.nio.file.Files; import java.nio.file.Paths; import java.nio.file.StandardCopyOption; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.RayPyActor; @@ -37,7 +33,6 @@ import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.functionmanager.PyFunctionDescriptor; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.objectstore.ObjectStoreProxy; -import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult; import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskLanguage; @@ -54,23 +49,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); - /** - * Default timeout of a get. - */ - private static final int GET_TIMEOUT_MS = 1000; - /** - * Split objects in this batch size when fetching or reconstructing them. - */ - private static final int FETCH_BATCH_SIZE = 1000; - /** - * Print a warning every this number of attempts. - */ - private static final int WARN_PER_NUM_ATTEMPTS = 50; - /** - * Max number of ids to print in the warning message. - */ - private static final int MAX_IDS_TO_PRINT_IN_WARNING = 20; - protected RayConfig rayConfig; protected WorkerContext workerContext; protected Worker worker; @@ -182,84 +160,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public List get(List objectIds) { - List ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null)); - boolean wasBlocked = false; - - try { - // A map that stores the unready object ids and their original indexes. - Map unready = new HashMap<>(); - for (int i = 0; i < objectIds.size(); i++) { - unready.put(objectIds.get(i), i); - } - int numAttempts = 0; - - // Repeat until we get all objects. - while (!unready.isEmpty()) { - List unreadyIds = new ArrayList<>(unready.keySet()); - - // For the initial fetch, we only fetch the objects, do not reconstruct them. - boolean fetchOnly = numAttempts == 0; - if (!fetchOnly) { - // If fetchOnly is false, this worker will be blocked. - wasBlocked = true; - } - // Call `fetchOrReconstruct` in batches. - for (List batch : splitIntoBatches(unreadyIds)) { - rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId()); - } - - // Get the objects from the object store, and parse the result. - List> getResults = objectStoreProxy.get(unreadyIds, GET_TIMEOUT_MS); - for (int i = 0; i < getResults.size(); i++) { - GetResult getResult = getResults.get(i); - if (getResult.exists) { - if (getResult.exception != null) { - // If the result is an exception, throw it. - throw getResult.exception; - } else { - // Set the result to the return list, and remove it from the unready map. - ObjectId id = unreadyIds.get(i); - ret.set(unready.get(id), getResult.object); - unready.remove(id); - } - } - } - - numAttempts += 1; - if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) { - // Print a warning if we've attempted too many times, but some objects are still - // unavailable. - List idsToPrint = new ArrayList<>(unready.keySet()); - if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) { - idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING); - } - String ids = idsToPrint.stream().map(ObjectId::toString) - .collect(Collectors.joining(", ")); - if (idsToPrint.size() < unready.size()) { - ids += ", etc"; - } - String msg = String.format("Attempted %d times to reconstruct objects," - + " but some objects are still unavailable. If this message continues to print," - + " it may indicate that object's creating task is hanging, or something wrong" - + " happened in raylet backend. %d object(s) pending: %s.", numAttempts, - unreadyIds.size(), ids); - LOGGER.warn(msg); - } - } - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), - workerContext.getCurrentTaskId()); - } - - return ret; - } finally { - // If there were objects that we weren't able to get locally, let the raylet backend - // know that we're now unblocked. - if (wasBlocked) { - rayletClient.notifyUnblocked(workerContext.getCurrentTaskId()); - } - } + return objectStoreProxy.get(objectIds); } @Override @@ -276,22 +177,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { rayletClient.setResource(resourceName, capacity, nodeId); } - private List> splitIntoBatches(List objectIds) { - List> batches = new ArrayList<>(); - int objectsSize = objectIds.size(); - - for (int i = 0; i < objectsSize; i += FETCH_BATCH_SIZE) { - int endIndex = i + FETCH_BATCH_SIZE; - List batchIds = (endIndex < objectsSize) - ? objectIds.subList(i, endIndex) - : objectIds.subList(i, objectsSize); - - batches.add(batchIds); - } - - return batches; - } - @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { return rayletClient.wait(waitList, numReturns, diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java index 8ec855bca..9fb672a61 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectInterface.java @@ -71,7 +71,8 @@ public class MockObjectInterface implements ObjectInterface { boolean firstCheck = true; while (ready < numObjects && (timeoutMs < 0 || remainingTime > 0)) { if (!firstCheck) { - long sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS); + long sleepTime = + timeoutMs < 0 ? GET_CHECK_INTERVAL_MS : Math.min(remainingTime, GET_CHECK_INTERVAL_MS); try { Thread.sleep(sleepTime); } catch (InterruptedException e) { diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 5470d719b..3c3696a9c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -1,9 +1,11 @@ package org.ray.runtime.objectstore; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import org.ray.api.exception.RayActorException; import org.ray.api.exception.RayException; import org.ray.api.exception.RayTaskException; @@ -48,12 +50,11 @@ public class ObjectStoreProxy { * Get an object from the object store. * * @param id Id of the object. - * @param timeoutMs Timeout in milliseconds. * @param Type of the object. * @return The GetResult object. */ - public GetResult get(ObjectId id, int timeoutMs) { - List> list = get(ImmutableList.of(id), timeoutMs); + public T get(ObjectId id) { + List list = get(ImmutableList.of(id)); return list.get(0); } @@ -61,59 +62,57 @@ public class ObjectStoreProxy { * Get a list of objects from the object store. * * @param ids List of the object ids. - * @param timeoutMs Timeout in milliseconds. * @param Type of these objects. * @return A list of GetResult objects. */ - public List> get(List ids, int timeoutMs) { - List dataAndMetaList = objectInterface.get(ids, timeoutMs); + @SuppressWarnings("unchecked") + public List get(List ids) { + // Pass -1 as timeout to wait until all objects are available in object store. + List dataAndMetaList = objectInterface.get(ids, -1); - List> results = new ArrayList<>(); + List results = new ArrayList<>(); for (int i = 0; i < dataAndMetaList.size(); i++) { NativeRayObject dataAndMeta = dataAndMetaList.get(i); - GetResult result; + Object object = null; if (dataAndMeta != null) { byte[] meta = dataAndMeta.metadata; byte[] data = dataAndMeta.data; if (meta != null && meta.length > 0) { // If meta is not null, deserialize the object from meta. - result = deserializeFromMeta(meta, data, + object = deserializeFromMeta(meta, data, workerContext.getCurrentClassLoader(), ids.get(i)); } else { // If data is not null, deserialize the Java object. - Object object = Serializer.decode(data, workerContext.getCurrentClassLoader()); - if (object instanceof RayException) { - // If the object is a `RayException`, it means that an error occurred during task - // execution. - result = new GetResult<>(true, null, (RayException) object); - } else { - // Otherwise, the object is valid. - result = new GetResult<>(true, (T) object, null); - } + object = Serializer.decode(data, workerContext.getCurrentClassLoader()); + } + if (object instanceof RayException) { + // If the object is a `RayException`, it means that an error occurred during task + // execution. + throw (RayException) object; } - } else { - // If both meta and data are null, the object doesn't exist in object store. - result = new GetResult<>(false, null, null); } - results.add(result); + results.add((T) object); } + // This check must be placed after the throw exception statement. + // Because if there was any exception, The get operation would return early + // and wouldn't wait until all objects exist. + Preconditions.checkState(dataAndMetaList.stream().allMatch(Objects::nonNull)); return results; } - @SuppressWarnings("unchecked") - private GetResult deserializeFromMeta(byte[] meta, byte[] data, + private Object deserializeFromMeta(byte[] meta, byte[] data, ClassLoader classLoader, ObjectId objectId) { if (Arrays.equals(meta, RAW_TYPE_META)) { - return (GetResult) new GetResult<>(true, data, null); + return data; } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { - return new GetResult<>(true, null, RayWorkerException.INSTANCE); + return RayWorkerException.INSTANCE; } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { - return new GetResult<>(true, null, RayActorException.INSTANCE); + return RayActorException.INSTANCE; } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { - return new GetResult<>(true, null, new UnreconstructableException(objectId)); + return new UnreconstructableException(objectId); } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { - return new GetResult<>(true, null, Serializer.decode(data, classLoader)); + return Serializer.decode(data, classLoader); } throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); } @@ -130,7 +129,8 @@ public class ObjectStoreProxy { // indicate it's raw binary. So that this object can also be read by Python. objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id); } else if (object instanceof RayTaskException) { - objectInterface.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id); + objectInterface + .put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id); } else { objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id); } @@ -146,32 +146,7 @@ public class ObjectStoreProxy { objectInterface.put(new NativeRayObject(serializedObject, null), id); } - /** - * A class that represents the result of a get operation. - */ - public static class GetResult { - - /** - * Whether this object exists in object store. - */ - public final boolean exists; - - /** - * The Java object that was fetched and deserialized from the object store. Note, this field - * only makes sense when @code{exists == true && exception !=null}. - */ - public final T object; - - /** - * If this field is not null, it represents the exception that occurred during object's creating - * task. - */ - public final RayException exception; - - GetResult(boolean exists, T object, RayException exception) { - this.exists = exists; - this.object = object; - this.exception = exception; - } + public ObjectInterface getObjectInterface() { + return objectInterface; } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 38995bf9b..d5212af91 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -154,17 +154,6 @@ public class MockRayletClient implements RayletClient { throw new RuntimeException("invalid execution flow here"); } - @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - TaskId currentTaskId) { - - } - - @Override - public void notifyUnblocked(TaskId currentTaskId) { - - } - @Override public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { return TaskId.randomId(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 8c6abcd5a..3db431db5 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -18,10 +18,6 @@ public interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, TaskId currentTaskId); - - void notifyUnblocked(TaskId currentTaskId); - TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex); WaitResult wait(List> waitFor, int numReturns, int diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index af94d933e..19ae8c8aa 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -93,28 +93,12 @@ public class RayletClientImpl implements RayletClient { return parseTaskSpecFromProtobuf(bytes); } - @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - TaskId currentTaskId) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Blocked on objects for task {}, object IDs are {}", - objectIds.get(0).getTaskId(), objectIds); - } - nativeFetchOrReconstruct(client, IdUtil.getIdBytes(objectIds), - fetchOnly, currentTaskId.getBytes()); - } - @Override public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); return new TaskId(bytes); } - @Override - public void notifyUnblocked(TaskId currentTaskId) { - nativeNotifyUnblocked(client, currentTaskId.getBytes()); - } - @Override public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { @@ -320,14 +304,6 @@ public class RayletClientImpl implements RayletClient { private static native void nativeDestroy(long client) throws RayException; - private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds, - boolean fetchOnly, byte[] currentTaskId) throws RayException; - - private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId) - throws RayException; - - private static native void nativePutObject(long client, byte[] taskId, byte[] objectId); - private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException; diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index c3dadd8f1..3bba9adf9 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -11,7 +11,7 @@ import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayActorImpl; -import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult; +import org.ray.runtime.objectstore.NativeRayObject; import org.testng.Assert; import org.testng.annotations.Test; @@ -100,9 +100,10 @@ public class ActorTest extends BaseTest { Ray.internal().free(ImmutableList.of(value.getId()), false, false); // Wait until the object is deleted, because the above free operation is async. while (true) { - GetResult result = ((AbstractRayRuntime) - Ray.internal()).getObjectStoreProxy().get(value.getId(), 0); - if (!result.exists) { + NativeRayObject result = ((AbstractRayRuntime) + Ray.internal()).getObjectStoreProxy().getObjectInterface() + .get(ImmutableList.of(value.getId()), 0).get(0); + if (result == null) { break; } TimeUnit.MILLISECONDS.sleep(100); diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 3c36f2201..13e9930fc 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -23,8 +23,9 @@ public class PlasmaFreeTest extends BaseTest { Assert.assertEquals("hello", helloString); Ray.internal().free(ImmutableList.of(helloId.getId()), true, false); - final boolean result = TestUtils.waitForCondition(() -> !((AbstractRayRuntime) Ray.internal()) - .getObjectStoreProxy().get(helloId.getId(), 0).exists, 50); + final boolean result = TestUtils.waitForCondition(() -> + ((AbstractRayRuntime) Ray.internal()).getObjectStoreProxy().getObjectInterface() + .get(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50); Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index 84adba6d7..eb2e9a909 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -17,9 +17,9 @@ public class PlasmaStoreTest extends BaseTest { AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy(); objectInterface.put(objectId, 1); - Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1); + Assert.assertEquals(objectInterface.get(objectId), (Integer) 1); objectInterface.put(objectId, 2); // Putting 2 objects with duplicate ID should fail but ignored. - Assert.assertEquals(objectInterface.get(objectId, -1).object, (Integer) 1); + Assert.assertEquals(objectInterface.get(objectId), (Integer) 1); } } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 51cfaca27..3a86c8c8e 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -92,43 +92,6 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestro delete raylet_client; } -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly, - jbyteArray currentTaskId) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - auto &raylet_client = *reinterpret_cast *>(client); - auto status = raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyUnblocked - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( - JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { - const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - auto &raylet_client = *reinterpret_cast *>(client); - auto status = raylet_client->NotifyUnblocked(current_task_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeWaitObject diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index d2538654a..ea9c507f4 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -39,33 +39,6 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(JNIEnv *, jclass, jlo JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlong); -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, jclass, - jlong, jobjectArray, - jboolean, - jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyUnblocked - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( - JNIEnv *, jclass, jlong, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePutObject - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativePutObject( - JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); - /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeWaitObject