import aioredis
import argparse
import asyncio
import binascii
import collections
import datetime
import json
import numpy as np
import time
import websockets

# Import flatbuffer bindings.
from ray.core.generated.LocalSchedulerInfoMessage import \
    LocalSchedulerInfoMessage

parser = argparse.ArgumentParser(
    description="parse information for the web ui")
parser.add_argument("--redis-address", required=True, type=str,
                    help="the address to use for redis")

loop = asyncio.get_event_loop()

IDENTIFIER_LENGTH = 20

# This prefix must match the value defined in ray_redis_module.cc.
DB_CLIENT_PREFIX = b"CL:"


def hex_identifier(identifier):
    return binascii.hexlify(identifier).decode()


def identifier(hex_identifier):
    return binascii.unhexlify(hex_identifier)


def key_to_hex_identifier(key):
    return hex_identifier(
        key[(key.index(b":") + 1):(key.index(b":") + IDENTIFIER_LENGTH + 1)])


def timestamp_to_date_string(timestamp):
    """Convert a time stamp returned by time.time() to a formatted string."""
    return (datetime.datetime.fromtimestamp(timestamp)
            .strftime("%Y/%m/%d %H:%M:%S"))


def key_to_hex_identifiers(key):
    # Extract worker_id and task_id from a key of the form
    # prefix:worker_id:task_id.
    offset = key.index(b":") + 1
    worker_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    offset += IDENTIFIER_LENGTH + 1
    task_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    return worker_id, task_id


async def hgetall_as_dict(redis_conn, key):
    """Run HGETALL on a key and return the flat reply as a dictionary."""
    fields = await redis_conn.execute("hgetall", key)
    return {fields[2 * i]: fields[2 * i + 1] for i in range(len(fields) // 2)}


# Cache information about the local schedulers.
local_schedulers = {}
errors = []


def duration_to_string(duration):
    """Format a duration in seconds as a string.

    Args:
        duration (float): The duration in seconds.

    Returns:
        A more human-readable version of the duration (for example,
            "3.5 hours" or "93 milliseconds").
    """
    if duration > 3600 * 24:
        duration_str = "{0:0.1f} days".format(duration / (3600 * 24))
    elif duration > 3600:
        duration_str = "{0:0.1f} hours".format(duration / 3600)
    elif duration > 60:
        duration_str = "{0:0.1f} minutes".format(duration / 60)
    elif duration > 1:
        duration_str = "{0:0.1f} seconds".format(duration)
    elif duration > 0.001:
        duration_str = "{0:0.1f} milliseconds".format(duration * 1000)
    else:
        duration_str = "{} microseconds".format(int(duration * 1000000))
    return duration_str


async def handle_get_statistics(websocket, redis_conn):
    cluster_start_time = float(
        await redis_conn.execute("get", "redis_start_time"))
    start_date = timestamp_to_date_string(cluster_start_time)
    uptime = duration_to_string(time.time() - cluster_start_time)
    client_keys = await redis_conn.execute("keys", "CL:*")
    clients = []
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        clients.append(client_fields)
    ip_addresses = list(set([client[b"node_ip_address"].decode("ascii")
                             for client in clients
                             if client[b"client_type"] == b"local_scheduler"]))
    num_nodes = len(ip_addresses)
    reply = {"uptime": uptime,
             "start_date": start_date,
             "nodes": num_nodes,
             "addresses": ip_addresses}
    await websocket.send(json.dumps(reply))


async def handle_get_drivers(websocket, redis_conn):
    keys = await redis_conn.execute("keys", "Drivers:*")
    drivers = []
    for key in keys:
        driver_fields = await hgetall_as_dict(redis_conn, key)
        driver_info = {
            "node ip address":
                driver_fields[b"node_ip_address"].decode("ascii"),
            "name": driver_fields[b"name"].decode("ascii")}
        driver_info["start time"] = timestamp_to_date_string(
            float(driver_fields[b"start_time"]))
        if b"end_time" in driver_fields:
            duration = (float(driver_fields[b"end_time"]) -
                        float(driver_fields[b"start_time"]))
        else:
            duration = time.time() - float(driver_fields[b"start_time"])
        driver_info["duration"] = duration_to_string(duration)
        if b"exception" in driver_fields:
            driver_info["status"] = "FAILED"
        elif b"end_time" not in driver_fields:
            driver_info["status"] = "IN PROGRESS"
        else:
            driver_info["status"] = "SUCCESS"
        if b"exception" in driver_fields:
            driver_info["exception"] = (
                driver_fields[b"exception"].decode("ascii"))
        drivers.append(driver_info)
    # Sort the drivers by their start times, most recent first.
    reply = sorted(drivers, key=(lambda driver: driver["start time"]))[::-1]
    await websocket.send(json.dumps(reply))


async def listen_for_errors(redis_ip_address, redis_port):
    pubsub_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    data_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    error_pattern = "__keyspace@0__:ErrorKeys"
    await pubsub_conn.execute_pubsub("psubscribe", error_pattern)
    channel = pubsub_conn.pubsub_patterns[error_pattern]
    print("Listening for error messages...")
    index = 0
    while (await channel.wait_message()):
        await channel.get()
        info = await data_conn.execute("lrange", "ErrorKeys", index, -1)
        for error_key in info:
            worker, task = key_to_hex_identifiers(error_key)
            # TODO(richard): Filter out workers so that only relevant task
            # errors are shown.
            result = await data_conn.execute("hget", error_key, "message")
            result = result.decode("ascii")
            # TODO(richard): Maybe also get rid of the coloring.
            errors.append({"driver_id": worker,
                           "task_id": task,
                           "error": result})
            index += 1


async def handle_get_errors(websocket):
    """Send error messages to the frontend."""
    await websocket.send(json.dumps(errors))


node_info = collections.OrderedDict()
worker_info = collections.OrderedDict()


async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
    # First update the cache of worker information.
    worker_keys = await redis_conn.execute("keys", "Workers:*")
    for key in worker_keys:
        worker_id = hex_identifier(key[len("Workers:"):])
        if worker_id not in worker_info:
            worker_info[worker_id] = await hgetall_as_dict(redis_conn, key)
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            if node_ip_address not in node_info:
                node_info[node_ip_address] = {"workers": []}
            node_info[node_ip_address]["workers"].append(worker_id)

    keys = await redis_conn.execute("keys", "event_log:*")

    if len(keys) == 0:
        # There are no tasks, so send a message to the client saying so.
        await websocket.send(json.dumps({"num_tasks": 0}))
    else:
        timestamps = []
        contents = []
        for key in keys:
            content = await redis_conn.execute("lrange", key, "0", "-1")
            contents.append(json.loads(content[0].decode()))
            timestamps += [timestamp for (timestamp, task, kind, info)
                           in contents[-1] if task == "ray:task"]
        timestamps.sort()
        time_cutoff = timestamps[(-2 * num_tasks):][0]
        max_time = timestamps[-1]
        min_time = time_cutoff - (max_time - time_cutoff) * 0.1
        max_time = max_time + (max_time - time_cutoff) * 0.1

        worker_ids = list(worker_info.keys())
        node_ip_addresses = list(node_info.keys())
        num_tasks = 0
        task_data = [{"task_data": [],
                      "num_workers":
                          len(node_info[node_ip_address]["workers"])}
                     for node_ip_address in node_ip_addresses]
        for i in range(len(keys)):
            worker_id, task_id = key_to_hex_identifiers(keys[i])
            data = contents[i]
            if worker_id not in worker_ids:
                # This case should be extremely rare.
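                # It would mean an event_log entry references a worker that
                # is missing from the Workers:* cache populated above.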
                raise Exception("A worker ID was not present in the list of "
                                "worker IDs.")
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            worker_index = node_info[node_ip_address]["workers"].index(
                worker_id)
            node_index = node_ip_addresses.index(node_ip_address)
            task_times = [timestamp for (timestamp, task, kind, info) in data
                          if task == "ray:task"]
            if task_times[1] <= time_cutoff:
                continue
            task_get_arguments_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:get_arguments"]
            task_execute_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:execute"]
            task_store_outputs_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:store_outputs"]
            task_info = {
                "task": task_times,
                "get_arguments": task_get_arguments_times,
                "execute": task_execute_times,
                "store_outputs": task_store_outputs_times,
                "worker_index": worker_index,
                "node_ip_address": node_ip_address,
                "task_formatted_time":
                    duration_to_string(task_times[1] - task_times[0]),
                "get_arguments_formatted_time":
                    duration_to_string(task_get_arguments_times[1] -
                                       task_get_arguments_times[0])}
            if len(task_execute_times) == 2:
                task_info["execute_formatted_time"] = duration_to_string(
                    task_execute_times[1] - task_execute_times[0])
            if len(task_store_outputs_times) == 2:
                task_info["store_outputs_formatted_time"] = duration_to_string(
                    task_store_outputs_times[1] - task_store_outputs_times[0])
            task_data[node_index]["task_data"].append(task_info)
            num_tasks += 1

        reply = {"min_time": min_time,
                 "max_time": max_time,
                 "num_tasks": num_tasks,
                 "task_data": task_data}
        await websocket.send(json.dumps(reply))


async def send_heartbeat_payload(websocket):
    """Send heartbeat updates to the frontend every half second."""
    while True:
        reply = []
        for local_scheduler_id, local_scheduler in local_schedulers.items():
            current_time = time.time()
            local_scheduler_info = {
                "local scheduler ID": local_scheduler_id,
                "time since heartbeat":
                    (duration_to_string(current_time -
                                        local_scheduler["last_heartbeat"])),
                "time since heartbeat numeric":
                    str(current_time - local_scheduler["last_heartbeat"]),
                "node ip address": local_scheduler["node_ip_address"]}
            reply.append(local_scheduler_info)
        # Send the payload to the frontend.
        await websocket.send(json.dumps(reply))
        # Wait for a little while so as not to overwhelm the frontend.
        await asyncio.sleep(0.5)


async def send_heartbeats(websocket, redis_conn):
    # First update the local scheduler info locally.
    client_keys = await redis_conn.execute("keys", "CL:*")
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        if client_fields[b"client_type"] == b"local_scheduler":
            local_scheduler_id = hex_identifier(
                client_fields[b"ray_client_id"])
            local_schedulers[local_scheduler_id] = {
                "node_ip_address":
                    client_fields[b"node_ip_address"].decode("ascii"),
                "local_scheduler_socket_name":
                    client_fields[b"local_scheduler_socket_name"]
                    .decode("ascii"),
                "aux_address": client_fields[b"aux_address"].decode("ascii"),
                "last_heartbeat": -1 * np.inf}

    # Subscribe to local scheduler heartbeats.
    await redis_conn.execute_pubsub("subscribe", "local_schedulers")

    # Start a method in the background to periodically update the frontend.
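    # asyncio.ensure_future schedules the coroutine without awaiting it, so
    # heartbeat payloads are pushed to the frontend concurrently with the
    # subscription loop below.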
    asyncio.ensure_future(send_heartbeat_payload(websocket))

    while True:
        msg = await redis_conn.pubsub_channels["local_schedulers"].get()
        heartbeat = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
            msg, 0)
        local_scheduler_id_bytes = heartbeat.DbClientId()
        local_scheduler_id = hex_identifier(local_scheduler_id_bytes)
        if local_scheduler_id not in local_schedulers:
            # A new local scheduler has joined the cluster. Ignore it. This
            # won't be displayed in the UI until the page is refreshed.
            continue
        local_schedulers[local_scheduler_id]["last_heartbeat"] = time.time()


async def cache_data_from_redis(redis_ip_address, redis_port):
    """Open up ports to listen for new updates from Redis."""
    # TODO(richard): A lot of code needs to be ported in order to open new
    # websockets.
    asyncio.ensure_future(listen_for_errors(redis_ip_address, redis_port))


async def handle_get_log_files(websocket, redis_conn):
    reply = {}
    # First get all keys for the log file lists.
    log_file_list_keys = await redis_conn.execute("keys", "LOG_FILENAMES:*")
    for log_file_list_key in log_file_list_keys:
        node_ip_address = log_file_list_key.decode("ascii").split(":")[1]
        reply[node_ip_address] = {}
        # Get all of the log filenames for this node IP address.
        log_filenames = await redis_conn.execute("lrange", log_file_list_key,
                                                 0, -1)
        for log_filename in log_filenames:
            log_filename_key = "LOGFILE:{}:{}".format(
                node_ip_address, log_filename.decode("ascii"))
            logfile = await redis_conn.execute("lrange", log_filename_key,
                                               0, -1)
            logfile = [line.decode("ascii") for line in logfile]
            reply[node_ip_address][log_filename.decode("ascii")] = logfile

    # Send the reply back to the front end.
    await websocket.send(json.dumps(reply))


async def serve_requests(websocket, path):
    redis_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)

    while True:
        command = json.loads(await websocket.recv())
        print("received command {}".format(command))

        if command["command"] == "get-statistics":
            await handle_get_statistics(websocket, redis_conn)
        elif command["command"] == "get-drivers":
            await handle_get_drivers(websocket, redis_conn)
        elif command["command"] == "get-recent-tasks":
            await handle_get_recent_tasks(websocket, redis_conn,
                                          command["num"])
        elif command["command"] == "get-errors":
            await handle_get_errors(websocket)
        elif command["command"] == "get-heartbeats":
            await send_heartbeats(websocket, redis_conn)
        elif command["command"] == "get-log-files":
            await handle_get_log_files(websocket, redis_conn)

        if command["command"] == "get-workers":
            result = []
            workers = await redis_conn.execute("keys", "WorkerInfo:*")
            for key in workers:
                content = await redis_conn.execute("hgetall", key)
                worker_id = key_to_hex_identifier(key)
                result.append({"worker": worker_id,
                               "export_counter": int(content[1])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-clients":
            result = []
            clients = await redis_conn.execute("keys", "CL:*")
            for key in clients:
                content = await redis_conn.execute("hgetall", key)
                result.append({"client": hex_identifier(content[1]),
                               "node_ip_address": content[3].decode(),
                               "client_type": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-objects":
            result = []
            objects = await redis_conn.execute("keys", "OI:*")
            for key in objects:
                content = await redis_conn.execute("hgetall", key)
                result.append({"object_id": hex_identifier(content[1]),
                               "hash": hex_identifier(content[3]),
                               "data_size": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-object-info":
            # TODO(pcm): Get the object here (have to connect to ray) and ship
            # content and type back to webclient. One challenge here is that
            # the naive implementation will block the web ui backend, which is
            # not ok if it is serving multiple users.
            await websocket.send(json.dumps({"object_id": "none"}))
        elif command["command"] == "get-tasks":
            result = []
            tasks = await redis_conn.execute("keys", "TT:*")
            for key in tasks:
                content = await redis_conn.execute("hgetall", key)
                result.append({"task_id": key_to_hex_identifier(key),
                               "state": int(content[1]),
                               "node_id": hex_identifier(content[3])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-timeline":
            tasks = collections.defaultdict(list)
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                content = await redis_conn.execute("lrange", key, "0", "-1")
                data = json.loads(content[0].decode())
                begin_and_end_time = [timestamp
                                      for (timestamp, task, kind, info)
                                      in data if task == "ray:task"]
                tasks[worker_id].append(
                    {"task_id": task_id,
                     "start_task": min(begin_and_end_time),
                     "end_task": max(begin_and_end_time)})
            await websocket.send(json.dumps(tasks))
        elif command["command"] == "get-events":
            result = []
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                answer = await redis_conn.execute("lrange", key, "0", "-1")
                assert len(answer) == 1
                events = json.loads(answer[0].decode())
                result.extend([{"worker_id": worker_id,
                                "task_id": task_id,
                                "time": event[0],
                                "type": event[1]}
                               for event in events])
            await websocket.send(json.dumps(result))


if __name__ == "__main__":
    args = parser.parse_args()
    redis_address = args.redis_address.split(":")
    redis_ip_address, redis_port = redis_address[0], int(redis_address[1])

    # The port here must match the value used by the frontend to connect over
    # websockets. TODO(richard): Automatically increment the port if it is
    # already taken.
    port = 8888

    loop.run_until_complete(cache_data_from_redis(redis_ip_address,
                                                  redis_port))
    start_server = websockets.serve(serve_requests, "localhost", port)
    loop.run_until_complete(start_server)
    loop.run_forever()