mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
Cross language serialization for primitive types (#7711)
* Cross language serialization for Java and Python * Use strict types when Python serializing * Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0 * Disable gc for optimizing msgpack loads * Fix merge bug * Java call Python use returnType; Fix ClassLoaderTest * Fix RayMethodsTest * Fix checkstyle * Fix lint * prepare_args raises exception if try to transfer a non-deserializable object to another language * Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double * Minor fixes * Fix compile error on linux * Fix lint in java/BUILD.bazel * Fix test_failure * Fix lint * Class<?> to Class<T>; Refine metadata bytes. * Rename FST to Fst; sort java dependencies * Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py * Improve CrossLanguageInvocationTest * Refactor MessagePackSerializer.java * Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java * Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java * Fix bug * Remove custom cross language type support * Replace Serializer.Meta with MutableBoolean * Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java * Refine MessagePackSerializer.pack * Ray.get support RayObject as input * Improve comments and error info * Remove classLoader argument from serializer * Separate msgpack from pickle5 in Python * Pair<byte[], MutableBoolean> to Pair<byte[], Boolean> * Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead * Refine test * small fixes Co-authored-by: 刘宝 <po.lb@antfin.com> Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
parent
e8c19aba41
commit
fc6259a656
42 changed files with 1057 additions and 313 deletions
|
@ -76,6 +76,7 @@ define_java_module(
|
|||
"@maven//:de_ruedigermoeller_fst",
|
||||
"@maven//:net_java_dev_jna_jna",
|
||||
"@maven//:org_apache_commons_commons_lang3",
|
||||
"@maven//:org_msgpack_msgpack_core",
|
||||
"@maven//:org_ow2_asm_asm",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
"@maven//:org_slf4j_slf4j_log4j12",
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package org.ray.api;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
@ -62,23 +63,41 @@ public final class Ray extends RayCall {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get an object from the object store.
|
||||
* Get an object by id from the object store.
|
||||
*
|
||||
* @param objectId The ID of the object to get.
|
||||
* @param objectType The type of the object to get.
|
||||
* @return The Java object.
|
||||
*/
|
||||
public static <T> T get(ObjectId objectId) {
|
||||
return runtime.get(objectId);
|
||||
public static <T> T get(ObjectId objectId, Class<T> objectType) {
|
||||
return runtime.get(objectId, objectType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
* Get a list of objects by ids from the object store.
|
||||
*
|
||||
* @param objectIds The list of object IDs.
|
||||
* @param objectType The type of object.
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
public static <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return runtime.get(objectIds);
|
||||
public static <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
return runtime.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of objects by RayObjects from the object store.
|
||||
*
|
||||
* @param objectList A list of RayObject to get.
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
public static <T> List<T> get(List<RayObject<T>> objectList) {
|
||||
List<ObjectId> objectIds = new ArrayList<>();
|
||||
Class<T> objectType = null;
|
||||
for (RayObject<T> o : objectList) {
|
||||
objectIds.add(o.getId());
|
||||
objectType = o.getType();
|
||||
}
|
||||
return runtime.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -19,5 +19,10 @@ public interface RayObject<T> {
|
|||
*/
|
||||
ObjectId getId();
|
||||
|
||||
/**
|
||||
* Get the Object type.
|
||||
*/
|
||||
Class<T> getType();
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -39,17 +39,19 @@ public interface RayRuntime {
|
|||
* Get an object from the object store.
|
||||
*
|
||||
* @param objectId The ID of the object to get.
|
||||
* @param objectType The type of the object to get.
|
||||
* @return The Java object.
|
||||
*/
|
||||
<T> T get(ObjectId objectId);
|
||||
<T> T get(ObjectId objectId, Class<T> objectType);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds The list of object IDs.
|
||||
* @param objectType The type of object.
|
||||
* @return A list of Java objects.
|
||||
*/
|
||||
<T> List<T> get(List<ObjectId> objectIds);
|
||||
<T> List<T> get(List<ObjectId> objectIds, Class<T> objectType);
|
||||
|
||||
/**
|
||||
* Wait for a list of RayObjects to be locally available, until specified number of objects are
|
||||
|
|
|
@ -15,11 +15,12 @@ def gen_java_deps():
|
|||
"de.ruedigermoeller:fst:2.57",
|
||||
"javax.xml.bind:jaxb-api:2.3.0",
|
||||
"org.apache.commons:commons-lang3:3.4",
|
||||
"org.msgpack:msgpack-core:0.8.20",
|
||||
"org.ow2.asm:asm:6.0",
|
||||
"org.slf4j:slf4j-log4j12:1.7.25",
|
||||
"org.testng:testng:6.9.10",
|
||||
"redis.clients:jedis:2.8.0",
|
||||
"net.java.dev.jna:jna:5.5.0"
|
||||
"net.java.dev.jna:jna:5.5.0",
|
||||
],
|
||||
repositories = [
|
||||
"https://repo1.maven.org/maven2/",
|
||||
|
|
|
@ -62,6 +62,11 @@
|
|||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.msgpack</groupId>
|
||||
<artifactId>msgpack-core</artifactId>
|
||||
<version>0.8.20</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
|
|
|
@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
|
|||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.RayActor;
|
||||
|
@ -15,7 +16,6 @@ import org.ray.api.function.PyActorClass;
|
|||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
|
@ -26,6 +26,7 @@ import org.ray.runtime.context.WorkerContext;
|
|||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
|
@ -73,18 +74,18 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
return new RayObjectImpl<T>(objectId, (Class<T>)(obj == null ? Object.class : obj.getClass()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId));
|
||||
public <T> T get(ObjectId objectId, Class<T> objectType) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId), objectType);
|
||||
return ret.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return objectStore.get(objectIds);
|
||||
public <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
return objectStore.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -99,41 +100,39 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callNormalFunction(functionDescriptor, args, numReturns, options);
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callNormalFunction(functionDescriptor, args, returnType, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
CallOptions options) {
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyRemoteFunction.moduleName,
|
||||
"",
|
||||
pyRemoteFunction.functionName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
|
||||
return callNormalFunction(functionDescriptor, args,
|
||||
/*returnType=*/Optional.of(pyRemoteFunction.returnType), options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callActorFunction(actor, functionDescriptor, args, numReturns);
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callActorFunction(actor, functionDescriptor, args, returnType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), pyActorMethod.methodName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
|
||||
return callActorFunction(pyActor, functionDescriptor, args,
|
||||
/*returnType=*/Optional.of(pyActorMethod.returnType));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -148,8 +147,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
ActorCreationOptions options) {
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyActorClass.moduleName,
|
||||
pyActorClass.className,
|
||||
|
@ -157,14 +155,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
for (Object arg : args) {
|
||||
Preconditions.checkArgument(
|
||||
(arg instanceof RayPyActor) || (arg instanceof byte[]),
|
||||
"Python argument can only be a RayPyActor or a byte array, not {}.",
|
||||
arg.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
|
@ -218,30 +208,32 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
Object[] args, Optional<Class<?>> returnType, CallOptions options) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage());
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
Preconditions.checkState(returnIds.size() == numReturns);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
return new RayObjectImpl(returnIds.get(0), returnType.get());
|
||||
}
|
||||
}
|
||||
|
||||
private RayObject callActorFunction(BaseActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage());
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
Preconditions.checkState(returnIds.size() == numReturns);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
return new RayObjectImpl(returnIds.get(0), returnType.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import org.ray.api.id.JobId;
|
|||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
|
||||
/**
|
||||
* The context of worker.
|
||||
|
@ -28,7 +29,7 @@ public interface WorkerContext {
|
|||
|
||||
/**
|
||||
* The class loader that is associated with the current job. It's used for locating classes when
|
||||
* dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}.
|
||||
* dealing with serialization and deserialization in {@link Serializer}.
|
||||
*/
|
||||
ClassLoader getCurrentClassLoader();
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.ray.runtime.functionmanager;
|
|||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Represents a Ray function (either a Method or a Constructor in Java) and its metadata.
|
||||
|
@ -67,6 +68,17 @@ public class RayFunction {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Return type.
|
||||
*/
|
||||
public Optional<Class<?>> getReturnType() {
|
||||
if (hasReturn()) {
|
||||
return Optional.of(((Method) executable).getReturnType());
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return executable.toString();
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
package org.ray.runtime.object;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
|
||||
/**
|
||||
* Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
|
||||
|
@ -21,29 +23,33 @@ public class ObjectSerializer {
|
|||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
|
||||
|
||||
/**
|
||||
* Deserialize an object from an {@link NativeRayObject} instance.
|
||||
*
|
||||
* @param nativeRayObject The object to deserialize.
|
||||
* @param objectId The associated object ID of the object.
|
||||
* @param classLoader The classLoader of the object.
|
||||
* @return The deserialized object.
|
||||
*/
|
||||
public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
|
||||
ClassLoader classLoader) {
|
||||
Class<?> objectType) {
|
||||
byte[] meta = nativeRayObject.metadata;
|
||||
byte[] data = nativeRayObject.data;
|
||||
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
|
||||
return Serializer.decode(data, objectType);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return new RayWorkerException();
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
|
@ -51,12 +57,15 @@ public class ObjectSerializer {
|
|||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, classLoader);
|
||||
return Serializer.decode(data, objectType);
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) {
|
||||
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
|
||||
.toString());
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
return Serializer.decode(data, classLoader);
|
||||
return Serializer.decode(data, objectType);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,12 +81,14 @@ public class ObjectSerializer {
|
|||
} else if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
|
||||
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
return new NativeRayObject(Serializer.encode(object),
|
||||
TASK_EXECUTION_EXCEPTION_META);
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
return new NativeRayObject(Serializer.encode(object), null);
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
return new NativeRayObject(serialized.getLeft(), serialized.getRight() ?
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ public abstract class ObjectStore {
|
|||
* @return A list of GetResult objects.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> List<T> get(List<ObjectId> ids) {
|
||||
public <T> List<T> get(List<ObjectId> ids, Class<?> elementType) {
|
||||
// Pass -1 as timeout to wait until all objects are available in object store.
|
||||
List<NativeRayObject> dataAndMetaList = getRaw(ids, -1);
|
||||
|
||||
|
@ -96,7 +96,7 @@ public abstract class ObjectStore {
|
|||
Object object = null;
|
||||
if (dataAndMeta != null) {
|
||||
object = ObjectSerializer
|
||||
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
|
||||
.deserialize(dataAndMeta, ids.get(i), elementType);
|
||||
}
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package org.ray.runtime.object;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
@ -20,13 +21,16 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
|||
*/
|
||||
private transient T object;
|
||||
|
||||
private Class<T> type;
|
||||
|
||||
/**
|
||||
* Whether the object is already gotten from the object store.
|
||||
*/
|
||||
private transient boolean objectGotten;
|
||||
|
||||
public RayObjectImpl(ObjectId id) {
|
||||
public RayObjectImpl(ObjectId id, Class<T> type) {
|
||||
this.id = id;
|
||||
this.type = type;
|
||||
object = null;
|
||||
objectGotten = false;
|
||||
}
|
||||
|
@ -34,7 +38,7 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
|||
@Override
|
||||
public synchronized T get() {
|
||||
if (!objectGotten) {
|
||||
object = Ray.get(id);
|
||||
object = Ray.get(id, type);
|
||||
objectGotten = true;
|
||||
}
|
||||
return object;
|
||||
|
@ -45,4 +49,9 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
|||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<T> getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
package org.ray.runtime.serializer;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.actor.NativeRayActorSerializer;
|
||||
|
||||
/**
|
||||
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
|
||||
*/
|
||||
public class FstSerializer {
|
||||
|
||||
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
|
||||
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
|
||||
return conf;
|
||||
});
|
||||
|
||||
|
||||
public static byte[] encode(Object obj) {
|
||||
FSTConfiguration current = conf.get();
|
||||
current.setClassLoader(Thread.currentThread().getContextClassLoader());
|
||||
return current.asByteArray(obj);
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs) {
|
||||
FSTConfiguration current = conf.get();
|
||||
current.setClassLoader(Thread.currentThread().getContextClassLoader());
|
||||
return (T) current.asObject(bs);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,270 @@
|
|||
package org.ray.runtime.serializer;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Array;
|
||||
import java.math.BigInteger;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.mutable.MutableBoolean;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.msgpack.core.MessageBufferPacker;
|
||||
import org.msgpack.core.MessagePack;
|
||||
import org.msgpack.core.MessagePacker;
|
||||
import org.msgpack.core.MessageUnpacker;
|
||||
import org.msgpack.value.ArrayValue;
|
||||
import org.msgpack.value.ExtensionValue;
|
||||
import org.msgpack.value.IntegerValue;
|
||||
import org.msgpack.value.Value;
|
||||
import org.msgpack.value.ValueType;
|
||||
|
||||
// We can't pack List / Map by MessagePack, because we don't know the type class when unpacking.
|
||||
public class MessagePackSerializer {
|
||||
|
||||
private static final byte LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID = 101;
|
||||
// The MessagePack length is an int, which takes up to 9 bytes.
|
||||
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
|
||||
private static final int MESSAGE_PACK_OFFSET = 9;
|
||||
|
||||
// Packers indexed by their corresponding Java class object.
|
||||
private static Map<Class<?>, TypePacker> packers = new HashMap<>();
|
||||
// Unpackers indexed by its corresponding MessagePack ValueType.
|
||||
private static Map<ValueType, TypeUnpacker> unpackers = new HashMap<>();
|
||||
// Null and array don't have a corresponding class, so define them separately.
|
||||
private static final TypePacker NULL_PACKER;
|
||||
private static final TypePacker ARRAY_PACKER;
|
||||
private static final TypePacker EXTENSION_PACKER;
|
||||
|
||||
static {
|
||||
// ===== Initialize packers =====
|
||||
// Null packer.
|
||||
NULL_PACKER = (object, packer, javaSerializer) -> packer.packNil();
|
||||
|
||||
// Array packer.
|
||||
ARRAY_PACKER = ((object, packer, javaSerializer) -> {
|
||||
int length = Array.getLength(object);
|
||||
packer.packArrayHeader(length);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
pack(Array.get(object, i), packer, javaSerializer);
|
||||
}
|
||||
});
|
||||
|
||||
// Extension packer.
|
||||
EXTENSION_PACKER = ((object, packer, javaSerializer) -> {
|
||||
javaSerializer.serialize(object, packer);
|
||||
});
|
||||
|
||||
packers.put(Boolean.class,
|
||||
((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object)));
|
||||
packers.put(Byte.class,
|
||||
((object, packer, javaSerializer) -> packer.packByte((Byte) object)));
|
||||
packers.put(Short.class,
|
||||
((object, packer, javaSerializer) -> packer.packShort((Short) object)));
|
||||
packers.put(Integer.class,
|
||||
((object, packer, javaSerializer) -> packer.packInt((Integer) object)));
|
||||
packers.put(Long.class,
|
||||
((object, packer, javaSerializer) -> packer.packLong((Long) object)));
|
||||
packers.put(BigInteger.class,
|
||||
((object, packer, javaSerializer) -> packer.packBigInteger((BigInteger) object)));
|
||||
packers.put(Float.class,
|
||||
((object, packer, javaSerializer) -> packer.packFloat((Float) object)));
|
||||
packers.put(Double.class,
|
||||
((object, packer, javaSerializer) -> packer.packDouble((Double) object)));
|
||||
packers.put(String.class,
|
||||
((object, packer, javaSerializer) -> packer.packString((String) object)));
|
||||
packers.put(byte[].class,
|
||||
((object, packer, javaSerializer) -> {
|
||||
byte[] bytes = (byte[]) object;
|
||||
packer.packBinaryHeader(bytes.length);
|
||||
packer.writePayload(bytes);
|
||||
}));
|
||||
|
||||
// ===== Initialize unpackers =====
|
||||
List<Class<?>> booleanClasses = ImmutableList.of(Boolean.class, boolean.class);
|
||||
List<Class<?>> byteClasses = ImmutableList.of(Byte.class, byte.class);
|
||||
List<Class<?>> shortClasses = ImmutableList.of(Short.class, short.class);
|
||||
List<Class<?>> intClasses = ImmutableList.of(Integer.class, int.class);
|
||||
List<Class<?>> longClasses = ImmutableList.of(Long.class, long.class);
|
||||
List<Class<?>> bigIntClasses = ImmutableList.of(BigInteger.class);
|
||||
List<Class<?>> floatClasses = ImmutableList.of(Float.class, float.class);
|
||||
List<Class<?>> doubleClasses = ImmutableList.of(Double.class, double.class);
|
||||
List<Class<?>> stringClasses = ImmutableList.of(String.class);
|
||||
List<Class<?>> binaryClasses = ImmutableList.of(byte[].class);
|
||||
|
||||
// Null unpacker.
|
||||
unpackers.put(ValueType.NIL, (value, targetClass, javaDeserializer) -> null);
|
||||
// Boolean unpacker.
|
||||
unpackers.put(ValueType.BOOLEAN, (value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(booleanClasses, targetClass),
|
||||
"Boolean can't be deserialized as {}.", targetClass);
|
||||
return value.asBooleanValue().getBoolean();
|
||||
});
|
||||
// Integer unpacker.
|
||||
unpackers.put(ValueType.INTEGER, ((value, targetClass, javaDeserializer) -> {
|
||||
IntegerValue iv = value.asIntegerValue();
|
||||
if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) {
|
||||
return iv.asByte();
|
||||
} else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) {
|
||||
return iv.asShort();
|
||||
} else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) {
|
||||
return iv.asInt();
|
||||
} else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) {
|
||||
return iv.asLong();
|
||||
} else if (checkTypeCompatible(bigIntClasses, targetClass)) {
|
||||
return iv.asBigInteger();
|
||||
}
|
||||
throw new IllegalArgumentException("Integer can't be deserialized as " + targetClass + ".");
|
||||
}));
|
||||
// Float unpacker.
|
||||
unpackers.put(ValueType.FLOAT, ((value, targetClass, javaDeserializer) -> {
|
||||
if (checkTypeCompatible(doubleClasses, targetClass)) {
|
||||
return value.asFloatValue().toDouble();
|
||||
} else if (checkTypeCompatible(floatClasses, targetClass)) {
|
||||
return value.asFloatValue().toFloat();
|
||||
}
|
||||
throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + ".");
|
||||
}));
|
||||
// String unpacker.
|
||||
unpackers.put(ValueType.STRING, ((value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(stringClasses, targetClass),
|
||||
"String can't be deserialized as {}.", targetClass);
|
||||
return value.asStringValue().asString();
|
||||
}));
|
||||
// Binary unpacker.
|
||||
unpackers.put(ValueType.BINARY, ((value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(binaryClasses, targetClass),
|
||||
"Binary can't be deserialized as {}.", targetClass);
|
||||
return value.asBinaryValue().asByteArray();
|
||||
}));
|
||||
// Array unpacker.
|
||||
unpackers.put(ValueType.ARRAY, ((value, targetClass, javaDeserializer) -> {
|
||||
ArrayValue av = value.asArrayValue();
|
||||
Class<?> componentType =
|
||||
targetClass.isArray() ? targetClass.getComponentType() : Object.class;
|
||||
Object array = Array.newInstance(componentType, av.size());
|
||||
for (int i = 0; i < av.size(); ++i) {
|
||||
Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer));
|
||||
}
|
||||
return array;
|
||||
}));
|
||||
// Extension unpacker.
|
||||
unpackers.put(ValueType.EXTENSION, ((value, targetClass, javaDeserializer) -> {
|
||||
ExtensionValue ev = value.asExtensionValue();
|
||||
byte extType = ev.getType();
|
||||
if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) {
|
||||
return javaDeserializer.deserialize(ev);
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + ".");
|
||||
}));
|
||||
}
|
||||
|
||||
interface JavaSerializer {
|
||||
|
||||
void serialize(Object object, MessagePacker packer) throws IOException;
|
||||
}
|
||||
|
||||
interface JavaDeserializer {
|
||||
|
||||
Object deserialize(ExtensionValue v);
|
||||
}
|
||||
|
||||
interface TypePacker {
|
||||
|
||||
void pack(Object object, MessagePacker packer,
|
||||
JavaSerializer javaSerializer) throws IOException;
|
||||
}
|
||||
|
||||
interface TypeUnpacker {
|
||||
|
||||
Object unpack(Value value, Class<?> targetClass,
|
||||
JavaDeserializer javaDeserializer);
|
||||
}
|
||||
|
||||
private static boolean checkTypeCompatible(List<Class<?>> expected, Class<?> actual) {
|
||||
for (Class<?> expectedClass : expected) {
|
||||
if (actual.isAssignableFrom(expectedClass)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static void pack(Object object, MessagePacker packer, JavaSerializer javaSerializer)
|
||||
throws IOException {
|
||||
TypePacker typePacker;
|
||||
if (object == null) {
|
||||
typePacker = NULL_PACKER;
|
||||
} else {
|
||||
Class<?> type = object.getClass();
|
||||
typePacker = packers.get(type);
|
||||
if (typePacker == null) {
|
||||
if (type.isArray()) {
|
||||
typePacker = ARRAY_PACKER;
|
||||
} else {
|
||||
typePacker = EXTENSION_PACKER;
|
||||
}
|
||||
}
|
||||
}
|
||||
typePacker.pack(object, packer, javaSerializer);
|
||||
}
|
||||
|
||||
private static Object unpack(Value v, Class<?> type, JavaDeserializer javaDeserializer) {
|
||||
return unpackers.get(v.getValueType()).unpack(v, type, javaDeserializer);
|
||||
}
|
||||
|
||||
public static Pair<byte[], Boolean> encode(Object obj) {
|
||||
MessageBufferPacker packer = MessagePack.newDefaultBufferPacker();
|
||||
try {
|
||||
// Reserve MESSAGE_PACK_OFFSET bytes for MessagePack bytes length.
|
||||
packer.writePayload(new byte[MESSAGE_PACK_OFFSET]);
|
||||
// Serialize input object by MessagePack.
|
||||
MutableBoolean isCrossLanguage = new MutableBoolean(true);
|
||||
pack(obj, packer, ((object, packer1) -> {
|
||||
byte[] payload = FstSerializer.encode(object);
|
||||
packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length);
|
||||
packer1.addPayload(payload);
|
||||
isCrossLanguage.setFalse();
|
||||
}));
|
||||
byte[] msgpackBytes = packer.toByteArray();
|
||||
// Serialize MessagePack bytes length.
|
||||
MessageBufferPacker headerPacker = MessagePack.newDefaultBufferPacker();
|
||||
Preconditions.checkState(msgpackBytes.length >= MESSAGE_PACK_OFFSET);
|
||||
headerPacker.packLong(msgpackBytes.length - MESSAGE_PACK_OFFSET);
|
||||
byte[] msgpackBytesLength = headerPacker.toByteArray();
|
||||
// Check serialized MessagePack bytes length is valid.
|
||||
Preconditions.checkState(msgpackBytesLength.length <= MESSAGE_PACK_OFFSET);
|
||||
// Write MessagePack bytes length to reserved buffer.
|
||||
System.arraycopy(msgpackBytesLength, 0, msgpackBytes, 0, msgpackBytesLength.length);
|
||||
return ImmutablePair.of(msgpackBytes, isCrossLanguage.getValue());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, Class<?> type) {
|
||||
try {
|
||||
// Read MessagePack bytes length.
|
||||
MessageUnpacker headerUnpacker = MessagePack.newDefaultUnpacker(bs, 0, MESSAGE_PACK_OFFSET);
|
||||
long msgpackBytesLength = headerUnpacker.unpackLong();
|
||||
headerUnpacker.close();
|
||||
// Check MessagePack bytes length is valid.
|
||||
Preconditions.checkState(MESSAGE_PACK_OFFSET + msgpackBytesLength <= bs.length);
|
||||
// Deserialize MessagePack bytes from MESSAGE_PACK_OFFSET.
|
||||
MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET,
|
||||
(int) msgpackBytesLength);
|
||||
Value v = unpacker.unpackValue();
|
||||
if (type == null) {
|
||||
type = Object.class;
|
||||
}
|
||||
return (T) unpack(v, type,
|
||||
((ExtensionValue ev) -> FstSerializer.decode(ev.getData())));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package org.ray.runtime.serializer;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
public class Serializer {
|
||||
|
||||
public static Pair<byte[], Boolean> encode(Object obj) {
|
||||
return MessagePackSerializer.encode(obj);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, Class<?> type) {
|
||||
return MessagePackSerializer.decode(bs, type);
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
package org.ray.runtime.task;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
@ -40,9 +40,18 @@ public class ArgumentsBuilder {
|
|||
id = ((RayObject) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (language != Language.JAVA) {
|
||||
boolean isCrossData =
|
||||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW);
|
||||
if (!isCrossData) {
|
||||
throw new IllegalArgumentException(String.format("Can't transfer %s data to %s",
|
||||
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
|
||||
}
|
||||
}
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
|
||||
.putRaw(value);
|
||||
.putRaw(value);
|
||||
value = null;
|
||||
}
|
||||
}
|
||||
|
@ -61,10 +70,10 @@ public class ArgumentsBuilder {
|
|||
/**
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
|
||||
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import java.lang.reflect.InvocationTargetException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
|
@ -97,7 +98,8 @@ public abstract class TaskExecutor<T extends ActorContext> {
|
|||
}
|
||||
actor = actorContext.currentActor;
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
|
||||
Object[] args = ArgumentsBuilder
|
||||
.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
|
||||
// Execute the task.
|
||||
Object result;
|
||||
try {
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
package org.ray.runtime.util;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.actor.NativeRayActorSerializer;
|
||||
|
||||
/**
|
||||
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
|
||||
*/
|
||||
public class Serializer {
|
||||
|
||||
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
|
||||
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
|
||||
return conf;
|
||||
});
|
||||
|
||||
public static byte[] encode(Object obj) {
|
||||
return conf.get().asByteArray(obj);
|
||||
}
|
||||
|
||||
public static byte[] encode(Object obj, ClassLoader classLoader) {
|
||||
byte[] result;
|
||||
FSTConfiguration current = conf.get();
|
||||
if (classLoader != null && classLoader != current.getClassLoader()) {
|
||||
ClassLoader old = current.getClassLoader();
|
||||
current.setClassLoader(classLoader);
|
||||
result = current.asByteArray(obj);
|
||||
current.setClassLoader(old);
|
||||
} else {
|
||||
result = current.asByteArray(obj);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs) {
|
||||
return (T) conf.get().asObject(bs);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, ClassLoader classLoader) {
|
||||
Object object;
|
||||
FSTConfiguration current = conf.get();
|
||||
if (classLoader != null && classLoader != current.getClassLoader()) {
|
||||
ClassLoader old = current.getClassLoader();
|
||||
current.setClassLoader(classLoader);
|
||||
object = current.asObject(bs);
|
||||
current.setClassLoader(old);
|
||||
} else {
|
||||
object = current.asObject(bs);
|
||||
}
|
||||
return (T) object;
|
||||
}
|
||||
|
||||
public static void setClassloader(ClassLoader classLoader) {
|
||||
conf.get().setClassLoader(classLoader);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package org.ray.runtime.util;
|
||||
|
||||
import org.apache.commons.lang3.mutable.MutableBoolean;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public class SerializerTest {
|
||||
|
||||
@Test
|
||||
public void testBasicSerialization() {
|
||||
// Test serialize / deserialize primitive types with type conversion.
|
||||
{
|
||||
Object[] foo = new Object[]{"hello", (byte) 1, 2.0, (short) 3, 4, 5L,
|
||||
new String[]{"hello", "world"}};
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
Object[] bar = Serializer.decode(serialized.getLeft(), Object[].class);
|
||||
Assert.assertTrue(serialized.getRight());
|
||||
Assert.assertEquals(foo[0], bar[0]);
|
||||
Assert.assertEquals(((Number) foo[1]).byteValue(), ((Number) bar[1]).byteValue());
|
||||
Assert.assertEquals(foo[2], bar[2]);
|
||||
Assert.assertEquals(((Number) foo[3]).intValue(), ((Number) bar[3]).intValue());
|
||||
Assert.assertEquals(((Number) foo[4]).intValue(), ((Number) bar[4]).intValue());
|
||||
Assert.assertEquals(((Number) foo[5]).intValue(), ((Number) bar[5]).intValue());
|
||||
}
|
||||
// Test multidimensional array.
|
||||
{
|
||||
Object[][] foo = new Object[][]{{1, 2}, {"3", 4}};
|
||||
Assert.expectThrows(RuntimeException.class, () -> {
|
||||
Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class);
|
||||
});
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
Object[][] bar = Serializer.decode(serialized.getLeft(), Object[][].class);
|
||||
Assert.assertTrue(serialized.getRight());
|
||||
Assert.assertEquals(((Number) foo[0][1]).intValue(), ((Number) bar[0][1]).intValue());
|
||||
Assert.assertEquals(foo[1][0], bar[1][0]);
|
||||
}
|
||||
// Test List.
|
||||
{
|
||||
ArrayList<String> foo = new ArrayList<>();
|
||||
foo.add("1");
|
||||
foo.add("2");
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
ArrayList<String> bar = Serializer.decode(serialized.getLeft(), String[].class);
|
||||
Assert.assertFalse(serialized.getRight());
|
||||
Assert.assertEquals(foo.get(0), bar.get(0));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -143,7 +143,7 @@ public class ActorTest extends BaseTest {
|
|||
try {
|
||||
// Try getting the object again, this should throw an UnreconstructableException.
|
||||
// Use `Ray.get()` to bypass the cache in `RayObjectImpl`.
|
||||
Ray.get(value.getId());
|
||||
Ray.get(value.getId(), value.getType());
|
||||
Assert.fail("This line should not be reachable.");
|
||||
} catch (UnreconstructableException e) {
|
||||
Assert.assertEquals(value.getId(), e.objectId);
|
||||
|
|
|
@ -4,6 +4,7 @@ import java.io.File;
|
|||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Optional;
|
||||
import javax.tools.JavaCompiler;
|
||||
import javax.tools.ToolProvider;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -101,12 +102,14 @@ public class ClassLoaderTest extends BaseTest {
|
|||
"()V");
|
||||
RayActor<?> actor1 = createActor(constructor);
|
||||
FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
|
||||
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0], 1).get();
|
||||
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0],
|
||||
Optional.of(Integer.class)).get();
|
||||
RayActor<?> actor2;
|
||||
while (true) {
|
||||
// Create another actor which share the same process of actor 1.
|
||||
actor2 = createActor(constructor);
|
||||
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0], 1).get();
|
||||
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0],
|
||||
Optional.of(Integer.class)).get();
|
||||
if (actor2Pid == pid) {
|
||||
break;
|
||||
}
|
||||
|
@ -116,15 +119,17 @@ public class ClassLoaderTest extends BaseTest {
|
|||
"getClassLoaderHashCode",
|
||||
"()I");
|
||||
RayObject<Integer> hashCode1 = callActorFunction(actor1, getClassLoaderHashCode, new Object[0],
|
||||
1);
|
||||
Optional.of(Integer.class));
|
||||
RayObject<Integer> hashCode2 = callActorFunction(actor2, getClassLoaderHashCode, new Object[0],
|
||||
1);
|
||||
Optional.of(Integer.class));
|
||||
Assert.assertEquals(hashCode1.get(), hashCode2.get());
|
||||
|
||||
FunctionDescriptor increase = new JavaFunctionDescriptor("ClassLoaderTester", "increase",
|
||||
"()I");
|
||||
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0], 1);
|
||||
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0], 1);
|
||||
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0],
|
||||
Optional.of(Integer.class));
|
||||
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0],
|
||||
Optional.of(Integer.class));
|
||||
Assert.assertNotEquals(value1.get(), value2.get());
|
||||
}
|
||||
|
||||
|
@ -138,11 +143,12 @@ public class ClassLoaderTest extends BaseTest {
|
|||
}
|
||||
|
||||
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception {
|
||||
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType)
|
||||
throws Exception {
|
||||
Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction",
|
||||
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
|
||||
BaseActor.class, FunctionDescriptor.class, Object[].class, Optional.class);
|
||||
callActorFunctionMethod.setAccessible(true);
|
||||
return (RayObject<T>) callActorFunctionMethod
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
|
||||
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, returnType);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ public class ClientExceptionTest extends BaseTest {
|
|||
public void testWaitAndCrash() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ObjectId randomId = ObjectId.fromRandom();
|
||||
RayObject<String> notExisting = new RayObjectImpl(randomId);
|
||||
RayObject<String> notExisting = new RayObjectImpl(randomId, String.class);
|
||||
|
||||
Thread thread = new Thread(() -> {
|
||||
try {
|
||||
|
|
|
@ -5,6 +5,9 @@ import com.google.common.collect.ImmutableMap;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.Ray;
|
||||
|
@ -51,18 +54,85 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
|||
|
||||
@Test
|
||||
public void testCallingPythonFunction() {
|
||||
RayObject<byte[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func", byte[].class),
|
||||
"hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
|
||||
Object[] inputs = new Object[]{
|
||||
true, // Boolean
|
||||
Byte.MAX_VALUE, // Byte
|
||||
Short.MAX_VALUE, // Short
|
||||
Integer.MAX_VALUE, // Integer
|
||||
Long.MAX_VALUE, // Long
|
||||
// BigInteger can support max value of 2^64-1, please refer to:
|
||||
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
|
||||
// If BigInteger larger than 2^64-1, the value can only be transferred among Java workers.
|
||||
BigInteger.valueOf(Long.MAX_VALUE), // BigInteger
|
||||
"Hello World!", // String
|
||||
1.234f, // Float
|
||||
1.234, // Double
|
||||
"example binary".getBytes()}; // byte[]
|
||||
for (Object o : inputs) {
|
||||
RayObject res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
|
||||
o);
|
||||
Assert.assertEquals(res.get(), o);
|
||||
}
|
||||
// null
|
||||
{
|
||||
Object input = null;
|
||||
RayObject<Object> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input);
|
||||
Object r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
}
|
||||
// array
|
||||
{
|
||||
int[] input = new int[]{1, 2};
|
||||
RayObject<int[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input);
|
||||
int[] r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
}
|
||||
// array of Object
|
||||
{
|
||||
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
|
||||
new int[]{1, 2}};
|
||||
RayObject<Object[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input);
|
||||
Object[] r = res.get();
|
||||
// If we tell the value type is Object, then all numbers will be Number type.
|
||||
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
|
||||
Assert.assertEquals(((Number) r[1]).floatValue(), input[1]);
|
||||
Assert.assertEquals(((Number) r[2]).doubleValue(), input[2]);
|
||||
// String cast
|
||||
Assert.assertEquals((String) r[3], input[3]);
|
||||
// binary cast
|
||||
Assert.assertEquals((byte[]) r[4], input[4]);
|
||||
// null
|
||||
Assert.assertEquals(r[5], input[5]);
|
||||
// Boolean cast
|
||||
Assert.assertEquals((Boolean) r[6], input[6]);
|
||||
// array cast
|
||||
Object[] r7array = (Object[]) r[7];
|
||||
int[] input7array = (int[]) input[7];
|
||||
Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]);
|
||||
Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]);
|
||||
}
|
||||
// Unsupported types, all Java specific types, e.g. List / Map...
|
||||
{
|
||||
Assert.expectThrows(Exception.class, () -> {
|
||||
List<Integer> input = Arrays.asList(1, 2);
|
||||
RayObject<List<Integer>> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
|
||||
(Class<List<Integer>>) input.getClass()), input);
|
||||
List<Integer> r = res.get();
|
||||
Assert.assertEquals(r, input);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPythonCallJavaFunction() {
|
||||
RayObject<byte[]> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", byte[].class),
|
||||
"hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes());
|
||||
RayObject<String> res = Ray.call(
|
||||
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", String.class));
|
||||
Assert.assertEquals(res.get(), "success");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -117,11 +187,33 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
|||
Assert.assertEquals(res.get(), "3".getBytes());
|
||||
}
|
||||
|
||||
public static byte[] bytesEcho(byte[] value) {
|
||||
public static Object[] pack(int i, String s, double f, Object[] o) {
|
||||
// This function will be called from test_cross_language_invocation.py
|
||||
String valueStr = new String(value);
|
||||
LOGGER.debug(String.format("bytesEcho called with: %s", valueStr));
|
||||
return ("[Java]bytesEcho -> " + valueStr).getBytes();
|
||||
return new Object[]{i, s, f, o};
|
||||
}
|
||||
|
||||
public static Object returnInput(Object o) {
|
||||
return o;
|
||||
}
|
||||
|
||||
public static boolean returnInputBoolean(boolean b) {
|
||||
return b;
|
||||
}
|
||||
|
||||
public static int returnInputInt(int i) {
|
||||
return i;
|
||||
}
|
||||
|
||||
public static double returnInputDouble(double d) {
|
||||
return d;
|
||||
}
|
||||
|
||||
public static String returnInputString(String s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
public static int[] returnInputIntArray(int[] l) {
|
||||
return l;
|
||||
}
|
||||
|
||||
public static byte[] callPythonActorHandle(byte[] value) {
|
||||
|
@ -135,6 +227,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
|||
}
|
||||
|
||||
public static class TestActor {
|
||||
|
||||
public TestActor(byte[] v) {
|
||||
value = v;
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ public class DynamicResourceTest extends BaseTest {
|
|||
// Assert ray call result.
|
||||
result = Ray.wait(ImmutableList.of(obj), 1, 1000);
|
||||
Assert.assertEquals(result.getReady().size(), 1);
|
||||
Assert.assertEquals(Ray.get(obj.getId()), "hi");
|
||||
Assert.assertEquals(obj.get(), "hi");
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -148,7 +148,7 @@ public class FailureTest extends BaseTest {
|
|||
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
|
||||
Instant start = Instant.now();
|
||||
try {
|
||||
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
|
||||
Ray.get(Arrays.asList(obj1, obj2));
|
||||
Assert.fail("Should throw RayException.");
|
||||
} catch (RayException e) {
|
||||
Instant end = Instant.now();
|
||||
|
|
|
@ -104,7 +104,7 @@ public class MultiThreadingTest extends BaseTest {
|
|||
runTestCaseInMultipleThreads(() -> {
|
||||
int arg = random.nextInt();
|
||||
RayObject<Integer> obj = Ray.put(arg);
|
||||
Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
|
||||
Assert.assertEquals(arg, (int) obj.get());
|
||||
}, LOOP_COUNTER);
|
||||
|
||||
TestUtils.warmUpCluster();
|
||||
|
@ -141,7 +141,7 @@ public class MultiThreadingTest extends BaseTest {
|
|||
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
|
||||
final Runnable[] runnables = new Runnable[]{
|
||||
() -> Ray.put(1),
|
||||
() -> Ray.get(fooObject.getId()),
|
||||
() -> Ray.get(fooObject.getId(), fooObject.getType()),
|
||||
fooObject::get,
|
||||
() -> Ray.wait(ImmutableList.of(fooObject)),
|
||||
Ray::getRuntimeContext,
|
||||
|
|
|
@ -16,8 +16,22 @@ public class ObjectStoreTest extends BaseTest {
|
|||
|
||||
@Test
|
||||
public void testPutAndGet() {
|
||||
RayObject<Integer> obj = Ray.put(1);
|
||||
Assert.assertEquals(1, (int) obj.get());
|
||||
{
|
||||
RayObject<Integer> obj = Ray.put(1);
|
||||
Assert.assertEquals(1, (int) obj.get());
|
||||
}
|
||||
|
||||
{
|
||||
String s = null;
|
||||
RayObject<String> obj = Ray.put(s);
|
||||
Assert.assertNull(obj.get());
|
||||
}
|
||||
|
||||
{
|
||||
List<List<String>> l = ImmutableList.of(ImmutableList.of("abc"));
|
||||
RayObject<List<List<String>>> obj = Ray.put(l);
|
||||
Assert.assertEquals(obj.get(), l);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -25,6 +39,6 @@ public class ObjectStoreTest extends BaseTest {
|
|||
List<Integer> ints = ImmutableList.of(1, 2, 3, 4, 5);
|
||||
List<ObjectId> ids = ints.stream().map(obj -> Ray.put(obj).getId())
|
||||
.collect(Collectors.toList());
|
||||
Assert.assertEquals(ints, Ray.get(ids));
|
||||
Assert.assertEquals(ints, Ray.get(ids, Integer.class));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,9 +15,9 @@ public class PlasmaStoreTest extends BaseTest {
|
|||
ObjectId objectId = ObjectId.fromRandom();
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
objectStore.put("1", objectId);
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
objectStore.put("2", objectId);
|
||||
// Putting the second object with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
Assert.assertEquals(Ray.get(objectId, String.class), "1");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ public class RayCallTest extends BaseTest {
|
|||
|
||||
ObjectId randomObjectId = ObjectId.fromRandom();
|
||||
Ray.call(RayCallTest::testNoReturn, randomObjectId);
|
||||
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
|
||||
Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
|
||||
}
|
||||
|
||||
private static int testNoParam() {
|
||||
|
|
|
@ -2,9 +2,7 @@ package org.ray.api.test;
|
|||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.testng.Assert;
|
||||
|
@ -15,10 +13,9 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
|
|||
@Test
|
||||
public void testSerializePyActor() {
|
||||
RayPyActor pyActor = Ray.createActor(new PyActorClass("test", "RaySerializerTest"));
|
||||
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
|
||||
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) ObjectSerializer
|
||||
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
|
||||
.deserialize(nativeRayObject, null, Object.class);
|
||||
Assert.assertEquals(result.getId(), pyActor.getId());
|
||||
Assert.assertEquals(result.getModuleName(), "test");
|
||||
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
|
||||
|
|
|
@ -28,7 +28,7 @@ public class StressTest extends BaseTest {
|
|||
resultIds.add(Ray.call(StressTest::echo, 1).getId());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(resultIds)) {
|
||||
for (Integer result : Ray.<Integer>get(resultIds, Integer.class)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(1));
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ public class StressTest extends BaseTest {
|
|||
objectIds.add(actor.call(Actor::ping).getId());
|
||||
}
|
||||
int sum = 0;
|
||||
for (Integer result : Ray.<Integer>get(objectIds)) {
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
sum += result;
|
||||
}
|
||||
return sum;
|
||||
|
@ -84,7 +84,7 @@ public class StressTest extends BaseTest {
|
|||
objectIds.add(worker.call(Worker::ping, 100).getId());
|
||||
}
|
||||
|
||||
for (Integer result : Ray.<Integer>get(objectIds)) {
|
||||
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
|
||||
Assert.assertEquals(result, Integer.valueOf(100));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,18 +5,47 @@ import ray
|
|||
|
||||
|
||||
@ray.remote
|
||||
def py_func(value):
|
||||
assert isinstance(value, bytes)
|
||||
return b"Response from Python: " + value
|
||||
def py_return_input(v):
|
||||
return v
|
||||
|
||||
|
||||
@ray.remote
|
||||
def py_func_call_java_function(value):
|
||||
assert isinstance(value, bytes)
|
||||
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"bytesEcho")
|
||||
r = f.remote(value)
|
||||
return b"[Python]py_func -> " + ray.get(r)
|
||||
def py_func_call_java_function():
|
||||
try:
|
||||
# None
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInput").remote(None)
|
||||
assert ray.get(r) is None
|
||||
# bool
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputBoolean").remote(True)
|
||||
assert ray.get(r) is True
|
||||
# int
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputInt").remote(100)
|
||||
assert ray.get(r) == 100
|
||||
# double
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputDouble").remote(1.23)
|
||||
assert ray.get(r) == 1.23
|
||||
# string
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputString").remote("Hello World!")
|
||||
assert ray.get(r) == "Hello World!"
|
||||
# list (tuple will be packed by pickle,
|
||||
# so only list can be transferred across language)
|
||||
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"returnInputIntArray").remote([1, 2, 3])
|
||||
assert ray.get(r) == [1, 2, 3]
|
||||
# pack
|
||||
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
|
||||
"pack")
|
||||
input = [100, "hello", 1.23, [1, "2", 3.0]]
|
||||
r = f.remote(*input)
|
||||
assert ray.get(r) == input
|
||||
return "success"
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
|
|
@ -92,6 +92,8 @@ from ray.exceptions import (
|
|||
RayTimeoutError,
|
||||
)
|
||||
from ray.utils import decode
|
||||
import gc
|
||||
import msgpack
|
||||
|
||||
cimport cpython
|
||||
|
||||
|
@ -106,8 +108,6 @@ include "includes/libcoreworker.pxi"
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMCOPY_THREADS = 6
|
||||
|
||||
|
||||
def set_internal_config(dict options):
|
||||
cdef:
|
||||
|
@ -257,8 +257,9 @@ cdef int prepare_resources(
|
|||
return 0
|
||||
|
||||
|
||||
cdef void prepare_args(
|
||||
CoreWorker core_worker, args, c_vector[CTaskArg] *args_vector):
|
||||
cdef prepare_args(
|
||||
CoreWorker core_worker,
|
||||
Language language, args, c_vector[CTaskArg] *args_vector):
|
||||
cdef:
|
||||
size_t size
|
||||
int64_t put_threshold
|
||||
|
@ -274,6 +275,13 @@ cdef void prepare_args(
|
|||
|
||||
else:
|
||||
serialized_arg = worker.get_serialization_context().serialize(arg)
|
||||
metadata = serialized_arg.metadata
|
||||
if language != Language.PYTHON:
|
||||
if metadata not in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
ray_constants.OBJECT_METADATA_TYPE_RAW]:
|
||||
raise Exception("Can't transfer {} data to {}".format(
|
||||
metadata, language))
|
||||
size = serialized_arg.total_bytes
|
||||
|
||||
# TODO(edoakes): any objects containing ObjectIDs are spilled to
|
||||
|
@ -283,12 +291,14 @@ cdef void prepare_args(
|
|||
if <int64_t>size <= put_threshold:
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](size))
|
||||
write_serialized_object(serialized_arg, arg_data)
|
||||
if size > 0:
|
||||
(<SerializedObject>serialized_arg).write_to(
|
||||
Buffer.make(arg_data))
|
||||
for object_id in serialized_arg.contained_object_ids:
|
||||
inlined_ids.push_back((<ObjectID>object_id).native())
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(make_shared[CRayObject](
|
||||
arg_data, string_to_buffer(serialized_arg.metadata),
|
||||
arg_data, string_to_buffer(metadata),
|
||||
inlined_ids)))
|
||||
inlined_ids.clear()
|
||||
else:
|
||||
|
@ -616,29 +626,6 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
|
|||
<uint8_t*>(c_str.data()), c_str.size(), True))
|
||||
|
||||
|
||||
cdef write_serialized_object(
|
||||
serialized_object, const shared_ptr[CBuffer]& buf):
|
||||
from ray.serialization import Pickle5SerializedObject, RawSerializedObject
|
||||
|
||||
if isinstance(serialized_object, RawSerializedObject):
|
||||
if buf.get() != NULL and buf.get().Size() > 0:
|
||||
size = serialized_object.total_bytes
|
||||
if MEMCOPY_THREADS > 1 and size > kMemcopyDefaultThreshold:
|
||||
parallel_memcopy(buf.get().Data(),
|
||||
<const uint8_t*> serialized_object.value,
|
||||
size, kMemcopyDefaultBlocksize,
|
||||
MEMCOPY_THREADS)
|
||||
else:
|
||||
memcpy(buf.get().Data(),
|
||||
<const uint8_t*>serialized_object.value, size)
|
||||
|
||||
elif isinstance(serialized_object, Pickle5SerializedObject):
|
||||
(<Pickle5Writer>serialized_object.writer).write_to(
|
||||
serialized_object.inband, buf, MEMCOPY_THREADS)
|
||||
else:
|
||||
raise TypeError("Unsupported serialization type.")
|
||||
|
||||
|
||||
cdef class CoreWorker:
|
||||
|
||||
def __cinit__(self, is_driver, store_socket, raylet_socket,
|
||||
|
@ -780,7 +767,9 @@ cdef class CoreWorker:
|
|||
&c_object_id, &data)
|
||||
|
||||
if not object_already_exists:
|
||||
write_serialized_object(serialized_object, data)
|
||||
if total_bytes > 0:
|
||||
(<SerializedObject>serialized_object).write_to(
|
||||
Buffer.make(data))
|
||||
if self.is_local_mode:
|
||||
c_object_id_vector.push_back(c_object_id)
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Put(
|
||||
|
@ -875,7 +864,7 @@ cdef class CoreWorker:
|
|||
num_return_vals, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask(
|
||||
|
@ -908,7 +897,7 @@ cdef class CoreWorker:
|
|||
prepare_resources(placement_resources, &c_placement_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor(
|
||||
|
@ -944,7 +933,7 @@ cdef class CoreWorker:
|
|||
task_options = CTaskOptions(num_return_vals, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
language.lang, function_descriptor.descriptor)
|
||||
prepare_args(self, args, &args_vector)
|
||||
prepare_args(self, language, args, &args_vector)
|
||||
|
||||
with nogil:
|
||||
check_status(
|
||||
|
@ -1133,8 +1122,9 @@ cdef class CoreWorker:
|
|||
for i, serialized_object in enumerate(serialized_objects):
|
||||
# A nullptr is returned if the object already exists.
|
||||
if returns[0][i].get() != NULL:
|
||||
write_serialized_object(
|
||||
serialized_object, returns[0][i].get().GetData())
|
||||
if returns[0][i].get().HasData():
|
||||
(<SerializedObject>serialized_object).write_to(
|
||||
Buffer.make(returns[0][i].get().GetData()))
|
||||
if self.is_local_mode:
|
||||
return_ids_vector.push_back(return_ids[i])
|
||||
check_status(
|
||||
|
|
|
@ -44,7 +44,7 @@ cdef class Buffer:
|
|||
def __getbuffer__(self, Py_buffer* buffer, int flags):
|
||||
buffer.readonly = 0
|
||||
buffer.buf = <char *>self.buffer.get().Data()
|
||||
buffer.format = 'b'
|
||||
buffer.format = 'B'
|
||||
buffer.internal = NULL
|
||||
buffer.itemsize = 1
|
||||
buffer.len = self.size
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
from libc.string cimport memcpy
|
||||
from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX
|
||||
from libcpp cimport nullptr
|
||||
import cython
|
||||
|
||||
DEF MEMCOPY_THREADS = 6
|
||||
|
||||
# This is the default alignment value for len(buffer) < 2048.
|
||||
DEF kMinorBufferAlign = 8
|
||||
|
@ -9,6 +13,8 @@ DEF kMajorBufferAlign = 64
|
|||
DEF kMajorBufferSize = 2048
|
||||
DEF kMemcopyDefaultBlocksize = 64
|
||||
DEF kMemcopyDefaultThreshold = 1024 * 1024
|
||||
DEF kLanguageSpecificTypeExtensionId = 101
|
||||
DEF kMessagePackOffset = 9
|
||||
|
||||
cdef extern from "ray/util/memory.h" namespace "ray" nogil:
|
||||
void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
|
||||
|
@ -82,7 +88,7 @@ cdef class SubBuffer:
|
|||
void *internal
|
||||
object buffer
|
||||
|
||||
def __cinit__(self, Buffer buffer):
|
||||
def __cinit__(self, object buffer):
|
||||
# Increase ref count.
|
||||
self.buffer = buffer
|
||||
self.suboffsets = NULL
|
||||
|
@ -142,15 +148,68 @@ cdef class SubBuffer:
|
|||
return self.size
|
||||
|
||||
|
||||
# See 'serialization.proto' for the memory layout in the Plasma buffer.
|
||||
def unpack_pickle5_buffers(Buffer buf):
|
||||
cdef class MessagePackSerializer(object):
|
||||
@staticmethod
|
||||
def dumps(o, python_serializer=None):
|
||||
def _default(obj):
|
||||
if python_serializer is not None:
|
||||
return msgpack.ExtType(kLanguageSpecificTypeExtensionId,
|
||||
msgpack.dumps(python_serializer(obj)))
|
||||
return obj
|
||||
try:
|
||||
# If strict_types were False, both list and tuple would
|
||||
# be packed to a msgpack array, so they could not be
|
||||
# distinguished when unpacking.
|
||||
return msgpack.dumps(o, default=_default,
|
||||
use_bin_type=True, strict_types=True)
|
||||
except ValueError as ex:
|
||||
# msgpack can't handle recursive objects, so we serialize them with
|
||||
# the Python serializer, e.g. pickle.
|
||||
return msgpack.dumps(_default(o), default=_default,
|
||||
use_bin_type=True, strict_types=True)
|
||||
|
||||
@classmethod
|
||||
def loads(cls, s, python_deserializer=None):
|
||||
def _ext_hook(code, data):
|
||||
if code == kLanguageSpecificTypeExtensionId:
|
||||
if python_deserializer is not None:
|
||||
return python_deserializer(msgpack.loads(data))
|
||||
raise Exception('Unrecognized ext type id: {}'.format(code))
|
||||
try:
|
||||
gc.disable() # Performance optimization for msgpack.
|
||||
return msgpack.loads(s, ext_hook=_ext_hook, raw=False)
|
||||
finally:
|
||||
gc.enable()
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def split_buffer(Buffer buf):
|
||||
cdef:
|
||||
shared_ptr[CBuffer] _buffer = buf.buffer
|
||||
const uint8_t *data = buf.buffer.get().Data()
|
||||
size_t size = _buffer.get().Size()
|
||||
size_t size = buf.buffer.get().Size()
|
||||
uint8_t[:] bufferview = buf
|
||||
int64_t msgpack_bytes_length
|
||||
|
||||
assert kMessagePackOffset <= size
|
||||
header_unpacker = msgpack.Unpacker()
|
||||
header_unpacker.feed(bufferview[:kMessagePackOffset])
|
||||
msgpack_bytes_length = header_unpacker.unpack()
|
||||
assert kMessagePackOffset + msgpack_bytes_length <= <int64_t>size
|
||||
return (bufferview[kMessagePackOffset:
|
||||
kMessagePackOffset + msgpack_bytes_length],
|
||||
bufferview[kMessagePackOffset + msgpack_bytes_length:])
|
||||
|
||||
|
||||
# See 'serialization.proto' for the memory layout in the Plasma buffer.
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def unpack_pickle5_buffers(uint8_t[:] bufferview):
|
||||
cdef:
|
||||
const uint8_t *data = &bufferview[0]
|
||||
size_t size = len(bufferview)
|
||||
CPythonObject python_object
|
||||
CPythonBuffer *buffer_meta
|
||||
c_string inband_data
|
||||
int64_t protobuf_offset
|
||||
int64_t protobuf_size
|
||||
int32_t i
|
||||
|
@ -167,14 +226,16 @@ def unpack_pickle5_buffers(Buffer buf):
|
|||
if not python_object.ParseFromArray(
|
||||
data + protobuf_offset, <int32_t>protobuf_size):
|
||||
raise ValueError("Protobuf object is corrupted.")
|
||||
inband_data.append(<char*>(data + python_object.inband_data_offset()),
|
||||
<size_t>python_object.inband_data_size())
|
||||
inband_data_offset = python_object.inband_data_offset()
|
||||
inband_data = bufferview[
|
||||
inband_data_offset:
|
||||
inband_data_offset + python_object.inband_data_size()]
|
||||
buffers_segment = data + python_object.raw_buffers_offset()
|
||||
pickled_buffers = []
|
||||
# Now read buffer meta
|
||||
for i in range(python_object.buffer_size()):
|
||||
buffer_meta = <CPythonBuffer *>&python_object.buffer(i)
|
||||
buffer = SubBuffer(buf)
|
||||
buffer = SubBuffer(bufferview)
|
||||
buffer.buf = <void*>(buffers_segment + buffer_meta.address())
|
||||
buffer.len = buffer_meta.length()
|
||||
buffer.itemsize = buffer_meta.itemsize()
|
||||
|
@ -207,6 +268,11 @@ cdef class Pickle5Writer:
|
|||
self._curr_buffer_addr = 0
|
||||
self._total_bytes = -1
|
||||
|
||||
def __dealloc__(self):
|
||||
# We must release the buffer, or we could experience memory leaks.
|
||||
for i in range(self.buffers.size()):
|
||||
cpython.PyBuffer_Release(&self.buffers[i])
|
||||
|
||||
def buffer_callback(self, pickle_buffer):
|
||||
cdef:
|
||||
Py_buffer view
|
||||
|
@ -240,14 +306,14 @@ cdef class Pickle5Writer:
|
|||
self._curr_buffer_addr += view.len
|
||||
self.buffers.push_back(view)
|
||||
|
||||
def get_total_bytes(self, const c_string &inband):
|
||||
def get_total_bytes(self, const uint8_t[:] inband):
|
||||
cdef:
|
||||
size_t protobuf_bytes = 0
|
||||
uint64_t inband_data_offset = sizeof(int64_t) * 2
|
||||
uint64_t raw_buffers_offset = padded_length_u64(
|
||||
inband_data_offset + inband.length(), kMajorBufferAlign)
|
||||
inband_data_offset + len(inband), kMajorBufferAlign)
|
||||
self.python_object.set_inband_data_offset(inband_data_offset)
|
||||
self.python_object.set_inband_data_size(inband.length())
|
||||
self.python_object.set_inband_data_size(len(inband))
|
||||
self.python_object.set_raw_buffers_offset(raw_buffers_offset)
|
||||
self.python_object.set_raw_buffers_size(self._curr_buffer_addr)
|
||||
# Since calculating the output size is expensive, we will
|
||||
|
@ -265,9 +331,11 @@ cdef class Pickle5Writer:
|
|||
self._total_bytes = self._protobuf_offset + protobuf_bytes
|
||||
return self._total_bytes
|
||||
|
||||
cdef void write_to(self, const c_string &inband, shared_ptr[CBuffer] data,
|
||||
int memcopy_threads):
|
||||
cdef uint8_t *ptr = data.get().Data()
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, const uint8_t[:] inband, uint8_t[:] data,
|
||||
int memcopy_threads) nogil:
|
||||
cdef uint8_t *ptr = &data[0]
|
||||
cdef int32_t protobuf_size
|
||||
cdef uint64_t buffer_addr
|
||||
cdef uint64_t buffer_len
|
||||
|
@ -284,7 +352,7 @@ cdef class Pickle5Writer:
|
|||
ptr + self._protobuf_offset)
|
||||
# Write inband data.
|
||||
memcpy(ptr + self.python_object.inband_data_offset(),
|
||||
inband.data(), inband.length())
|
||||
&inband[0], len(inband))
|
||||
# Write buffer data.
|
||||
ptr += self.python_object.raw_buffers_offset()
|
||||
for i in range(self.python_object.buffer_size()):
|
||||
|
@ -298,5 +366,141 @@ cdef class Pickle5Writer:
|
|||
kMemcopyDefaultBlocksize, memcopy_threads)
|
||||
else:
|
||||
memcpy(ptr + buffer_addr, self.buffers[i].buf, buffer_len)
|
||||
# We must release the buffer, or we could experience memory leaks.
|
||||
cpython.PyBuffer_Release(&self.buffers[i])
|
||||
|
||||
|
||||
cdef class SerializedObject(object):
|
||||
cdef:
|
||||
object _metadata
|
||||
object _contained_object_ids
|
||||
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
raise NotImplementedError("{}.total_bytes not implemented.".format(
|
||||
type(self).__name__))
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
raise NotImplementedError("{}.write_to not implemented.".format(
|
||||
type(self).__name__))
|
||||
|
||||
|
||||
cdef class Pickle5SerializedObject(SerializedObject):
|
||||
cdef:
|
||||
const uint8_t[:] inband
|
||||
Pickle5Writer writer
|
||||
object _total_bytes
|
||||
|
||||
def __init__(self, metadata, inband, Pickle5Writer writer,
|
||||
contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
self._total_bytes = None
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
if self._total_bytes is None:
|
||||
self._total_bytes = self.writer.get_total_bytes(self.inband)
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
self.writer.write_to(self.inband, buffer, MEMCOPY_THREADS)
|
||||
|
||||
|
||||
cdef class MessagePackSerializedObject(SerializedObject):
|
||||
cdef:
|
||||
SerializedObject nest_serialized_object
|
||||
object msgpack_header
|
||||
object msgpack_data
|
||||
int64_t _msgpack_header_bytes
|
||||
int64_t _msgpack_data_bytes
|
||||
int64_t _total_bytes
|
||||
const uint8_t *msgpack_header_ptr
|
||||
const uint8_t *msgpack_data_ptr
|
||||
|
||||
def __init__(self, metadata, msgpack_data,
|
||||
SerializedObject nest_serialized_object=None):
|
||||
if nest_serialized_object:
|
||||
contained_object_ids = nest_serialized_object.contained_object_ids
|
||||
total_bytes = nest_serialized_object.total_bytes
|
||||
else:
|
||||
contained_object_ids = []
|
||||
total_bytes = 0
|
||||
super(MessagePackSerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.nest_serialized_object = nest_serialized_object
|
||||
self.msgpack_header = msgpack_header = msgpack.dumps(len(msgpack_data))
|
||||
self.msgpack_data = msgpack_data
|
||||
self._msgpack_header_bytes = len(msgpack_header)
|
||||
self._msgpack_data_bytes = len(msgpack_data)
|
||||
self._total_bytes = (kMessagePackOffset +
|
||||
self._msgpack_data_bytes +
|
||||
total_bytes)
|
||||
self.msgpack_header_ptr = <const uint8_t*>msgpack_header
|
||||
self.msgpack_data_ptr = <const uint8_t*>msgpack_data
|
||||
assert self._msgpack_header_bytes <= kMessagePackOffset
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
cdef uint8_t *ptr = &buffer[0]
|
||||
|
||||
# Write msgpack data first.
|
||||
memcpy(ptr, self.msgpack_header_ptr, self._msgpack_header_bytes)
|
||||
memcpy(ptr + kMessagePackOffset,
|
||||
self.msgpack_data_ptr, self._msgpack_data_bytes)
|
||||
|
||||
if self.nest_serialized_object is not None:
|
||||
self.nest_serialized_object.write_to(
|
||||
buffer[kMessagePackOffset + self._msgpack_data_bytes:])
|
||||
|
||||
|
||||
cdef class RawSerializedObject(SerializedObject):
|
||||
cdef:
|
||||
object value
|
||||
const uint8_t *value_ptr
|
||||
int64_t _total_bytes
|
||||
|
||||
def __init__(self, value):
|
||||
super(RawSerializedObject,
|
||||
self).__init__(ray_constants.OBJECT_METADATA_TYPE_RAW)
|
||||
self.value = value
|
||||
self.value_ptr = <const uint8_t*> value
|
||||
self._total_bytes = len(value)
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return self._total_bytes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void write_to(self, uint8_t[:] buffer) nogil:
|
||||
if (MEMCOPY_THREADS > 1 and
|
||||
self._total_bytes > kMemcopyDefaultThreshold):
|
||||
parallel_memcopy(&buffer[0],
|
||||
self.value_ptr,
|
||||
self._total_bytes, kMemcopyDefaultBlocksize,
|
||||
MEMCOPY_THREADS)
|
||||
else:
|
||||
memcpy(&buffer[0], self.value_ptr, self._total_bytes)
|
||||
|
|
|
@ -180,13 +180,12 @@ PROCESS_TYPE_GCS_SERVER = "gcs_server"
|
|||
|
||||
LOG_MONITOR_MAX_OPEN_FILES = 200
|
||||
|
||||
# A constant used as object metadata to indicate the object is raw binary.
|
||||
RAW_BUFFER_METADATA = b"RAW"
|
||||
# A constant used as object metadata to indicate the object is pickled. This
|
||||
# format is only ever used for Python inline task argument values.
|
||||
PICKLE_BUFFER_METADATA = b"PICKLE"
|
||||
# A constant used as object metadata to indicate the object is pickle5 format.
|
||||
PICKLE5_BUFFER_METADATA = b"PICKLE5"
|
||||
# A constant used as object metadata to indicate the object is cross language.
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG"
|
||||
# A constant used as object metadata to indicate the object is python specific.
|
||||
OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
|
||||
# A constant used as object metadata to indicate the object is raw bytes.
|
||||
OBJECT_METADATA_TYPE_RAW = b"RAW"
|
||||
|
||||
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
|
||||
|
||||
|
|
|
@ -15,7 +15,15 @@ from ray.exceptions import (
|
|||
RayWorkerError,
|
||||
UnreconstructableError,
|
||||
)
|
||||
from ray._raylet import Pickle5Writer, unpack_pickle5_buffers
|
||||
from ray._raylet import (
|
||||
split_buffer,
|
||||
unpack_pickle5_buffers,
|
||||
Pickle5Writer,
|
||||
Pickle5SerializedObject,
|
||||
MessagePackSerializer,
|
||||
MessagePackSerializedObject,
|
||||
RawSerializedObject,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,51 +42,6 @@ class DeserializationError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class SerializedObject:
|
||||
def __init__(self, metadata, contained_object_ids=None):
|
||||
self._metadata = metadata
|
||||
self._contained_object_ids = contained_object_ids or []
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def contained_object_ids(self):
|
||||
return self._contained_object_ids
|
||||
|
||||
|
||||
class Pickle5SerializedObject(SerializedObject):
|
||||
def __init__(self, metadata, inband, writer, contained_object_ids):
|
||||
super(Pickle5SerializedObject, self).__init__(metadata,
|
||||
contained_object_ids)
|
||||
self.inband = inband
|
||||
self.writer = writer
|
||||
# cached total bytes
|
||||
self._total_bytes = None
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
if self._total_bytes is None:
|
||||
self._total_bytes = self.writer.get_total_bytes(self.inband)
|
||||
return self._total_bytes
|
||||
|
||||
|
||||
class RawSerializedObject(SerializedObject):
|
||||
def __init__(self, value):
|
||||
super(RawSerializedObject,
|
||||
self).__init__(ray_constants.RAW_BUFFER_METADATA)
|
||||
self.value = value
|
||||
|
||||
@property
|
||||
def total_bytes(self):
|
||||
return len(self.value)
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
"""Attempt to produce a deterministic class ID for a given class.
|
||||
|
||||
|
@ -265,23 +228,51 @@ class SerializationContext:
|
|||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_msgpack_data(self, data, metadata):
|
||||
msgpack_data, pickle5_data = split_buffer(data)
|
||||
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE:
|
||||
python_objects = []
|
||||
else:
|
||||
python_objects = self._deserialize_pickle5_data(pickle5_data)
|
||||
|
||||
try:
|
||||
|
||||
def _python_deserializer(index):
|
||||
return python_objects[index]
|
||||
|
||||
obj = MessagePackSerializer.loads(msgpack_data,
|
||||
_python_deserializer)
|
||||
except Exception:
|
||||
raise DeserializationError()
|
||||
return obj
|
||||
|
||||
def _deserialize_object(self, data, metadata, object_id):
|
||||
if metadata:
|
||||
if metadata == ray_constants.PICKLE5_BUFFER_METADATA:
|
||||
return self._deserialize_pickle5_data(data)
|
||||
if metadata in [
|
||||
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
|
||||
ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
]:
|
||||
return self._deserialize_msgpack_data(data, metadata)
|
||||
# Check if the object should be returned as raw bytes.
|
||||
if metadata == ray_constants.RAW_BUFFER_METADATA:
|
||||
if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW:
|
||||
if data is None:
|
||||
return b""
|
||||
return data.to_pybytes()
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
error_type = int(metadata)
|
||||
try:
|
||||
error_type = int(metadata)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Can't deserialize object: {}, metadata: {}".format(
|
||||
object_id, metadata))
|
||||
|
||||
# RayTaskError is serialized with pickle5 in the data field.
|
||||
# TODO (kfstorm): exception serialization should be language
|
||||
# independent.
|
||||
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"):
|
||||
obj = self._deserialize_pickle5_data(data)
|
||||
obj = self._deserialize_msgpack_data(data, metadata)
|
||||
assert isinstance(obj, RayTaskError)
|
||||
return obj
|
||||
elif error_type == ErrorType.Value("WORKER_DIED"):
|
||||
|
@ -347,6 +338,43 @@ class SerializationContext:
|
|||
|
||||
return results
|
||||
|
||||
def _serialize_to_pickle5(self, metadata, value):
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
|
||||
def _serialize_to_msgpack(self, metadata, value):
|
||||
python_objects = []
|
||||
|
||||
def _python_serializer(o):
|
||||
index = len(python_objects)
|
||||
python_objects.append(o)
|
||||
return index
|
||||
|
||||
msgpack_data = MessagePackSerializer.dumps(value, _python_serializer)
|
||||
|
||||
if python_objects:
|
||||
pickle5_serialized_object = \
|
||||
self._serialize_to_pickle5(metadata, python_objects)
|
||||
else:
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE
|
||||
pickle5_serialized_object = None
|
||||
|
||||
return MessagePackSerializedObject(metadata, msgpack_data,
|
||||
pickle5_serialized_object)
|
||||
|
||||
def serialize(self, value):
|
||||
"""Serialize an object.
|
||||
|
||||
|
@ -365,23 +393,9 @@ class SerializationContext:
|
|||
metadata = str(ErrorType.Value(
|
||||
"TASK_EXECUTION_EXCEPTION")).encode("ascii")
|
||||
else:
|
||||
metadata = ray_constants.PICKLE5_BUFFER_METADATA
|
||||
metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
|
||||
|
||||
writer = Pickle5Writer()
|
||||
# TODO(swang): Check that contained_object_ids is empty.
|
||||
try:
|
||||
self.set_in_band_serialization()
|
||||
inband = pickle.dumps(
|
||||
value, protocol=5, buffer_callback=writer.buffer_callback)
|
||||
except Exception as e:
|
||||
self.get_and_clear_contained_object_ids()
|
||||
raise e
|
||||
finally:
|
||||
self.set_out_of_band_serialization()
|
||||
|
||||
return Pickle5SerializedObject(
|
||||
metadata, inband, writer,
|
||||
self.get_and_clear_contained_object_ids())
|
||||
return self._serialize_to_msgpack(metadata, value)
|
||||
|
||||
def register_custom_serializer(self,
|
||||
cls,
|
||||
|
|
|
@ -13,3 +13,13 @@ def test_cross_language_raise_kwargs(shutdown_only):
|
|||
|
||||
with pytest.raises(Exception, match="kwargs"):
|
||||
ray.java_actor_class("a").remote(x="arg1")
|
||||
|
||||
|
||||
def test_cross_language_raise_exception(shutdown_only):
|
||||
ray.init(load_code_from_local=True, include_java=True)
|
||||
|
||||
class PythonObject(object):
|
||||
pass
|
||||
|
||||
with pytest.raises(Exception, match="transfer"):
|
||||
ray.java_function("a", "b").remote(PythonObject())
|
||||
|
|
|
@ -172,9 +172,19 @@ def find_version(*filepath):
|
|||
|
||||
|
||||
requires = [
|
||||
"numpy >= 1.16", "filelock", "jsonschema", "click", "colorama", "pyyaml",
|
||||
"redis >= 3.3.2", "protobuf >= 3.8.0", "py-spy >= 0.2.0", "aiohttp",
|
||||
"google", "grpcio"
|
||||
"aiohttp",
|
||||
"click",
|
||||
"colorama",
|
||||
"filelock",
|
||||
"google",
|
||||
"grpcio",
|
||||
"jsonschema",
|
||||
"msgpack >= 0.6.0, < 1.0.0",
|
||||
"numpy >= 1.16",
|
||||
"protobuf >= 3.8.0",
|
||||
"py-spy >= 0.2.0",
|
||||
"pyyaml",
|
||||
"redis >= 3.3.2",
|
||||
]
|
||||
|
||||
setup(
|
||||
|
|
|
@ -13,8 +13,8 @@ def gen_streaming_java_deps():
|
|||
"org.slf4j:slf4j-api:1.7.12",
|
||||
"org.slf4j:slf4j-log4j12:1.7.25",
|
||||
"org.apache.logging.log4j:log4j-core:2.8.2",
|
||||
"org.testng:testng:6.9.10",
|
||||
"org.msgpack:msgpack-core:0.8.20",
|
||||
"org.testng:testng:6.9.10",
|
||||
],
|
||||
repositories = [
|
||||
"https://repo1.maven.org/maven2/",
|
||||
|
|
|
@ -2,7 +2,7 @@ package org.ray.streaming.runtime.core.collector;
|
|||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Collection;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
import org.ray.streaming.api.collector.Collector;
|
||||
import org.ray.streaming.api.partition.Partition;
|
||||
import org.ray.streaming.message.Record;
|
||||
|
@ -31,7 +31,7 @@ public class OutputCollector implements Collector<Record> {
|
|||
@Override
|
||||
public void collect(Record record) {
|
||||
int[] partitions = this.partition.partition(record, outputQueues.length);
|
||||
ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record));
|
||||
ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft());
|
||||
for (int partition : partitions) {
|
||||
writer.write(outputQueues[partition], msgBuffer);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package org.ray.streaming.runtime.worker.tasks;
|
||||
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
import org.ray.streaming.runtime.core.processor.Processor;
|
||||
import org.ray.streaming.runtime.transfer.Message;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
|
@ -28,7 +28,7 @@ public abstract class InputStreamTask extends StreamTask {
|
|||
if (item != null) {
|
||||
byte[] bytes = new byte[item.body().remaining()];
|
||||
item.body().get(bytes);
|
||||
Object obj = Serializer.decode(bytes);
|
||||
Object obj = Serializer.decode(bytes, Object.class);
|
||||
processor.process(obj);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue