# ray/webui/backend/ray_ui.py

import aioredis
import argparse
import asyncio
import binascii
import collections
import datetime
import json
import numpy as np
import time
import websockets

# Import flatbuffer bindings.
from ray.core.generated.LocalSchedulerInfoMessage import \
    LocalSchedulerInfoMessage

parser = argparse.ArgumentParser(
    description="parse information for the web ui")
parser.add_argument("--redis-address", required=True, type=str,
                    help="the address to use for redis")

loop = asyncio.get_event_loop()
IDENTIFIER_LENGTH = 20
# This prefix must match the value defined in ray_redis_module.cc.
DB_CLIENT_PREFIX = b"CL:"


def hex_identifier(identifier):
    return binascii.hexlify(identifier).decode()


def identifier(hex_identifier):
    return binascii.unhexlify(hex_identifier)
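

# For illustration (not in the original file): these two helpers are inverses
# over the 20-byte binary IDs that Ray stores in Redis. For example,
# hex_identifier(b"\x00" * 20) returns "00" * 20 (a 40-character hex string),
# and identifier("00" * 20) returns b"\x00" * 20.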


def key_to_hex_identifier(key):
    return hex_identifier(
        key[(key.index(b":") + 1):(key.index(b":") + IDENTIFIER_LENGTH + 1)])


def timestamp_to_date_string(timestamp):
    """Convert a time stamp returned by time.time() to a formatted string."""
    return (datetime.datetime.fromtimestamp(timestamp)
            .strftime("%Y/%m/%d %H:%M:%S"))


def key_to_hex_identifiers(key):
    # Extract worker_id and task_id from key of the form
    # prefix:worker_id:task_id.
    offset = key.index(b":") + 1
    worker_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    offset += IDENTIFIER_LENGTH + 1
    task_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    return worker_id, task_id


async def hgetall_as_dict(redis_conn, key):
    fields = await redis_conn.execute("hgetall", key)
    return {fields[2 * i]: fields[2 * i + 1] for i in range(len(fields) // 2)}
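

# For illustration (not in the original file): the low-level aioredis
# connection returns HGETALL replies as a flat list of alternating field and
# value bytes, e.g.
#     [b"node_ip_address", b"127.0.0.1", b"client_type", b"local_scheduler"]
# which hgetall_as_dict turns into
#     {b"node_ip_address": b"127.0.0.1", b"client_type": b"local_scheduler"}.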


# Cache information about the local schedulers.
local_schedulers = {}
errors = []


def duration_to_string(duration):
    """Format a duration in seconds as a string.

    Args:
        duration (float): The duration in seconds.

    Returns:
        A more human-readable version of the duration (for example,
        "3.5 hours" or "93 milliseconds").
    """
    if duration > 3600 * 24:
        duration_str = "{0:0.1f} days".format(duration / (3600 * 24))
    elif duration > 3600:
        duration_str = "{0:0.1f} hours".format(duration / 3600)
    elif duration > 60:
        duration_str = "{0:0.1f} minutes".format(duration / 60)
    elif duration > 1:
        duration_str = "{0:0.1f} seconds".format(duration)
    elif duration > 0.001:
        duration_str = "{0:0.1f} milliseconds".format(duration * 1000)
    else:
        duration_str = "{} microseconds".format(int(duration * 1000000))
    return duration_str
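

# For illustration (not in the original file), some sample outputs:
#     duration_to_string(2 * 3600 * 24)  ->  "2.0 days"
#     duration_to_string(90)             ->  "1.5 minutes"
#     duration_to_string(0.042)          ->  "42.0 milliseconds"
#     duration_to_string(0.0005)         ->  "500 microseconds"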


async def handle_get_statistics(websocket, redis_conn):
    cluster_start_time = float(
        await redis_conn.execute("get", "redis_start_time"))
    start_date = timestamp_to_date_string(cluster_start_time)
    uptime = duration_to_string(time.time() - cluster_start_time)
    client_keys = await redis_conn.execute("keys", "CL:*")
    clients = []
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        clients.append(client_fields)
    ip_addresses = list(set([client[b"node_ip_address"].decode("ascii")
                             for client in clients
                             if client[b"client_type"] == b"local_scheduler"]))
    num_nodes = len(ip_addresses)
    reply = {"uptime": uptime,
             "start_date": start_date,
             "nodes": num_nodes,
             "addresses": ip_addresses}
    await websocket.send(json.dumps(reply))


async def handle_get_drivers(websocket, redis_conn):
    keys = await redis_conn.execute("keys", "Drivers:*")
    drivers = []
    for key in keys:
        driver_fields = await hgetall_as_dict(redis_conn, key)
        driver_info = {
            "node ip address":
                driver_fields[b"node_ip_address"].decode("ascii"),
            "name": driver_fields[b"name"].decode("ascii")}
        driver_info["start time"] = timestamp_to_date_string(
            float(driver_fields[b"start_time"]))
        if b"end_time" in driver_fields:
            duration = (float(driver_fields[b"end_time"]) -
                        float(driver_fields[b"start_time"]))
        else:
            duration = time.time() - float(driver_fields[b"start_time"])
        driver_info["duration"] = duration_to_string(duration)
        if b"exception" in driver_fields:
            driver_info["status"] = "FAILED"
        elif b"end_time" not in driver_fields:
            driver_info["status"] = "IN PROGRESS"
        else:
            driver_info["status"] = "SUCCESS"
        if b"exception" in driver_fields:
            driver_info["exception"] = (
                driver_fields[b"exception"].decode("ascii"))
        drivers.append(driver_info)
    # Sort the drivers by their start times.
    reply = sorted(drivers, key=(lambda driver: driver["start time"]))[::-1]
    await websocket.send(json.dumps(reply))


async def listen_for_errors(redis_ip_address, redis_port):
    pubsub_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    data_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    error_pattern = "__keyspace@0__:ErrorKeys"
    await pubsub_conn.execute_pubsub("psubscribe", error_pattern)
    channel = pubsub_conn.pubsub_patterns[error_pattern]
    print("Listening for error messages...")
    index = 0
    while await channel.wait_message():
        await channel.get()
        info = await data_conn.execute("lrange", "ErrorKeys", index, -1)
        for error_key in info:
            worker, task = key_to_hex_identifiers(error_key)
            # TODO(richard): Filter out workers so that only relevant task
            # errors are necessary.
            result = await data_conn.execute("hget", error_key, "message")
            result = result.decode("ascii")
            # TODO(richard): Maybe also get rid of the coloring.
            errors.append({"driver_id": worker,
                           "task_id": task,
                           "error": result})
            index += 1


async def handle_get_errors(websocket):
    """Send error messages to the frontend."""
    await websocket.send(json.dumps(errors))


node_info = collections.OrderedDict()
worker_info = collections.OrderedDict()


async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
    # First update the cache of worker information.
    worker_keys = await redis_conn.execute("keys", "Workers:*")
    for key in worker_keys:
        worker_id = hex_identifier(key[len("Workers:"):])
        if worker_id not in worker_info:
            worker_info[worker_id] = await hgetall_as_dict(redis_conn, key)
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            if node_ip_address not in node_info:
                node_info[node_ip_address] = {"workers": []}
            node_info[node_ip_address]["workers"].append(worker_id)

    keys = await redis_conn.execute("keys", "event_log:*")
    if len(keys) == 0:
        # There are no tasks, so send a message to the client saying so.
        await websocket.send(json.dumps({"num_tasks": 0}))
    else:
        timestamps = []
        contents = []
        for key in keys:
            content = await redis_conn.execute("lrange", key, "0", "-1")
            contents.append(json.loads(content[0].decode()))
            timestamps += [timestamp for (timestamp, task, kind, info)
                           in contents[-1] if task == "ray:task"]
        timestamps.sort()
        time_cutoff = timestamps[(-2 * num_tasks):][0]
        max_time = timestamps[-1]
        min_time = time_cutoff - (max_time - time_cutoff) * 0.1
        max_time = max_time + (max_time - time_cutoff) * 0.1
        worker_ids = list(worker_info.keys())
        node_ip_addresses = list(node_info.keys())
        num_tasks = 0
        task_data = [
            {"task_data": [],
             "num_workers": len(node_info[node_ip_address]["workers"])}
            for node_ip_address in node_ip_addresses]
        for i in range(len(keys)):
            worker_id, task_id = key_to_hex_identifiers(keys[i])
            data = contents[i]
            if worker_id not in worker_ids:
                # This case should be extremely rare.
                raise Exception("A worker ID was not present in the list of "
                                "worker IDs.")
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            worker_index = node_info[node_ip_address]["workers"].index(
                worker_id)
            node_index = node_ip_addresses.index(node_ip_address)
            task_times = [timestamp for (timestamp, task, kind, info) in data
                          if task == "ray:task"]
            if task_times[1] <= time_cutoff:
                continue
            task_get_arguments_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:get_arguments"]
            task_execute_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:execute"]
            task_store_outputs_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:store_outputs"]
            task_info = {
                "task": task_times,
                "get_arguments": task_get_arguments_times,
                "execute": task_execute_times,
                "store_outputs": task_store_outputs_times,
                "worker_index": worker_index,
                "node_ip_address": node_ip_address,
                "task_formatted_time": duration_to_string(
                    task_times[1] - task_times[0]),
                "get_arguments_formatted_time": duration_to_string(
                    task_get_arguments_times[1] -
                    task_get_arguments_times[0])}
            if len(task_execute_times) == 2:
                task_info["execute_formatted_time"] = duration_to_string(
                    task_execute_times[1] - task_execute_times[0])
            if len(task_store_outputs_times) == 2:
                task_info["store_outputs_formatted_time"] = duration_to_string(
                    task_store_outputs_times[1] - task_store_outputs_times[0])
            task_data[node_index]["task_data"].append(task_info)
            num_tasks += 1
        reply = {"min_time": min_time,
                 "max_time": max_time,
                 "num_tasks": num_tasks,
                 "task_data": task_data}
        await websocket.send(json.dumps(reply))


async def send_heartbeat_payload(websocket):
    """Send heartbeat updates to the frontend every half second."""
    while True:
        reply = []
        for local_scheduler_id, local_scheduler in local_schedulers.items():
            current_time = time.time()
            local_scheduler_info = {
                "local scheduler ID": local_scheduler_id,
                "time since heartbeat": duration_to_string(
                    current_time - local_scheduler["last_heartbeat"]),
                "time since heartbeat numeric":
                    str(current_time - local_scheduler["last_heartbeat"]),
                "node ip address": local_scheduler["node_ip_address"]}
            reply.append(local_scheduler_info)
        # Send the payload to the frontend.
        await websocket.send(json.dumps(reply))
        # Wait for a little while so as not to overwhelm the frontend.
        await asyncio.sleep(0.5)


async def send_heartbeats(websocket, redis_conn):
    # First update the local scheduler info locally.
    client_keys = await redis_conn.execute("keys", "CL:*")
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        if client_fields[b"client_type"] == b"local_scheduler":
            local_scheduler_id = hex_identifier(
                client_fields[b"ray_client_id"])
            local_schedulers[local_scheduler_id] = {
                "node_ip_address":
                    client_fields[b"node_ip_address"].decode("ascii"),
                "local_scheduler_socket_name":
                    client_fields[b"local_scheduler_socket_name"].decode(
                        "ascii"),
                "aux_address": client_fields[b"aux_address"].decode("ascii"),
                "last_heartbeat": -1 * np.inf}
    # Subscribe to local scheduler heartbeats.
    await redis_conn.execute_pubsub("subscribe", "local_schedulers")
    # Start a method in the background to periodically update the frontend.
    asyncio.ensure_future(send_heartbeat_payload(websocket))
    while True:
        msg = await redis_conn.pubsub_channels["local_schedulers"].get()
        heartbeat = (LocalSchedulerInfoMessage
                     .GetRootAsLocalSchedulerInfoMessage(msg, 0))
        local_scheduler_id_bytes = heartbeat.DbClientId()
        local_scheduler_id = hex_identifier(local_scheduler_id_bytes)
        if local_scheduler_id not in local_schedulers:
            # A new local scheduler has joined the cluster. Ignore it. This
            # won't be displayed in the UI until the page is refreshed.
            continue
        local_schedulers[local_scheduler_id]["last_heartbeat"] = time.time()


async def cache_data_from_redis(redis_ip_address, redis_port):
    """Open up ports to listen for new updates from Redis."""
    # TODO(richard): A lot of code needs to be ported in order to open new
    # websockets.
    asyncio.ensure_future(listen_for_errors(redis_ip_address, redis_port))


async def handle_get_log_files(websocket, redis_conn):
    reply = {}
    # First get all keys for the log file lists.
    log_file_list_keys = await redis_conn.execute("keys", "LOG_FILENAMES:*")
    for log_file_list_key in log_file_list_keys:
        node_ip_address = log_file_list_key.decode("ascii").split(":")[1]
        reply[node_ip_address] = {}
        # Get all of the log filenames for this node IP address.
        log_filenames = await redis_conn.execute(
            "lrange", log_file_list_key, 0, -1)
        for log_filename in log_filenames:
            log_filename_key = "LOGFILE:{}:{}".format(
                node_ip_address, log_filename.decode("ascii"))
            logfile = await redis_conn.execute(
                "lrange", log_filename_key, 0, -1)
            logfile = [line.decode("ascii") for line in logfile]
            reply[node_ip_address][log_filename.decode("ascii")] = logfile
    # Send the reply back to the front end.
    await websocket.send(json.dumps(reply))


async def serve_requests(websocket, path):
    """Handle commands sent by the frontend over a websocket connection.

    Each request is a JSON object with a "command" field; the reply is sent
    back over the same websocket as JSON.
    """
    redis_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    while True:
        command = json.loads(await websocket.recv())
        print("received command {}".format(command))
        if command["command"] == "get-statistics":
            await handle_get_statistics(websocket, redis_conn)
        elif command["command"] == "get-drivers":
            await handle_get_drivers(websocket, redis_conn)
        elif command["command"] == "get-recent-tasks":
            await handle_get_recent_tasks(websocket, redis_conn, command["num"])
        elif command["command"] == "get-errors":
            await handle_get_errors(websocket)
        elif command["command"] == "get-heartbeats":
            await send_heartbeats(websocket, redis_conn)
        elif command["command"] == "get-log-files":
            await handle_get_log_files(websocket, redis_conn)
        elif command["command"] == "get-workers":
            result = []
            workers = await redis_conn.execute("keys", "WorkerInfo:*")
            for key in workers:
                content = await redis_conn.execute("hgetall", key)
                worker_id = key_to_hex_identifier(key)
                result.append({"worker": worker_id,
                               "export_counter": int(content[1])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-clients":
            result = []
            clients = await redis_conn.execute("keys", "CL:*")
            for key in clients:
                content = await redis_conn.execute("hgetall", key)
                result.append({"client": hex_identifier(content[1]),
                               "node_ip_address": content[3].decode(),
                               "client_type": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-objects":
            result = []
            objects = await redis_conn.execute("keys", "OI:*")
            for key in objects:
                content = await redis_conn.execute("hgetall", key)
                result.append({"object_id": hex_identifier(content[1]),
                               "hash": hex_identifier(content[3]),
                               "data_size": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-object-info":
            # TODO(pcm): Get the object here (have to connect to ray) and ship
            # content and type back to webclient. One challenge here is that
            # the naive implementation will block the web ui backend, which is
            # not ok if it is serving multiple users.
            await websocket.send(json.dumps({"object_id": "none"}))
        elif command["command"] == "get-tasks":
            result = []
            tasks = await redis_conn.execute("keys", "TT:*")
            for key in tasks:
                content = await redis_conn.execute("hgetall", key)
                result.append({"task_id": key_to_hex_identifier(key),
                               "state": int(content[1]),
                               "node_id": hex_identifier(content[3])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-timeline":
            tasks = collections.defaultdict(list)
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                content = await redis_conn.execute("lrange", key, "0", "-1")
                data = json.loads(content[0].decode())
                begin_and_end_time = [timestamp
                                      for (timestamp, task, kind, info)
                                      in data if task == "ray:task"]
                tasks[worker_id].append(
                    {"task_id": task_id,
                     "start_task": min(begin_and_end_time),
                     "end_task": max(begin_and_end_time)})
            await websocket.send(json.dumps(tasks))
        elif command["command"] == "get-events":
            result = []
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                answer = await redis_conn.execute("lrange", key, "0", "-1")
                assert len(answer) == 1
                events = json.loads(answer[0].decode())
                result.extend([{"worker_id": worker_id,
                                "task_id": task_id,
                                "time": event[0],
                                "type": event[1]} for event in events])
            await websocket.send(json.dumps(result))


if __name__ == "__main__":
    args = parser.parse_args()
    redis_address = args.redis_address.split(":")
    redis_ip_address, redis_port = redis_address[0], int(redis_address[1])
    # The port here must match the value used by the frontend to connect over
    # websockets. TODO(richard): Automatically increment the port if it is
    # already taken.
    port = 8888
    loop.run_until_complete(
        cache_data_from_redis(redis_ip_address, redis_port))
    start_server = websockets.serve(serve_requests, "localhost", port)
    loop.run_until_complete(start_server)
    loop.run_forever()
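
# Example usage (a sketch, not part of the original file): start the backend
# with
#     python ray_ui.py --redis-address=127.0.0.1:6379
# and query it from a websocket client, e.g. assuming the `websockets`
# client API:
#
#     import asyncio
#     import json
#     import websockets
#
#     async def example_client():
#         ws = await websockets.connect("ws://localhost:8888")
#         await ws.send(json.dumps({"command": "get-statistics"}))
#         print(json.loads(await ws.recv()))
#
#     asyncio.get_event_loop().run_until_complete(example_client())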