diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 5fd6cd20f..43c53ca43 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -6,6 +6,7 @@ import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.exception.RayActorException; +import io.ray.runtime.exception.RayException; import io.ray.runtime.exception.RayIntentionalSystemExitException; import io.ray.runtime.exception.RayTaskException; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; @@ -85,6 +86,12 @@ public abstract class TaskExecutor { return results; } + private void throwIfDependencyFailed(Object arg) { + if (arg instanceof RayException) { + throw (RayException) arg; + } + } + protected List execute(List rayFunctionInfo, List argsBytes) { runtime.setIsContextSet(true); TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); @@ -122,6 +129,10 @@ public abstract class TaskExecutor { } Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.executable.getParameterTypes()); + for (Object arg : args) { + throwIfDependencyFailed(arg); + } + // Execute the task. Object result; try { diff --git a/java/test/src/main/java/io/ray/test/FailureTest.java b/java/test/src/main/java/io/ray/test/FailureTest.java index 524c1ebff..9dd84ea7a 100644 --- a/java/test/src/main/java/io/ray/test/FailureTest.java +++ b/java/test/src/main/java/io/ray/test/FailureTest.java @@ -48,6 +48,10 @@ public class FailureTest extends BaseTest { return 0; } + public static int echo(int obj) { + return obj; + } + public static class BadActor { public BadActor(boolean failOnCreation) { @@ -184,4 +188,15 @@ public class FailureTest extends BaseTest { Assert.assertEquals(ex3.getCause().getClass(), UnreconstructableException.class); Assert.assertEquals(((UnreconstructableException) ex3.getCause()).objectId, objectId); } + + public void testTaskChainWithException() { + ObjectRef obj1 = Ray.task(FailureTest::badFunc).remote(); + ObjectRef obj2 = Ray.task(FailureTest::echo, obj1).remote(); + RayTaskException ex = Assert.expectThrows(RayTaskException.class, () -> Ray.get(obj2)); + Assert.assertTrue(ex.getCause() instanceof RayTaskException); + RayTaskException ex2 = (RayTaskException) ex.getCause(); + Assert.assertTrue(ex2.getCause() instanceof RuntimeException); + RuntimeException ex3 = (RuntimeException) ex2.getCause(); + Assert.assertEquals(EXCEPTION_MESSAGE, ex3.getMessage()); + } }