mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[java] Pass large args by reference (#3504)
This commit is contained in:
parent
de3fdeb5b5
commit
7fd24e384b
5 changed files with 59 additions and 15 deletions
|
@ -27,13 +27,16 @@ import org.ray.runtime.task.ArgumentsBuilder;
|
|||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.ray.runtime.util.logger.RayLog;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Core functionality to implement Ray APIs.
|
||||
*/
|
||||
public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
private static final int GET_TIMEOUT_MS = 1000;
|
||||
private static final int FETCH_BATCH_SIZE = 1000;
|
||||
|
||||
|
@ -75,10 +78,26 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
|
||||
public <T> void put(UniqueId objectId, T obj) {
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId);
|
||||
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.put(objectId, obj, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Store a serialized object in the object store.
|
||||
*
|
||||
* @param obj The serialized Java object to be stored.
|
||||
* @return A RayObject instance that represents the in-store object.
|
||||
*/
|
||||
public RayObject<Object> putSerialized(byte[] obj) {
|
||||
UniqueId objectId = UniqueIdUtil.computePutId(
|
||||
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
|
||||
UniqueId taskId = workerContext.getCurrentTask().taskId;
|
||||
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.putSerialized(objectId, obj, null);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(UniqueId objectId) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId));
|
||||
|
@ -142,8 +161,9 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
}
|
||||
}
|
||||
|
||||
RayLog.core
|
||||
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId);
|
||||
}
|
||||
List<T> finalRet = new ArrayList<>();
|
||||
|
||||
for (Pair<T, GetStatus> value : ret) {
|
||||
|
@ -152,8 +172,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
|||
|
||||
return finalRet;
|
||||
} catch (RayException e) {
|
||||
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
|
||||
+ " get with Exception", e);
|
||||
LOGGER.error("Failed to get objects for task {}.", taskId, e);
|
||||
throw e;
|
||||
} finally {
|
||||
// If there were objects that we weren't able to get locally, let the local
|
||||
|
|
|
@ -34,8 +34,9 @@ public class MockObjectStore implements ObjectStoreLink {
|
|||
}
|
||||
UniqueId uniqueId = new UniqueId(objectId);
|
||||
data.put(uniqueId, value);
|
||||
metadata.put(uniqueId, metadataValue);
|
||||
|
||||
if (metadataValue != null) {
|
||||
metadata.put(uniqueId, metadataValue);
|
||||
}
|
||||
if (scheduler != null) {
|
||||
scheduler.onObjectPut(uniqueId);
|
||||
}
|
||||
|
|
|
@ -75,6 +75,10 @@ public class ObjectStoreProxy {
|
|||
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
|
||||
}
|
||||
|
||||
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
|
||||
store.put(id.getBytes(), obj, metadata);
|
||||
}
|
||||
|
||||
public enum GetStatus {
|
||||
SUCCESS, FAILED
|
||||
}
|
||||
|
|
|
@ -6,14 +6,17 @@ import org.ray.api.Ray;
|
|||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
|
||||
public class ArgumentsBuilder {
|
||||
|
||||
private static boolean checkSimpleValue(Object o) {
|
||||
// TODO(raulchen): implement this.
|
||||
return true;
|
||||
}
|
||||
/**
|
||||
* If the the size of an argument's serialized data is smaller than this number,
|
||||
* the argument will be passed by value. Otherwise it'll be passed by reference.
|
||||
*/
|
||||
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
|
||||
|
||||
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
|
@ -30,10 +33,13 @@ public class ArgumentsBuilder {
|
|||
data = Serializer.encode(arg);
|
||||
} else if (arg instanceof RayObject) {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else if (checkSimpleValue(arg)) {
|
||||
data = Serializer.encode(arg);
|
||||
} else {
|
||||
id = Ray.put(arg).getId();
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
|
||||
} else {
|
||||
data = serialized;
|
||||
}
|
||||
}
|
||||
if (id != null) {
|
||||
ret[i] = FunctionArg.passByReference(id);
|
||||
|
|
|
@ -2,6 +2,8 @@ package org.ray.api.test;
|
|||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
|
@ -66,6 +68,15 @@ public class RayCallTest {
|
|||
return val;
|
||||
}
|
||||
|
||||
public static class LargeObject implements Serializable {
|
||||
private byte[] data = new byte[1024 * 1024];
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
private static LargeObject testLargeObject(LargeObject largeObject) {
|
||||
return largeObject;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test calling and returning different types.
|
||||
*/
|
||||
|
@ -83,6 +94,8 @@ public class RayCallTest {
|
|||
Assert.assertEquals(list, Ray.call(RayCallTest::testList, list).get());
|
||||
Map<String, Integer> map = ImmutableMap.of("1", 1, "2", 2);
|
||||
Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get());
|
||||
LargeObject largeObject = new LargeObject();
|
||||
Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get());
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
|
@ -130,4 +143,5 @@ public class RayCallTest {
|
|||
Assert.assertEquals(5, (int) Ray.call(RayCallTest::testFiveParams, 1, 1, 1, 1, 1).get());
|
||||
Assert.assertEquals(6, (int) Ray.call(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).get());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue