fix java UT about multi-threading (#8014)

This commit is contained in:
Kai Yang 2020-04-27 15:11:22 +08:00 committed by GitHub
parent 7ec2223c84
commit 1d5bceddf0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 56 deletions

View file

@ -36,6 +36,7 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -227,21 +228,27 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
ExecutorService executorService;
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
ExecutorService actorExecutorService = Executors.newSingleThreadExecutor();
executorService = Executors.newSingleThreadExecutor();
synchronized (actorTaskExecutorServices) {
actorTaskExecutorServices.put(getActorId(taskSpec), actorExecutorService);
actorTaskExecutorServices.put(getActorId(taskSpec), executorService);
}
actorExecutorService.submit(runnable);
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
synchronized (actorTaskExecutorServices) {
ExecutorService actorExecutorService =
actorTaskExecutorServices.get(getActorId(taskSpec));
actorExecutorService.submit(runnable);
executorService = actorTaskExecutorServices.get(getActorId(taskSpec));
}
} else {
// Normal task.
normalTaskExecutorService.submit(runnable);
executorService = normalTaskExecutorService;
}
try {
executorService.submit(runnable);
} catch (RejectedExecutionException e) {
if (executorService.isShutdown()) {
LOGGER.warn("Ignore task submission due to the ExecutorService is shutdown. Task: {}",
taskSpec);
}
}
} else {
// If some dependencies aren't ready yet, put this task in waiting list.

View file

@ -136,10 +136,13 @@ public class MultiThreadingTest extends BaseTest {
Assert.assertEquals(actorId, actorIdTester.getId());
}
static boolean testMissingWrapRunnable() throws InterruptedException {
/**
* Call this method each time to avoid hitting the cache in {@link RayObject#get()}.
*/
static Runnable[] generateRunnables() {
final RayObject<Integer> fooObject = Ray.put(1);
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
final Runnable[] runnables = new Runnable[]{
return new Runnable[]{
() -> Ray.put(1),
() -> Ray.get(fooObject.getId(), fooObject.getType()),
fooObject::get,
@ -149,58 +152,72 @@ public class MultiThreadingTest extends BaseTest {
() -> Ray.createActor(Echo::new),
() -> fooActor.call(Echo::echo, 1),
};
}
// It's OK to run them in main thread.
for (Runnable runnable : runnables) {
runnable.run();
}
Exception[] exception = new Exception[1];
Thread thread = new Thread(Ray.wrapRunnable(() -> {
try {
// It would be OK to run them in another thread if wrapped the runnable.
for (Runnable runnable : runnables) {
runnable.run();
}
} catch (Exception ex) {
exception[0] = ex;
static boolean testMissingWrapRunnable() throws InterruptedException {
{
Runnable[] runnables = generateRunnables();
// It's OK to run them in main thread.
for (Runnable runnable : runnables) {
runnable.run();
}
}));
thread.start();
thread.join();
if (exception[0] != null) {
throw new RuntimeException("Exception occurred in thread.", exception[0]);
}
thread = new Thread(() -> {
try {
// It wouldn't be OK to run them in another thread if not wrapped the runnable.
for (Runnable runnable : runnables) {
Assert.expectThrows(RayException.class, runnable::run);
Throwable[] throwable = new Throwable[1];
{
Runnable[] runnables = generateRunnables();
Thread thread = new Thread(Ray.wrapRunnable(() -> {
try {
// It would be OK to run them in another thread if wrapped the runnable.
for (Runnable runnable : runnables) {
runnable.run();
}
} catch (Throwable ex) {
throwable[0] = ex;
}
} catch (Exception ex) {
exception[0] = ex;
}));
thread.start();
thread.join();
if (throwable[0] != null) {
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
}
});
thread.start();
thread.join();
if (exception[0] != null) {
throw new RuntimeException("Exception occurred in thread.", exception[0]);
}
Runnable[] wrappedRunnables = new Runnable[runnables.length];
for (int i = 0; i < runnables.length; i++) {
wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]);
}
// It would be OK to run the wrapped runnables in the current thread.
for (Runnable runnable : wrappedRunnables) {
runnable.run();
{
Runnable[] runnables = generateRunnables();
Thread thread = new Thread(() -> {
try {
// It wouldn't be OK to run them in another thread if not wrapped the runnable.
for (Runnable runnable : runnables) {
Assert.expectThrows(RayException.class, runnable::run);
}
} catch (Throwable ex) {
throwable[0] = ex;
}
});
thread.start();
thread.join();
if (throwable[0] != null) {
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
}
}
// It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
wrappedRunnables[0].run();
runnables[0].run();
{
Runnable[] runnables = generateRunnables();
Runnable[] wrappedRunnables = new Runnable[runnables.length];
for (int i = 0; i < runnables.length; i++) {
wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]);
}
// It would be OK to run the wrapped runnables in the current thread.
for (Runnable runnable : wrappedRunnables) {
runnable.run();
}
// It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread.
wrappedRunnables[0].run();
runnables[0].run();
}
// Return true here to make the Ray.call returns an RayObject.
return true;
@ -219,19 +236,19 @@ public class MultiThreadingTest extends BaseTest {
@Test
public void testGetAndSetAsyncContext() throws InterruptedException {
Object asyncContext = Ray.getAsyncContext();
Exception[] exception = new Exception[1];
Throwable[] throwable = new Throwable[1];
Thread thread = new Thread(() -> {
try {
Ray.setAsyncContext(asyncContext);
Ray.put(1);
} catch (Exception ex) {
exception[0] = ex;
} catch (Throwable ex) {
throwable[0] = ex;
}
});
thread.start();
thread.join();
if (exception[0] != null) {
throw new RuntimeException("Exception occurred in thread.", exception[0]);
if (throwable[0] != null) {
throw new RuntimeException("Exception occurred in thread.", throwable[0]);
}
}