Fix “argument type mismatch” when an exception occurs in chained tasks (#17636)

This commit is contained in:
Kai Yang 2021-08-07 17:47:43 +08:00 committed by GitHub
parent c415c26644
commit 9b3c0ad35b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 0 deletions

View file

@ -6,6 +6,7 @@ import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.exception.RayException;
import io.ray.runtime.exception.RayIntentionalSystemExitException;
import io.ray.runtime.exception.RayTaskException;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
@ -85,6 +86,12 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
return results;
}
private void throwIfDependencyFailed(Object arg) {
if (arg instanceof RayException) {
throw (RayException) arg;
}
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Object> argsBytes) {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
@ -122,6 +129,10 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
}
Object[] args =
ArgumentsBuilder.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
for (Object arg : args) {
throwIfDependencyFailed(arg);
}
// Execute the task.
Object result;
try {

View file

@ -48,6 +48,10 @@ public class FailureTest extends BaseTest {
return 0;
}
public static int echo(int obj) {
return obj;
}
public static class BadActor {
public BadActor(boolean failOnCreation) {
@ -184,4 +188,15 @@ public class FailureTest extends BaseTest {
Assert.assertEquals(ex3.getCause().getClass(), UnreconstructableException.class);
Assert.assertEquals(((UnreconstructableException) ex3.getCause()).objectId, objectId);
}
public void testTaskChainWithException() {
ObjectRef<Integer> obj1 = Ray.task(FailureTest::badFunc).remote();
ObjectRef<Integer> obj2 = Ray.task(FailureTest::echo, obj1).remote();
RayTaskException ex = Assert.expectThrows(RayTaskException.class, () -> Ray.get(obj2));
Assert.assertTrue(ex.getCause() instanceof RayTaskException);
RayTaskException ex2 = (RayTaskException) ex.getCause();
Assert.assertTrue(ex2.getCause() instanceof RuntimeException);
RuntimeException ex3 = (RuntimeException) ex2.getCause();
Assert.assertEquals(EXCEPTION_MESSAGE, ex3.getMessage());
}
}