Allow setting redis shard ports through ray start (also object store memory). (#1581)

* Allow passing in --object-store-memory to ray start.

* Allow setting ports for the redis shards.

* Reorder arguments and infer number of shards from ports.

* Move code block into only the head node case.

* Add test.
This commit is contained in:
Robert Nishihara 2018-02-22 11:05:37 -08:00 committed by Philipp Moritz
parent a3b44309dd
commit 330159d8bd
3 changed files with 62 additions and 11 deletions

View file

@ -53,8 +53,14 @@ def cli():
@click.option("--redis-max-clients", required=False, type=int, @click.option("--redis-max-clients", required=False, type=int,
help=("If provided, attempt to configure Redis with this " help=("If provided, attempt to configure Redis with this "
"maximum number of clients.")) "maximum number of clients."))
@click.option("--redis-shard-ports", required=False, type=str,
help="the port to use for the Redis shards other than the "
"primary Redis shard")
@click.option("--object-manager-port", required=False, type=int, @click.option("--object-manager-port", required=False, type=int,
help="the port to use for starting the object manager") help="the port to use for starting the object manager")
@click.option("--object-store-memory", required=False, type=int,
help="the maximum amount of memory (in bytes) to allow the "
"object store to use")
@click.option("--num-workers", required=False, type=int, @click.option("--num-workers", required=False, type=int,
help=("The initial number of workers to start on this node, " help=("The initial number of workers to start on this node, "
"note that the local scheduler may start additional " "note that the local scheduler may start additional "
@ -81,15 +87,10 @@ def cli():
@click.option("--autoscaling-config", required=False, type=str, @click.option("--autoscaling-config", required=False, type=str,
help="the file that contains the autoscaling config") help="the file that contains the autoscaling config")
def start(node_ip_address, redis_address, redis_port, num_redis_shards, def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, object_manager_port, num_workers, num_cpus, redis_max_clients, redis_shard_ports, object_manager_port,
num_gpus, resources, head, no_ui, block, plasma_directory, object_store_memory, num_workers, num_cpus, num_gpus, resources,
huge_pages, autoscaling_config): head, no_ui, block, plasma_directory, huge_pages,
# Note that we redirect stdout and stderr to /dev/null because otherwise autoscaling_config):
# attempts to print may cause exceptions if a process is started inside of
# an SSH connection and the SSH connection dies. TODO(rkn): This is a
# temporary fix. We should actually redirect stdout and stderr to Redis in
# some way.
# Convert hostnames to numerical IP address. # Convert hostnames to numerical IP address.
if node_ip_address is not None: if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address) node_ip_address = services.address_to_ip(node_ip_address)
@ -113,6 +114,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
if head: if head:
# Start Ray on the head node. # Start Ray on the head node.
if redis_shard_ports is not None:
redis_shard_ports = redis_shard_ports.split(",")
# Infer the number of Redis shards from the ports if the number is
# not provided.
if num_redis_shards is None:
num_redis_shards = len(redis_shard_ports)
# Check that the arguments match.
if len(redis_shard_ports) != num_redis_shards:
raise Exception("If --redis-shard-ports is provided, it must "
"have the form '6380,6381,6382', and the "
"number of ports provided must equal "
"--num-redis-shards (which is 1 if not "
"provided)")
if redis_address is not None: if redis_address is not None:
raise Exception("If --head is passed in, a Redis server will be " raise Exception("If --head is passed in, a Redis server will be "
"started, so a Redis address should not be " "started, so a Redis address should not be "
@ -134,6 +149,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
address_info=address_info, address_info=address_info,
node_ip_address=node_ip_address, node_ip_address=node_ip_address,
redis_port=redis_port, redis_port=redis_port,
redis_shard_ports=redis_shard_ports,
object_store_memory=object_store_memory,
num_workers=num_workers, num_workers=num_workers,
cleanup=False, cleanup=False,
redirect_output=True, redirect_output=True,
@ -162,6 +179,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
if redis_port is not None: if redis_port is not None:
raise Exception("If --head is not passed in, --redis-port is not " raise Exception("If --head is not passed in, --redis-port is not "
"allowed") "allowed")
if redis_shard_ports is not None:
raise Exception("If --head is not passed in, --redis-shard-ports "
"is not allowed")
if redis_address is None: if redis_address is None:
raise Exception("If --head is not passed in, --redis-address must " raise Exception("If --head is not passed in, --redis-address must "
"be provided.") "be provided.")
@ -200,6 +220,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_address=redis_address, redis_address=redis_address,
object_manager_ports=[object_manager_port], object_manager_ports=[object_manager_port],
num_workers=num_workers, num_workers=num_workers,
object_store_memory=object_store_memory,
cleanup=False, cleanup=False,
redirect_output=True, redirect_output=True,
resources=resources, resources=resources,

View file

@ -369,6 +369,7 @@ def check_version_info(redis_client):
def start_redis(node_ip_address, def start_redis(node_ip_address,
port=None, port=None,
redis_shard_ports=None,
num_redis_shards=1, num_redis_shards=1,
redis_max_clients=None, redis_max_clients=None,
redirect_output=False, redirect_output=False,
@ -381,6 +382,8 @@ def start_redis(node_ip_address,
for recording the log filenames in Redis. for recording the log filenames in Redis.
port (int): If provided, the primary Redis shard will be started on port (int): If provided, the primary Redis shard will be started on
this port. this port.
redis_shard_ports: A list of the ports to use for the non-primary Redis
shards.
num_redis_shards (int): If provided, the number of Redis shards to num_redis_shards (int): If provided, the number of Redis shards to
start, in addition to the primary one. The default value is one start, in addition to the primary one. The default value is one
shard. shard.
@ -403,6 +406,12 @@ def start_redis(node_ip_address,
redis_stdout_file, redis_stderr_file = new_log_files( redis_stdout_file, redis_stderr_file = new_log_files(
"redis", redirect_output) "redis", redirect_output)
if redis_shard_ports is None:
redis_shard_ports = num_redis_shards * [None]
elif len(redis_shard_ports) != num_redis_shards:
raise Exception("The number of Redis shard ports does not match the "
"number of Redis shards.")
assigned_port, _ = start_redis_instance( assigned_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port, node_ip_address=node_ip_address, port=port,
redis_max_clients=redis_max_clients, redis_max_clients=redis_max_clients,
@ -425,17 +434,20 @@ def start_redis(node_ip_address,
# Store version information in the primary Redis shard. # Store version information in the primary Redis shard.
_put_version_info_in_redis(redis_client) _put_version_info_in_redis(redis_client)
# Start other Redis shards listening on random ports. Each Redis shard logs # Start other Redis shards. Each Redis shard logs to a separate file,
# to a separate file, prefixed by "redis-<shard number>". # prefixed by "redis-<shard number>".
redis_shards = [] redis_shards = []
for i in range(num_redis_shards): for i in range(num_redis_shards):
redis_stdout_file, redis_stderr_file = new_log_files( redis_stdout_file, redis_stderr_file = new_log_files(
"redis-{}".format(i), redirect_output) "redis-{}".format(i), redirect_output)
redis_shard_port, _ = start_redis_instance( redis_shard_port, _ = start_redis_instance(
node_ip_address=node_ip_address, node_ip_address=node_ip_address,
port=redis_shard_ports[i],
redis_max_clients=redis_max_clients, redis_max_clients=redis_max_clients,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
cleanup=cleanup) cleanup=cleanup)
if redis_shard_ports[i] is not None:
assert redis_shard_port == redis_shard_ports[i]
shard_address = address(node_ip_address, redis_shard_port) shard_address = address(node_ip_address, redis_shard_port)
redis_shards.append(shard_address) redis_shards.append(shard_address)
# Store redis shard information in the primary redis shard. # Store redis shard information in the primary redis shard.
@ -942,6 +954,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
def start_ray_processes(address_info=None, def start_ray_processes(address_info=None,
node_ip_address="127.0.0.1", node_ip_address="127.0.0.1",
redis_port=None, redis_port=None,
redis_shard_ports=None,
num_workers=None, num_workers=None,
num_local_schedulers=1, num_local_schedulers=1,
object_store_memory=None, object_store_memory=None,
@ -969,6 +982,8 @@ def start_ray_processes(address_info=None,
to. If None, then a random port will be chosen. If the key to. If None, then a random port will be chosen. If the key
"redis_address" is in address_info, then this argument will be "redis_address" is in address_info, then this argument will be
ignored. ignored.
redis_shard_ports: A list of the ports to use for the non-primary Redis
shards.
num_workers (int): The number of workers to start. num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers num_local_schedulers (int): The total number of local schedulers
required. This is also the total number of object stores required. required. This is also the total number of object stores required.
@ -1042,6 +1057,7 @@ def start_ray_processes(address_info=None,
if redis_address is None: if redis_address is None:
redis_address, redis_shards = start_redis( redis_address, redis_shards = start_redis(
node_ip_address, port=redis_port, node_ip_address, port=redis_port,
redis_shard_ports=redis_shard_ports,
num_redis_shards=num_redis_shards, num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients, redis_max_clients=redis_max_clients,
redirect_output=True, redirect_output=True,
@ -1203,6 +1219,7 @@ def start_ray_node(node_ip_address,
object_manager_ports=None, object_manager_ports=None,
num_workers=0, num_workers=0,
num_local_schedulers=1, num_local_schedulers=1,
object_store_memory=None,
worker_path=None, worker_path=None,
cleanup=True, cleanup=True,
redirect_output=False, redirect_output=False,
@ -1224,6 +1241,8 @@ def start_ray_node(node_ip_address,
num_local_schedulers (int): The number of local schedulers to start. num_local_schedulers (int): The number of local schedulers to start.
This is also the number of plasma stores and plasma managers to This is also the number of plasma stores and plasma managers to
start. start.
object_store_memory (int): The maximum amount of memory (in bytes) to
let the plasma store use.
worker_path (str): The path of the source code that will be run by the worker_path (str): The path of the source code that will be run by the
worker. worker.
cleanup (bool): If cleanup is true, then the processes started here cleanup (bool): If cleanup is true, then the processes started here
@ -1248,6 +1267,7 @@ def start_ray_node(node_ip_address,
node_ip_address=node_ip_address, node_ip_address=node_ip_address,
num_workers=num_workers, num_workers=num_workers,
num_local_schedulers=num_local_schedulers, num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
worker_path=worker_path, worker_path=worker_path,
include_log_monitor=True, include_log_monitor=True,
cleanup=cleanup, cleanup=cleanup,
@ -1260,6 +1280,7 @@ def start_ray_node(node_ip_address,
def start_ray_head(address_info=None, def start_ray_head(address_info=None,
node_ip_address="127.0.0.1", node_ip_address="127.0.0.1",
redis_port=None, redis_port=None,
redis_shard_ports=None,
num_workers=0, num_workers=0,
num_local_schedulers=1, num_local_schedulers=1,
object_store_memory=None, object_store_memory=None,
@ -1285,6 +1306,8 @@ def start_ray_head(address_info=None,
to. If None, then a random port will be chosen. If the key to. If None, then a random port will be chosen. If the key
"redis_address" is in address_info, then this argument will be "redis_address" is in address_info, then this argument will be
ignored. ignored.
redis_shard_ports: A list of the ports to use for the non-primary Redis
shards.
num_workers (int): The number of workers to start. num_workers (int): The number of workers to start.
num_local_schedulers (int): The total number of local schedulers num_local_schedulers (int): The total number of local schedulers
required. This is also the total number of object stores required. required. This is also the total number of object stores required.
@ -1326,6 +1349,7 @@ def start_ray_head(address_info=None,
address_info=address_info, address_info=address_info,
node_ip_address=node_ip_address, node_ip_address=node_ip_address,
redis_port=redis_port, redis_port=redis_port,
redis_shard_ports=redis_shard_ports,
num_workers=num_workers, num_workers=num_workers,
num_local_schedulers=num_local_schedulers, num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory, object_store_memory=object_store_memory,

View file

@ -216,6 +216,11 @@ class StartRayScriptTest(unittest.TestCase):
"--redis-port", "6379"]) "--redis-port", "6379"])
subprocess.Popen(["ray", "stop"]).wait() subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with redis shard ports specified.
subprocess.check_output(["ray", "start", "--head",
"--redis-shard-ports", "6380,6381,6382"])
subprocess.Popen(["ray", "stop"]).wait()
# Test starting Ray with a node IP address specified. # Test starting Ray with a node IP address specified.
subprocess.check_output(["ray", "start", "--head", subprocess.check_output(["ray", "start", "--head",
"--node-ip-address", "127.0.0.1"]) "--node-ip-address", "127.0.0.1"])
@ -245,6 +250,7 @@ class StartRayScriptTest(unittest.TestCase):
subprocess.check_output(["ray", "start", "--head", subprocess.check_output(["ray", "start", "--head",
"--num-workers", "20", "--num-workers", "20",
"--redis-port", "6379", "--redis-port", "6379",
"--redis-shard-ports", "6380,6381,6382",
"--object-manager-port", "12345", "--object-manager-port", "12345",
"--num-cpus", "100", "--num-cpus", "100",
"--num-gpus", "0", "--num-gpus", "0",