[Xlang][Java] Fix Java overrided default method cannot be invoked. (#21491)

In Xlang(Python call Java), a Java method which overrides a `default` method of the super class is not able to be invoked successfully, due to we treat it as overloaded method instead of overrided method. This PR correctly handle it at the case it overrides a `default` method.

Before this PR, the following usage is not able to be invoked from Python -> Java.
```Java
public interface ExampleInterface {
  default String echo(String inp) {
    return inp;
  }
}
public class ExampleImpl implements ExampleInterface {
  @Override
  public String echo(String inp) {
    return inp + " echo";
  }
}
```
```python
/// Invoke it in Python.
cls = ray.java_actor_class("io.ray.serve.util.ExampleImpl")
handle = cls.remote()
print(ray.get(handle.echo.remote("hi")))
```
This commit is contained in:
Qing Wang 2022-01-11 23:11:24 +08:00 committed by GitHub
parent 9ac34ecc94
commit bb647626cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 11 deletions

View file

@ -157,7 +157,7 @@ public class FunctionManager {
/** The job's corresponding class loader. */
final ClassLoader classLoader;
/** Functions per class, per function name + type descriptor. */
ConcurrentMap<String, Map<Pair<String, String>, RayFunction>> functions;
ConcurrentMap<String, Map<Pair<String, String>, Pair<RayFunction, Boolean>>> functions;
JobFunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
@ -165,7 +165,8 @@ public class FunctionManager {
}
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
Map<Pair<String, String>, Pair<RayFunction, Boolean>> classFunctions =
functions.get(descriptor.className);
if (classFunctions == null) {
synchronized (this) {
classFunctions = functions.get(descriptor.className);
@ -176,7 +177,7 @@ public class FunctionManager {
}
}
final Pair<String, String> key = ImmutablePair.of(descriptor.name, descriptor.signature);
RayFunction func = classFunctions.get(key);
RayFunction func = classFunctions.get(key).getLeft();
if (func == null) {
if (classFunctions.containsKey(key)) {
throw new RuntimeException(
@ -192,9 +193,11 @@ public class FunctionManager {
}
/** Load all functions from a class. */
Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
Map<Pair<String, String>, Pair<RayFunction, Boolean>> loadFunctionsForClass(String className) {
// If RayFunction is null, the function is overloaded.
Map<Pair<String, String>, RayFunction> map = new HashMap<>();
// The value of this map is a pair of <rayFunction, isDefault>.
// The `isDefault` is used to mark if the method is a marked as default keyword.
Map<Pair<String, String>, Pair<RayFunction, Boolean>> map = new HashMap<>();
try {
Class clazz = Class.forName(className, true, classLoader);
List<Executable> executables = new ArrayList<>();
@ -227,13 +230,18 @@ public class FunctionManager {
RayFunction rayFunction =
new RayFunction(
e, classLoader, new JavaFunctionDescriptor(className, methodName, signature));
map.put(ImmutablePair.of(methodName, signature), rayFunction);
final boolean isDefault = e instanceof Method && ((Method) e).isDefault();
map.put(
ImmutablePair.of(methodName, signature), ImmutablePair.of(rayFunction, isDefault));
// For cross language call java function without signature
final Pair<String, String> emptyDescriptor = ImmutablePair.of(methodName, "");
if (map.containsKey(emptyDescriptor)) {
map.put(emptyDescriptor, null); // Mark this function as overloaded.
/// default method is not overloaded, so we should filter it.
if (map.containsKey(emptyDescriptor) && !map.get(emptyDescriptor).getRight()) {
map.put(
emptyDescriptor,
ImmutablePair.of(null, false)); // Mark this function as overloaded.
} else {
map.put(emptyDescriptor, rayFunction);
map.put(emptyDescriptor, ImmutablePair.of(rayFunction, isDefault));
}
}
} catch (Exception e) {

View file

@ -183,7 +183,7 @@ public class FunctionManagerTest {
@Test
public void testLoadFunctionTableForClass() {
JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader());
Map<Pair<String, String>, RayFunction> res =
Map<Pair<String, String>, Pair<RayFunction, Boolean>> res =
functionTable.loadFunctionsForClass(ChildClass.class.getName());
// The result should be 4 entries:
// 1, the constructor with signature
@ -211,7 +211,7 @@ public class FunctionManagerTest {
overloadFunctionDescriptorDouble.signature)));
Assert.assertTrue(res.containsKey(ImmutablePair.of(overloadFunctionDescriptorInt.name, "")));
Pair<String, String> overloadKey = ImmutablePair.of(overloadFunctionDescriptorInt.name, "");
RayFunction func = res.get(overloadKey);
RayFunction func = res.get(overloadKey).getLeft();
// The function is overloaded.
Assert.assertTrue(res.containsKey(overloadKey));
Assert.assertNull(func);

View file

@ -396,4 +396,15 @@ public class CrossLanguageInvocationTest extends BaseTest {
private byte[] value;
}
public void testPyCallJavaOeveridedMethodWithDefault() {
ObjectRef<Object> res =
Ray.task(
PyFunction.of(
PYTHON_MODULE,
"py_func_call_java_overrided_method_with_default_keyword",
Object.class))
.remote();
Assert.assertEquals("hi", res.get());
}
}

View file

@ -0,0 +1,9 @@
package io.ray.test;
public class ExampleImpl implements ExampleInterface {
@Override
public String echo(String str) {
return str;
}
}

View file

@ -0,0 +1,8 @@
package io.ray.test;
public interface ExampleInterface {
default String echo(String str) {
return "default" + str;
}
}

View file

@ -131,3 +131,10 @@ def py_func_get_and_invoke_named_actor():
java_named_actor = ray.get_actor("java_named_actor")
assert ray.get(java_named_actor.concat.remote(b"world")) == b"helloworld"
return b"true"
@ray.remote
def py_func_call_java_overrided_method_with_default_keyword():
cls = ray.java_actor_class("io.ray.test.ExampleImpl")
handle = cls.remote()
return ray.get(handle.echo.remote("hi"))