Enable including Java worker for ray start command (#3838)

This commit is contained in:
Wang Qing 2019-02-04 16:23:43 +08:00 committed by Hao Chen
parent 7ef830bef1
commit e1c68a0881
8 changed files with 277 additions and 17 deletions

View file

@ -77,4 +77,38 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<build>
<plugins>
<!-- During the package phase, copy this module's dependency jars into
${basedir}/../../build/java so that they sit next to the module's own jar.
NOTE(review): presumably this is the same build/java directory that the default
Java worker classpath of `ray start` points at; confirm against services.py. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>copy-dependencies-to-build</id>
<phase>package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${basedir}/../../build/java</outputDirectory>
<!-- Existing release and snapshot artifacts are left alone, except that a newer
source artifact replaces an older copy. -->
<overWriteReleases>false</overWriteReleases>
<overWriteSnapshots>false</overWriteSnapshots>
<overWriteIfNewer>true</overWriteIfNewer>
</configuration>
</execution>
</executions>
</plugin>
<!-- Write this module's own jar into the same build/java directory as its
dependencies, instead of the default target/ directory. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>2.3.1</version>
<configuration>
<outputDirectory>${basedir}/../../build/java</outputDirectory>
</configuration>
</plugin>
</plugins>
</build>
</project> </project>

View file

@ -71,6 +71,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
} }
redisClient = new RedisClient(rayConfig.getRedisAddress()); redisClient = new RedisClient(rayConfig.getRedisAddress());
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
rayletClient = new RayletClientImpl( rayletClient = new RayletClientImpl(

View file

@ -4,7 +4,6 @@ import org.ray.api.Ray;
import org.ray.api.RayActor; import org.ray.api.RayActor;
import org.ray.api.RayObject; import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote; import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc2;
import org.ray.api.id.UniqueId; import org.ray.api.id.UniqueId;
import org.ray.runtime.RayActorImpl; import org.ray.runtime.RayActorImpl;
import org.testng.Assert; import org.testng.Assert;

View file

@ -0,0 +1,115 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.lang.ProcessBuilder.Redirect;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
/**
 * Test starting a ray cluster with multi-language support.
 *
 * <p>Before each test method a head node is started through the {@code ray start} command with
 * {@code --include-java}, and this JVM connects to it via {@link Ray#init()}. After each test
 * method the driver disconnects and the cluster is stopped with {@code ray stop}. The test is
 * skipped when the {@code ray} command cannot be found.
 */
public class MultiLanguageClusterTest {

  private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class);

  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";

  /**
   * Remote function used by the test; returns its argument unchanged.
   */
  @RayRemote
  public static String echo(String word) {
    return word;
  }

  /**
   * Execute an external command and wait for it to finish.
   *
   * <p>The child's stdout and stderr are inherited from this process.
   *
   * @param command The command and its arguments.
   * @param waitTimeoutSeconds Maximum time to wait for the command to exit, in seconds.
   * @return Whether the command succeeded, i.e. exited with code 0.
   * @throws RuntimeException If the command cannot be started, does not exit within the
   *     timeout, or the waiting thread is interrupted.
   */
  private boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
    final String commandLine = String.join(" ", command);
    LOGGER.info("Executing command: {}", commandLine);

    final Process process;
    try {
      process = new ProcessBuilder(command)
          .redirectOutput(Redirect.INHERIT)
          .redirectError(Redirect.INHERIT)
          .start();
    } catch (IOException e) {
      throw new RuntimeException("Error executing command " + commandLine, e);
    }

    try {
      // waitFor returns false on timeout; calling exitValue() on a still-running process would
      // throw IllegalThreadStateException and leave the child behind, so handle it explicitly.
      if (!process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS)) {
        process.destroyForcibly();
        throw new RuntimeException(
            "Command timed out after " + waitTimeoutSeconds + " seconds: " + commandLine);
      }
    } catch (InterruptedException e) {
      process.destroyForcibly();
      // Restore the interrupt flag so that callers further up can observe the interruption.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Interrupted while executing command " + commandLine, e);
    }
    return process.exitValue() == 0;
  }

  @BeforeMethod
  public void setUp() {
    // Check whether 'ray' command is installed.
    boolean rayCommandExists = executeCommand(ImmutableList.of("which", "ray"), 5);
    if (!rayCommandExists) {
      throw new SkipException("Skipping test, because ray command doesn't exist.");
    }

    // Delete existing socket files.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists() && !file.delete()) {
        LOGGER.warn("Failed to delete existing socket file {}", socket);
      }
    }

    // Start ray cluster.
    final List<String> startCommand = ImmutableList.of(
        "ray",
        "start",
        "--head",
        "--redis-port=6379",
        "--include-java",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        "--java-worker-options=-classpath ../../build/java/*:../../java/test/target/*"
    );
    if (!executeCommand(startCommand, 10)) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster.
    System.setProperty("ray.home", "../..");
    System.setProperty("ray.redis.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    Ray.init();
  }

  @AfterMethod
  public void tearDown() {
    try {
      // Disconnect from the cluster.
      Ray.shutdown();
    } finally {
      // Clean up in a finally block so that a failing shutdown does not leave the system
      // properties set or the externally started cluster running.
      System.clearProperty("ray.home");
      System.clearProperty("ray.redis.address");
      System.clearProperty("ray.object-store.socket-name");
      System.clearProperty("ray.raylet.socket-name");

      // Stop ray cluster.
      final List<String> stopCommand = ImmutableList.of(
          "ray",
          "stop"
      );
      if (!executeCommand(stopCommand, 10)) {
        throw new RuntimeException("Couldn't stop ray cluster");
      }
    }
  }

  @Test
  public void testMultiLanguageCluster() {
    RayObject<String> obj = Ray.call(MultiLanguageClusterTest::echo, "hello");
    // TestNG's argument order is (actual, expected).
    Assert.assertEquals(obj.get(), "hello");
  }
}

View file

@ -62,6 +62,11 @@ class Node(object):
if head: if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True) ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
else:
redis_client = ray.services.create_redis_client(
ray_params.redis_address, ray_params.redis_password)
ray_params.include_java = (
ray.services.include_java_from_redis(redis_client))
self._ray_params = ray_params self._ray_params = ray_params
self._config = (json.loads(ray_params._internal_config) self._config = (json.loads(ray_params._internal_config)
@ -224,7 +229,10 @@ class Node(object):
use_profiler=use_profiler, use_profiler=use_profiler,
stdout_file=stdout_file, stdout_file=stdout_file,
stderr_file=stderr_file, stderr_file=stderr_file,
config=self._config) config=self._config,
include_java=self._ray_params.include_java,
java_worker_options=self._ray_params.java_worker_options,
)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]

View file

@ -70,6 +70,9 @@ class RayParams(object):
monitor the log files for all processes on this node and push their monitor the log files for all processes on this node and push their
contents to Redis. contents to Redis.
autoscaling_config: path to autoscaling config file. autoscaling_config: path to autoscaling config file.
include_java (bool): If True, the raylet backend can also support
Java worker.
java_worker_options (str): The command options for Java worker.
_internal_config (str): JSON configuration for overriding _internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY. RayConfig defaults. For testing purposes ONLY.
""" """
@ -106,6 +109,8 @@ class RayParams(object):
temp_dir=None, temp_dir=None,
include_log_monitor=None, include_log_monitor=None,
autoscaling_config=None, autoscaling_config=None,
include_java=False,
java_worker_options=None,
_internal_config=None): _internal_config=None):
self.object_id_seed = object_id_seed self.object_id_seed = object_id_seed
self.redis_address = redis_address self.redis_address = redis_address
@ -136,6 +141,8 @@ class RayParams(object):
self.temp_dir = temp_dir self.temp_dir = temp_dir
self.include_log_monitor = include_log_monitor self.include_log_monitor = include_log_monitor
self.autoscaling_config = autoscaling_config self.autoscaling_config = autoscaling_config
self.include_java = include_java
self.java_worker_options = java_worker_options
self._internal_config = _internal_config self._internal_config = _internal_config
self._check_usage() self._check_usage()
@ -146,7 +153,7 @@ class RayParams(object):
kwargs: The keyword arguments to set corresponding fields. kwargs: The keyword arguments to set corresponding fields.
""" """
for arg in kwargs: for arg in kwargs:
if (hasattr(self, arg)): if hasattr(self, arg):
setattr(self, arg, kwargs[arg]) setattr(self, arg, kwargs[arg])
else: else:
raise ValueError("Invalid RayParams parameter in" raise ValueError("Invalid RayParams parameter in"
@ -161,7 +168,7 @@ class RayParams(object):
kwargs: The keyword arguments to set corresponding fields. kwargs: The keyword arguments to set corresponding fields.
""" """
for arg in kwargs: for arg in kwargs:
if (hasattr(self, arg)): if hasattr(self, arg):
if getattr(self, arg) is None: if getattr(self, arg) is None:
setattr(self, arg, kwargs[arg]) setattr(self, arg, kwargs[arg])
else: else:
@ -180,6 +187,10 @@ class RayParams(object):
"num_gpus instead.") "num_gpus instead.")
if self.num_workers is not None: if self.num_workers is not None:
raise Exception( raise ValueError(
"The 'num_workers' argument is deprecated. Please use " "The 'num_workers' argument is deprecated. Please use "
"'num_cpus' instead.") "'num_cpus' instead.")
if self.include_java is None and self.java_worker_options is not None:
raise ValueError("Should not specify `java-worker-options` "
"without providing `include-java`.")

View file

@ -201,6 +201,17 @@ def cli(logging_level, logging_format):
"--temp-dir", "--temp-dir",
default=None, default=None,
help="manually specify the root temporary dir of the Ray process") help="manually specify the root temporary dir of the Ray process")
@click.option(
"--include-java",
is_flag=True,
default=None,
help="Enable Java worker support.")
@click.option(
"--java-worker-options",
required=False,
default=None,
type=str,
help="Overwrite the options to start Java workers.")
@click.option( @click.option(
"--internal-config", "--internal-config",
default=None, default=None,
@ -212,8 +223,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_memory, num_workers, num_cpus, num_gpus, resources, head, redis_max_memory, num_workers, num_cpus, num_gpus, resources, head,
no_ui, block, plasma_directory, huge_pages, autoscaling_config, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
no_redirect_worker_output, no_redirect_output, no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir, plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
internal_config): java_worker_options, internal_config):
# Convert hostnames to numerical IP address. # Convert hostnames to numerical IP address.
if node_ip_address is not None: if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address) node_ip_address = services.address_to_ip(node_ip_address)
@ -245,6 +256,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
plasma_store_socket_name=plasma_store_socket_name, plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name, raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir, temp_dir=temp_dir,
include_java=include_java,
java_worker_options=java_worker_options,
_internal_config=internal_config) _internal_config=internal_config)
if head: if head:
@ -280,7 +293,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
num_redis_shards=num_redis_shards, num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients, redis_max_clients=redis_max_clients,
include_webui=(not no_ui), include_webui=(not no_ui),
autoscaling_config=autoscaling_config) autoscaling_config=autoscaling_config,
include_java=False,
)
node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False) node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False)
redis_address = node.redis_address redis_address = node.redis_address
@ -322,6 +337,10 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
if no_ui: if no_ui:
raise Exception("If --head is not passed in, the --no-ui flag is " raise Exception("If --head is not passed in, the --no-ui flag is "
"not relevant.") "not relevant.")
if include_java is not None:
raise ValueError("--include-java should only be set for the head "
"node.")
redis_ip_address, redis_port = redis_address.split(":") redis_ip_address, redis_port = redis_address.split(":")
# Wait for the Redis server to be started. And throw an exception if we # Wait for the Redis server to be started. And throw an exception if we
@ -348,7 +367,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
check_no_existing_redis_clients(ray_params.node_ip_address, check_no_existing_redis_clients(ray_params.node_ip_address,
redis_client) redis_client)
ray_params.update(redis_address=redis_address) ray_params.update(redis_address=redis_address)
node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False) node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False)
logger.info("\nStarted Ray on this node. If you wish to terminate the " logger.info("\nStarted Ray on this node. If you wish to terminate the "
"processes that have been started, run\n\n" "processes that have been started, run\n\n"

View file

@ -21,14 +21,19 @@ import pyarrow
import ray import ray
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray.tempfile_services import (get_ipython_notebook_path, get_temp_root, from ray.tempfile_services import (
new_redis_log_file) get_ipython_notebook_path,
get_logs_dir_path,
get_temp_root,
new_redis_log_file,
)
# True if processes are run in the valgrind profiler. # True if processes are run in the valgrind profiler.
RUN_RAYLET_PROFILER = False RUN_RAYLET_PROFILER = False
RUN_PLASMA_STORE_PROFILER = False RUN_PLASMA_STORE_PROFILER = False
# Location of the redis server and module. # Location of the redis server and module.
RAY_HOME = os.path.join(os.path.dirname(__file__), "../..")
REDIS_EXECUTABLE = os.path.join( REDIS_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)), os.path.abspath(os.path.dirname(__file__)),
"core/src/ray/thirdparty/redis/src/redis-server") "core/src/ray/thirdparty/redis/src/redis-server")
@ -60,6 +65,10 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
RAYLET_EXECUTABLE = os.path.join( RAYLET_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet") os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
DEFAULT_JAVA_WORKER_OPTIONS = "-classpath {}".format(
os.path.join(
os.path.abspath(os.path.dirname(__file__)), "../../../build/java/*"))
# Logger for this module. It should be configured at the entry point # Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray provides a default configuration at # into the program using Ray. Ray provides a default configuration at
# entry/init points. # entry/init points.
@ -93,6 +102,18 @@ def new_port():
return random.randint(10000, 65535) return random.randint(10000, 65535)
def include_java_from_redis(redis_client):
    """Read the cluster-wide "INCLUDE_JAVA" flag from Redis.

    Args:
        redis_client (StrictRedis): The redis client to GCS.

    Returns:
        True if this cluster backend enables Java worker.
    """
    # The value is compared as bytes; a missing key (None) yields False.
    stored_flag = redis_client.get("INCLUDE_JAVA")
    return stored_flag == b"1"
def remaining_processes_alive(): def remaining_processes_alive():
"""See if the remaining processes are alive or not. """See if the remaining processes are alive or not.
@ -249,8 +270,8 @@ def start_ray_process(command,
no redirection should happen, then this should be None. no redirection should happen, then this should be None.
Returns: Returns:
Inormation about the process that was started including a handle to the Information about the process that was started including a handle to
process that was started. the process that was started.
""" """
# Detect which flags are set through environment variables. # Detect which flags are set through environment variables.
valgrind_env_var = "RAY_{}_VALGRIND".format(process_type.upper()) valgrind_env_var = "RAY_{}_VALGRIND".format(process_type.upper())
@ -451,7 +472,8 @@ def start_redis(node_ip_address,
redirect_worker_output=False, redirect_worker_output=False,
password=None, password=None,
use_credis=None, use_credis=None,
redis_max_memory=None): redis_max_memory=None,
include_java=False):
"""Start the Redis global state store. """Start the Redis global state store.
Args: Args:
@ -481,6 +503,8 @@ def start_redis(node_ip_address,
LRU eviction of entries. This only applies to the sharded redis LRU eviction of entries. This only applies to the sharded redis
tables (task, object, and profile tables). By default, this is tables (task, object, and profile tables). By default, this is
capped at 10GB but can be set higher. capped at 10GB but can be set higher.
include_java (bool): If True, the raylet backend can also support
Java worker.
Returns: Returns:
A tuple of the address for the primary Redis shard, a list of A tuple of the address for the primary Redis shard, a list of
@ -555,6 +579,10 @@ def start_redis(node_ip_address,
primary_redis_client.set("RedirectOutput", 1 primary_redis_client.set("RedirectOutput", 1
if redirect_worker_output else 0) if redirect_worker_output else 0)
# Put the include_java flag in the primary Redis server so that other nodes
# can read it and know whether or not to enable cross-language support.
primary_redis_client.set("INCLUDE_JAVA", 1 if include_java else 0)
# Store version information in the primary Redis shard. # Store version information in the primary Redis shard.
_put_version_info_in_redis(primary_redis_client) _put_version_info_in_redis(primary_redis_client)
@ -960,7 +988,9 @@ def start_raylet(redis_address,
use_profiler=False, use_profiler=False,
stdout_file=None, stdout_file=None,
stderr_file=None, stderr_file=None,
config=None): config=None,
include_java=False,
java_worker_options=None):
"""Start a raylet, which is a combined local scheduler and object manager. """Start a raylet, which is a combined local scheduler and object manager.
Args: Args:
@ -989,7 +1019,9 @@ def start_raylet(redis_address,
no redirection should happen, then this should be None. no redirection should happen, then this should be None.
config (dict|None): Optional Raylet configuration that will config (dict|None): Optional Raylet configuration that will
override defaults in RayConfig. override defaults in RayConfig.
include_java (bool): If True, the raylet backend can also support
Java worker.
java_worker_options (str): The command options for Java worker.
Returns: Returns:
ProcessInfo for the process that was started. ProcessInfo for the process that was started.
""" """
@ -1016,6 +1048,14 @@ def start_raylet(redis_address,
gcs_ip_address, gcs_port = redis_address.split(":") gcs_ip_address, gcs_port = redis_address.split(":")
if include_java is True:
java_worker_options = (java_worker_options
or DEFAULT_JAVA_WORKER_OPTIONS)
java_worker_command = build_java_worker_command(
java_worker_options, redis_address, plasma_store_name, raylet_name)
else:
java_worker_command = ""
# Create the command that the Raylet will use to start workers. # Create the command that the Raylet will use to start workers.
start_worker_command = ("{} {} " start_worker_command = ("{} {} "
"--node-ip-address={} " "--node-ip-address={} "
@ -1052,7 +1092,7 @@ def start_raylet(redis_address,
resource_argument, resource_argument,
config_str, config_str,
start_worker_command, start_worker_command,
"", # Worker command for Java, not needed for Python. java_worker_command,
redis_password or "", redis_password or "",
get_temp_root(), get_temp_root(),
] ]
@ -1073,6 +1113,40 @@ def start_raylet(redis_address,
return process_info return process_info
def build_java_worker_command(java_worker_options, redis_address,
                              plasma_store_name, raylet_name):
    """Assemble the command line used to start a Java worker.

    The result is "java", the given options, a "-D" system property for each
    connection setting that is not None, the Ray home and log directories,
    and finally the worker main class.

    Args:
        java_worker_options (str): The command options for Java worker.
        redis_address (str): Redis address of GCS.
        plasma_store_name (str): The name of the plasma store socket to connect
            to.
        raylet_name (str): The name of the raylet socket to create.

    Returns:
        The command string for starting Java worker.
    """
    assert java_worker_options is not None

    # Settings that are only emitted when a value was supplied.
    optional_settings = (
        ("-Dray.redis.address={} ", redis_address),
        ("-Dray.object-store.socket-name={} ", plasma_store_name),
        ("-Dray.raylet.socket-name={} ", raylet_name),
    )

    segments = ["java {} ".format(java_worker_options)]
    for template, value in optional_settings:
        if value is not None:
            segments.append(template.format(value))

    segments.append("-Dray.home={} ".format(RAY_HOME))
    segments.append("-Dray.log-dir={} ".format(get_logs_dir_path()))
    segments.append("org.ray.runtime.runner.worker.DefaultWorker")
    return "".join(segments)
def determine_plasma_store_config(object_store_memory=None, def determine_plasma_store_config(object_store_memory=None,
plasma_directory=None, plasma_directory=None,
huge_pages=False): huge_pages=False):