mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Dashboard] Remove token authentication from dashboard (#5888)
This commit is contained in:
parent
26a724c5e6
commit
235dec8aa3
7 changed files with 66 additions and 62 deletions
|
@ -51,58 +51,29 @@ class Dashboard(object):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
host,
|
||||
port,
|
||||
redis_address,
|
||||
http_port,
|
||||
token,
|
||||
temp_dir,
|
||||
redis_password=None):
|
||||
"""Initialize the dashboard object."""
|
||||
self.ip = ray.services.get_node_ip_address()
|
||||
self.port = http_port
|
||||
self.token = token
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.redis_client = ray.services.create_redis_client(
|
||||
redis_address, password=redis_password)
|
||||
self.temp_dir = temp_dir
|
||||
|
||||
self.node_stats = NodeStats(redis_address, redis_password)
|
||||
|
||||
# Setting the environment variable RAY_DASHBOARD_DEV=1 disables some
|
||||
# security checks in the dashboard server to ease development while
|
||||
# using the React dev server. Specifically, when this option is set, we
|
||||
# disable the token-based authentication mechanism and allow
|
||||
# cross-origin requests to be made.
|
||||
# allow cross-origin requests to be made.
|
||||
self.is_dev = os.environ.get("RAY_DASHBOARD_DEV") == "1"
|
||||
|
||||
self.app = aiohttp.web.Application(
|
||||
middlewares=[] if self.is_dev else [self.auth_middleware])
|
||||
self.app = aiohttp.web.Application()
|
||||
self.setup_routes()
|
||||
|
||||
@aiohttp.web.middleware
|
||||
async def auth_middleware(self, req, handler):
|
||||
def valid_token(req):
|
||||
# If the cookie token is correct, accept that.
|
||||
try:
|
||||
if req.cookies["token"] == self.token:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# If the query token is correct, accept that.
|
||||
try:
|
||||
if req.query["token"] == self.token:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Reject.
|
||||
logger.warning("Dashboard: rejected an invalid token")
|
||||
return False
|
||||
|
||||
# Check that the token is present, either in query or as cookie.
|
||||
if not valid_token(req):
|
||||
return aiohttp.web.Response(status=401, text="401 Unauthorized")
|
||||
|
||||
resp = await handler(req)
|
||||
resp.cookies["token"] = self.token
|
||||
return resp
|
||||
|
||||
def setup_routes(self):
|
||||
def forbidden() -> aiohttp.web.Response:
|
||||
return aiohttp.web.Response(status=403, text="403 Forbidden")
|
||||
|
@ -197,7 +168,7 @@ class Dashboard(object):
|
|||
self.app.router.add_get("/{_}", get_forbidden)
|
||||
|
||||
def log_dashboard_url(self):
|
||||
url = "http://{}:{}?token={}".format(self.ip, self.port, self.token)
|
||||
url = ray.services.get_webui_url_from_redis(self.redis_client)
|
||||
with open(os.path.join(self.temp_dir, "dashboard_url"), "w") as f:
|
||||
f.write(url)
|
||||
logger.info("Dashboard running on {}".format(url))
|
||||
|
@ -205,7 +176,7 @@ class Dashboard(object):
|
|||
def run(self):
|
||||
self.log_dashboard_url()
|
||||
self.node_stats.start()
|
||||
aiohttp.web.run_app(self.app, host="0.0.0.0", port=self.port)
|
||||
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
|
||||
|
||||
|
||||
class NodeStats(threading.Thread):
|
||||
|
@ -383,15 +354,16 @@ if __name__ == "__main__":
|
|||
description=("Parse Redis server for the "
|
||||
"dashboard to connect to."))
|
||||
parser.add_argument(
|
||||
"--http-port",
|
||||
"--host",
|
||||
required=True,
|
||||
type=str,
|
||||
choices=["127.0.0.1", "0.0.0.0"],
|
||||
help="The host to use for the HTTP server.")
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
required=True,
|
||||
type=int,
|
||||
help="The port to use for the HTTP server.")
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The token to use for the HTTP server.")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
@ -427,9 +399,9 @@ if __name__ == "__main__":
|
|||
|
||||
try:
|
||||
dashboard = Dashboard(
|
||||
args.host,
|
||||
args.port,
|
||||
args.redis_address,
|
||||
args.http_port,
|
||||
args.token,
|
||||
args.temp_dir,
|
||||
redis_password=args.redis_password,
|
||||
)
|
||||
|
|
|
@ -465,6 +465,7 @@ class Node(object):
|
|||
"""Start the dashboard."""
|
||||
stdout_file, stderr_file = self.new_log_files("dashboard", True)
|
||||
self._webui_url, process_info = ray.services.start_dashboard(
|
||||
self._ray_params.webui_host,
|
||||
self.redis_address,
|
||||
self._temp_dir,
|
||||
stdout_file=stdout_file,
|
||||
|
|
|
@ -57,6 +57,10 @@ class RayParams(object):
|
|||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which displays the status of the Ray cluster.
|
||||
webui_host: The host to bind the web UI server to. Can either be
|
||||
127.0.0.1 (localhost) or 0.0.0.0 (available from all interfaces).
|
||||
By default, this is set to 127.0.0.1 to prevent access from
|
||||
external machines.
|
||||
logging_level: Logging level, default will be logging.INFO.
|
||||
logging_format: Logging format, default contains a timestamp,
|
||||
filename, line number, and message. See ray_constants.py.
|
||||
|
@ -104,6 +108,7 @@ class RayParams(object):
|
|||
worker_path=None,
|
||||
huge_pages=False,
|
||||
include_webui=None,
|
||||
webui_host="127.0.0.1",
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
plasma_store_socket_name=None,
|
||||
|
@ -140,6 +145,7 @@ class RayParams(object):
|
|||
self.worker_path = worker_path
|
||||
self.huge_pages = huge_pages
|
||||
self.include_webui = include_webui
|
||||
self.webui_host = webui_host
|
||||
self.plasma_store_socket_name = plasma_store_socket_name
|
||||
self.raylet_socket_name = raylet_socket_name
|
||||
self.temp_dir = temp_dir
|
||||
|
|
|
@ -162,6 +162,14 @@ def cli(logging_level, logging_format):
|
|||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument if the UI should be started")
|
||||
@click.option(
|
||||
"--webui-host",
|
||||
required=False,
|
||||
type=click.Choice(["127.0.0.1", "0.0.0.0"]),
|
||||
default="127.0.0.1",
|
||||
help="The host to bind the web UI server to. Can either be 127.0.0.1 "
|
||||
"(localhost) or 0.0.0.0 (available from all interfaces). By default, this "
|
||||
"is set to 127.0.0.1 to prevent access from external machines.")
|
||||
@click.option(
|
||||
"--block",
|
||||
is_flag=True,
|
||||
|
@ -234,7 +242,7 @@ def start(node_ip_address, redis_address, address, redis_port,
|
|||
num_redis_shards, redis_max_clients, redis_password,
|
||||
redis_shard_ports, object_manager_port, node_manager_port, memory,
|
||||
object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
|
||||
head, include_webui, block, plasma_directory, huge_pages,
|
||||
head, include_webui, webui_host, block, plasma_directory, huge_pages,
|
||||
autoscaling_config, no_redirect_worker_output, no_redirect_output,
|
||||
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
||||
java_worker_options, load_code_from_local, use_pickle,
|
||||
|
@ -277,6 +285,7 @@ def start(node_ip_address, redis_address, address, redis_port,
|
|||
temp_dir=temp_dir,
|
||||
include_java=include_java,
|
||||
include_webui=include_webui,
|
||||
webui_host=webui_host,
|
||||
java_worker_options=java_worker_options,
|
||||
load_code_from_local=load_code_from_local,
|
||||
use_pickle=use_pickle,
|
||||
|
|
|
@ -2,7 +2,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import binascii
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
@ -13,6 +12,7 @@ import resource
|
|||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import redis
|
||||
|
||||
|
@ -545,7 +545,7 @@ def check_version_info(redis_client):
|
|||
true_version_info = tuple(json.loads(ray.utils.decode(redis_reply)))
|
||||
version_info = _compute_version_info()
|
||||
if version_info != true_version_info:
|
||||
node_ip_address = ray.services.get_node_ip_address()
|
||||
node_ip_address = get_node_ip_address()
|
||||
error_message = ("Version mismatch: The cluster was started with:\n"
|
||||
" Ray: " + true_version_info[0] + "\n"
|
||||
" Python: " + true_version_info[1] + "\n"
|
||||
|
@ -972,7 +972,8 @@ def start_reporter(redis_address,
|
|||
return process_info
|
||||
|
||||
|
||||
def start_dashboard(redis_address,
|
||||
def start_dashboard(host,
|
||||
redis_address,
|
||||
temp_dir,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
|
@ -980,6 +981,7 @@ def start_dashboard(redis_address,
|
|||
"""Start a dashboard process.
|
||||
|
||||
Args:
|
||||
host (str): The host to bind the dashboard web server to.
|
||||
redis_address (str): The address of the Redis instance.
|
||||
temp_dir (str): The temporary directory used for log files and
|
||||
information for this Ray session.
|
||||
|
@ -1002,17 +1004,15 @@ def start_dashboard(redis_address,
|
|||
except socket.error:
|
||||
port += 1
|
||||
|
||||
token = ray.utils.decode(binascii.hexlify(os.urandom(24)))
|
||||
|
||||
dashboard_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py")
|
||||
command = [
|
||||
sys.executable,
|
||||
"-u",
|
||||
dashboard_filepath,
|
||||
"--host={}".format(host),
|
||||
"--port={}".format(port),
|
||||
"--redis-address={}".format(redis_address),
|
||||
"--http-port={}".format(port),
|
||||
"--token={}".format(token),
|
||||
"--temp-dir={}".format(temp_dir),
|
||||
]
|
||||
if redis_password:
|
||||
|
@ -1034,10 +1034,20 @@ def start_dashboard(redis_address,
|
|||
ray_constants.PROCESS_TYPE_DASHBOARD,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
dashboard_url = "http://{}:{}/?token={}".format(
|
||||
ray.services.get_node_ip_address(), port, token)
|
||||
dashboard_url = "http://{}:{}".format(
|
||||
host if host == "127.0.0.1" else get_node_ip_address(), port)
|
||||
print("\n" + "=" * 70)
|
||||
print("View the dashboard at {}".format(dashboard_url))
|
||||
print("View the dashboard at {}.".format(dashboard_url))
|
||||
if host == "127.0.0.1":
|
||||
note = (
|
||||
"Note: If Ray is running on a remote node, you will need to set "
|
||||
"up an SSH tunnel with local port forwarding in order to access "
|
||||
"the dashboard in your browser, e.g. by running "
|
||||
"'ssh -L {}:{}:{} <username>@<host>'. Alternatively, you can set "
|
||||
"webui_host=\"0.0.0.0\" in the call to ray.init() to allow direct "
|
||||
"access from external machines.")
|
||||
note = note.format(port, host, port)
|
||||
print("\n".join(textwrap.wrap(note, width=70)))
|
||||
print("=" * 70 + "\n")
|
||||
return dashboard_url, process_info
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
@ -18,13 +19,12 @@ def test_get_webui(shutdown_only):
|
|||
webui_url = addresses["webui_url"]
|
||||
assert ray.get_webui_url() == webui_url
|
||||
|
||||
base, token = webui_url.split("?")
|
||||
assert token.startswith("token=")
|
||||
assert re.match(r"^http://\d+\.\d+\.\d+\.\d+:8080$", webui_url)
|
||||
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
node_info = requests.get(base + "api/node_info?" + token).json()
|
||||
node_info = requests.get(webui_url + "/api/node_info").json()
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
if time.time() > start_time + 30:
|
||||
|
|
|
@ -1160,6 +1160,7 @@ def init(address=None,
|
|||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=False,
|
||||
webui_host="127.0.0.1",
|
||||
job_id=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
|
@ -1239,6 +1240,10 @@ def init(address=None,
|
|||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which displays the status of the Ray cluster.
|
||||
webui_host: The host to bind the web UI server to. Can either be
|
||||
127.0.0.1 (localhost) or 0.0.0.0 (available from all interfaces).
|
||||
By default, this is set to 127.0.0.1 to prevent access from
|
||||
external machines.
|
||||
job_id: The ID of this job.
|
||||
configure_logging: True if allow the logging cofiguration here.
|
||||
Otherwise, the users may want to configure it by their own.
|
||||
|
@ -1321,6 +1326,7 @@ def init(address=None,
|
|||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages,
|
||||
include_webui=include_webui,
|
||||
webui_host=webui_host,
|
||||
memory=memory,
|
||||
object_store_memory=object_store_memory,
|
||||
redis_max_memory=redis_max_memory,
|
||||
|
|
Loading…
Add table
Reference in a new issue