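"""Backend server for the Ray web UI.

This process reads cluster information out of Redis and pushes it to the web
UI frontend over a websocket connection.
"""
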
import aioredis
import argparse
import asyncio
import binascii
import collections
import datetime
import json
import numpy as np
import os
import redis
import sys
import time
import websockets

parser = argparse.ArgumentParser(
    description="parse information for the web ui")
parser.add_argument("--redis-address", required=True, type=str,
                    help="the address to use for redis")

loop = asyncio.get_event_loop()
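# Ray IDs (for example, client, worker, and task IDs) are IDENTIFIER_LENGTH
# bytes of raw binary data; they are hex-encoded before being sent to the
# frontend.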
IDENTIFIER_LENGTH = 20
# This prefix must match the value defined in ray_redis_module.cc.
DB_CLIENT_PREFIX = b"CL:"


def hex_identifier(identifier):
    return binascii.hexlify(identifier).decode()


def identifier(hex_identifier):
    return binascii.unhexlify(hex_identifier)


def key_to_hex_identifier(key):
    return hex_identifier(
        key[(key.index(b":") + 1):(key.index(b":") + IDENTIFIER_LENGTH + 1)])


def timestamp_to_date_string(timestamp):
    """Convert a time stamp returned by time.time() to a formatted string."""
    return datetime.datetime.fromtimestamp(timestamp).strftime(
        "%Y/%m/%d %H:%M:%S")


def key_to_hex_identifiers(key):
    # Extract worker_id and task_id from a key of the form
    # prefix:worker_id:task_id.
    offset = key.index(b":") + 1
    worker_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    offset += IDENTIFIER_LENGTH + 1
    task_id = hex_identifier(key[offset:(offset + IDENTIFIER_LENGTH)])
    return worker_id, task_id


async def hgetall_as_dict(redis_conn, key):
    """Run HGETALL and convert the flat field/value reply into a dict."""
    fields = await redis_conn.execute("hgetall", key)
    return {fields[2 * i]: fields[2 * i + 1] for i in range(len(fields) // 2)}
# Cache information about the local schedulers.
local_schedulers = {}


def duration_to_string(duration):
    """Format a duration in seconds as a string.

    Args:
        duration (float): The duration in seconds.

    Returns:
        A human-readable string describing the duration (for example,
            "3.5 hours" or "93 milliseconds").
    """
    if duration > 3600 * 24:
        duration_str = "{0:0.1f} days".format(duration / (3600 * 24))
    elif duration > 3600:
        duration_str = "{0:0.1f} hours".format(duration / 3600)
    elif duration > 60:
        duration_str = "{0:0.1f} minutes".format(duration / 60)
    elif duration > 1:
        duration_str = "{0:0.1f} seconds".format(duration)
    elif duration > 0.001:
        duration_str = "{0:0.1f} milliseconds".format(duration * 1000)
    else:
        duration_str = "{} microseconds".format(int(duration * 1000000))
    return duration_str
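
# For example, duration_to_string(90) returns "1.5 minutes" and
# duration_to_string(0.0005) returns "500 microseconds".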


async def handle_get_statistics(websocket, redis_conn):
    """Send basic cluster statistics (uptime and nodes) to the frontend."""
    cluster_start_time = float(
        await redis_conn.execute("get", "redis_start_time"))
    start_date = timestamp_to_date_string(cluster_start_time)
    uptime = duration_to_string(time.time() - cluster_start_time)
    client_keys = await redis_conn.execute("keys", "CL:*")
    clients = []
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        clients.append(client_fields)
    ip_addresses = list(set([client[b"node_ip_address"].decode("ascii")
                             for client in clients
                             if client[b"client_type"] == b"local_scheduler"]))
    num_nodes = len(ip_addresses)
    reply = {"uptime": uptime,
             "start_date": start_date,
             "nodes": num_nodes,
             "addresses": ip_addresses}
    await websocket.send(json.dumps(reply))


async def handle_get_drivers(websocket, redis_conn):
    """Send information about the cluster's drivers to the frontend."""
    keys = await redis_conn.execute("keys", "Drivers:*")
    drivers = []
    for key in keys:
        driver_fields = await hgetall_as_dict(redis_conn, key)
        driver_info = {
            "node ip address":
                driver_fields[b"node_ip_address"].decode("ascii"),
            "name": driver_fields[b"name"].decode("ascii")}
        driver_info["start time"] = timestamp_to_date_string(
            float(driver_fields[b"start_time"]))
        if b"end_time" in driver_fields:
            duration = (float(driver_fields[b"end_time"]) -
                        float(driver_fields[b"start_time"]))
        else:
            duration = time.time() - float(driver_fields[b"start_time"])
        driver_info["duration"] = duration_to_string(duration)
        if b"exception" in driver_fields:
            driver_info["status"] = "FAILED"
        elif b"end_time" not in driver_fields:
            driver_info["status"] = "IN PROGRESS"
        else:
            driver_info["status"] = "SUCCESS"
        if b"exception" in driver_fields:
            driver_info["exception"] = (
                driver_fields[b"exception"].decode("ascii"))
        drivers.append(driver_info)
    # Sort the drivers by their start times, most recent first.
    reply = sorted(drivers, key=(lambda driver: driver["start time"]))[::-1]
    await websocket.send(json.dumps(reply))
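

# Caches used by handle_get_recent_tasks. node_info maps a node's IP address
# to the IDs of the workers on that node, and worker_info maps a worker ID to
# the fields stored for that worker in Redis.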
node_info = collections.OrderedDict()
worker_info = collections.OrderedDict()


async def handle_get_recent_tasks(websocket, redis_conn, num_tasks):
    """Send data about the most recently executed tasks to the frontend."""
    # First update the cache of worker information.
    worker_keys = await redis_conn.execute("keys", "Workers:*")
    for key in worker_keys:
        worker_id = hex_identifier(key[len("Workers:"):])
        if worker_id not in worker_info:
            worker_info[worker_id] = await hgetall_as_dict(redis_conn, key)
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            if node_ip_address not in node_info:
                node_info[node_ip_address] = {"workers": []}
            node_info[node_ip_address]["workers"].append(worker_id)
    keys = await redis_conn.execute("keys", "event_log:*")
    if len(keys) == 0:
        # There are no tasks, so send a message to the client saying so.
        await websocket.send(json.dumps({"num_tasks": 0}))
    else:
        timestamps = []
        contents = []
        for key in keys:
            content = await redis_conn.execute("lrange", key, "0", "-1")
            contents.append(json.loads(content[0].decode()))
            timestamps += [timestamp
                           for (timestamp, task, kind, info) in contents[-1]
                           if task == "ray:task"]
        timestamps.sort()
        # Each task contributes two "ray:task" timestamps (its start and end
        # times), so the last 2 * num_tasks timestamps correspond to roughly
        # the num_tasks most recent tasks.
        time_cutoff = timestamps[(-2 * num_tasks):][0]
        max_time = timestamps[-1]
        min_time = time_cutoff - (max_time - time_cutoff) * 0.1
        max_time = max_time + (max_time - time_cutoff) * 0.1
        worker_ids = list(worker_info.keys())
        node_ip_addresses = list(node_info.keys())
        num_tasks = 0
        task_data = [
            {"task_data": [],
             "num_workers": len(node_info[node_ip_address]["workers"])}
            for node_ip_address in node_ip_addresses]
        for i in range(len(keys)):
            worker_id, task_id = key_to_hex_identifiers(keys[i])
            data = contents[i]
            if worker_id not in worker_ids:
                # This case should be extremely rare.
                raise Exception("A worker ID was not present in the list of "
                                "worker IDs.")
            node_ip_address = (worker_info[worker_id][b"node_ip_address"]
                               .decode("ascii"))
            worker_index = node_info[node_ip_address]["workers"].index(
                worker_id)
            node_index = node_ip_addresses.index(node_ip_address)
            task_times = [timestamp
                          for (timestamp, task, kind, info) in data
                          if task == "ray:task"]
            if task_times[1] <= time_cutoff:
                continue
            task_get_arguments_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:get_arguments"]
            task_execute_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:execute"]
            task_store_outputs_times = [
                timestamp for (timestamp, task, kind, info) in data
                if task == "ray:task:store_outputs"]
            task_data[node_index]["task_data"].append(
                {"task": task_times,
                 "get_arguments": task_get_arguments_times,
                 "execute": task_execute_times,
                 "store_outputs": task_store_outputs_times,
                 "worker_index": worker_index,
                 "node_ip_address": node_ip_address,
                 "task_formatted_time": duration_to_string(
                     task_times[1] - task_times[0]),
                 "get_arguments_formatted_time": duration_to_string(
                     task_get_arguments_times[1] -
                     task_get_arguments_times[0]),
                 "execute_formatted_time": duration_to_string(
                     task_execute_times[1] - task_execute_times[0]),
                 "store_outputs_formatted_time": duration_to_string(
                     task_store_outputs_times[1] -
                     task_store_outputs_times[0])})
            num_tasks += 1
        reply = {"min_time": min_time,
                 "max_time": max_time,
                 "num_tasks": num_tasks,
                 "task_data": task_data}
        await websocket.send(json.dumps(reply))


async def send_heartbeat_payload(websocket):
    """Send heartbeat updates to the frontend every half second."""
    while True:
        reply = []
        for local_scheduler_id, local_scheduler in local_schedulers.items():
            current_time = time.time()
            local_scheduler_info = {
                "local scheduler ID": local_scheduler_id,
                "time since heartbeat": duration_to_string(
                    current_time - local_scheduler["last_heartbeat"]),
                "time since heartbeat numeric": str(
                    current_time - local_scheduler["last_heartbeat"]),
                "node ip address": local_scheduler["node_ip_address"]}
            reply.append(local_scheduler_info)
        # Send the payload to the frontend.
        await websocket.send(json.dumps(reply))
        # Wait for a little while so as not to overwhelm the frontend.
        await asyncio.sleep(0.5)


async def send_heartbeats(websocket, redis_conn):
    """Forward local scheduler heartbeat information to the frontend."""
    # First update the local scheduler info locally.
    client_keys = await redis_conn.execute("keys", "CL:*")
    clients = []
    for client_key in client_keys:
        client_fields = await hgetall_as_dict(redis_conn, client_key)
        if client_fields[b"client_type"] == b"local_scheduler":
            local_scheduler_id = hex_identifier(
                client_fields[b"ray_client_id"])
            local_schedulers[local_scheduler_id] = {
                "node_ip_address":
                    client_fields[b"node_ip_address"].decode("ascii"),
                "local_scheduler_socket_name":
                    client_fields[b"local_scheduler_socket_name"].decode(
                        "ascii"),
                "aux_address": client_fields[b"aux_address"].decode("ascii"),
                "last_heartbeat": -1 * np.inf}
    # Subscribe to local scheduler heartbeats.
    await redis_conn.execute_pubsub("subscribe", "local_schedulers")
    # Start a method in the background to periodically update the frontend.
    asyncio.ensure_future(send_heartbeat_payload(websocket))
    while True:
        msg = await redis_conn.pubsub_channels["local_schedulers"].get()
        local_scheduler_id_bytes = msg[:IDENTIFIER_LENGTH]
        local_scheduler_id = hex_identifier(local_scheduler_id_bytes)
        if local_scheduler_id not in local_schedulers:
            # A new local scheduler has joined the cluster. Ignore it. This
            # won't be displayed in the UI until the page is refreshed.
            continue
        local_schedulers[local_scheduler_id]["last_heartbeat"] = time.time()
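

# Each message from the frontend is a JSON object with a "command" field (and,
# for "get-recent-tasks", a "num" field). serve_requests dispatches on the
# command and replies over the same websocket.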


async def serve_requests(websocket, path):
    """Handle requests from the web UI frontend over a websocket."""
    redis_conn = await aioredis.create_connection(
        (redis_ip_address, redis_port), loop=loop)
    # We loop infinitely because otherwise the websocket will be closed.
    # TODO(rkn): Maybe we should open a new websocket for every request
    # instead of looping here.
    while True:
        command = json.loads(await websocket.recv())
        print("received command {}".format(command))

        if command["command"] == "get-statistics":
            await handle_get_statistics(websocket, redis_conn)
        elif command["command"] == "get-drivers":
            await handle_get_drivers(websocket, redis_conn)
        elif command["command"] == "get-recent-tasks":
            await handle_get_recent_tasks(websocket, redis_conn,
                                          command["num"])
        elif command["command"] == "get-heartbeats":
            await send_heartbeats(websocket, redis_conn)
        elif command["command"] == "get-workers":
            result = []
            workers = await redis_conn.execute("keys", "WorkerInfo:*")
            for key in workers:
                content = await redis_conn.execute("hgetall", key)
                worker_id = key_to_hex_identifier(key)
                result.append({"worker": worker_id,
                               "export_counter": int(content[1])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-clients":
            result = []
            clients = await redis_conn.execute("keys", "CL:*")
            for key in clients:
                content = await redis_conn.execute("hgetall", key)
                result.append({"client": hex_identifier(content[1]),
                               "node_ip_address": content[3].decode(),
                               "client_type": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-objects":
            result = []
            objects = await redis_conn.execute("keys", "OI:*")
            for key in objects:
                content = await redis_conn.execute("hgetall", key)
                result.append({"object_id": hex_identifier(content[1]),
                               "hash": hex_identifier(content[3]),
                               "data_size": content[5].decode()})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-object-info":
            # TODO(pcm): Get the object here (have to connect to ray) and ship
            # content and type back to webclient. One challenge here is that
            # the naive implementation will block the web ui backend, which is
            # not ok if it is serving multiple users.
            await websocket.send(json.dumps({"object_id": "none"}))
        elif command["command"] == "get-tasks":
            result = []
            tasks = await redis_conn.execute("keys", "TT:*")
            for key in tasks:
                content = await redis_conn.execute("hgetall", key)
                result.append({"task_id": key_to_hex_identifier(key),
                               "state": int(content[1]),
                               "node_id": hex_identifier(content[3])})
            await websocket.send(json.dumps(result))
        elif command["command"] == "get-timeline":
            tasks = collections.defaultdict(list)
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                content = await redis_conn.execute("lrange", key, "0", "-1")
                data = json.loads(content[0].decode())
                begin_and_end_time = [
                    timestamp for (timestamp, task, kind, info) in data
                    if task == "ray:task"]
                tasks[worker_id].append(
                    {"task_id": task_id,
                     "start_task": min(begin_and_end_time),
                     "end_task": max(begin_and_end_time)})
            await websocket.send(json.dumps(tasks))
        elif command["command"] == "get-events":
            result = []
            for key in await redis_conn.execute("keys", "event_log:*"):
                worker_id, task_id = key_to_hex_identifiers(key)
                answer = await redis_conn.execute("lrange", key, "0", "-1")
                assert len(answer) == 1
                events = json.loads(answer[0].decode())
                result.extend([{"worker_id": worker_id,
                                "task_id": task_id,
                                "time": event[0],
                                "type": event[1]} for event in events])
            await websocket.send(json.dumps(result))


if __name__ == "__main__":
    args = parser.parse_args()
    redis_address = args.redis_address.split(":")
    redis_ip_address, redis_port = redis_address[0], int(redis_address[1])

    # The port here must match the value used by the frontend to connect over
    # websockets.
    port = 8888
    start_server = websockets.serve(serve_requests, "localhost", port)
    loop.run_until_complete(start_server)
    loop.run_forever()
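
# Example invocation (the Redis address below is hypothetical):
#     python ray_ui.py --redis-address 127.0.0.1:6379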