Propagate backend error to worker (#4039)

Hao Chen 2019-02-16 11:39:15 +08:00 committed by GitHub
parent 4be3d0c5d3
commit de17443dc2
21 changed files with 635 additions and 258 deletions
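
With this change, failures in the raylet backend are surfaced to workers as typed exceptions instead of invalid-deserialization errors. The sketch below is hypothetical driver code, not part of this diff: the `Counter` actor class and its `increase` method stand in for the test actors updated further down, and the class name is made up. It only illustrates the user-visible behavior in the Java API that the new tests exercise.

import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;

public class ErrorHandlingSketch {
  public static void main(String[] args) {
    Ray.init();
    // `Counter` stands in for any @RayRemote actor class with an int constructor
    // and an `increase` method, as in the tests below.
    RayActor<Counter> actor = Ray.createActor(Counter::new, 0);
    try {
      Ray.call(Counter::increase, actor).get();
    } catch (RayTaskException e) {
      // The task threw during execution; the user's exception is preserved in the cause chain.
      Throwable userError = e.getCause();
    } catch (RayActorException e) {
      // The actor process died while executing the task, or the task was sent to a dead actor.
    } catch (RayWorkerException e) {
      // The worker process died while executing the task.
    }
    Ray.shutdown();
  }
}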

View file

@ -19,22 +19,38 @@ import shlex
# These lines added to enable Sphinx to work without installing Ray.
import mock
MOCK_MODULES = [
"gym", "gym.spaces", "scipy", "scipy.signal", "tensorflow",
"tensorflow.contrib", "tensorflow.contrib.all_reduce",
"tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers",
"tensorflow.contrib.slim", "tensorflow.contrib.rnn", "tensorflow.core",
"tensorflow.core.util", "tensorflow.python", "tensorflow.python.client",
"tensorflow.python.util", "ray.core.generated",
"gym",
"gym.spaces",
"ray._raylet",
"ray.core.generated",
"ray.core.generated.ActorCheckpointIdData",
"ray.core.generated.ClientTableData", "ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.ClientTableData",
"ray.core.generated.DriverTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ErrorType",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatBatchTableData",
"ray.core.generated.DriverTableData", "ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.Language",
"ray.core.generated.ObjectTableData",
"ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub", "ray.core.generated.Language",
"ray._raylet"
"ray.core.generated.ProfileTableData",
"ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub",
"ray.core.generated.ray.protocol.Task",
"scipy",
"scipy.signal",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.all_reduce",
"tensorflow.contrib.all_reduce.python",
"tensorflow.contrib.layers",
"tensorflow.contrib.rnn",
"tensorflow.contrib.slim",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()

View file

@ -0,0 +1,16 @@
package org.ray.api.exception;
/**
* Indicates that the actor died unexpectedly before finishing a task.
*
* This exception could happen either because the actor process dies while executing a task, or
* because a task is submitted to a dead actor.
*/
public class RayActorException extends RayException {
public static final RayActorException INSTANCE = new RayActorException();
private RayActorException() {
super("The actor died unexpectedly before finishing this task.");
}
}

View file

@ -0,0 +1,15 @@
package org.ray.api.exception;
/**
* Indicates that a task threw an exception during execution.
*
* If a task throws an exception during execution, a RayTaskException is stored in the object store
* as the task's output. Then when the object is retrieved from the object store, this exception
* will be thrown to propagate the error message.
*/
public class RayTaskException extends RayException {
public RayTaskException(String message, Throwable cause) {
super(message, cause);
}
}

View file

@ -0,0 +1,13 @@
package org.ray.api.exception;
/**
* Indicates that the worker died unexpectedly while executing a task.
*/
public class RayWorkerException extends RayException {
public static final RayWorkerException INSTANCE = new RayWorkerException();
private RayWorkerException() {
super("The worker died unexpectedly while executing this task.");
}
}

View file

@ -0,0 +1,23 @@
package org.ray.api.exception;
import org.ray.api.id.UniqueId;
/**
* Indicates that an object is lost (either evicted or explicitly deleted) and cannot be
* reconstructed.
*
* Note that this exception only happens for actor objects. If the actor's current state is past the
* object's creating task, the actor cannot re-run the task to reconstruct the object.
*/
public class UnreconstructableException extends RayException {
public final UniqueId objectId;
public UnreconstructableException(UniqueId objectId) {
super(String.format(
"Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.",
objectId));
this.objectId = objectId;
}
}
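
The scenario described in the Javadoc can be reproduced along the lines of the hedged sketch below, which mirrors the `testUnreconstructableActorObject` test added to ActorTest later in this diff. The `UnreconstructableSketch` class name is made up, and the `Counter` actor (int constructor, `getValue()` accessor) is assumed from that test code.

import com.google.common.collect.ImmutableList;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.exception.UnreconstructableException;

public class UnreconstructableSketch {
  public static void main(String[] args) throws InterruptedException {
    Ray.init();
    RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
    RayObject<Integer> value = Ray.call(Counter::getValue, counter);
    value.get();  // Succeeds while the object is still in the object store.
    // Delete the object. `free` is asynchronous; the real test polls the object store
    // (ObjectStoreProxy.get with a zero timeout) instead of sleeping.
    Ray.internal().free(ImmutableList.of(value.getId()), false);
    Thread.sleep(1000);
    try {
      value.get();
      // The actor has moved past the object's creating task, so the object
      // cannot be reconstructed and this get should not succeed.
    } catch (UnreconstructableException e) {
      assert e.objectId.equals(value.getId());  // The lost object's id is attached.
    }
    Ray.shutdown();
  }
}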

View file

@ -1,12 +1,15 @@
package org.ray.runtime;
import static java.util.stream.Collectors.toList;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
@ -21,7 +24,7 @@ import org.ray.runtime.config.RayConfig;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
@ -37,9 +40,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
/**
* Default timeout of a get.
*/
private static final int GET_TIMEOUT_MS = 1000;
/**
* Split object ids into batches of this size when fetching or reconstructing them.
*/
private static final int FETCH_BATCH_SIZE = 1000;
private static final int LIMITED_RETRY_COUNTER = 10;
/**
* Print a warning every this many attempts.
*/
private static final int WARN_PER_NUM_ATTEMPTS = 50;
/**
* Max number of ids to print in the warning message.
*/
private static final int MAX_IDS_TO_PRINT_IN_WARNING = 20;
protected RayConfig rayConfig;
protected WorkerContext workerContext;
@ -75,7 +91,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
objectStoreProxy.put(objectId, obj);
}
@ -87,10 +103,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
workerContext.getCurrentTaskId(), workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTaskId();
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj, null);
objectStoreProxy.putSerialized(objectId, obj);
return new RayObjectImpl<>(objectId);
}
@ -102,63 +118,68 @@ public abstract class AbstractRayRuntime implements RayRuntime {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
List<T> ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null));
boolean wasBlocked = false;
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
List<List<UniqueId>> fetchBatches = splitIntoBatches(objectIds);
for (List<UniqueId> batch : fetchBatches) {
rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId());
// A map that stores the unready object ids and their original indexes.
Map<UniqueId, Integer> unready = new HashMap<>();
for (int i = 0; i < objectIds.size(); i++) {
unready.put(objectIds.get(i), i);
}
int numAttempts = 0;
// Get the objects. We initially try to get the objects immediately.
List<Pair<T, GetStatus>> ret = objectStoreProxy
.get(objectIds, GET_TIMEOUT_MS, false);
assert ret.size() == numObjectIds;
// Repeat until we get all objects.
while (!unready.isEmpty()) {
List<UniqueId> unreadyIds = new ArrayList<>(unready.keySet());
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map<UniqueId, Integer> unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
unreadys.put(objectIds.get(i), i);
// For the initial fetch, we only fetch the objects and do not reconstruct them.
boolean fetchOnly = numAttempts == 0;
if (!fetchOnly) {
// If fetchOnly is false, this worker will be blocked.
wasBlocked = true;
}
}
wasBlocked = (unreadys.size() > 0);
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
int retryCounter = 0;
while (unreadys.size() > 0) {
retryCounter++;
List<UniqueId> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueId>> reconstructBatches = splitIntoBatches(unreadyList);
for (List<UniqueId> batch : reconstructBatches) {
rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId());
// Call `fetchOrReconstruct` in batches.
for (List<UniqueId> batch : splitIntoBatches(unreadyIds)) {
rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId());
}
List<Pair<T, GetStatus>> results = objectStoreProxy
.get(unreadyList, GET_TIMEOUT_MS, false);
// Remove any entries for objects we received during this iteration so we
// don't retrieve the same object twice.
for (int i = 0; i < results.size(); i++) {
Pair<T, GetStatus> value = results.get(i);
if (value.getRight() == GetStatus.SUCCESS) {
UniqueId id = unreadyList.get(i);
ret.set(unreadys.get(id), value);
unreadys.remove(id);
// Get the objects from the object store, and parse the result.
List<GetResult<T>> getResults = objectStoreProxy.get(unreadyIds, GET_TIMEOUT_MS);
for (int i = 0; i < getResults.size(); i++) {
GetResult<T> getResult = getResults.get(i);
if (getResult.exists) {
if (getResult.exception != null) {
// If the result is an exception, throw it.
throw getResult.exception;
} else {
// Set the result to the return list, and remove it from the unready map.
UniqueId id = unreadyIds.get(i);
ret.set(unready.get(id), getResult.object);
unready.remove(id);
}
}
}
if (retryCounter % LIMITED_RETRY_COUNTER == 0) {
LOGGER.warn("Attempted {} times to reconstruct objects {}, "
+ "but haven't received response. If this message continues to print,"
+ " it may indicate that the task is hanging, or someting wrong "
+ "happened in raylet backend.",
retryCounter, unreadys.keySet());
numAttempts += 1;
if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) {
// Print a warning if we've attempted too many times, but some objects are still
// unavailable.
List<UniqueId> idsToPrint = new ArrayList<>(unready.keySet());
if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) {
idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING);
}
String ids = idsToPrint.stream().map(UniqueId::toString)
.collect(Collectors.joining(", "));
if (idsToPrint.size() < unready.size()) {
ids += ", etc";
}
String msg = String.format("Attempted %d times to reconstruct objects,"
+ " but some objects are still unavailable. If this message continues to print,"
+ " it may indicate that object's creating task is hanging, or something wrong"
+ " happened in raylet backend. %d object(s) pending: %s.", numAttempts,
unreadyIds.size(), ids);
LOGGER.warn(msg);
}
}
@ -167,19 +188,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
workerContext.getCurrentTaskId());
}
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
finalRet.add(value.getLeft());
}
return finalRet;
} catch (RayException e) {
LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e);
throw e;
return ret;
} finally {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
// If there were objects that we weren't able to get locally, let the raylet backend
// know that we're now unblocked.
if (wasBlocked) {
rayletClient.notifyUnblocked(workerContext.getCurrentTaskId());
}
@ -252,6 +264,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
/**
* Create the task specification.
*
* @param func The target remote function.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
@ -278,7 +291,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
if (!resources.containsKey(ResourceUtil.CPU_LITERAL)
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
&& !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) {
resources.put(ResourceUtil.CPU_LITERAL, 0.0);
}
@ -323,6 +336,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return rayletClient;
}
public ObjectStoreProxy getObjectStoreProxy() {
return objectStoreProxy;
}
public FunctionManager getFunctionManager() {
return functionManager;
}

View file

@ -6,7 +6,7 @@ import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
@ -118,7 +118,7 @@ public class Worker {
} catch (Exception e) {
LOGGER.error("Error executing task " + spec, e);
if (!spec.isActorCreationTask()) {
runtime.put(returnId, new RayException("Error executing task " + spec, e));
runtime.put(returnId, new RayTaskException("Error executing task " + spec, e));
} else {
actorCreationException = e;
currentActorId = returnId;

View file

@ -1,30 +1,42 @@
package org.ray.runtime.objectstore;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.arrow.plasma.exceptions.DuplicateObjectException;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.ErrorType;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Object store proxy, which handles serialization and deserialization, and utilize a {@code
* org.ray.spi.ObjectStoreLink} to actually store data.
* A class that is used to put/get objects to/from the object store.
*/
public class ObjectStoreProxy {
private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class);
private static final int GET_TIMEOUT_MS = 1000;
private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED)
.getBytes();
private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED)
.getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes();
private final AbstractRayRuntime runtime;
@ -41,68 +53,134 @@ public class ObjectStoreProxy {
});
}
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
throws RayException {
return get(objectId, GET_TIMEOUT_MS, isMetadata);
/**
* Get an object from the object store.
*
* @param id Id of the object.
* @param timeoutMs Timeout in milliseconds.
* @param <T> Type of the object.
* @return The GetResult object.
*/
public <T> GetResult<T> get(UniqueId id, int timeoutMs) {
List<GetResult<T>> list = get(ImmutableList.of(id), timeoutMs);
return list.get(0);
}
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws RayException {
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
objectStore.get().release(id.getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
return Pair.of(t, GetStatus.SUCCESS);
} else {
return Pair.of(null, GetStatus.FAILED);
}
}
/**
* Get a list of objects from the object store.
*
* @param ids List of the object ids.
* @param timeoutMs Timeout in milliseconds.
* @param <T> Type of these objects.
* @return A list of GetResult objects.
*/
public <T> List<GetResult<T>> get(List<UniqueId> ids, int timeoutMs) {
byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids);
List<ObjectStoreData> dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs);
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMetadata)
throws RayException {
return get(objectIds, GET_TIMEOUT_MS, isMetadata);
}
List<GetResult<T>> results = new ArrayList<>();
for (int i = 0; i < dataAndMetaList.size(); i++) {
// TODO(hchen): Plasma API returns data and metadata in the wrong order; this should be fixed
// from the arrow side first.
byte[] meta = dataAndMetaList.get(i).data;
byte[] data = dataAndMetaList.get(i).metadata;
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws RayException {
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
for (int i = 0; i < objs.size(); i++) {
byte[] obj = objs.get(i);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
objectStore.get().release(ids.get(i).getBytes());
if (t instanceof RayException) {
throw (RayException) t;
GetResult<T> result;
if (meta != null) {
// If meta is not null, deserialize the exception.
RayException exception = deserializeRayExceptionFromMeta(meta, ids.get(i));
result = new GetResult<>(true, null, exception);
} else if (data != null) {
// If data is not null, deserialize the Java object.
Object object = Serializer.decode(data, runtime.getWorkerContext().getCurrentClassLoader());
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
// execution.
result = new GetResult<>(true, null, (RayException) object);
} else {
// Otherwise, the object is valid.
result = new GetResult<>(true, (T) object, null);
}
ret.add(Pair.of(t, GetStatus.SUCCESS));
} else {
ret.add(Pair.of(null, GetStatus.FAILED));
// If both meta and data are null, the object doesn't exist in the object store.
result = new GetResult<>(false, null, null);
}
if (meta != null || data != null) {
// Release the object from the object store.
objectStore.get().release(binaryIds[i]);
}
results.add(result);
}
return ret;
return results;
}
public void put(UniqueId id, Object obj, Object metadata) {
private RayException deserializeRayExceptionFromMeta(byte[] meta, UniqueId objectId) {
if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return RayWorkerException.INSTANCE;
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
return RayActorException.INSTANCE;
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}
/**
* Serialize and put an object to the object store.
*
* @param id Id of the object.
* @param object The object to put.
*/
public void put(UniqueId id, Object object) {
try {
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
objectStore.get().put(id.getBytes(), Serializer.encode(object), null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
}
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
/**
* Put an already serialized object to the object store.
*
* @param id Id of the object.
* @param serializedObject The serialized object to put.
*/
public void putSerialized(UniqueId id, byte[] serializedObject) {
try {
objectStore.get().put(id.getBytes(), obj, metadata);
objectStore.get().put(id.getBytes(), serializedObject, null);
} catch (DuplicateObjectException e) {
LOGGER.warn(e.getMessage());
}
}
public enum GetStatus {
SUCCESS, FAILED
/**
* A class that represents the result of a get operation.
*/
public static class GetResult<T> {
/**
* Whether this object exists in the object store.
*/
public final boolean exists;
/**
* The Java object that was fetched and deserialized from the object store. Note that this field
* only makes sense when {@code exists == true && exception == null}.
*/
public final T object;
/**
* If this field is not null, it represents the exception that occurred during the object's
* creating task.
*/
public final RayException exception;
GetResult(boolean exists, T object, RayException exception) {
this.exists = exists;
this.object = object;
this.exception = exception;
}
}
}
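
For callers inside the runtime, the new `GetResult` replaces the old `Pair<T, GetStatus>`: existence, value, and error are now separate fields. Below is a hedged sketch of how a caller is expected to branch on it; the `getOrThrow` helper, the `proxy` and `id` parameters, and the `MyValue` type are placeholders, and the fragment assumes the same imports as ObjectStoreProxy. `AbstractRayRuntime.get` earlier in this diff follows the same pattern.

static MyValue getOrThrow(ObjectStoreProxy proxy, UniqueId id) {
  // Non-blocking check: a timeout of 0 returns immediately (as the ActorTest below uses).
  GetResult<MyValue> result = proxy.get(id, 0);
  if (!result.exists) {
    // Not in the local object store yet; the runtime would fetch/reconstruct and retry.
    return null;
  }
  if (result.exception != null) {
    // The object's creating task failed, or the object is unreconstructable. Rethrow so
    // the error propagates to the user, exactly as AbstractRayRuntime.get does.
    throw result.exception;
  }
  // The object exists and was deserialized successfully.
  return result.object;
}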

View file

@ -10,6 +10,7 @@ import org.ray.api.Checkpointable;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.annotation.RayRemote;
import org.ray.api.exception.RayActorException;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.testng.Assert;
@ -60,11 +61,8 @@ public class ActorReconstructionTest extends BaseTest {
try {
Ray.call(Counter::increase, actor).get();
Assert.fail("The above task didn't fail.");
} catch (StringIndexOutOfBoundsException e) {
// Raylet backend will put invalid data in task's result to indicate the task has failed.
// Thus, Java deserialization will fail and throw `StringIndexOutOfBoundsException`.
// TODO(hchen): we should use object's metadata to indicate task failure,
// instead of throwing this exception.
} catch (RayActorException e) {
// We should receive a RayActorException because the actor is dead.
}
}

View file

@ -1,11 +1,16 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayActorImpl;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetResult;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -83,4 +88,30 @@ public class ActorTest extends BaseTest {
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get());
}
@Test
public void testUnreconstructableActorObject() throws InterruptedException {
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
// Call an actor method.
RayObject value = Ray.call(Counter::getValue, counter);
Assert.assertEquals(100, value.get());
// Delete the object from the object store.
Ray.internal().free(ImmutableList.of(value.getId()), false);
// Wait until the object is deleted, because the above free operation is async.
while (true) {
GetResult<Integer> result = ((AbstractRayRuntime)
Ray.internal()).getObjectStoreProxy().get(value.getId(), 0);
if (!result.exists) {
break;
}
TimeUnit.MILLISECONDS.sleep(100);
}
try {
// Try getting the object again, this should throw an UnreconstructableException.
value.get();
Assert.fail("This line should not be reachable.");
} catch (UnreconstructableException e) {
Assert.assertEquals(value.getId(), e.objectId);
}
}
}

View file

@ -3,7 +3,9 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -15,6 +17,11 @@ public class FailureTest extends BaseTest {
throw new RuntimeException(EXCEPTION_MESSAGE);
}
public static int badFunc2() {
System.exit(-1);
return 0;
}
public static class BadActor {
public BadActor(boolean failOnCreation) {
@ -23,17 +30,21 @@ public class FailureTest extends BaseTest {
}
}
public int func() {
public int badMethod() {
throw new RuntimeException(EXCEPTION_MESSAGE);
}
public int badMethod2() {
System.exit(-1);
return 0;
}
}
private static void assertTaskFail(RayObject<?> rayObject) {
private static void assertTaskFailedWithRayTaskException(RayObject<?> rayObject) {
try {
rayObject.get();
Assert.fail("Task didn't fail.");
} catch (RayException e) {
e.printStackTrace();
} catch (RayTaskException e) {
Throwable rootCause = e.getCause();
while (rootCause.getCause() != null) {
rootCause = rootCause.getCause();
@ -45,19 +56,49 @@ public class FailureTest extends BaseTest {
@Test
public void testNormalTaskFailure() {
assertTaskFail(Ray.call(FailureTest::badFunc));
assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc));
}
@Test
public void testActorCreationFailure() {
RayActor<BadActor> actor = Ray.createActor(BadActor::new, true);
assertTaskFail(Ray.call(BadActor::func, actor));
assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor));
}
@Test
public void testActorTaskFailure() {
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
assertTaskFail(Ray.call(BadActor::func, actor));
assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor));
}
@Test
public void testWorkerProcessDying() {
try {
Ray.call(FailureTest::badFunc2).get();
Assert.fail("This line shouldn't be reached.");
} catch (RayWorkerException e) {
// When the worker process dies while executing a task, we should receive a
// RayWorkerException.
}
}
@Test
public void testActorProcessDying() {
RayActor<BadActor> actor = Ray.createActor(BadActor::new, false);
try {
Ray.call(BadActor::badMethod2, actor).get();
Assert.fail("This line shouldn't be reached.");
} catch (RayActorException e) {
// When the actor process dies while executing a task, we should receive a
// RayActorException.
}
try {
Ray.call(BadActor::badMethod, actor).get();
Assert.fail("This line shouldn't be reached.");
} catch (RayActorException e) {
// When an actor task is submitted to a dead actor, we should also receive a
// RayActorException.
}
}
}

python/ray/exceptions.py (new file, 105 lines)
View file

@ -0,0 +1,105 @@
import os
import colorama
try:
import setproctitle
except ImportError:
setproctitle = None
class RayError(Exception):
"""Super class of all ray exception types."""
pass
class RayTaskError(RayError):
"""Indicates that a task threw an exception during execution.
If a task throws an exception during execution, a RayTaskError is stored in
the object store for each of the task's outputs. When an object is
retrieved from the object store, the Python method that retrieved it checks
to see if the object is a RayTaskError and if it is then an exception is
thrown propagating the error message.
Attributes:
function_name (str): The name of the function that failed and produced
the RayTaskError.
traceback_str (str): The traceback from the exception.
"""
def __init__(self, function_name, traceback_str):
"""Initialize a RayTaskError."""
if setproctitle:
self.proctitle = setproctitle.getproctitle()
else:
self.proctitle = "ray_worker"
self.pid = os.getpid()
self.host = os.uname()[1]
self.function_name = function_name
self.traceback_str = traceback_str
assert traceback_str is not None
def __str__(self):
"""Format a RayTaskError as a string."""
lines = self.traceback_str.split("\n")
out = []
in_worker = False
for line in lines:
if line.startswith("Traceback "):
out.append("{}{}{} (pid={}, host={})".format(
colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET,
self.pid, self.host))
elif in_worker:
in_worker = False
elif "ray/worker.py" in line or "ray/function_manager.py" in line:
in_worker = True
else:
out.append(line)
return "\n".join(out)
class RayWorkerError(RayError):
"""Indicates that the worker died unexpectedly while executing a task."""
def __str__(self):
return "The worker died unexpectedly while executing this task."
class RayActorError(RayError):
"""Indicates that the actor died unexpectedly before finishing a task.
This exception could happen either because the actor process dies while
executing a task, or because a task is submitted to a dead actor.
"""
def __str__(self):
return "The actor died unexpectedly before finishing this task."
class UnreconstructableError(RayError):
"""Indicates that an object is lost and cannot be reconstructed.
Note that this exception only happens for actor objects. If the actor's
current state is past the object's creating task, the actor cannot re-run
the task to reconstruct the object.
Attributes:
object_id: ID of the object.
"""
def __init__(self, object_id):
self.object_id = object_id
def __str__(self):
return ("Object {} is lost (either evicted or explicitly deleted) and "
+ "cannot be reconstructed.").format(self.object_id.hex())
RAY_EXCEPTION_TYPES = [
RayError,
RayTaskError,
RayWorkerError,
RayActorError,
UnreconstructableError,
]

View file

@ -4,7 +4,6 @@ from __future__ import print_function
from contextlib import contextmanager
import atexit
import colorama
import faulthandler
import hashlib
import inspect
@ -28,18 +27,43 @@ import ray.experimental.state as state
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
import ray.node
import ray.parameter
import ray.ray_constants as ray_constants
import ray.remote_function
import ray.serialization as serialization
import ray.services as services
import ray.signature
import ray.ray_constants as ray_constants
from ray import (
ActorHandleID,
ActorID,
ClientID,
DriverID,
ObjectID,
TaskID,
)
from ray import import_thread
from ray import ObjectID, DriverID, ActorID, ActorHandleID, ClientID, TaskID
from ray import profiling
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
import ray.parameter
from ray.utils import (check_oversized_pickle, is_cython, _random_string,
thread_safe_client, setup_logger)
from ray.core.generated.ErrorType import ErrorType
from ray.exceptions import (
RayActorError,
RayError,
RayTaskError,
RayWorkerError,
UnreconstructableError,
RAY_EXCEPTION_TYPES,
)
from ray.function_manager import (
FunctionActorManager,
FunctionDescriptor,
)
from ray.utils import (
_random_string,
check_oversized_pickle,
is_cython,
setup_logger,
thread_safe_client,
)
SCRIPT_MODE = 0
WORKER_MODE = 1
@ -68,55 +92,6 @@ except ImportError:
setproctitle = None
class RayTaskError(Exception):
"""An object used internally to represent a task that threw an exception.
If a task throws an exception during execution, a RayTaskError is stored in
the object store for each of the task's outputs. When an object is
retrieved from the object store, the Python method that retrieved it checks
to see if the object is a RayTaskError and if it is then an exception is
thrown propagating the error message.
Currently, we either use the exception attribute or the traceback attribute
but not both.
Attributes:
function_name (str): The name of the function that failed and produced
the RayTaskError.
traceback_str (str): The traceback from the exception.
"""
def __init__(self, function_name, traceback_str):
"""Initialize a RayTaskError."""
if setproctitle:
self.proctitle = setproctitle.getproctitle()
else:
self.proctitle = "ray_worker"
self.pid = os.getpid()
self.host = os.uname()[1]
self.function_name = function_name
self.traceback_str = traceback_str
assert traceback_str is not None
def __str__(self):
"""Format a RayTaskError as a string."""
lines = self.traceback_str.split("\n")
out = []
in_worker = False
for line in lines:
if line.startswith("Traceback "):
out.append("{}{}{} (pid={}, host={})".format(
colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET,
self.pid, self.host))
elif in_worker:
in_worker = False
elif "ray/worker.py" in line or "ray/function_manager.py" in line:
in_worker = True
else:
out.append(line)
return "\n".join(out)
class ActorCheckpointInfo(object):
"""Information used to maintain actor checkpoints."""
@ -400,6 +375,8 @@ class Worker(object):
start_time = time.time()
# Only send the warning once.
warning_sent = False
serialization_context = self.get_serialization_context(
self.task_driver_id)
while True:
try:
# We divide very large get requests into smaller get requests
@ -407,23 +384,23 @@ class Worker(object):
# long time, if the store is blocked, it can block the manager
# as well as a consequence.
results = []
for i in range(0, len(object_ids),
ray._config.worker_get_request_size()):
results += self.plasma_client.get(
object_ids[i:(
i + ray._config.worker_get_request_size())],
batch_size = ray._config.worker_fetch_request_size()
for i in range(0, len(object_ids), batch_size):
metadata_data_pairs = self.plasma_client.get_buffers(
object_ids[i:i + batch_size],
timeout,
self.get_serialization_context(self.task_driver_id))
with_meta=True,
)
for j in range(len(metadata_data_pairs)):
metadata, data = metadata_data_pairs[j]
results.append(
self._deserialize_object_from_arrow(
data,
metadata,
object_ids[i + j],
serialization_context,
))
return results
except pyarrow.lib.ArrowInvalid:
# TODO(ekl): the local scheduler could include relevant
# metadata in the task kill case for a better error message
invalid_error = RayTaskError(
"<unknown>",
"Invalid return value: likely worker died or was killed "
"while executing the task; check previous logs or dmesg "
"for errors.")
return [invalid_error] * len(object_ids)
except pyarrow.DeserializationCallbackError:
# Wait a little bit for the import thread to import the class.
# If we currently have the worker lock, we need to release it
@ -448,6 +425,30 @@ class Worker(object):
driver_id=self.task_driver_id)
warning_sent = True
def _deserialize_object_from_arrow(self, data, metadata, object_id,
serialization_context):
if metadata:
# If metadata is not empty, return an exception object based on
# the error type.
error_type = int(metadata)
if error_type == ErrorType.WORKER_DIED:
return RayWorkerError()
elif error_type == ErrorType.ACTOR_DIED:
return RayActorError()
elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE:
return UnreconstructableError(ray.ObjectID(object_id.binary()))
else:
assert False, "Unrecognized error type " + str(error_type)
elif data:
# If data is not empty, deserialize the object.
# Note, the lock is needed because `serialization_context` isn't
# thread-safe.
with self.plasma_client.lock:
return pyarrow.deserialize(data, serialization_context)
else:
# Object isn't available in plasma.
return plasma.ObjectNotAvailable
def get_object(self, object_ids):
"""Get the value or values in the object store associated with the IDs.
@ -741,7 +742,7 @@ class Worker(object):
passed by value.
Raises:
RayTaskError: This exception is raised if a task that
RayError: This exception is raised if a task that
created one of the arguments failed.
"""
arguments = []
@ -749,7 +750,7 @@ class Worker(object):
if isinstance(arg, ObjectID):
# get the object from the local object store
argument = self.get_object([arg])[0]
if isinstance(argument, RayTaskError):
if isinstance(argument, RayError):
raise argument
else:
# pass the argument by value
@ -831,11 +832,6 @@ class Worker(object):
with profiling.profile("task:deserialize_arguments"):
arguments = self._get_arguments_for_execution(
function_name, args)
except RayTaskError as e:
self._handle_process_task_failure(
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
except Exception as e:
self._handle_process_task_failure(
function_descriptor, return_object_ids, e,
@ -1155,12 +1151,15 @@ def _initialize_serialization(driver_id, worker=global_worker):
worker.serialization_context_map[driver_id] = serialization_context
register_custom_serializer(
RayTaskError,
use_dict=True,
local=True,
driver_id=driver_id,
class_id="ray.RayTaskError")
# Register exception types.
for error_cls in RAY_EXCEPTION_TYPES:
register_custom_serializer(
error_cls,
use_dict=True,
local=True,
driver_id=driver_id,
class_id=error_cls.__module__ + ". " + error_cls.__name__,
)
# Tell Ray to serialize lambdas with pickle.
register_custom_serializer(
type(lambda: 0),
@ -2229,14 +2228,14 @@ def get(object_ids):
if isinstance(object_ids, list):
values = worker.get_object(object_ids)
for i, value in enumerate(values):
if isinstance(value, RayTaskError):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
raise value
return values
else:
value = worker.get_object([object_ids])[0]
if isinstance(value, RayTaskError):
# If the result is a RayTaskError, then the task that created
if isinstance(value, RayError):
# If the result is a RayError, then the task that created
# this object failed, and we should propagate the error message
# here.
last_task_error_raise_time = time.time()

View file

@ -347,3 +347,23 @@ table ActorCheckpointIdData {
// A list of the timestamps for each of the above `checkpoint_ids`.
timestamps: [long];
}
// This enum type is used as an object's metadata to indicate that the object's creating
// task has failed because of a certain error.
// TODO(hchen): We may want to make these errors more specific. E.g., we may want
// to distinguish between intentional and expected actor failures, and between
// worker process failure and node failure.
enum ErrorType:int {
// Indicates that a task failed because the worker died unexpectedly while executing it.
WORKER_DIED = 1,
// Indicates that a task failed because the actor died unexpectedly before finishing it.
ACTOR_DIED = 2,
// Indicates that an object is lost and cannot be reconstructed.
// Note that this currently only happens to actor objects. When the actor's state is already
// past the object's creating task, the actor cannot re-run the task.
// TODO(hchen): we may want to reuse this error type for more cases. E.g.,
// 1) An object that was put by the driver.
// 2) The object's creating task is already cleaned up from GCS (this currently
// crashes raylet).
OBJECT_UNRECONSTRUCTABLE = 3,
}
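
The enum's numeric values are the cross-language contract: the raylet stores the error code's decimal string as the failed object's metadata (see the `CreateAndSeal` call in node_manager.cc below), the Java worker compares the metadata bytes against these constants, and the Python worker parses them with `int(metadata)`. The fragment below is only a hedged restatement of that mapping in Java (the `errorFromMeta` helper name is made up; it is essentially what `deserializeRayExceptionFromMeta` in ObjectStoreProxy above already does, using the generated `ErrorType` constants and the same imports).

// The raylet writes std::to_string(static_cast<int>(error_type)) as metadata, so the
// worker-side byte patterns are simply "1", "2", and "3" encoded as ASCII.
static RayException errorFromMeta(byte[] meta, UniqueId objectId) {
  if (Arrays.equals(meta, String.valueOf(ErrorType.WORKER_DIED).getBytes())) {
    return RayWorkerException.INSTANCE;
  } else if (Arrays.equals(meta, String.valueOf(ErrorType.ACTOR_DIED).getBytes())) {
    return RayActorException.INSTANCE;
  } else if (Arrays.equals(meta, String.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes())) {
    return new UnreconstructableException(objectId);
  }
  throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
}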

View file

@ -574,7 +574,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id,
auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id);
auto removed_tasks = local_queues_.RemoveTasks(tasks_to_remove);
for (auto const &task : removed_tasks) {
TreatTaskAsFailed(task);
TreatTaskAsFailed(task, ErrorType::ACTOR_DIED);
}
} else {
RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING);
@ -858,7 +858,7 @@ void NodeManager::ProcessDisconnectClientMessage(
// `HandleDisconnectedActor`.
if (actor_id.is_nil()) {
const Task &task = local_queues_.RemoveTask(task_id);
TreatTaskAsFailed(task);
TreatTaskAsFailed(task, ErrorType::WORKER_DIED);
}
if (!intentional_disconnect) {
@ -1214,9 +1214,10 @@ bool NodeManager::CheckDependencyManagerInvariant() const {
return true;
}
void NodeManager::TreatTaskAsFailed(const Task &task) {
void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) {
const TaskSpecification &spec = task.GetTaskSpecification();
RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed.";
RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error "
<< EnumNameErrorType(error_type) << ".";
// If this was an actor creation task that tried to resume from a checkpoint,
// then erase it here since the task did not finish.
if (spec.IsActorCreationTask()) {
@ -1231,20 +1232,22 @@ void NodeManager::TreatTaskAsFailed(const Task &task) {
// information about the TaskSpecification implementation.
num_returns -= 1;
}
const std::string meta = std::to_string(static_cast<int>(error_type));
for (int64_t i = 0; i < num_returns; i++) {
const ObjectID object_id = spec.ReturnId(i);
std::shared_ptr<Buffer> data;
// TODO(ekl): this writes an invalid arrow object, which is sufficient to
// signal that the worker failed, but it would be nice to return more
// detailed failure metadata in the future.
arrow::Status status =
store_client_.Create(object_id.to_plasma_id(), 1, NULL, 0, &data);
if (!status.IsPlasmaObjectExists()) {
// TODO(rkn): We probably don't want this checks. E.g., if the object
// store is full, we don't want to kill the raylet.
RAY_ARROW_CHECK_OK(status);
RAY_ARROW_CHECK_OK(store_client_.Seal(object_id.to_plasma_id()));
const auto object_id = spec.ReturnId(i).to_plasma_id();
arrow::Status status = store_client_.CreateAndSeal(object_id, "", meta);
if (!status.ok() && !status.IsPlasmaObjectExists()) {
// If we failed to save the error code, log a warning and push an error message
// to the driver.
std::ostringstream stream;
stream << "An plasma error (" << status.ToString() << ") occurred while saving"
<< " error code to object " << object_id << ". Anyone who's getting this"
<< " object may hang forever.";
std::string error_message = stream.str();
RAY_LOG(WARNING) << error_message;
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
task.GetTaskSpecification().DriverId(), "task", error_message,
current_time_ms()));
}
}
// A task failing is equivalent to assigning and finishing the task, so clean
@ -1297,7 +1300,7 @@ void NodeManager::TreatTaskAsFailedIfLost(const Task &task) {
// The object does not exist on any nodes but has been created
// before, so the object has been lost. Mark the task as failed to
// prevent any tasks that depend on this object from hanging.
TreatTaskAsFailed(task);
TreatTaskAsFailed(task, ErrorType::OBJECT_UNRECONSTRUCTABLE);
*task_marked_as_failed = true;
}
}
@ -1343,7 +1346,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
if (actor_entry->second.GetState() == ActorState::DEAD) {
// If this actor is dead, either because the actor process is dead
// or because its residing node is dead, treat this task as failed.
TreatTaskAsFailed(task);
TreatTaskAsFailed(task, ErrorType::ACTOR_DIED);
} else {
// If this actor is alive, check whether this actor is local.
auto node_manager_id = actor_entry->second.GetNodeManagerId();

View file

@ -157,8 +157,9 @@ class NodeManager {
/// the local queue.
///
/// \param task The task to fail.
/// \param error_type The type of the error that caused this task to fail.
/// \return Void.
void TreatTaskAsFailed(const Task &task);
void TreatTaskAsFailed(const Task &task, const ErrorType &error_type);
/// This is similar to TreatTaskAsFailed, but it will only mark the task as
/// failed if at least one of the task's return values is lost. A return
/// value is lost if it has been created before, but no longer exists on any

View file

@ -1405,7 +1405,7 @@ def test_exception_raised_when_actor_node_dies(head_node_cluster):
# Submit some new actor tasks.
x_ids = [actor.inc.remote() for _ in range(5)]
for x_id in x_ids:
with pytest.raises(ray.worker.RayTaskError):
with pytest.raises(ray.exceptions.RayActorError):
# There is some small chance that ray.get will actually
# succeed (if the object is transferred before the raylet
# dies).
@ -2128,7 +2128,7 @@ def test_actor_eviction(shutdown_only):
try:
ray.get(obj)
num_success += 1
except ray.worker.RayTaskError:
except ray.exceptions.UnreconstructableError:
num_evicted += 1
# Some objects should have been evicted, and some should still be in the
# object store.
@ -2173,7 +2173,7 @@ def test_actor_reconstruction(ray_start_regular):
pid = ray.get(actor.get_pid.remote())
os.kill(pid, signal.SIGKILL)
# The actor has exceeded max reconstructions, and this task should fail.
with pytest.raises(ray.worker.RayTaskError):
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.increase.remote())
# Create another actor.
@ -2181,7 +2181,7 @@ def test_actor_reconstruction(ray_start_regular):
# Intentionally exit the actor
actor.__ray_terminate__.remote()
# Check that the actor won't be reconstructed.
with pytest.raises(ray.worker.RayTaskError):
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.increase.remote())
@ -2241,7 +2241,7 @@ def test_actor_reconstruction_on_node_failure(head_node_cluster):
object_store_socket = ray.get(actor.get_object_store_socket.remote())
kill_node(object_store_socket)
# The actor has exceeded max reconstructions, and this task should fail.
with pytest.raises(ray.worker.RayTaskError):
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.increase.remote())

View file

@ -279,7 +279,7 @@ def test_worker_failed(ray_start_workers_separate_multinode):
for object_id in object_ids:
try:
ray.get(object_id)
except ray.worker.RayTaskError:
except (ray.exceptions.RayTaskError, ray.exceptions.RayWorkerError):
pass
@ -424,7 +424,7 @@ def test_actor_creation_node_failure(ray_start_cluster):
for i, out in enumerate(children_out):
try:
ray.get(out)
except ray.worker.RayTaskError:
except ray.exceptions.RayActorError:
children[i] = Child.remote(death_probability)
# Remove a node. Any actor creation tasks that were forwarded to this
# node must be reconstructed.

View file

@ -319,7 +319,8 @@ def test_worker_dying(ray_start_regular):
def f():
eval("exit()")
f.remote()
with pytest.raises(ray.exceptions.RayWorkerError):
ray.get(f.remote())
wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)
@ -340,9 +341,9 @@ def test_actor_worker_dying(ray_start_regular):
a = Actor.remote()
[obj], _ = ray.wait([a.kill.remote()], timeout=5.0)
with pytest.raises(Exception):
with pytest.raises(ray.exceptions.RayActorError):
ray.get(obj)
with pytest.raises(Exception):
with pytest.raises(ray.exceptions.RayTaskError):
ray.get(consume.remote(obj))
wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)

View file

@ -2621,7 +2621,7 @@ def test_inline_objects(shutdown_only):
value = ray.get(inline_object)
assert value == "inline"
inlined += 1
except ray.worker.RayTaskError:
except ray.exceptions.UnreconstructableError:
pass
# Make sure some objects were inlined. Some of them may not get inlined
# because we evict the object soon after creating it.
@ -2638,7 +2638,7 @@ def test_inline_objects(shutdown_only):
ray.worker.global_worker.plasma_client.delete([plasma_id])
# Objects created by an actor that were evicted and larger than the
# maximum inline object size cannot be retrieved or reconstructed.
with pytest.raises(ray.worker.RayTaskError):
with pytest.raises(ray.exceptions.UnreconstructableError):
ray.get(non_inline_object) == 10000 * [1]

View file

@ -106,7 +106,7 @@ def test_task_crash(ray_start):
try:
ray.get(object_id)
except Exception as e:
assert type(e) == ray.worker.RayTaskError
assert type(e) == ray.exceptions.RayTaskError
finally:
result_list = signal.receive([object_id], timeout=5)
assert len(result_list) == 1
@ -142,7 +142,7 @@ def test_actor_crash(ray_start):
try:
ray.get(a.crash.remote())
except Exception as e:
assert type(e) == ray.worker.RayTaskError
assert type(e) == ray.exceptions.RayTaskError
finally:
result_list = signal.receive([a], timeout=5)
assert len(result_list) == 1
@ -184,7 +184,7 @@ def test_actor_crash_init2(ray_start):
try:
ray.get(a.method.remote())
except Exception as e:
assert type(e) == ray.worker.RayTaskError
assert type(e) == ray.exceptions.RayTaskError
finally:
result_list = receive_all_signals([a], timeout=5)
assert len(result_list) == 2