[java] Pass large args by reference (#3504)

This commit is contained in:
bibabolynn 2018-12-14 23:32:35 +08:00 committed by Hao Chen
parent de3fdeb5b5
commit 7fd24e384b
5 changed files with 59 additions and 15 deletions

View file

@ -27,13 +27,16 @@ import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.UniqueIdUtil;
import org.ray.runtime.util.logger.RayLog;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class AbstractRayRuntime implements RayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
private static final int GET_TIMEOUT_MS = 1000;
private static final int FETCH_BATCH_SIZE = 1000;
@ -75,10 +78,26 @@ public abstract class AbstractRayRuntime implements RayRuntime {
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTask().taskId;
RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId);
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
}
/**
* Store a serialized object in the object store.
*
* @param obj The serialized Java object to be stored.
* @return A RayObject instance that represents the in-store object.
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTask().taskId;
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj, null);
return new RayObjectImpl<>(objectId);
}
@Override
public <T> T get(UniqueId objectId) throws RayException {
List<T> ret = get(ImmutableList.of(objectId));
@ -142,8 +161,9 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
}
RayLog.core
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId);
}
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
@ -152,8 +172,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return finalRet;
} catch (RayException e) {
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
+ " get with Exception", e);
LOGGER.error("Failed to get objects for task {}.", taskId, e);
throw e;
} finally {
// If there were objects that we weren't able to get locally, let the local

View file

@ -34,8 +34,9 @@ public class MockObjectStore implements ObjectStoreLink {
}
UniqueId uniqueId = new UniqueId(objectId);
data.put(uniqueId, value);
metadata.put(uniqueId, metadataValue);
if (metadataValue != null) {
metadata.put(uniqueId, metadataValue);
}
if (scheduler != null) {
scheduler.onObjectPut(uniqueId);
}

View file

@ -75,6 +75,10 @@ public class ObjectStoreProxy {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
store.put(id.getBytes(), obj, metadata);
}
public enum GetStatus {
SUCCESS, FAILED
}

View file

@ -6,14 +6,17 @@ import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.util.Serializer;
public class ArgumentsBuilder {
private static boolean checkSimpleValue(Object o) {
// TODO(raulchen): implement this.
return true;
}
/**
* If the the size of an argument's serialized data is smaller than this number,
* the argument will be passed by value. Otherwise it'll be passed by reference.
*/
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
/**
* Convert real function arguments to task spec arguments.
@ -30,10 +33,13 @@ public class ArgumentsBuilder {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (checkSimpleValue(arg)) {
data = Serializer.encode(arg);
} else {
id = Ray.put(arg).getId();
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
} else {
data = serialized;
}
}
if (id != null) {
ret[i] = FunctionArg.passByReference(id);

View file

@ -2,6 +2,8 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
@ -66,6 +68,15 @@ public class RayCallTest {
return val;
}
public static class LargeObject implements Serializable {
private byte[] data = new byte[1024 * 1024];
}
@RayRemote
private static LargeObject testLargeObject(LargeObject largeObject) {
return largeObject;
}
/**
* Test calling and returning different types.
*/
@ -83,6 +94,8 @@ public class RayCallTest {
Assert.assertEquals(list, Ray.call(RayCallTest::testList, list).get());
Map<String, Integer> map = ImmutableMap.of("1", 1, "2", 2);
Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get());
LargeObject largeObject = new LargeObject();
Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get());
}
@RayRemote
@ -130,4 +143,5 @@ public class RayCallTest {
Assert.assertEquals(5, (int) Ray.call(RayCallTest::testFiveParams, 1, 1, 1, 1, 1).get());
Assert.assertEquals(6, (int) Ray.call(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).get());
}
}