[Java] some small improvements (#14565)

This commit is contained in:
Kai Yang 2021-03-12 12:26:55 +08:00 committed by GitHub
parent 9cf328d616
commit f60bd3afee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 121 additions and 149 deletions

View file

@ -13,9 +13,9 @@ package io.ray.api;
* }
* }
* // Create an actor, and get a handle.
* ActorHandle<MyActor> myActor = Ray.createActor(MyActor::new);
* ActorHandle<MyActor> myActor = Ray.actor(MyActor::new).remote();
* // Call the `echo` method remotely.
* ObjectRef<Integer> result = myActor.call(MyActor::echo, 1);
* ObjectRef<Integer> result = myActor.task(MyActor::echo, 1).remote();
* // Get the result of the remote `echo` method.
* Assert.assertEqual(result.get(), 1);
* }</pre>

View file

@ -2,6 +2,7 @@ package io.ray.api.runtimecontext;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import java.util.List;
/** A class used for getting information of Ray runtime. */
@ -10,6 +11,9 @@ public interface RuntimeContext {
/** Get the current Job ID. */
JobId getCurrentJobId();
/** Get current task ID. */
TaskId getCurrentTaskId();
/**
* Get the current actor ID.
*

View file

@ -21,6 +21,7 @@ import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.context.RuntimeContextImpl;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.functionmanager.FunctionDescriptor;
@ -71,6 +72,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public <T> ObjectRef<T> put(T obj) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Putting Object in Task {}.", workerContext.getCurrentTaskId());
}
ObjectId objectId = objectStore.put(obj);
return new ObjectRefImpl<T>(objectId, (Class<T>) (obj == null ? Object.class : obj.getClass()));
}
@ -90,21 +94,30 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
objectIds.add(objectRefImpl.getId());
objectType = objectRefImpl.getType();
}
LOGGER.debug("Getting Objects {}.", objectIds);
return objectStore.get(objectIds, objectType);
}
@Override
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
objectStore.delete(
List<ObjectId> objectIds =
objectRefs.stream()
.map(ref -> ((ObjectRefImpl<?>) ref).getId())
.collect(Collectors.toList()),
localOnly);
.collect(Collectors.toList());
LOGGER.debug("Freeing Objects {}, localOnly = {}.", objectIds, localOnly);
objectStore.delete(objectIds, localOnly);
}
@Override
public <T> WaitResult<T> wait(
List<ObjectRef<T>> waitList, int numReturns, int timeoutMs, boolean fetchLocal) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
"Waiting Objects {} with minimum number {} within {} ms.",
waitList,
numReturns,
timeoutMs);
}
return objectStore.wait(waitList, numReturns, timeoutMs, fetchLocal);
}
@ -259,6 +272,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
CallOptions options) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
if (options == null) {
options = new CallOptions.Builder().build();
}
List<ObjectId> returnIds =
taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns);
@ -275,6 +291,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
Object[] args,
Optional<Class<?>> returnType) {
int numReturns = returnType.isPresent() ? 1 : 0;
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Submitting Actor Task {}.", functionDescriptor);
}
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds =
taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null);
@ -288,6 +307,19 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private BaseActorHandle createActorImpl(
FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) {
if (LOGGER.isDebugEnabled()) {
if (options == null) {
LOGGER.debug("Creating Actor {} with default options.", functionDescriptor);
} else {
LOGGER.debug("Creating Actor {}, jvmOptions = {}.", functionDescriptor, options.jvmOptions);
}
}
if (rayConfig.runMode == RunMode.SINGLE_PROCESS
&& functionDescriptor.getLanguage() != Language.JAVA) {
throw new IllegalArgumentException(
"Ray doesn't support cross-language invocation in local mode.");
}
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));

View file

@ -1,6 +1,7 @@
package io.ray.runtime.context;
import com.google.common.base.Preconditions;
import com.google.protobuf.ByteString;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
@ -9,6 +10,7 @@ import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.task.LocalModeTaskSubmitter;
import java.util.Random;
/** Worker context for local mode. */
public class LocalModeWorkerContext implements WorkerContext {
@ -19,6 +21,14 @@ public class LocalModeWorkerContext implements WorkerContext {
public LocalModeWorkerContext(JobId jobId) {
this.jobId = jobId;
// Create a dummy driver task with a random task id, so that we can call
// `getCurrentTaskId` from a driver.
byte[] driverTaskId = new byte[TaskId.LENGTH];
new Random().nextBytes(driverTaskId);
TaskSpec dummyDriverTask =
TaskSpec.newBuilder().setTaskId(ByteString.copyFrom(driverTaskId)).build();
currentTask.set(dummyDriverTask);
}
@Override

View file

@ -3,6 +3,7 @@ package io.ray.runtime.context;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.RayRuntimeInternal;
@ -30,6 +31,11 @@ public class RuntimeContextImpl implements RuntimeContext {
return actorId;
}
@Override
public TaskId getCurrentTaskId() {
return runtime.getWorkerContext().getCurrentTaskId();
}
@Override
public boolean wasCurrentActorRestarted() {
if (isSingleProcess()) {

View file

@ -10,13 +10,11 @@ import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.runtime.generated.Gcs;
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
import io.ray.runtime.generated.Gcs.TablePrefix;
import io.ray.runtime.placementgroup.PlacementGroupUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -132,8 +130,6 @@ public class GcsClient {
}
public boolean wasCurrentActorRestarted(ActorId actorId) {
byte[] key = ArrayUtils.addAll(TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes());
// TODO(ZhuSenlin): Get the actor table data from CoreWorker later.
byte[] value = globalStateAccessor.getActorInfo(actorId);
if (value == null) {

View file

@ -1,7 +1,6 @@
package io.ray.runtime.gcs;
import com.google.common.base.Strings;
import java.util.List;
import java.util.Map;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
@ -55,12 +54,6 @@ public class RedisClient {
}
}
public Map<byte[], byte[]> hgetAll(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hgetAll(key);
}
}
public String get(final String key, final String field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
@ -85,17 +78,6 @@ public class RedisClient {
}
}
/**
* Return the specified elements of the list stored at the specified key.
*
* @return Multi bulk reply, specifically a list of elements in the specified range.
*/
public List<byte[]> lrange(byte[] key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.lrange(key, start, end);
}
}
/** Whether the key exists in Redis. */
public boolean exists(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {

View file

@ -90,7 +90,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
LOGGER.debug("Executing task {} {}", taskId, rayFunctionInfo);
T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) {
@ -103,6 +103,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
// Find the executable object.
RayFunction rayFunction = localRayFunction.get();
try {
// Find the executable object.
@ -133,6 +135,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
result = rayFunction.getConstructor().newInstance(args);
}
} catch (InvocationTargetException e) {
LOGGER.error("Execute rayFunction {} failed. actor {}, args {}", rayFunction, actor, args);
if (e.getCause() != null) {
throw e.getCause();
} else {
@ -156,30 +159,41 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
throw (RayIntentionalSystemExitException) e;
}
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
NativeRayObject serializedException;
try {
serializedException =
ObjectSerializer.serialize(
new RayTaskException("Error executing task " + taskId, e));
} catch (Exception unserializable) {
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
// application-level exception is not serializable. `ObjectSerializer.serialize`
// will throw an exception and crash the worker.
// Refer to the case `TaskExceptionTest.java` for more details.
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
serializedException =
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Error executing task %s with the exception: %s",
taskId, ExceptionUtils.getStackTrace(e))));
if (rayFunction != null) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
NativeRayObject serializedException;
try {
serializedException =
ObjectSerializer.serialize(
new RayTaskException("Error executing task " + taskId, e));
} catch (Exception unserializable) {
// We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the
// application-level exception is not serializable. `ObjectSerializer.serialize`
// will throw an exception and crash the worker.
// Refer to the case `TaskExceptionTest.java` for more details.
LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable);
serializedException =
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Error executing task %s with the exception: %s",
taskId, ExceptionUtils.getStackTrace(e))));
}
Preconditions.checkNotNull(serializedException);
returnObjects.add(serializedException);
}
Preconditions.checkNotNull(serializedException);
returnObjects.add(serializedException);
} else {
returnObjects.add(
ObjectSerializer.serialize(
new RayTaskException(
String.format(
"Function %s of task %s doesn't exist",
String.join(".", rayFunctionInfo), taskId),
e)));
}
} else {
actorContext.actorCreationException = e;

View file

@ -1,14 +1,10 @@
package io.ray.runtime.util;
import com.google.common.base.Strings;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.util.Enumeration;
import java.util.concurrent.ThreadLocalRandom;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -16,9 +12,6 @@ public class NetworkUtil {
private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUtil.class);
private static final int MIN_PORT = 10000;
private static final int MAX_PORT = 65535;
public static String getIpAddress(String interfaceName) {
try {
Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
@ -50,29 +43,4 @@ public class NetworkUtil {
return "127.0.0.1";
}
public static int getUnusedPort() {
while (true) {
int port = ThreadLocalRandom.current().nextInt(MAX_PORT - MIN_PORT) + MIN_PORT;
if (isPortAvailable(port)) {
return port;
}
}
}
public static boolean isPortAvailable(int port) {
if (port < 1 || port > 65535) {
throw new IllegalArgumentException("Invalid start port: " + port);
}
try (ServerSocket ss = new ServerSocket(port);
DatagramSocket ds = new DatagramSocket(port)) {
ss.setReuseAddress(true);
ds.setReuseAddress(true);
return true;
} catch (IOException ignored) {
/* should not be thrown */
return false;
}
}
}

View file

@ -1,64 +0,0 @@
package io.ray.runtime.util;
import java.util.HashMap;
import java.util.Map;
public class ResourceUtil {
public static final String CPU_LITERAL = "CPU";
public static final String GPU_LITERAL = "GPU";
/**
* Convert resources map to a string that is used for the command line argument of starting
* raylet.
*
* @param resources The resources map to be converted.
* @return The starting-raylet command line argument, like "CPU,4,GPU,0".
*/
public static String getResourcesStringFromMap(Map<String, Double> resources) {
StringBuilder builder = new StringBuilder();
if (resources != null) {
int count = 1;
for (Map.Entry<String, Double> entry : resources.entrySet()) {
builder.append(entry.getKey()).append(",").append(entry.getValue());
if (count != resources.size()) {
builder.append(",");
}
count++;
}
}
return builder.toString();
}
/**
* Parse the static resources configure field and convert to the resources map.
*
* @param resources The static resources string to be parsed.
* @return The map whose key represents the resource name and the value represents the resource
* quantity.
* @throws IllegalArgumentException If the resources string's format does match, it will throw an
* IllegalArgumentException.
*/
public static Map<String, Double> getResourcesMapFromString(String resources)
throws IllegalArgumentException {
Map<String, Double> ret = new HashMap<>();
if (resources != null) {
String[] items = resources.split(",");
for (String item : items) {
String trimItem = item.trim();
if (trimItem.isEmpty()) {
continue;
}
String[] resourcePair = trimItem.split(":");
if (resourcePair.length != 2) {
throw new IllegalArgumentException("Format of static resources configure is invalid.");
}
final String resourceName = resourcePair[0].trim();
final Double resourceValue = Double.valueOf(resourcePair[1].trim());
ret.put(resourceName, resourceValue);
}
}
return ret;
}
}

View file

@ -1,10 +1,8 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.testng.Assert;
import org.testng.annotations.Test;
@ -33,7 +31,6 @@ public class ActorConcurrentCallTest extends BaseTest {
ObjectRef<String> obj2 = actor.task(ConcurrentActor::countDown).remote();
ObjectRef<String> obj3 = actor.task(ConcurrentActor::countDown).remote();
List<Integer> expectedResult = ImmutableList.of(1, 2, 3);
Assert.assertEquals(obj1.get(), "ok");
Assert.assertEquals(obj2.get(), "ok");
Assert.assertEquals(obj3.get(), "ok");

View file

@ -4,8 +4,10 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.runtime.task.ArgumentsBuilder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
@ -14,6 +16,9 @@ import org.testng.annotations.Test;
/** Test Ray.call API */
public class RayCallTest extends BaseTest {
private static final byte[] LARGE_RAW_DATA =
new byte[ArgumentsBuilder.LARGEST_SIZE_PASS_BY_VALUE + 100];
private static int testInt(int val) {
return val;
}
@ -22,6 +27,10 @@ public class RayCallTest extends BaseTest {
return val;
}
private static byte[] testBytes(byte[] val) {
return val;
}
private static short testShort(short val) {
return val;
}
@ -98,6 +107,12 @@ public class RayCallTest extends BaseTest {
// Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
}
@Test
public void testBytesType() {
Assert.assertEquals(
"123".getBytes(), Ray.task(RayCallTest::testBytes, "123".getBytes()).remote().get());
}
private static int testNoParam() {
return 0;
}
@ -138,4 +153,13 @@ public class RayCallTest extends BaseTest {
Assert.assertEquals(
6, (int) Ray.task(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).remote().get());
}
private static Boolean testLargeRawData(byte[] data) {
return Arrays.equals(data, LARGE_RAW_DATA);
}
@Test
public void testLargeRawDataArgument() {
Assert.assertTrue(Ray.task(RayCallTest::testLargeRawData, LARGE_RAW_DATA).remote().get());
}
}

View file

@ -4,6 +4,7 @@ import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.testng.Assert;
@ -29,12 +30,14 @@ public class RuntimeContextTest extends BaseTest {
@Test
public void testRuntimeContextInDriver() {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertNotEquals(Ray.getRuntimeContext().getCurrentTaskId(), TaskId.NIL);
}
public static class RuntimeContextTester {
public String testRuntimeContext(ActorId actorId) {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertNotEquals(Ray.getRuntimeContext().getCurrentTaskId(), TaskId.NIL);
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
return "ok";
}