[Deploy]Don't start shard redis in local if we specify external redis. (#17856)

* Don't start shard redis in local if we specify external redis

* lint

* reuse primary as shard

* add test

* lint

* lint

* lint
This commit is contained in:
Tao Wang 2021-08-27 16:45:09 +08:00 committed by GitHub
parent a25cc47399
commit 7620afb8be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 20 deletions

View file

@ -832,17 +832,6 @@ def start_redis(node_ip_address,
addresses for the remaining shards, and the processes that were
started.
"""
if len(redirect_files) != 1 + num_redis_shards:
raise ValueError("The number of redirect file pairs should be equal "
"to the number of redis shards (including the "
"primary shard) we will start.")
if redis_shard_ports is None:
redis_shard_ports = num_redis_shards * [None]
elif len(redis_shard_ports) != num_redis_shards:
raise RuntimeError("The number of Redis shard ports does not match "
"the number of Redis shards.")
processes = []
if external_addresses is not None:
@ -855,6 +844,17 @@ def start_redis(node_ip_address,
# Deleting the key to avoid duplicated rpush.
primary_redis_client.delete("RedisShards")
else:
if len(redirect_files) != 1 + num_redis_shards:
raise ValueError(
"The number of redirect file pairs should be equal "
"to the number of redis shards (including the "
"primary shard) we will start.")
if redis_shard_ports is None:
redis_shard_ports = num_redis_shards * [None]
elif len(redis_shard_ports) != num_redis_shards:
raise RuntimeError(
"The number of Redis shard ports does not match "
"the number of Redis shards.")
redis_executable = REDIS_EXECUTABLE
redis_stdout_file, redis_stderr_file = redirect_files[0]
@ -910,7 +910,7 @@ def start_redis(node_ip_address,
# other Redis shards at a high, random port.
last_shard_port = new_port(denylist=port_denylist) - 1
for i in range(num_redis_shards):
if external_addresses is not None and len(external_addresses) > 1:
if external_addresses is not None:
shard_address = external_addresses[i + 1]
else:
redis_stdout_file, redis_stderr_file = redirect_files[i + 1]
@ -1209,7 +1209,7 @@ def start_dashboard(require_dashboard,
f"--log-dir={logdir}", f"--logging-rotate-bytes={max_bytes}",
f"--logging-rotate-backup-count={backup_count}"
]
if redis_password:
if redis_password is not None:
command += ["--redis-password", redis_password]
process_info = start_ray_process(
command,

View file

@ -693,10 +693,12 @@ class Node:
def start_redis(self):
"""Start the Redis servers."""
assert self._redis_address is None
redis_log_files = [self.get_log_file_handles("redis", unique=True)]
for i in range(self._ray_params.num_redis_shards):
redis_log_files.append(
self.get_log_file_handles(f"redis-shard_{i}", unique=True))
redis_log_files = []
if self._ray_params.external_addresses is None:
redis_log_files = [self.get_log_file_handles("redis", unique=True)]
for i in range(self._ray_params.num_redis_shards):
redis_log_files.append(
self.get_log_file_handles(f"redis-shard_{i}", unique=True))
(self._redis_address, redis_shards,
process_infos) = ray._private.services.start_redis(

View file

@ -576,7 +576,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
num_redis_shards = None
# Start Ray on the head node.
if redis_shard_ports is not None:
if redis_shard_ports is not None and address is None:
redis_shard_ports = redis_shard_ports.split(",")
# Infer the number of Redis shards from the ports if the number is
# not provided.
@ -588,6 +588,10 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
"If the primary one is not reachable, we starts new one(s) "
"with `{}` in local.", cf.bold("--address"), cf.bold("--port"))
external_addresses = address.split(",")
# We reuse primary redis as sharding when there's only one
# instance provided.
if len(external_addresses) == 1:
external_addresses.append(external_addresses[0])
reachable = False
try:
[primary_redis_ip, port] = external_addresses[0].split(":")
@ -604,8 +608,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
if reachable:
ray_params.update_if_absent(
external_addresses=external_addresses)
if len(external_addresses) > 1:
num_redis_shards = len(external_addresses) - 1
num_redis_shards = len(external_addresses) - 1
if redis_password == ray_constants.REDIS_DEFAULT_PASSWORD:
cli_logger.warning(
"`{}` should not be specified as empty string if "

View file

@ -9,6 +9,7 @@ import json
import ray
from ray.cluster_utils import Cluster
from ray._private.services import REDIS_EXECUTABLE, _start_redis_instance
from ray._private.test_utils import init_error_pubsub
import ray._private.gcs_utils as gcs_utils
@ -211,6 +212,24 @@ def call_ray_start(request):
subprocess.check_call(["ray", "stop"])
@pytest.fixture
def call_ray_start_with_external_redis(request):
ports = getattr(request, "param", "6379")
port_list = ports.split(",")
for port in port_list:
_start_redis_instance(REDIS_EXECUTABLE, int(port), password="123")
address_str = ",".join(map(lambda x: "localhost:" + x, port_list))
cmd = f"ray start --head --address={address_str} --redis-password=123"
subprocess.call(cmd.split(" "))
yield address_str.split(",")[0]
# Disconnect from the Ray cluster.
ray.shutdown()
# Kill the Ray cluster.
subprocess.check_call(["ray", "stop"])
@pytest.fixture
def call_ray_stop_only():
yield

View file

@ -0,0 +1,42 @@
import os
import pytest
import sys
import ray
@pytest.mark.parametrize(
"call_ray_start_with_external_redis", [
"6379",
"6379,6380",
"6379,6380,6381",
],
indirect=True)
def test_using_hostnames(call_ray_start_with_external_redis):
ray.init(address="127.0.0.1:6379", _redis_password="123")
@ray.remote
def f():
return 1
assert ray.get(f.remote()) == 1
@ray.remote
class Counter:
def __init__(self):
self.count = 0
def inc_and_get(self):
self.count += 1
return self.count
counter = Counter.remote()
assert ray.get(counter.inc_and_get.remote()) == 1
if __name__ == "__main__":
import pytest
# Make subprocess happy in bazel.
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LANG"] = "en_US.UTF-8"
sys.exit(pytest.main(["-v", __file__]))