Fix monitor.py bottleneck by removing excess Redis queries. (#1786)

* Fix monitor.py bottleneck by removing excess Redis queries.

* Remove unnecessary default value.
This commit is contained in:
Robert Nishihara 2018-03-26 22:30:38 -07:00 committed by Philipp Moritz
parent 51fdbe3867
commit de3cfa223d

View file

@ -95,6 +95,9 @@ class Monitor(object):
self.dead_local_schedulers = set() self.dead_local_schedulers = set()
self.live_plasma_managers = Counter() self.live_plasma_managers = Counter()
self.dead_plasma_managers = set() self.dead_plasma_managers = set()
# Keep a mapping from local scheduler client ID to IP address to use
# for updating the load metrics.
self.local_scheduler_id_to_ip_map = dict()
self.load_metrics = LoadMetrics() self.load_metrics = LoadMetrics()
if autoscaling_config: if autoscaling_config:
self.autoscaler = StandardAutoscaler( self.autoscaler = StandardAutoscaler(
@ -268,22 +271,15 @@ class Monitor(object):
static = message.StaticResources(i) static = message.StaticResources(i)
dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value() dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value()
static_resources[static.Key().decode("utf-8")] = static.Value() static_resources[static.Key().decode("utf-8")] = static.Value()
# Update the load metrics for this local scheduler.
client_id = binascii.hexlify(message.DbClientId()).decode("utf-8") client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")
clients = ray.global_state.client_table() ip = self.local_scheduler_id_to_ip_map.get(client_id)
local_schedulers = [
entry for client in clients.values() for entry in client
if (entry["ClientType"] == "local_scheduler" and not
entry["Deleted"])
]
ip = None
for ls in local_schedulers:
if ls["DBClientID"] == client_id:
ip = ls["AuxAddress"].split(":")[0]
if ip: if ip:
self.load_metrics.update(ip, static_resources, dynamic_resources) self.load_metrics.update(ip, static_resources, dynamic_resources)
else: else:
print("Warning: could not find ip for client {} in {}".format( print("Warning: could not find ip for client {}."
client_id, local_schedulers)) .format(client_id))
def plasma_manager_heartbeat_handler(self, unused_channel, data): def plasma_manager_heartbeat_handler(self, unused_channel, data):
"""Handle a plasma manager heartbeat from Redis. """Handle a plasma manager heartbeat from Redis.
@ -437,13 +433,17 @@ class Monitor(object):
self._clean_up_entries_for_driver(driver_id) self._clean_up_entries_for_driver(driver_id)
def process_messages(self): def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels. """Process all messages ready in the subscription channels.
This reads messages from the subscription channels and calls the This reads messages from the subscription channels and calls the
appropriate handlers until there are no messages left. appropriate handlers until there are no messages left.
Args:
max_messages: The maximum number of messages to process before
returning.
""" """
while True: for _ in range(max_messages):
message = self.subscribe_client.get_message() message = self.subscribe_client.get_message()
if message is None: if message is None:
return return
@ -515,6 +515,15 @@ class Monitor(object):
# Handle messages from the subscription channels. # Handle messages from the subscription channels.
while True: while True:
# Update the mapping from local scheduler client ID to IP address.
# This is only used to update the load metrics for the autoscaler.
local_schedulers = self.state.local_schedulers()
self.local_scheduler_id_to_ip_map = {}
for local_scheduler_info in local_schedulers:
client_id = local_scheduler_info["DBClientID"]
ip_address = local_scheduler_info["AuxAddress"].split(":")[0]
self.local_scheduler_id_to_ip_map[client_id] = ip_address
# Process autoscaling actions # Process autoscaling actions
if self.autoscaler: if self.autoscaler:
self.autoscaler.update() self.autoscaler.update()
@ -556,6 +565,10 @@ class Monitor(object):
# messages. # messages.
time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3) time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)
# TODO(rkn): This infinite loop should be inside of a try/except block,
# and if an exception is thrown we should push an error message to all
# drivers.
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the " parser = argparse.ArgumentParser(description=("Parse Redis server for the "
@ -575,9 +588,6 @@ if __name__ == "__main__":
redis_ip_address = get_ip_address(args.redis_address) redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address) redis_port = get_port(args.redis_address)
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
if args.autoscaling_config: if args.autoscaling_config:
autoscaling_config = os.path.expanduser(args.autoscaling_config) autoscaling_config = os.path.expanduser(args.autoscaling_config)
else: else: