mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Java] some small improvements (#14565)
This commit is contained in:
parent
9cf328d616
commit
f60bd3afee
13 changed files with 121 additions and 149 deletions
|
@ -13,9 +13,9 @@ package io.ray.api;
|
|||
* }
|
||||
* }
|
||||
* // Create an actor, and get a handle.
|
||||
* ActorHandle<MyActor> myActor = Ray.createActor(MyActor::new);
|
||||
* ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
|
||||
* // Call the `echo` method remotely.
|
||||
* ObjectRef<Integer> result = myActor.call(MyActor::echo, 1);
|
||||
* ObjectRef<Integer> result = myActor.task(MyActor::echo, 1).remote();
|
||||
* // Get the result of the remote `echo` method.
|
||||
* Assert.assertEqual(result.get(), 1);
|
||||
* }</pre>
|
||||
|
|
|
@ -2,6 +2,7 @@ package io.ray.api.runtimecontext;
|
|||
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import java.util.List;
|
||||
|
||||
/** A class used for getting information of Ray runtime. */
|
||||
|
@ -10,6 +11,9 @@ public interface RuntimeContext {
|
|||
/** Get the current Job ID. */
|
||||
JobId getCurrentJobId();
|
||||
|
||||
/** Get current task ID. */
|
||||
TaskId getCurrentTaskId();
|
||||
|
||||
/**
|
||||
* Get the current actor ID.
|
||||
*
|
||||
|
|
|
@ -21,6 +21,7 @@ import io.ray.api.options.PlacementGroupCreationOptions;
|
|||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import io.ray.runtime.config.RayConfig;
|
||||
import io.ray.runtime.config.RunMode;
|
||||
import io.ray.runtime.context.RuntimeContextImpl;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import io.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
|
@ -71,6 +72,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
@Override
|
||||
public <T> ObjectRef<T> put(T obj) {
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Putting Object in Task {}.", workerContext.getCurrentTaskId());
|
||||
}
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
return new ObjectRefImpl<T>(objectId, (Class<T>) (obj == null ? Object.class : obj.getClass()));
|
||||
}
|
||||
|
@ -90,21 +94,30 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
objectIds.add(objectRefImpl.getId());
|
||||
objectType = objectRefImpl.getType();
|
||||
}
|
||||
LOGGER.debug("Getting Objects {}.", objectIds);
|
||||
return objectStore.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
|
||||
objectStore.delete(
|
||||
List<ObjectId> objectIds =
|
||||
objectRefs.stream()
|
||||
.map(ref -> ((ObjectRefImpl<?>) ref).getId())
|
||||
.collect(Collectors.toList()),
|
||||
localOnly);
|
||||
.collect(Collectors.toList());
|
||||
LOGGER.debug("Freeing Objects {}, localOnly = {}.", objectIds, localOnly);
|
||||
objectStore.delete(objectIds, localOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(
|
||||
List<ObjectRef<T>> waitList, int numReturns, int timeoutMs, boolean fetchLocal) {
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug(
|
||||
"Waiting Objects {} with minimum number {} within {} ms.",
|
||||
waitList,
|
||||
numReturns,
|
||||
timeoutMs);
|
||||
}
|
||||
return objectStore.wait(waitList, numReturns, timeoutMs, fetchLocal);
|
||||
}
|
||||
|
||||
|
@ -259,6 +272,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
CallOptions options) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
|
||||
if (options == null) {
|
||||
options = new CallOptions.Builder().build();
|
||||
}
|
||||
List<ObjectId> returnIds =
|
||||
taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns);
|
||||
|
@ -275,6 +291,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
Object[] args,
|
||||
Optional<Class<?>> returnType) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Submitting Actor Task {}.", functionDescriptor);
|
||||
}
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
|
||||
List<ObjectId> returnIds =
|
||||
taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null);
|
||||
|
@ -288,6 +307,19 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
|||
|
||||
private BaseActorHandle createActorImpl(
|
||||
FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) {
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
if (options == null) {
|
||||
LOGGER.debug("Creating Actor {} with default options.", functionDescriptor);
|
||||
} else {
|
||||
LOGGER.debug("Creating Actor {}, jvmOptions = {}.", functionDescriptor, options.jvmOptions);
|
||||
}
|
||||
}
|
||||
if (rayConfig.runMode == RunMode.SINGLE_PROCESS
|
||||
&& functionDescriptor.getLanguage() != Language.JAVA) {
|
||||
throw new IllegalArgumentException(
|
||||
"Ray doesn't support cross-language invocation in local mode.");
|
||||
}
|
||||
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
|
||||
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
|
||||
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package io.ray.runtime.context;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
|
@ -9,6 +10,7 @@ import io.ray.runtime.generated.Common.Address;
|
|||
import io.ray.runtime.generated.Common.TaskSpec;
|
||||
import io.ray.runtime.generated.Common.TaskType;
|
||||
import io.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
import java.util.Random;
|
||||
|
||||
/** Worker context for local mode. */
|
||||
public class LocalModeWorkerContext implements WorkerContext {
|
||||
|
@ -19,6 +21,14 @@ public class LocalModeWorkerContext implements WorkerContext {
|
|||
|
||||
public LocalModeWorkerContext(JobId jobId) {
|
||||
this.jobId = jobId;
|
||||
|
||||
// Create a dummy driver task with a random task id, so that we can call
|
||||
// `getCurrentTaskId` from a driver.
|
||||
byte[] driverTaskId = new byte[TaskId.LENGTH];
|
||||
new Random().nextBytes(driverTaskId);
|
||||
TaskSpec dummyDriverTask =
|
||||
TaskSpec.newBuilder().setTaskId(ByteString.copyFrom(driverTaskId)).build();
|
||||
currentTask.set(dummyDriverTask);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -3,6 +3,7 @@ package io.ray.runtime.context;
|
|||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.api.runtimecontext.NodeInfo;
|
||||
import io.ray.api.runtimecontext.RuntimeContext;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
|
@ -30,6 +31,11 @@ public class RuntimeContextImpl implements RuntimeContext {
|
|||
return actorId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
return runtime.getWorkerContext().getCurrentTaskId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean wasCurrentActorRestarted() {
|
||||
if (isSingleProcess()) {
|
||||
|
|
|
@ -10,13 +10,11 @@ import io.ray.api.placementgroup.PlacementGroup;
|
|||
import io.ray.api.runtimecontext.NodeInfo;
|
||||
import io.ray.runtime.generated.Gcs;
|
||||
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
|
||||
import io.ray.runtime.generated.Gcs.TablePrefix;
|
||||
import io.ray.runtime.placementgroup.PlacementGroupUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -132,8 +130,6 @@ public class GcsClient {
|
|||
}
|
||||
|
||||
public boolean wasCurrentActorRestarted(ActorId actorId) {
|
||||
byte[] key = ArrayUtils.addAll(TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes());
|
||||
|
||||
// TODO(ZhuSenlin): Get the actor table data from CoreWorker later.
|
||||
byte[] value = globalStateAccessor.getActorInfo(actorId);
|
||||
if (value == null) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package io.ray.runtime.gcs;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import redis.clients.jedis.Jedis;
|
||||
import redis.clients.jedis.JedisPool;
|
||||
|
@ -55,12 +54,6 @@ public class RedisClient {
|
|||
}
|
||||
}
|
||||
|
||||
public Map<byte[], byte[]> hgetAll(byte[] key) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
return jedis.hgetAll(key);
|
||||
}
|
||||
}
|
||||
|
||||
public String get(final String key, final String field) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
if (field == null) {
|
||||
|
@ -85,17 +78,6 @@ public class RedisClient {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the specified elements of the list stored at the specified key.
|
||||
*
|
||||
* @return Multi bulk reply, specifically a list of elements in the specified range.
|
||||
*/
|
||||
public List<byte[]> lrange(byte[] key, long start, long end) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
return jedis.lrange(key, start, end);
|
||||
}
|
||||
}
|
||||
|
||||
/** Whether the key exists in Redis. */
|
||||
public boolean exists(byte[] key) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
|
|
|
@ -90,7 +90,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
runtime.setIsContextSet(true);
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo);
|
||||
|
||||
T actorContext = null;
|
||||
if (taskType == TaskType.ACTOR_CREATION_TASK) {
|
||||
|
@ -103,6 +103,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
// Find the executable object.
|
||||
|
||||
RayFunction rayFunction = localRayFunction.get();
|
||||
try {
|
||||
// Find the executable object.
|
||||
|
@ -133,6 +135,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
result = rayFunction.getConstructor().newInstance(args);
|
||||
}
|
||||
} catch (InvocationTargetException e) {
|
||||
LOGGER.error("Execute rayFunction {} failed. actor {}, args {}", rayFunction, actor, args);
|
||||
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
|
@ -156,30 +159,41 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
|||
throw (RayIntentionalSystemExitException) e;
|
||||
}
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
|
||||
if (hasReturn || isCrossLanguage) {
|
||||
NativeRayObject serializedException;
|
||||
try {
|
||||
serializedException =
|
||||
ObjectSerializer.serialize(
|
||||
new RayTaskException("Error executing task " + taskId, e));
|
||||
} catch (Exception unserializable) {
|
||||
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
|
||||
// application-level exception is not serializable. `ObjectSerializer.serialize`
|
||||
// will throw an exception and crash the worker.
|
||||
// Refer to the case `TaskExceptionTest.java` for more details.
|
||||
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
|
||||
serializedException =
|
||||
ObjectSerializer.serialize(
|
||||
new RayTaskException(
|
||||
String.format(
|
||||
"Error executing task %s with the exception: %s",
|
||||
taskId, ExceptionUtils.getStackTrace(e))));
|
||||
if (rayFunction != null) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
|
||||
if (hasReturn || isCrossLanguage) {
|
||||
NativeRayObject serializedException;
|
||||
try {
|
||||
serializedException =
|
||||
ObjectSerializer.serialize(
|
||||
new RayTaskException("Error executing task " + taskId, e));
|
||||
} catch (Exception unserializable) {
|
||||
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
|
||||
// application-level exception is not serializable. `ObjectSerializer.serialize`
|
||||
// will throw an exception and crash the worker.
|
||||
// Refer to the case `TaskExceptionTest.java` for more details.
|
||||
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
|
||||
serializedException =
|
||||
ObjectSerializer.serialize(
|
||||
new RayTaskException(
|
||||
String.format(
|
||||
"Error executing task %s with the exception: %s",
|
||||
taskId, ExceptionUtils.getStackTrace(e))));
|
||||
}
|
||||
Preconditions.checkNotNull(serializedException);
|
||||
returnObjects.add(serializedException);
|
||||
}
|
||||
Preconditions.checkNotNull(serializedException);
|
||||
returnObjects.add(serializedException);
|
||||
} else {
|
||||
returnObjects.add(
|
||||
ObjectSerializer.serialize(
|
||||
new RayTaskException(
|
||||
String.format(
|
||||
"Function %s of task %s doesn't exist",
|
||||
String.join(".", rayFunctionInfo), taskId),
|
||||
e)));
|
||||
}
|
||||
} else {
|
||||
actorContext.actorCreationException = e;
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
package io.ray.runtime.util;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import java.io.IOException;
|
||||
import java.net.DatagramSocket;
|
||||
import java.net.Inet6Address;
|
||||
import java.net.InetAddress;
|
||||
import java.net.NetworkInterface;
|
||||
import java.net.ServerSocket;
|
||||
import java.util.Enumeration;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -16,9 +12,6 @@ public class NetworkUtil {
|
|||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUtil.class);
|
||||
|
||||
private static final int MIN_PORT = 10000;
|
||||
private static final int MAX_PORT = 65535;
|
||||
|
||||
public static String getIpAddress(String interfaceName) {
|
||||
try {
|
||||
Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
|
||||
|
@ -50,29 +43,4 @@ public class NetworkUtil {
|
|||
|
||||
return "127.0.0.1";
|
||||
}
|
||||
|
||||
public static int getUnusedPort() {
|
||||
while (true) {
|
||||
int port = ThreadLocalRandom.current().nextInt(MAX_PORT - MIN_PORT) + MIN_PORT;
|
||||
if (isPortAvailable(port)) {
|
||||
return port;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean isPortAvailable(int port) {
|
||||
if (port < 1 || port > 65535) {
|
||||
throw new IllegalArgumentException("Invalid start port: " + port);
|
||||
}
|
||||
|
||||
try (ServerSocket ss = new ServerSocket(port);
|
||||
DatagramSocket ds = new DatagramSocket(port)) {
|
||||
ss.setReuseAddress(true);
|
||||
ds.setReuseAddress(true);
|
||||
return true;
|
||||
} catch (IOException ignored) {
|
||||
/* should not be thrown */
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
package io.ray.runtime.util;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ResourceUtil {
|
||||
public static final String CPU_LITERAL = "CPU";
|
||||
public static final String GPU_LITERAL = "GPU";
|
||||
|
||||
/**
|
||||
* Convert resources map to a string that is used for the command line argument of starting
|
||||
* raylet.
|
||||
*
|
||||
* @param resources The resources map to be converted.
|
||||
* @return The starting-raylet command line argument, like "CPU,4,GPU,0".
|
||||
*/
|
||||
public static String getResourcesStringFromMap(Map<String, Double> resources) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
if (resources != null) {
|
||||
int count = 1;
|
||||
for (Map.Entry<String, Double> entry : resources.entrySet()) {
|
||||
builder.append(entry.getKey()).append(",").append(entry.getValue());
|
||||
if (count != resources.size()) {
|
||||
builder.append(",");
|
||||
}
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the static resources configure field and convert to the resources map.
|
||||
*
|
||||
* @param resources The static resources string to be parsed.
|
||||
* @return The map whose key represents the resource name and the value represents the resource
|
||||
* quantity.
|
||||
* @throws IllegalArgumentException If the resources string's format does match, it will throw an
|
||||
* IllegalArgumentException.
|
||||
*/
|
||||
public static Map<String, Double> getResourcesMapFromString(String resources)
|
||||
throws IllegalArgumentException {
|
||||
Map<String, Double> ret = new HashMap<>();
|
||||
if (resources != null) {
|
||||
String[] items = resources.split(",");
|
||||
for (String item : items) {
|
||||
String trimItem = item.trim();
|
||||
if (trimItem.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
String[] resourcePair = trimItem.split(":");
|
||||
|
||||
if (resourcePair.length != 2) {
|
||||
throw new IllegalArgumentException("Format of static resources configure is invalid.");
|
||||
}
|
||||
|
||||
final String resourceName = resourcePair[0].trim();
|
||||
final Double resourceValue = Double.valueOf(resourcePair[1].trim());
|
||||
ret.put(resourceName, resourceValue);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -1,10 +1,8 @@
|
|||
package io.ray.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.ray.api.ActorHandle;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
@ -33,7 +31,6 @@ public class ActorConcurrentCallTest extends BaseTest {
|
|||
ObjectRef<String> obj2 = actor.task(ConcurrentActor::countDown).remote();
|
||||
ObjectRef<String> obj3 = actor.task(ConcurrentActor::countDown).remote();
|
||||
|
||||
List<Integer> expectedResult = ImmutableList.of(1, 2, 3);
|
||||
Assert.assertEquals(obj1.get(), "ok");
|
||||
Assert.assertEquals(obj2.get(), "ok");
|
||||
Assert.assertEquals(obj3.get(), "ok");
|
||||
|
|
|
@ -4,8 +4,10 @@ import com.google.common.collect.ImmutableList;
|
|||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.task.ArgumentsBuilder;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.testng.Assert;
|
||||
|
@ -14,6 +16,9 @@ import org.testng.annotations.Test;
|
|||
/** Test Ray.call API */
|
||||
public class RayCallTest extends BaseTest {
|
||||
|
||||
private static final byte[] LARGE_RAW_DATA =
|
||||
new byte[ArgumentsBuilder.LARGEST_SIZE_PASS_BY_VALUE + 100];
|
||||
|
||||
private static int testInt(int val) {
|
||||
return val;
|
||||
}
|
||||
|
@ -22,6 +27,10 @@ public class RayCallTest extends BaseTest {
|
|||
return val;
|
||||
}
|
||||
|
||||
private static byte[] testBytes(byte[] val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
private static short testShort(short val) {
|
||||
return val;
|
||||
}
|
||||
|
@ -98,6 +107,12 @@ public class RayCallTest extends BaseTest {
|
|||
// Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBytesType() {
|
||||
Assert.assertEquals(
|
||||
"123".getBytes(), Ray.task(RayCallTest::testBytes, "123".getBytes()).remote().get());
|
||||
}
|
||||
|
||||
private static int testNoParam() {
|
||||
return 0;
|
||||
}
|
||||
|
@ -138,4 +153,13 @@ public class RayCallTest extends BaseTest {
|
|||
Assert.assertEquals(
|
||||
6, (int) Ray.task(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).remote().get());
|
||||
}
|
||||
|
||||
private static Boolean testLargeRawData(byte[] data) {
|
||||
return Arrays.equals(data, LARGE_RAW_DATA);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLargeRawDataArgument() {
|
||||
Assert.assertTrue(Ray.task(RayCallTest::testLargeRawData, LARGE_RAW_DATA).remote().get());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import io.ray.api.ActorHandle;
|
|||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import org.testng.Assert;
|
||||
|
@ -29,12 +30,14 @@ public class RuntimeContextTest extends BaseTest {
|
|||
@Test
|
||||
public void testRuntimeContextInDriver() {
|
||||
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
|
||||
Assert.assertNotEquals(Ray.getRuntimeContext().getCurrentTaskId(), TaskId.NIL);
|
||||
}
|
||||
|
||||
public static class RuntimeContextTester {
|
||||
|
||||
public String testRuntimeContext(ActorId actorId) {
|
||||
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
|
||||
Assert.assertNotEquals(Ray.getRuntimeContext().getCurrentTaskId(), TaskId.NIL);
|
||||
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
|
||||
return "ok";
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue