mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Wait longer when getting redis shards to initialize global state API. (#786)
This commit is contained in:
parent
1fe49d7676
commit
c394a65ffc
1 changed files with 50 additions and 16 deletions
|
@ -57,6 +57,7 @@ class GlobalState(object):
|
|||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
self.redis_client = None
|
||||
self.redis_clients = None
|
||||
|
||||
def _check_connected(self):
|
||||
"""Check that the object has been initialized before it is used.
|
||||
|
@ -69,32 +70,65 @@ class GlobalState(object):
|
|||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port):
|
||||
if self.redis_clients is None:
|
||||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port,
|
||||
timeout=20):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
|
||||
It's possible that certain keys in Redis may not have been fully
|
||||
populated yet. In this case, we will retry this method until they have
|
||||
been populated or we exceed a timeout.
|
||||
|
||||
Args:
|
||||
redis_ip_address: The IP address of the node that the Redis server
|
||||
lives on.
|
||||
redis_port: The port that the Redis server is listening on.
|
||||
timeout: The maximum amount of time (in seconds) that we should
|
||||
wait for the keys in Redis to be populated.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
num_redis_shards = None
|
||||
ip_address_ports = []
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Attempt to get the number of Redis shards.
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
print("Waiting longer for NumRedisShards to be populated.")
|
||||
time.sleep(1)
|
||||
continue
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
# Attempt to get all of the Redis shards.
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||
end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
print("Waiting longer for RedisShards to be populated.")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
# If we got here then we successfully got all of the information.
|
||||
break
|
||||
|
||||
# Check to see if we timed out.
|
||||
if time.time() - start_time >= timeout:
|
||||
raise Exception("Timed out while attempting to initialize the "
|
||||
"global state. num_redis_shards = {}, "
|
||||
"ip_address_ports = {}"
|
||||
.format(num_redis_shards, ip_address_ports))
|
||||
|
||||
# Get the rest of the information.
|
||||
self.redis_clients = []
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
raise Exception("No entry found for NumRedisShards")
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||
end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
raise Exception("Expected {} Redis shard addresses, found "
|
||||
"{}".format(num_redis_shards,
|
||||
len(ip_address_ports)))
|
||||
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
|
|
Loading…
Add table
Reference in a new issue