[Java] Add java api overload doc and test (#14204)

This commit is contained in:
chaokunyang 2021-02-19 19:46:35 +08:00 committed by GitHub
parent ec344b87c7
commit f8a36eb350
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 0 deletions

View file

@ -206,6 +206,66 @@ Ray supports resource specific accelerator types. The `accelerator_type` field c
See `ray.util.accelerators` to see available accelerator types. Current automatically detected accelerator types include Nvidia GPUs.
Overloaded Functions
--------------------
Ray Java API supports calling overloaded java functions remotely. However, due to the limitation of Java compiler type inference, one must explicitly cast the method reference to the correct function type. For example, consider the following.
Overloaded normal task call:
.. code:: java
public static class MyRayApp {
public static int overloadFunction() {
return 1;
}
public static int overloadFunction(int x) {
return x;
}
}
// Invoke overloaded functions.
Assert.assertEquals((int) Ray.task((RayFunc0<Integer>) MyRayApp::overloadFunction).remote().get(), 1);
Assert.assertEquals((int) Ray.task((RayFunc1<Integer, Integer>) MyRayApp::overloadFunction, 2).remote().get(), 2);
Overloaded actor task call:
.. code:: java
public static class Counter {
protected int value = 0;
public int increment() {
this.value += 1;
return this.value;
}
}
public static class CounterOverloaded extends Counter {
public int increment(int diff) {
super.value += diff;
return super.value;
}
public int increment(int diff1, int diff2) {
super.value += diff1 + diff2;
return super.value;
}
}
.. code:: java
ActorHandle<CounterOverloaded> a = Ray.actor(CounterOverloaded::new).remote();
// Call an overloaded actor method by super class method reference.
Assert.assertEquals((int) a.task(Counter::increment).remote().get(), 1);
// Call an overloaded actor method, cast method reference first.
a.task((RayFunc1<CounterOverloaded, Integer>) CounterOverloaded::increment).remote();
a.task((RayFunc2<CounterOverloaded, Integer, Integer>) CounterOverloaded::increment, 10).remote();
a.task((RayFunc3<CounterOverloaded, Integer, Integer, Integer>) CounterOverloaded::increment, 10, 10).remote();
Assert.assertEquals((int) a.task(Counter::increment).remote().get(), 33);
Nested Remote Functions
-----------------------

View file

@ -2,6 +2,9 @@ package io.ray.docdemo;
import io.ray.api.ActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.RayFunc1;
import io.ray.api.function.RayFunc2;
import io.ray.api.function.RayFunc3;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.testng.Assert;
@ -33,6 +36,18 @@ public class UsingActorsDemo {
}
}
public static class CounterOverloaded extends Counter {
public int increment(int diff) {
super.value += diff;
return super.value;
}
public int increment(int diff1, int diff2) {
super.value += diff1 + diff2;
return super.value;
}
}
public static class CounterFactory {
public static Counter createCounter() {
@ -71,6 +86,19 @@ public class UsingActorsDemo {
Assert.assertEquals((int) a.task(Counter::increment).remote().get(), 11);
}
{
ActorHandle<CounterOverloaded> a = Ray.actor(CounterOverloaded::new).remote();
// Call an overloaded actor method by super class method reference.
Assert.assertEquals((int) a.task(Counter::increment).remote().get(), 1);
// Call an overloaded actor method, cast method reference first.
a.task((RayFunc1<CounterOverloaded, Integer>) CounterOverloaded::increment).remote();
a.task((RayFunc2<CounterOverloaded, Integer, Integer>) CounterOverloaded::increment, 10)
.remote();
RayFunc3<CounterOverloaded, Integer, Integer, Integer> f = CounterOverloaded::increment;
a.task(f, 10, 10).remote();
Assert.assertEquals((int) a.task(Counter::increment).remote().get(), 33);
}
{
Ray.actor(GpuActor::new).setResource("CPU", 2.0).setResource("GPU", 0.5).remote();
}

View file

@ -5,6 +5,8 @@ import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.WaitResult;
import io.ray.api.function.RayFunc0;
import io.ray.api.function.RayFunc1;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
@ -33,6 +35,14 @@ public class WalkthroughDemo {
public static int functionWithAnArgument(int value) {
return value + 1;
}
public static int overloadFunction() {
return 1;
}
public static int overloadFunction(int x) {
return x;
}
}
public static void demoTasks() {
@ -51,6 +61,13 @@ public class WalkthroughDemo {
Ray.task(MyRayApp::slowFunction).remote();
}
// Invoke overloaded functions.
Assert.assertEquals(
(int) Ray.task((RayFunc0<Integer>) MyRayApp::overloadFunction).remote().get(), 1);
Assert.assertEquals(
(int) Ray.task((RayFunc1<Integer, Integer>) MyRayApp::overloadFunction, 2).remote().get(),
2);
ObjectRef<Integer> objRef1 = Ray.task(MyRayApp::myFunction).remote();
Assert.assertTrue(objRef1.get() == 1);