[Dashboard] Remove token authentication from dashboard (#5888)

This commit is contained in:
Mitchell Stern 2019-10-21 12:48:48 -07:00 committed by Philipp Moritz
parent 26a724c5e6
commit 235dec8aa3
7 changed files with 66 additions and 62 deletions

View file

@ -51,58 +51,29 @@ class Dashboard(object):
"""
def __init__(self,
host,
port,
redis_address,
http_port,
token,
temp_dir,
redis_password=None):
"""Initialize the dashboard object."""
self.ip = ray.services.get_node_ip_address()
self.port = http_port
self.token = token
self.host = host
self.port = port
self.redis_client = ray.services.create_redis_client(
redis_address, password=redis_password)
self.temp_dir = temp_dir
self.node_stats = NodeStats(redis_address, redis_password)
# Setting the environment variable RAY_DASHBOARD_DEV=1 disables some
# security checks in the dashboard server to ease development while
# using the React dev server. Specifically, when this option is set, we
# disable the token-based authentication mechanism and allow
# cross-origin requests to be made.
# allow cross-origin requests to be made.
self.is_dev = os.environ.get("RAY_DASHBOARD_DEV") == "1"
self.app = aiohttp.web.Application(
middlewares=[] if self.is_dev else [self.auth_middleware])
self.app = aiohttp.web.Application()
self.setup_routes()
@aiohttp.web.middleware
async def auth_middleware(self, req, handler):
def valid_token(req):
# If the cookie token is correct, accept that.
try:
if req.cookies["token"] == self.token:
return True
except KeyError:
pass
# If the query token is correct, accept that.
try:
if req.query["token"] == self.token:
return True
except KeyError:
pass
# Reject.
logger.warning("Dashboard: rejected an invalid token")
return False
# Check that the token is present, either in query or as cookie.
if not valid_token(req):
return aiohttp.web.Response(status=401, text="401 Unauthorized")
resp = await handler(req)
resp.cookies["token"] = self.token
return resp
def setup_routes(self):
def forbidden() -> aiohttp.web.Response:
return aiohttp.web.Response(status=403, text="403 Forbidden")
@ -197,7 +168,7 @@ class Dashboard(object):
self.app.router.add_get("/{_}", get_forbidden)
def log_dashboard_url(self):
url = "http://{}:{}?token={}".format(self.ip, self.port, self.token)
url = ray.services.get_webui_url_from_redis(self.redis_client)
with open(os.path.join(self.temp_dir, "dashboard_url"), "w") as f:
f.write(url)
logger.info("Dashboard running on {}".format(url))
@ -205,7 +176,7 @@ class Dashboard(object):
def run(self):
self.log_dashboard_url()
self.node_stats.start()
aiohttp.web.run_app(self.app, host="0.0.0.0", port=self.port)
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
class NodeStats(threading.Thread):
@ -383,15 +354,16 @@ if __name__ == "__main__":
description=("Parse Redis server for the "
"dashboard to connect to."))
parser.add_argument(
"--http-port",
"--host",
required=True,
type=str,
choices=["127.0.0.1", "0.0.0.0"],
help="The host to use for the HTTP server.")
parser.add_argument(
"--port",
required=True,
type=int,
help="The port to use for the HTTP server.")
parser.add_argument(
"--token",
required=True,
type=str,
help="The token to use for the HTTP server.")
parser.add_argument(
"--redis-address",
required=True,
@ -427,9 +399,9 @@ if __name__ == "__main__":
try:
dashboard = Dashboard(
args.host,
args.port,
args.redis_address,
args.http_port,
args.token,
args.temp_dir,
redis_password=args.redis_password,
)

View file

@ -465,6 +465,7 @@ class Node(object):
"""Start the dashboard."""
stdout_file, stderr_file = self.new_log_files("dashboard", True)
self._webui_url, process_info = ray.services.start_dashboard(
self._ray_params.webui_host,
self.redis_address,
self._temp_dir,
stdout_file=stdout_file,

View file

@ -57,6 +57,10 @@ class RayParams(object):
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which displays the status of the Ray cluster.
webui_host: The host to bind the web UI server to. Can either be
127.0.0.1 (localhost) or 0.0.0.0 (available from all interfaces).
By default, this is set to 127.0.0.1 to prevent access from
external machines.
logging_level: Logging level, default will be logging.INFO.
logging_format: Logging format, default contains a timestamp,
filename, line number, and message. See ray_constants.py.
@ -104,6 +108,7 @@ class RayParams(object):
worker_path=None,
huge_pages=False,
include_webui=None,
webui_host="127.0.0.1",
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
plasma_store_socket_name=None,
@ -140,6 +145,7 @@ class RayParams(object):
self.worker_path = worker_path
self.huge_pages = huge_pages
self.include_webui = include_webui
self.webui_host = webui_host
self.plasma_store_socket_name = plasma_store_socket_name
self.raylet_socket_name = raylet_socket_name
self.temp_dir = temp_dir

View file

@ -162,6 +162,14 @@ def cli(logging_level, logging_format):
is_flag=True,
default=False,
help="provide this argument if the UI should be started")
@click.option(
"--webui-host",
required=False,
type=click.Choice(["127.0.0.1", "0.0.0.0"]),
default="127.0.0.1",
help="The host to bind the web UI server to. Can either be 127.0.0.1 "
"(localhost) or 0.0.0.0 (available from all interfaces). By default, this "
"is set to 127.0.0.1 to prevent access from external machines.")
@click.option(
"--block",
is_flag=True,
@ -234,7 +242,7 @@ def start(node_ip_address, redis_address, address, redis_port,
num_redis_shards, redis_max_clients, redis_password,
redis_shard_ports, object_manager_port, node_manager_port, memory,
object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
head, include_webui, block, plasma_directory, huge_pages,
head, include_webui, webui_host, block, plasma_directory, huge_pages,
autoscaling_config, no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
java_worker_options, load_code_from_local, use_pickle,
@ -277,6 +285,7 @@ def start(node_ip_address, redis_address, address, redis_port,
temp_dir=temp_dir,
include_java=include_java,
include_webui=include_webui,
webui_host=webui_host,
java_worker_options=java_worker_options,
load_code_from_local=load_code_from_local,
use_pickle=use_pickle,

View file

@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import collections
import json
import logging
@ -13,6 +12,7 @@ import resource
import socket
import subprocess
import sys
import textwrap
import time
import redis
@ -545,7 +545,7 @@ def check_version_info(redis_client):
true_version_info = tuple(json.loads(ray.utils.decode(redis_reply)))
version_info = _compute_version_info()
if version_info != true_version_info:
node_ip_address = ray.services.get_node_ip_address()
node_ip_address = get_node_ip_address()
error_message = ("Version mismatch: The cluster was started with:\n"
" Ray: " + true_version_info[0] + "\n"
" Python: " + true_version_info[1] + "\n"
@ -972,7 +972,8 @@ def start_reporter(redis_address,
return process_info
def start_dashboard(redis_address,
def start_dashboard(host,
redis_address,
temp_dir,
stdout_file=None,
stderr_file=None,
@ -980,6 +981,7 @@ def start_dashboard(redis_address,
"""Start a dashboard process.
Args:
host (str): The host to bind the dashboard web server to.
redis_address (str): The address of the Redis instance.
temp_dir (str): The temporary directory used for log files and
information for this Ray session.
@ -1002,17 +1004,15 @@ def start_dashboard(redis_address,
except socket.error:
port += 1
token = ray.utils.decode(binascii.hexlify(os.urandom(24)))
dashboard_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py")
command = [
sys.executable,
"-u",
dashboard_filepath,
"--host={}".format(host),
"--port={}".format(port),
"--redis-address={}".format(redis_address),
"--http-port={}".format(port),
"--token={}".format(token),
"--temp-dir={}".format(temp_dir),
]
if redis_password:
@ -1034,10 +1034,20 @@ def start_dashboard(redis_address,
ray_constants.PROCESS_TYPE_DASHBOARD,
stdout_file=stdout_file,
stderr_file=stderr_file)
dashboard_url = "http://{}:{}/?token={}".format(
ray.services.get_node_ip_address(), port, token)
dashboard_url = "http://{}:{}".format(
host if host == "127.0.0.1" else get_node_ip_address(), port)
print("\n" + "=" * 70)
print("View the dashboard at {}".format(dashboard_url))
print("View the dashboard at {}.".format(dashboard_url))
if host == "127.0.0.1":
note = (
"Note: If Ray is running on a remote node, you will need to set "
"up an SSH tunnel with local port forwarding in order to access "
"the dashboard in your browser, e.g. by running "
"'ssh -L {}:{}:{} <username>@<host>'. Alternatively, you can set "
"webui_host=\"0.0.0.0\" in the call to ray.init() to allow direct "
"access from external machines.")
note = note.format(port, host, port)
print("\n".join(textwrap.wrap(note, width=70)))
print("=" * 70 + "\n")
return dashboard_url, process_info

View file

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import sys
import time
@ -18,13 +19,12 @@ def test_get_webui(shutdown_only):
webui_url = addresses["webui_url"]
assert ray.get_webui_url() == webui_url
base, token = webui_url.split("?")
assert token.startswith("token=")
assert re.match(r"^http://\d+\.\d+\.\d+\.\d+:8080$", webui_url)
start_time = time.time()
while True:
try:
node_info = requests.get(base + "api/node_info?" + token).json()
node_info = requests.get(webui_url + "/api/node_info").json()
break
except requests.exceptions.ConnectionError:
if time.time() > start_time + 30:

View file

@ -1160,6 +1160,7 @@ def init(address=None,
plasma_directory=None,
huge_pages=False,
include_webui=False,
webui_host="127.0.0.1",
job_id=None,
configure_logging=True,
logging_level=logging.INFO,
@ -1239,6 +1240,10 @@ def init(address=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which displays the status of the Ray cluster.
webui_host: The host to bind the web UI server to. Can either be
127.0.0.1 (localhost) or 0.0.0.0 (available from all interfaces).
By default, this is set to 127.0.0.1 to prevent access from
external machines.
job_id: The ID of this job.
configure_logging: True if allow the logging cofiguration here.
Otherwise, the users may want to configure it by their own.
@ -1321,6 +1326,7 @@ def init(address=None,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
include_webui=include_webui,
webui_host=webui_host,
memory=memory,
object_store_memory=object_store_memory,
redis_max_memory=redis_max_memory,