mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Wait longer when getting redis shards to initialize global state API. (#786)
This commit is contained in:
parent
1fe49d7676
commit
c394a65ffc
1 changed files with 50 additions and 16 deletions
|
@ -57,6 +57,7 @@ class GlobalState(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Create a GlobalState object."""
|
"""Create a GlobalState object."""
|
||||||
self.redis_client = None
|
self.redis_client = None
|
||||||
|
self.redis_clients = None
|
||||||
|
|
||||||
def _check_connected(self):
|
def _check_connected(self):
|
||||||
"""Check that the object has been initialized before it is used.
|
"""Check that the object has been initialized before it is used.
|
||||||
|
@ -69,32 +70,65 @@ class GlobalState(object):
|
||||||
raise Exception("The ray.global_state API cannot be used before "
|
raise Exception("The ray.global_state API cannot be used before "
|
||||||
"ray.init has been called.")
|
"ray.init has been called.")
|
||||||
|
|
||||||
def _initialize_global_state(self, redis_ip_address, redis_port):
|
if self.redis_clients is None:
|
||||||
|
raise Exception("The ray.global_state API cannot be used before "
|
||||||
|
"ray.init has been called.")
|
||||||
|
|
||||||
|
def _initialize_global_state(self, redis_ip_address, redis_port,
|
||||||
|
timeout=20):
|
||||||
"""Initialize the GlobalState object by connecting to Redis.
|
"""Initialize the GlobalState object by connecting to Redis.
|
||||||
|
|
||||||
|
It's possible that certain keys in Redis may not have been fully
|
||||||
|
populated yet. In this case, we will retry this method until they have
|
||||||
|
been populated or we exceed a timeout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
redis_ip_address: The IP address of the node that the Redis server
|
redis_ip_address: The IP address of the node that the Redis server
|
||||||
lives on.
|
lives on.
|
||||||
redis_port: The port that the Redis server is listening on.
|
redis_port: The port that the Redis server is listening on.
|
||||||
|
timeout: The maximum amount of time (in seconds) that we should
|
||||||
|
wait for the keys in Redis to be populated.
|
||||||
"""
|
"""
|
||||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||||
port=redis_port)
|
port=redis_port)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
num_redis_shards = None
|
||||||
|
ip_address_ports = []
|
||||||
|
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
# Attempt to get the number of Redis shards.
|
||||||
|
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||||
|
if num_redis_shards is None:
|
||||||
|
print("Waiting longer for NumRedisShards to be populated.")
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
num_redis_shards = int(num_redis_shards)
|
||||||
|
if (num_redis_shards < 1):
|
||||||
|
raise Exception("Expected at least one Redis shard, found "
|
||||||
|
"{}.".format(num_redis_shards))
|
||||||
|
|
||||||
|
# Attempt to get all of the Redis shards.
|
||||||
|
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||||
|
end=-1)
|
||||||
|
if len(ip_address_ports) != num_redis_shards:
|
||||||
|
print("Waiting longer for RedisShards to be populated.")
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If we got here then we successfully got all of the information.
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check to see if we timed out.
|
||||||
|
if time.time() - start_time >= timeout:
|
||||||
|
raise Exception("Timed out while attempting to initialize the "
|
||||||
|
"global state. num_redis_shards = {}, "
|
||||||
|
"ip_address_ports = {}"
|
||||||
|
.format(num_redis_shards, ip_address_ports))
|
||||||
|
|
||||||
|
# Get the rest of the information.
|
||||||
self.redis_clients = []
|
self.redis_clients = []
|
||||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
|
||||||
if num_redis_shards is None:
|
|
||||||
raise Exception("No entry found for NumRedisShards")
|
|
||||||
num_redis_shards = int(num_redis_shards)
|
|
||||||
if (num_redis_shards < 1):
|
|
||||||
raise Exception("Expected at least one Redis shard, found "
|
|
||||||
"{}.".format(num_redis_shards))
|
|
||||||
|
|
||||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
|
||||||
end=-1)
|
|
||||||
if len(ip_address_ports) != num_redis_shards:
|
|
||||||
raise Exception("Expected {} Redis shard addresses, found "
|
|
||||||
"{}".format(num_redis_shards,
|
|
||||||
len(ip_address_ports)))
|
|
||||||
|
|
||||||
for ip_address_port in ip_address_ports:
|
for ip_address_port in ip_address_ports:
|
||||||
shard_address, shard_port = ip_address_port.split(b":")
|
shard_address, shard_port = ip_address_port.split(b":")
|
||||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||||
|
|
Loading…
Add table
Reference in a new issue