diff --git a/python/ray/services.py b/python/ray/services.py index 297c7e022..79f4b22e2 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -388,6 +388,7 @@ def check_version_info(redis_client): def start_credis(node_ip_address, + redis_address, port=None, redirect_output=False, cleanup=True): @@ -400,6 +401,8 @@ def start_credis(node_ip_address, Args: node_ip_address: The IP address of the current node. This is only used for recording the log filenames in Redis. + redis_address (str): The IP address and port of the primary redis + server. port (int): If provided, the primary Redis shard will be started on this port. redirect_output (bool): True if output should be redirected to a file @@ -439,7 +442,15 @@ def start_credis(node_ip_address, master_client.execute_command("MASTER.ADD", node_ip_address, head_port) master_client.execute_command("MASTER.ADD", node_ip_address, tail_port) - return address(node_ip_address, master_port) + credis_address = address(node_ip_address, master_port) + + # Register credis master in redis + redis_ip_address, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + redis_client.set("credis_address", credis_address) + + return credis_address def start_redis(node_ip_address, @@ -1142,7 +1153,7 @@ def start_ray_processes(address_info=None, address_info["redis_address"] = redis_address if "RAY_USE_NEW_GCS" in os.environ: credis_address = start_credis( - node_ip_address, cleanup=cleanup) + node_ip_address, redis_address, cleanup=cleanup) address_info["credis_address"] = credis_address time.sleep(0.1) diff --git a/test/credis_test.py b/test/credis_test.py index 7f6ece0d0..08cc7fa96 100644 --- a/test/credis_test.py +++ b/test/credis_test.py @@ -14,17 +14,21 @@ import ray "Tests functionality of the new GCS.") class CredisTest(unittest.TestCase): def setUp(self): - self.config = ray.init() + self.config = ray.init(num_workers=0) def tearDown(self): ray.worker.cleanup() def test_credis_started(self): assert "credis_address" in self.config - address, port = self.config["credis_address"].split(":") - redis_client = redis.StrictRedis(host=address, - port=port) - assert redis_client.ping() is True + credis_address, credis_port = self.config["credis_address"].split(":") + credis_client = redis.StrictRedis(host=credis_address, + port=credis_port) + assert credis_client.ping() is True + + redis_client = ray.worker.global_state.redis_client + addr = redis_client.get("credis_address").decode("ascii") + assert addr == self.config["credis_address"] if __name__ == "__main__":