Wait longer when getting redis shards to initialize global state API. (#786)

This commit is contained in:
Robert Nishihara 2017-07-31 17:56:11 -07:00 committed by Philipp Moritz
parent 1fe49d7676
commit c394a65ffc

View file

@ -57,6 +57,7 @@ class GlobalState(object):
def __init__(self): def __init__(self):
"""Create a GlobalState object.""" """Create a GlobalState object."""
self.redis_client = None self.redis_client = None
self.redis_clients = None
def _check_connected(self): def _check_connected(self):
"""Check that the object has been initialized before it is used. """Check that the object has been initialized before it is used.
@ -69,32 +70,65 @@ class GlobalState(object):
raise Exception("The ray.global_state API cannot be used before " raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.") "ray.init has been called.")
def _initialize_global_state(self, redis_ip_address, redis_port): if self.redis_clients is None:
raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.")
def _initialize_global_state(self, redis_ip_address, redis_port,
timeout=20):
"""Initialize the GlobalState object by connecting to Redis. """Initialize the GlobalState object by connecting to Redis.
It's possible that certain keys in Redis may not have been fully
populated yet. In this case, we will retry this method until they have
been populated or we exceed a timeout.
Args: Args:
redis_ip_address: The IP address of the node that the Redis server redis_ip_address: The IP address of the node that the Redis server
lives on. lives on.
redis_port: The port that the Redis server is listening on. redis_port: The port that the Redis server is listening on.
timeout: The maximum amount of time (in seconds) that we should
wait for the keys in Redis to be populated.
""" """
self.redis_client = redis.StrictRedis(host=redis_ip_address, self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port) port=redis_port)
start_time = time.time()
num_redis_shards = None
ip_address_ports = []
while time.time() - start_time < timeout:
# Attempt to get the number of Redis shards.
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
print("Waiting longer for NumRedisShards to be populated.")
time.sleep(1)
continue
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
# Attempt to get all of the Redis shards.
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
end=-1)
if len(ip_address_ports) != num_redis_shards:
print("Waiting longer for RedisShards to be populated.")
time.sleep(1)
continue
# If we got here then we successfully got all of the information.
break
# Check to see if we timed out.
if time.time() - start_time >= timeout:
raise Exception("Timed out while attempting to initialize the "
"global state. num_redis_shards = {}, "
"ip_address_ports = {}"
.format(num_redis_shards, ip_address_ports))
# Get the rest of the information.
self.redis_clients = [] self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards,
len(ip_address_ports)))
for ip_address_port in ip_address_ports: for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":") shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address, self.redis_clients.append(redis.StrictRedis(host=shard_address,