mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
This commit is contained in:
parent
95877be8ee
commit
7f1bacc7dc
1637 changed files with 75167 additions and 59155 deletions
|
@ -39,14 +39,16 @@ def perform_auth():
|
||||||
resp = requests.get(
|
resp = requests.get(
|
||||||
"https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
|
"https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
|
||||||
auth=auth,
|
auth=auth,
|
||||||
params={"job_id": os.environ["BUILDKITE_JOB_ID"]})
|
params={"job_id": os.environ["BUILDKITE_JOB_ID"]},
|
||||||
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
def handle_docker_login(resp):
|
def handle_docker_login(resp):
|
||||||
pwd = resp.json()["docker_password"]
|
pwd = resp.json()["docker_password"]
|
||||||
subprocess.call(
|
subprocess.call(
|
||||||
["docker", "login", "--username", "raytravisbot", "--password", pwd])
|
["docker", "login", "--username", "raytravisbot", "--password", pwd]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def gather_paths(dir_path) -> List[str]:
|
def gather_paths(dir_path) -> List[str]:
|
||||||
|
@ -86,7 +88,7 @@ def upload_paths(paths, resp, destination):
|
||||||
"branch_wheels": f"{branch}/{sha}/{fn}",
|
"branch_wheels": f"{branch}/{sha}/{fn}",
|
||||||
"jars": f"jars/latest/{current_os}/{fn}",
|
"jars": f"jars/latest/{current_os}/{fn}",
|
||||||
"branch_jars": f"jars/{branch}/{sha}/{current_os}/{fn}",
|
"branch_jars": f"jars/{branch}/{sha}/{current_os}/{fn}",
|
||||||
"logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}"
|
"logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}",
|
||||||
}[destination]
|
}[destination]
|
||||||
of["file"] = open(path, "rb")
|
of["file"] = open(path, "rb")
|
||||||
r = requests.post(c["url"], files=of)
|
r = requests.post(c["url"], files=of)
|
||||||
|
@ -95,14 +97,19 @@ def upload_paths(paths, resp, destination):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Helper script to upload files to S3 bucket")
|
description="Helper script to upload files to S3 bucket"
|
||||||
|
)
|
||||||
parser.add_argument("--path", type=str, required=False)
|
parser.add_argument("--path", type=str, required=False)
|
||||||
parser.add_argument("--destination", type=str)
|
parser.add_argument("--destination", type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert args.destination in {
|
assert args.destination in {
|
||||||
"branch_jars", "branch_wheels", "jars", "logs", "wheels",
|
"branch_jars",
|
||||||
"docker_login"
|
"branch_wheels",
|
||||||
|
"jars",
|
||||||
|
"logs",
|
||||||
|
"wheels",
|
||||||
|
"docker_login",
|
||||||
}
|
}
|
||||||
assert "BUILDKITE_JOB_ID" in os.environ
|
assert "BUILDKITE_JOB_ID" in os.environ
|
||||||
assert "BUILDKITE_COMMIT" in os.environ
|
assert "BUILDKITE_COMMIT" in os.environ
|
||||||
|
|
|
@ -51,8 +51,10 @@ del monitor_actor
|
||||||
test_utils.wait_for_condition(no_resource_leaks)
|
test_utils.wait_for_condition(no_resource_leaks)
|
||||||
|
|
||||||
rate = MAX_ACTORS_IN_CLUSTER / (end_time - start_time)
|
rate = MAX_ACTORS_IN_CLUSTER / (end_time - start_time)
|
||||||
print(f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
|
print(
|
||||||
f"{end_time - start_time}s. ({rate} actors/s)")
|
f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
|
||||||
|
f"{end_time - start_time}s. ({rate} actors/s)"
|
||||||
|
)
|
||||||
|
|
||||||
if "TEST_OUTPUT_JSON" in os.environ:
|
if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
||||||
|
@ -62,6 +64,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
"time": end_time - start_time,
|
"time": end_time - start_time,
|
||||||
"success": "1",
|
"success": "1",
|
||||||
"_peak_memory": round(used_gb, 2),
|
"_peak_memory": round(used_gb, 2),
|
||||||
"_peak_process_memory": usage
|
"_peak_process_memory": usage,
|
||||||
}
|
}
|
||||||
json.dump(results, out_file)
|
json.dump(results, out_file)
|
||||||
|
|
|
@ -77,8 +77,10 @@ del monitor_actor
|
||||||
test_utils.wait_for_condition(no_resource_leaks)
|
test_utils.wait_for_condition(no_resource_leaks)
|
||||||
|
|
||||||
rate = MAX_PLACEMENT_GROUPS / (end_time - start_time)
|
rate = MAX_PLACEMENT_GROUPS / (end_time - start_time)
|
||||||
print(f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
|
print(
|
||||||
f"{end_time - start_time}s. ({rate} pgs/s)")
|
f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
|
||||||
|
f"{end_time - start_time}s. ({rate} pgs/s)"
|
||||||
|
)
|
||||||
|
|
||||||
if "TEST_OUTPUT_JSON" in os.environ:
|
if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
||||||
|
@ -88,6 +90,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
"time": end_time - start_time,
|
"time": end_time - start_time,
|
||||||
"success": "1",
|
"success": "1",
|
||||||
"_peak_memory": round(used_gb, 2),
|
"_peak_memory": round(used_gb, 2),
|
||||||
"_peak_process_memory": usage
|
"_peak_process_memory": usage,
|
||||||
}
|
}
|
||||||
json.dump(results, out_file)
|
json.dump(results, out_file)
|
||||||
|
|
|
@ -16,9 +16,7 @@ def test_max_running_tasks(num_tasks):
|
||||||
def task():
|
def task():
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
refs = [
|
refs = [task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")]
|
||||||
task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")
|
|
||||||
]
|
|
||||||
|
|
||||||
max_cpus = ray.cluster_resources()["CPU"]
|
max_cpus = ray.cluster_resources()["CPU"]
|
||||||
min_cpus_available = max_cpus
|
min_cpus_available = max_cpus
|
||||||
|
@ -48,8 +46,7 @@ def no_resource_leaks():
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option("--num-tasks", required=True, type=int, help="Number of tasks to launch.")
|
||||||
"--num-tasks", required=True, type=int, help="Number of tasks to launch.")
|
|
||||||
def test(num_tasks):
|
def test(num_tasks):
|
||||||
ray.init(address="auto")
|
ray.init(address="auto")
|
||||||
|
|
||||||
|
@ -66,8 +63,10 @@ def test(num_tasks):
|
||||||
test_utils.wait_for_condition(no_resource_leaks)
|
test_utils.wait_for_condition(no_resource_leaks)
|
||||||
|
|
||||||
rate = num_tasks / (end_time - start_time - sleep_time)
|
rate = num_tasks / (end_time - start_time - sleep_time)
|
||||||
print(f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
|
print(
|
||||||
f"({rate} tasks/s)")
|
f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
|
||||||
|
f"({rate} tasks/s)"
|
||||||
|
)
|
||||||
|
|
||||||
if "TEST_OUTPUT_JSON" in os.environ:
|
if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
||||||
|
@ -77,7 +76,7 @@ def test(num_tasks):
|
||||||
"time": end_time - start_time,
|
"time": end_time - start_time,
|
||||||
"success": "1",
|
"success": "1",
|
||||||
"_peak_memory": round(used_gb, 2),
|
"_peak_memory": round(used_gb, 2),
|
||||||
"_peak_process_memory": usage
|
"_peak_process_memory": usage,
|
||||||
}
|
}
|
||||||
json.dump(results, out_file)
|
json.dump(results, out_file)
|
||||||
|
|
||||||
|
|
|
@ -25,10 +25,12 @@ class SimpleActor:
|
||||||
|
|
||||||
|
|
||||||
def start_tasks(num_task, num_cpu_per_task, task_duration):
|
def start_tasks(num_task, num_cpu_per_task, task_duration):
|
||||||
ray.get([
|
ray.get(
|
||||||
simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
|
[
|
||||||
for _ in range(num_task)
|
simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
|
||||||
])
|
for _ in range(num_task)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def measure(f):
|
def measure(f):
|
||||||
|
@ -40,13 +42,16 @@ def measure(f):
|
||||||
|
|
||||||
def start_actor(num_actors, num_actors_per_nodes, job):
|
def start_actor(num_actors, num_actors_per_nodes, job):
|
||||||
resources = {"node": floor(1.0 / num_actors_per_nodes)}
|
resources = {"node": floor(1.0 / num_actors_per_nodes)}
|
||||||
submission_cost, actors = measure(lambda: [
|
submission_cost, actors = measure(
|
||||||
SimpleActor.options(resources=resources, num_cpus=0).remote(job)
|
lambda: [
|
||||||
for _ in range(num_actors)])
|
SimpleActor.options(resources=resources, num_cpus=0).remote(job)
|
||||||
ready_cost, _ = measure(
|
for _ in range(num_actors)
|
||||||
lambda: ray.get([actor.ready.remote() for actor in actors]))
|
]
|
||||||
|
)
|
||||||
|
ready_cost, _ = measure(lambda: ray.get([actor.ready.remote() for actor in actors]))
|
||||||
actor_job_cost, _ = measure(
|
actor_job_cost, _ = measure(
|
||||||
lambda: ray.get([actor.do_job.remote() for actor in actors]))
|
lambda: ray.get([actor.do_job.remote() for actor in actors])
|
||||||
|
)
|
||||||
return (submission_cost, ready_cost, actor_job_cost)
|
return (submission_cost, ready_cost, actor_job_cost)
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,33 +59,32 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(prog="Test Scheduling")
|
parser = argparse.ArgumentParser(prog="Test Scheduling")
|
||||||
# Task workloads
|
# Task workloads
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--total-num-task",
|
"--total-num-task", type=int, help="Total number of tasks.", required=False
|
||||||
type=int,
|
)
|
||||||
help="Total number of tasks.",
|
|
||||||
required=False)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-cpu-per-task",
|
"--num-cpu-per-task",
|
||||||
type=int,
|
type=int,
|
||||||
help="Resources needed for tasks.",
|
help="Resources needed for tasks.",
|
||||||
required=False)
|
required=False,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--task-duration-s",
|
"--task-duration-s",
|
||||||
type=int,
|
type=int,
|
||||||
help="How long does each task execute.",
|
help="How long does each task execute.",
|
||||||
required=False,
|
required=False,
|
||||||
default=1)
|
default=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Actor workloads
|
# Actor workloads
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--total-num-actors",
|
"--total-num-actors", type=int, help="Total number of actors.", required=True
|
||||||
type=int,
|
)
|
||||||
help="Total number of actors.",
|
|
||||||
required=True)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-actors-per-nodes",
|
"--num-actors-per-nodes",
|
||||||
type=int,
|
type=int,
|
||||||
help="How many actors to allocate for each nodes.",
|
help="How many actors to allocate for each nodes.",
|
||||||
required=True)
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
ray.init(address="auto")
|
ray.init(address="auto")
|
||||||
|
|
||||||
|
@ -92,13 +96,14 @@ if __name__ == "__main__":
|
||||||
job = None
|
job = None
|
||||||
if args.total_num_task is not None:
|
if args.total_num_task is not None:
|
||||||
if args.num_cpu_per_task is None:
|
if args.num_cpu_per_task is None:
|
||||||
args.num_cpu_per_task = floor(
|
args.num_cpu_per_task = floor(1.0 * total_cpus / args.total_num_task)
|
||||||
1.0 * total_cpus / args.total_num_task)
|
|
||||||
job = lambda: start_tasks( # noqa: E731
|
job = lambda: start_tasks( # noqa: E731
|
||||||
args.total_num_task, args.num_cpu_per_task, args.task_duration_s)
|
args.total_num_task, args.num_cpu_per_task, args.task_duration_s
|
||||||
|
)
|
||||||
|
|
||||||
submission_cost, ready_cost, actor_job_cost = start_actor(
|
submission_cost, ready_cost, actor_job_cost = start_actor(
|
||||||
args.total_num_actors, args.num_actors_per_nodes, job)
|
args.total_num_actors, args.num_actors_per_nodes, job
|
||||||
|
)
|
||||||
|
|
||||||
output = os.environ.get("TEST_OUTPUT_JSON")
|
output = os.environ.get("TEST_OUTPUT_JSON")
|
||||||
|
|
||||||
|
@ -118,6 +123,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if output is not None:
|
if output is not None:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
p = Path(output)
|
p = Path(output)
|
||||||
p.write_text(json.dumps(result))
|
p.write_text(json.dumps(result))
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,7 @@ def num_alive_nodes():
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option(
|
@click.option("--num-nodes", required=True, type=int, help="The target number of nodes")
|
||||||
"--num-nodes", required=True, type=int, help="The target number of nodes")
|
|
||||||
def wait_cluster(num_nodes: int):
|
def wait_cluster(num_nodes: int):
|
||||||
ray.init(address="auto")
|
ray.init(address="auto")
|
||||||
while num_alive_nodes() != num_nodes:
|
while num_alive_nodes() != num_nodes:
|
||||||
|
|
|
@ -9,7 +9,7 @@ from time import perf_counter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
NUM_NODES = 50
|
NUM_NODES = 50
|
||||||
OBJECT_SIZE = 2**30
|
OBJECT_SIZE = 2 ** 30
|
||||||
|
|
||||||
|
|
||||||
def num_alive_nodes():
|
def num_alive_nodes():
|
||||||
|
@ -60,6 +60,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
"broadcast_time": end - start,
|
"broadcast_time": end - start,
|
||||||
"object_size": OBJECT_SIZE,
|
"object_size": OBJECT_SIZE,
|
||||||
"num_nodes": NUM_NODES,
|
"num_nodes": NUM_NODES,
|
||||||
"success": "1"
|
"success": "1",
|
||||||
}
|
}
|
||||||
json.dump(results, out_file)
|
json.dump(results, out_file)
|
||||||
|
|
|
@ -13,7 +13,7 @@ MAX_ARGS = 10000
|
||||||
MAX_RETURNS = 3000
|
MAX_RETURNS = 3000
|
||||||
MAX_RAY_GET_ARGS = 10000
|
MAX_RAY_GET_ARGS = 10000
|
||||||
MAX_QUEUED_TASKS = 1_000_000
|
MAX_QUEUED_TASKS = 1_000_000
|
||||||
MAX_RAY_GET_SIZE = 100 * 2**30
|
MAX_RAY_GET_SIZE = 100 * 2 ** 30
|
||||||
|
|
||||||
|
|
||||||
def assert_no_leaks():
|
def assert_no_leaks():
|
||||||
|
@ -189,8 +189,7 @@ print(f"Many args time: {args_time} ({MAX_ARGS} args)")
|
||||||
print(f"Many returns time: {returns_time} ({MAX_RETURNS} returns)")
|
print(f"Many returns time: {returns_time} ({MAX_RETURNS} returns)")
|
||||||
print(f"Ray.get time: {get_time} ({MAX_RAY_GET_ARGS} args)")
|
print(f"Ray.get time: {get_time} ({MAX_RAY_GET_ARGS} args)")
|
||||||
print(f"Queued task time: {queued_time} ({MAX_QUEUED_TASKS} tasks)")
|
print(f"Queued task time: {queued_time} ({MAX_QUEUED_TASKS} tasks)")
|
||||||
print(f"Ray.get large object time: {large_object_time} "
|
print(f"Ray.get large object time: {large_object_time} " f"({MAX_RAY_GET_SIZE} bytes)")
|
||||||
f"({MAX_RAY_GET_SIZE} bytes)")
|
|
||||||
|
|
||||||
if "TEST_OUTPUT_JSON" in os.environ:
|
if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
|
||||||
|
@ -205,6 +204,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
|
||||||
"num_queued": MAX_QUEUED_TASKS,
|
"num_queued": MAX_QUEUED_TASKS,
|
||||||
"large_object_time": large_object_time,
|
"large_object_time": large_object_time,
|
||||||
"large_object_size": MAX_RAY_GET_SIZE,
|
"large_object_size": MAX_RAY_GET_SIZE,
|
||||||
"success": "1"
|
"success": "1",
|
||||||
}
|
}
|
||||||
json.dump(results, out_file)
|
json.dump(results, out_file)
|
||||||
|
|
|
@ -66,7 +66,7 @@ def get_remote_url(remote):
|
||||||
|
|
||||||
def replace_suffix(base, old_suffix, new_suffix=""):
|
def replace_suffix(base, old_suffix, new_suffix=""):
|
||||||
if base.endswith(old_suffix):
|
if base.endswith(old_suffix):
|
||||||
base = base[:len(base) - len(old_suffix)] + new_suffix
|
base = base[: len(base) - len(old_suffix)] + new_suffix
|
||||||
return base
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,12 +199,21 @@ def monitor():
|
||||||
expected_line = "{}\t{}".format(expected_sha, ref)
|
expected_line = "{}\t{}".format(expected_sha, ref)
|
||||||
|
|
||||||
if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
|
if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
|
||||||
logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
|
logger.info(
|
||||||
remote, expected_line)
|
"Not monitoring %s on %s due to keep-alive on: %s",
|
||||||
|
ref,
|
||||||
|
remote,
|
||||||
|
expected_line,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
|
logger.info(
|
||||||
get_remote_url(remote), ref, expected_line)
|
"Monitoring %s (%s) for changes in %s: %s",
|
||||||
|
remote,
|
||||||
|
get_remote_url(remote),
|
||||||
|
ref,
|
||||||
|
expected_line,
|
||||||
|
)
|
||||||
|
|
||||||
for to_wait in yield_poll_schedule():
|
for to_wait in yield_poll_schedule():
|
||||||
time.sleep(to_wait)
|
time.sleep(to_wait)
|
||||||
|
@ -217,12 +226,21 @@ def monitor():
|
||||||
status = ex.returncode
|
status = ex.returncode
|
||||||
|
|
||||||
if status == 2:
|
if status == 2:
|
||||||
logger.info("Terminating job as %s has been deleted on %s: %s",
|
logger.info(
|
||||||
ref, remote, expected_line)
|
"Terminating job as %s has been deleted on %s: %s",
|
||||||
|
ref,
|
||||||
|
remote,
|
||||||
|
expected_line,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
elif status != 0:
|
elif status != 0:
|
||||||
logger.error("Error %d: unable to check %s on %s: %s", status, ref,
|
logger.error(
|
||||||
remote, expected_line)
|
"Error %d: unable to check %s on %s: %s",
|
||||||
|
status,
|
||||||
|
ref,
|
||||||
|
remote,
|
||||||
|
expected_line,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prev = expected_line
|
prev = expected_line
|
||||||
expected_line = detect_spurious_commit(line, expected_line, remote)
|
expected_line = detect_spurious_commit(line, expected_line, remote)
|
||||||
|
@ -230,14 +248,24 @@ def monitor():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Terminating job as %s has been updated on %s\n"
|
"Terminating job as %s has been updated on %s\n"
|
||||||
" from:\t%s\n"
|
" from:\t%s\n"
|
||||||
" to: \t%s", ref, remote, expected_line, line)
|
" to: \t%s",
|
||||||
|
ref,
|
||||||
|
remote,
|
||||||
|
expected_line,
|
||||||
|
line,
|
||||||
|
)
|
||||||
time.sleep(1) # wait for CI to flush output
|
time.sleep(1) # wait for CI to flush output
|
||||||
break
|
break
|
||||||
if expected_line != prev:
|
if expected_line != prev:
|
||||||
logger.info(
|
logger.info(
|
||||||
"%s appeared to spuriously change on %s\n"
|
"%s appeared to spuriously change on %s\n"
|
||||||
" from:\t%s\n"
|
" from:\t%s\n"
|
||||||
" to: \t%s", ref, remote, prev, expected_line)
|
" to: \t%s",
|
||||||
|
ref,
|
||||||
|
remote,
|
||||||
|
prev,
|
||||||
|
expected_line,
|
||||||
|
)
|
||||||
|
|
||||||
return terminate_my_process_group()
|
return terminate_my_process_group()
|
||||||
|
|
||||||
|
@ -259,9 +287,8 @@ def main(program, *args):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(levelname)s: %(message)s",
|
format="%(levelname)s: %(message)s", stream=sys.stderr, level=logging.DEBUG
|
||||||
stream=sys.stderr,
|
)
|
||||||
level=logging.DEBUG)
|
|
||||||
try:
|
try:
|
||||||
raise SystemExit(main(*sys.argv) or 0)
|
raise SystemExit(main(*sys.argv) or 0)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
271
ci/repro-ci.py
271
ci/repro-ci.py
|
@ -53,7 +53,10 @@ def maybe_fetch_buildkite_token():
|
||||||
"secretsmanager", region_name="us-west-2"
|
"secretsmanager", region_name="us-west-2"
|
||||||
).get_secret_value(
|
).get_secret_value(
|
||||||
SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
|
SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
|
||||||
"buildkite/ro-token")["SecretString"]
|
"buildkite/ro-token"
|
||||||
|
)[
|
||||||
|
"SecretString"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def escape(v: Any):
|
def escape(v: Any):
|
||||||
|
@ -85,26 +88,26 @@ def env_str(env: Dict[str, Any]):
|
||||||
|
|
||||||
def script_str(v: Any):
|
def script_str(v: Any):
|
||||||
if isinstance(v, bool):
|
if isinstance(v, bool):
|
||||||
return f"\"{int(v)}\""
|
return f'"{int(v)}"'
|
||||||
elif isinstance(v, Number):
|
elif isinstance(v, Number):
|
||||||
return f"\"{v}\""
|
return f'"{v}"'
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
return "(" + " ".join(f"\"{shlex.quote(w)}\"" for w in v) + ")"
|
return "(" + " ".join(f'"{shlex.quote(w)}"' for w in v) + ")"
|
||||||
else:
|
else:
|
||||||
return f"\"{shlex.quote(v)}\""
|
return f'"{shlex.quote(v)}"'
|
||||||
|
|
||||||
|
|
||||||
class ReproSession:
|
class ReproSession:
|
||||||
plugin_default_env = {
|
plugin_default_env = {
|
||||||
"docker": {
|
"docker": {"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False}
|
||||||
"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
buildkite_token: str,
|
self,
|
||||||
instance_name: Optional[str] = None,
|
buildkite_token: str,
|
||||||
logger: Optional[logging.Logger] = None):
|
instance_name: Optional[str] = None,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
):
|
||||||
self.logger = logger or logging.getLogger(self.__class__.__name__)
|
self.logger = logger or logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
self.bk = Buildkite()
|
self.bk = Buildkite()
|
||||||
|
@ -139,12 +142,15 @@ class ReproSession:
|
||||||
# https://buildkite.com/ray-project/ray-builders-pr/
|
# https://buildkite.com/ray-project/ray-builders-pr/
|
||||||
# builds/19635#55a0d71a-831e-4f68-b668-2b10c6f65ee6
|
# builds/19635#55a0d71a-831e-4f68-b668-2b10c6f65ee6
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
"https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)")
|
"https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)"
|
||||||
|
)
|
||||||
org, pipeline, build_id, job_id = pattern.match(session_url).groups()
|
org, pipeline, build_id, job_id = pattern.match(session_url).groups()
|
||||||
|
|
||||||
self.logger.debug(f"Parsed session URL: {session_url}. "
|
self.logger.debug(
|
||||||
f"Got org='{org}', pipeline='{pipeline}', "
|
f"Parsed session URL: {session_url}. "
|
||||||
f"build_id='{build_id}', job_id='{job_id}'.")
|
f"Got org='{org}', pipeline='{pipeline}', "
|
||||||
|
f"build_id='{build_id}', job_id='{job_id}'."
|
||||||
|
)
|
||||||
|
|
||||||
self.org = org
|
self.org = org
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
@ -155,7 +161,8 @@ class ReproSession:
|
||||||
assert self.bk
|
assert self.bk
|
||||||
|
|
||||||
self.env = self.bk.jobs().get_job_environment_variables(
|
self.env = self.bk.jobs().get_job_environment_variables(
|
||||||
self.org, self.pipeline, self.build_id, self.job_id)["env"]
|
self.org, self.pipeline, self.build_id, self.job_id
|
||||||
|
)["env"]
|
||||||
|
|
||||||
if overwrite:
|
if overwrite:
|
||||||
self.env.update(overwrite)
|
self.env.update(overwrite)
|
||||||
|
@ -166,33 +173,30 @@ class ReproSession:
|
||||||
assert self.env
|
assert self.env
|
||||||
|
|
||||||
if not self.aws_instance_name:
|
if not self.aws_instance_name:
|
||||||
self.aws_instance_name = (
|
self.aws_instance_name = f"repro_ci_{self.build_id}_{self.job_id[:8]}"
|
||||||
f"repro_ci_{self.build_id}_{self.job_id[:8]}")
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"No instance name provided, using {self.aws_instance_name}")
|
f"No instance name provided, using {self.aws_instance_name}"
|
||||||
|
)
|
||||||
|
|
||||||
instance_type = self.env["BUILDKITE_AGENT_META_DATA_AWS_INSTANCE_TYPE"]
|
instance_type = self.env["BUILDKITE_AGENT_META_DATA_AWS_INSTANCE_TYPE"]
|
||||||
instance_ami = self.env["BUILDKITE_AGENT_META_DATA_AWS_AMI_ID"]
|
instance_ami = self.env["BUILDKITE_AGENT_META_DATA_AWS_AMI_ID"]
|
||||||
instance_sg = "sg-0ccfca2ef191c04ae"
|
instance_sg = "sg-0ccfca2ef191c04ae"
|
||||||
instance_block_device_mappings = [{
|
instance_block_device_mappings = [
|
||||||
"DeviceName": "/dev/xvda",
|
{"DeviceName": "/dev/xvda", "Ebs": {"VolumeSize": 500}}
|
||||||
"Ebs": {
|
]
|
||||||
"VolumeSize": 500
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Check if instance exists:
|
# Check if instance exists:
|
||||||
running_instances = self.ec2_resource.instances.filter(Filters=[{
|
running_instances = self.ec2_resource.instances.filter(
|
||||||
"Name": "tag:repro_name",
|
Filters=[
|
||||||
"Values": [self.aws_instance_name]
|
{"Name": "tag:repro_name", "Values": [self.aws_instance_name]},
|
||||||
}, {
|
{"Name": "instance-state-name", "Values": ["running"]},
|
||||||
"Name": "instance-state-name",
|
]
|
||||||
"Values": ["running"]
|
)
|
||||||
}])
|
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Check if instance with name {self.aws_instance_name} "
|
f"Check if instance with name {self.aws_instance_name} "
|
||||||
f"already exists...")
|
f"already exists..."
|
||||||
|
)
|
||||||
|
|
||||||
for instance in running_instances:
|
for instance in running_instances:
|
||||||
self.aws_instance_id = instance.id
|
self.aws_instance_id = instance.id
|
||||||
|
@ -201,8 +205,8 @@ class ReproSession:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Instance with name {self.aws_instance_name} not found, "
|
f"Instance with name {self.aws_instance_name} not found, " f"creating..."
|
||||||
f"creating...")
|
)
|
||||||
|
|
||||||
# Else, not running, yet, start.
|
# Else, not running, yet, start.
|
||||||
instance = self.ec2_resource.create_instances(
|
instance = self.ec2_resource.create_instances(
|
||||||
|
@ -211,20 +215,18 @@ class ReproSession:
|
||||||
InstanceType=instance_type,
|
InstanceType=instance_type,
|
||||||
KeyName=self.ssh_key_name,
|
KeyName=self.ssh_key_name,
|
||||||
SecurityGroupIds=[instance_sg],
|
SecurityGroupIds=[instance_sg],
|
||||||
TagSpecifications=[{
|
TagSpecifications=[
|
||||||
"ResourceType": "instance",
|
{
|
||||||
"Tags": [{
|
"ResourceType": "instance",
|
||||||
"Key": "repro_name",
|
"Tags": [{"Key": "repro_name", "Value": self.aws_instance_name}],
|
||||||
"Value": self.aws_instance_name
|
}
|
||||||
}]
|
],
|
||||||
}],
|
|
||||||
MinCount=1,
|
MinCount=1,
|
||||||
MaxCount=1,
|
MaxCount=1,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
self.aws_instance_id = instance.id
|
self.aws_instance_id = instance.id
|
||||||
self.logger.info(
|
self.logger.info(f"Created new instance with ID {self.aws_instance_id}")
|
||||||
f"Created new instance with ID {self.aws_instance_id}")
|
|
||||||
|
|
||||||
def aws_wait_for_instance(self):
|
def aws_wait_for_instance(self):
|
||||||
assert self.aws_instance_id
|
assert self.aws_instance_id
|
||||||
|
@ -234,28 +236,32 @@ class ReproSession:
|
||||||
repro_instance_state = None
|
repro_instance_state = None
|
||||||
while repro_instance_state != "running":
|
while repro_instance_state != "running":
|
||||||
detail = self.ec2_client.describe_instances(
|
detail = self.ec2_client.describe_instances(
|
||||||
InstanceIds=[self.aws_instance_id], )
|
InstanceIds=[self.aws_instance_id],
|
||||||
repro_instance_state = \
|
)
|
||||||
detail["Reservations"][0]["Instances"][0]["State"]["Name"]
|
repro_instance_state = detail["Reservations"][0]["Instances"][0]["State"][
|
||||||
|
"Name"
|
||||||
|
]
|
||||||
|
|
||||||
if repro_instance_state != "running":
|
if repro_instance_state != "running":
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
self.aws_instance_ip = detail["Reservations"][0]["Instances"][0][
|
self.aws_instance_ip = detail["Reservations"][0]["Instances"][0][
|
||||||
"PublicIpAddress"]
|
"PublicIpAddress"
|
||||||
|
]
|
||||||
|
|
||||||
def aws_stop_instance(self):
|
def aws_stop_instance(self):
|
||||||
assert self.aws_instance_id
|
assert self.aws_instance_id
|
||||||
|
|
||||||
self.ec2_client.terminate_instances(
|
self.ec2_client.terminate_instances(
|
||||||
InstanceIds=[self.aws_instance_id], )
|
InstanceIds=[self.aws_instance_id],
|
||||||
|
)
|
||||||
|
|
||||||
def print_stop_command(self):
|
def print_stop_command(self):
|
||||||
click.secho("To stop this instance in the future, run this: ")
|
click.secho("To stop this instance in the future, run this: ")
|
||||||
click.secho(
|
click.secho(
|
||||||
f"aws ec2 terminate-instances "
|
f"aws ec2 terminate-instances " f"--instance-ids={self.aws_instance_id}",
|
||||||
f"--instance-ids={self.aws_instance_id}",
|
bold=True,
|
||||||
bold=True)
|
)
|
||||||
|
|
||||||
def create_new_ssh_client(self):
|
def create_new_ssh_client(self):
|
||||||
assert self.aws_instance_ip
|
assert self.aws_instance_ip
|
||||||
|
@ -264,7 +270,8 @@ class ReproSession:
|
||||||
self.ssh.close()
|
self.ssh.close()
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Creating SSH client and waiting for SSH to become available...")
|
"Creating SSH client and waiting for SSH to become available..."
|
||||||
|
)
|
||||||
|
|
||||||
ssh = paramiko.client.SSHClient()
|
ssh = paramiko.client.SSHClient()
|
||||||
ssh.load_system_host_keys()
|
ssh.load_system_host_keys()
|
||||||
|
@ -275,7 +282,8 @@ class ReproSession:
|
||||||
ssh.connect(
|
ssh.connect(
|
||||||
self.aws_instance_ip,
|
self.aws_instance_ip,
|
||||||
username=self.ssh_user,
|
username=self.ssh_user,
|
||||||
key_filename=os.path.expanduser(self.ssh_key_file))
|
key_filename=os.path.expanduser(self.ssh_key_file),
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except paramiko.ssh_exception.NoValidConnectionsError:
|
except paramiko.ssh_exception.NoValidConnectionsError:
|
||||||
self.logger.info("SSH not ready, yet, sleeping 5 seconds")
|
self.logger.info("SSH not ready, yet, sleeping 5 seconds")
|
||||||
|
@ -291,8 +299,7 @@ class ReproSession:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
def exec():
|
def exec():
|
||||||
stdin, stdout, stderr = self.ssh.exec_command(
|
stdin, stdout, stderr = self.ssh.exec_command(command, get_pty=True)
|
||||||
command, get_pty=True)
|
|
||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
for line in stdout.readlines():
|
for line in stdout.readlines():
|
||||||
|
@ -321,12 +328,13 @@ class ReproSession:
|
||||||
return result.get("output", "")
|
return result.get("output", "")
|
||||||
|
|
||||||
def execute_ssh_command(
|
def execute_ssh_command(
|
||||||
self,
|
self,
|
||||||
command: str,
|
command: str,
|
||||||
env: Optional[Dict[str, str]] = None,
|
env: Optional[Dict[str, str]] = None,
|
||||||
as_script: bool = False,
|
as_script: bool = False,
|
||||||
quiet: bool = False,
|
quiet: bool = False,
|
||||||
command_wrapper: Optional[Callable[[str], str]] = None) -> str:
|
command_wrapper: Optional[Callable[[str], str]] = None,
|
||||||
|
) -> str:
|
||||||
assert self.ssh
|
assert self.ssh
|
||||||
|
|
||||||
if not command_wrapper:
|
if not command_wrapper:
|
||||||
|
@ -360,23 +368,25 @@ class ReproSession:
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def execute_ssh_commands(self,
|
def execute_ssh_commands(
|
||||||
commands: List[str],
|
self,
|
||||||
env: Optional[Dict[str, str]] = None,
|
commands: List[str],
|
||||||
quiet: bool = False):
|
env: Optional[Dict[str, str]] = None,
|
||||||
|
quiet: bool = False,
|
||||||
|
):
|
||||||
for command in commands:
|
for command in commands:
|
||||||
self.execute_ssh_command(command, env=env, quiet=quiet)
|
self.execute_ssh_command(command, env=env, quiet=quiet)
|
||||||
|
|
||||||
def execute_docker_command(self,
|
def execute_docker_command(
|
||||||
command: str,
|
self, command: str, env: Optional[Dict[str, str]] = None, quiet: bool = False
|
||||||
env: Optional[Dict[str, str]] = None,
|
):
|
||||||
quiet: bool = False):
|
|
||||||
def command_wrapper(s):
|
def command_wrapper(s):
|
||||||
escaped = s.replace("'", "'\"'\"'")
|
escaped = s.replace("'", "'\"'\"'")
|
||||||
return f"docker exec -it ray_container /bin/bash -ci '{escaped}'"
|
return f"docker exec -it ray_container /bin/bash -ci '{escaped}'"
|
||||||
|
|
||||||
self.execute_ssh_command(
|
self.execute_ssh_command(
|
||||||
command, env=env, quiet=quiet, command_wrapper=command_wrapper)
|
command, env=env, quiet=quiet, command_wrapper=command_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_instance(self):
|
def prepare_instance(self):
|
||||||
self.create_new_ssh_client()
|
self.create_new_ssh_client()
|
||||||
|
@ -387,8 +397,9 @@ class ReproSession:
|
||||||
|
|
||||||
self.logger.info("Preparing instance (installing docker etc.)")
|
self.logger.info("Preparing instance (installing docker etc.)")
|
||||||
commands = [
|
commands = [
|
||||||
"sudo yum install -y docker", "sudo service docker start",
|
"sudo yum install -y docker",
|
||||||
f"sudo usermod -aG docker {self.ssh_user}"
|
"sudo service docker start",
|
||||||
|
f"sudo usermod -aG docker {self.ssh_user}",
|
||||||
]
|
]
|
||||||
self.execute_ssh_commands(commands, quiet=True)
|
self.execute_ssh_commands(commands, quiet=True)
|
||||||
self.create_new_ssh_client()
|
self.create_new_ssh_client()
|
||||||
|
@ -398,13 +409,18 @@ class ReproSession:
|
||||||
def docker_login(self):
|
def docker_login(self):
|
||||||
self.logger.info("Logging into docker...")
|
self.logger.info("Logging into docker...")
|
||||||
credentials = boto3.client(
|
credentials = boto3.client(
|
||||||
"ecr", region_name="us-west-2").get_authorization_token()
|
"ecr", region_name="us-west-2"
|
||||||
token = base64.b64decode(credentials["authorizationData"][0][
|
).get_authorization_token()
|
||||||
"authorizationToken"]).decode("utf-8").replace("AWS:", "")
|
token = (
|
||||||
|
base64.b64decode(credentials["authorizationData"][0]["authorizationToken"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.replace("AWS:", "")
|
||||||
|
)
|
||||||
endpoint = credentials["authorizationData"][0]["proxyEndpoint"]
|
endpoint = credentials["authorizationData"][0]["proxyEndpoint"]
|
||||||
|
|
||||||
self.execute_ssh_command(
|
self.execute_ssh_command(
|
||||||
f"docker login -u AWS -p {token} {endpoint}", quiet=True)
|
f"docker login -u AWS -p {token} {endpoint}", quiet=True
|
||||||
|
)
|
||||||
|
|
||||||
def fetch_buildkite_plugins(self):
|
def fetch_buildkite_plugins(self):
|
||||||
assert self.env
|
assert self.env
|
||||||
|
@ -415,8 +431,9 @@ class ReproSession:
|
||||||
for collection in plugins:
|
for collection in plugins:
|
||||||
for plugin, options in collection.items():
|
for plugin, options in collection.items():
|
||||||
plugin_url, plugin_version = plugin.split("#")
|
plugin_url, plugin_version = plugin.split("#")
|
||||||
if not plugin_url.startswith(
|
if not plugin_url.startswith("http://") or not plugin_url.startswith(
|
||||||
"http://") or not plugin_url.startswith("https://"):
|
"https://"
|
||||||
|
):
|
||||||
plugin_url = f"https://{plugin_url}"
|
plugin_url = f"https://{plugin_url}"
|
||||||
|
|
||||||
plugin_name = plugin_url.split("/")[-1].rstrip(".git")
|
plugin_name = plugin_url.split("/")[-1].rstrip(".git")
|
||||||
|
@ -432,7 +449,7 @@ class ReproSession:
|
||||||
"version": plugin_version,
|
"version": plugin_version,
|
||||||
"dir": plugin_dir,
|
"dir": plugin_dir,
|
||||||
"env": plugin_env,
|
"env": plugin_env,
|
||||||
"details": {}
|
"details": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_plugin_env(self, plugin_short: str, options: Dict[str, Any]):
|
def get_plugin_env(self, plugin_short: str, options: Dict[str, Any]):
|
||||||
|
@ -457,30 +474,33 @@ class ReproSession:
|
||||||
self.execute_ssh_command(
|
self.execute_ssh_command(
|
||||||
f"[ ! -e {plugin_dir} ] && git clone --depth 1 "
|
f"[ ! -e {plugin_dir} ] && git clone --depth 1 "
|
||||||
f"--branch {plugin_version} {plugin_url} {plugin_dir}",
|
f"--branch {plugin_version} {plugin_url} {plugin_dir}",
|
||||||
quiet=True)
|
quiet=True,
|
||||||
|
)
|
||||||
|
|
||||||
def load_plugin_details(self, plugin: str):
|
def load_plugin_details(self, plugin: str):
|
||||||
assert plugin in self.plugins
|
assert plugin in self.plugins
|
||||||
|
|
||||||
plugin_dir = self.plugins[plugin]["dir"]
|
plugin_dir = self.plugins[plugin]["dir"]
|
||||||
|
|
||||||
yaml_str = self.execute_ssh_command(
|
yaml_str = self.execute_ssh_command(f"cat {plugin_dir}/plugin.yml", quiet=True)
|
||||||
f"cat {plugin_dir}/plugin.yml", quiet=True)
|
|
||||||
|
|
||||||
details = yaml.safe_load(yaml_str)
|
details = yaml.safe_load(yaml_str)
|
||||||
self.plugins[plugin]["details"] = details
|
self.plugins[plugin]["details"] = details
|
||||||
return details
|
return details
|
||||||
|
|
||||||
def execute_plugin_hook(self,
|
def execute_plugin_hook(
|
||||||
plugin: str,
|
self,
|
||||||
hook: str,
|
plugin: str,
|
||||||
env: Optional[Dict[str, Any]] = None,
|
hook: str,
|
||||||
script_command: Optional[str] = None):
|
env: Optional[Dict[str, Any]] = None,
|
||||||
|
script_command: Optional[str] = None,
|
||||||
|
):
|
||||||
assert plugin in self.plugins
|
assert plugin in self.plugins
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Executing Buildkite hook for plugin {plugin}: {hook}. "
|
f"Executing Buildkite hook for plugin {plugin}: {hook}. "
|
||||||
f"This pulls a Docker image and could take a while.")
|
f"This pulls a Docker image and could take a while."
|
||||||
|
)
|
||||||
|
|
||||||
plugin_dir = self.plugins[plugin]["dir"]
|
plugin_dir = self.plugins[plugin]["dir"]
|
||||||
plugin_env = self.plugins[plugin]["env"].copy()
|
plugin_env = self.plugins[plugin]["env"].copy()
|
||||||
|
@ -500,21 +520,23 @@ class ReproSession:
|
||||||
|
|
||||||
def print_buildkite_command(self, skipped: bool = False):
|
def print_buildkite_command(self, skipped: bool = False):
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
print("These are the commands you need to execute to fully reproduce "
|
print(
|
||||||
"the run")
|
"These are the commands you need to execute to fully reproduce " "the run"
|
||||||
|
)
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
print(self.env["BUILDKITE_COMMAND"])
|
print(self.env["BUILDKITE_COMMAND"])
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
if skipped and self.skipped_commands:
|
if skipped and self.skipped_commands:
|
||||||
print("Some of the commands above have already been run. "
|
print(
|
||||||
"Remaining commands:")
|
"Some of the commands above have already been run. "
|
||||||
|
"Remaining commands:"
|
||||||
|
)
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
print("\n".join(self.skipped_commands))
|
print("\n".join(self.skipped_commands))
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
def run_buildkite_command(self,
|
def run_buildkite_command(self, command_filter: Optional[List[str]] = None):
|
||||||
command_filter: Optional[List[str]] = None):
|
|
||||||
commands = self.env["BUILDKITE_COMMAND"].split("\n")
|
commands = self.env["BUILDKITE_COMMAND"].split("\n")
|
||||||
regexes = [re.compile(cf) for cf in command_filter or []]
|
regexes = [re.compile(cf) for cf in command_filter or []]
|
||||||
|
|
||||||
|
@ -537,15 +559,18 @@ class ReproSession:
|
||||||
f"grep -q 'source ~/.env' $HOME/.bashrc "
|
f"grep -q 'source ~/.env' $HOME/.bashrc "
|
||||||
f"|| echo 'source ~/.env' >> $HOME/.bashrc; "
|
f"|| echo 'source ~/.env' >> $HOME/.bashrc; "
|
||||||
f"echo 'export {escaped}' > $HOME/.env",
|
f"echo 'export {escaped}' > $HOME/.env",
|
||||||
quiet=True)
|
quiet=True,
|
||||||
|
)
|
||||||
|
|
||||||
def attach_to_container(self):
|
def attach_to_container(self):
|
||||||
self.logger.info("Attaching to AWS instance...")
|
self.logger.info("Attaching to AWS instance...")
|
||||||
ssh_command = (f"ssh -ti {self.ssh_key_file} "
|
ssh_command = (
|
||||||
f"-o StrictHostKeyChecking=no "
|
f"ssh -ti {self.ssh_key_file} "
|
||||||
f"-o ServerAliveInterval=30 "
|
f"-o StrictHostKeyChecking=no "
|
||||||
f"{self.ssh_user}@{self.aws_instance_ip} "
|
f"-o ServerAliveInterval=30 "
|
||||||
f"'docker exec -it ray_container bash -l'")
|
f"{self.ssh_user}@{self.aws_instance_ip} "
|
||||||
|
f"'docker exec -it ray_container bash -l'"
|
||||||
|
)
|
||||||
|
|
||||||
subprocess.run(ssh_command, shell=True)
|
subprocess.run(ssh_command, shell=True)
|
||||||
|
|
||||||
|
@ -555,29 +580,32 @@ class ReproSession:
|
||||||
@click.option("-n", "--instance-name", default=None)
|
@click.option("-n", "--instance-name", default=None)
|
||||||
@click.option("-c", "--commands", is_flag=True, default=False)
|
@click.option("-c", "--commands", is_flag=True, default=False)
|
||||||
@click.option("-f", "--filters", multiple=True, default=[])
|
@click.option("-f", "--filters", multiple=True, default=[])
|
||||||
def main(session_url: Optional[str],
|
def main(
|
||||||
instance_name: Optional[str] = None,
|
session_url: Optional[str],
|
||||||
commands: bool = False,
|
instance_name: Optional[str] = None,
|
||||||
filters: Optional[List[str]] = None):
|
commands: bool = False,
|
||||||
|
filters: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
random.seed(1235)
|
random.seed(1235)
|
||||||
|
|
||||||
logger = logging.getLogger("main")
|
logger = logging.getLogger("main")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(
|
handler.setFormatter(
|
||||||
logging.Formatter("[%(levelname)s %(asctime)s] "
|
logging.Formatter(
|
||||||
"%(filename)s: %(lineno)d "
|
"[%(levelname)s %(asctime)s] " "%(filename)s: %(lineno)d " "%(message)s"
|
||||||
"%(message)s"))
|
)
|
||||||
|
)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
maybe_fetch_buildkite_token()
|
maybe_fetch_buildkite_token()
|
||||||
repro = ReproSession(
|
repro = ReproSession(
|
||||||
os.environ["BUILDKITE_TOKEN"],
|
os.environ["BUILDKITE_TOKEN"], instance_name=instance_name, logger=logger
|
||||||
instance_name=instance_name,
|
)
|
||||||
logger=logger)
|
|
||||||
|
|
||||||
session_url = session_url or click.prompt(
|
session_url = session_url or click.prompt(
|
||||||
"Please copy and paste the Buildkite job build URI here")
|
"Please copy and paste the Buildkite job build URI here"
|
||||||
|
)
|
||||||
|
|
||||||
repro.set_session(session_url)
|
repro.set_session(session_url)
|
||||||
|
|
||||||
|
@ -610,13 +638,16 @@ def main(session_url: Optional[str],
|
||||||
"BUILDKITE_PLUGIN_DOCKER_TTY": "0",
|
"BUILDKITE_PLUGIN_DOCKER_TTY": "0",
|
||||||
"BUILDKITE_PLUGIN_DOCKER_MOUNT_CHECKOUT": "0",
|
"BUILDKITE_PLUGIN_DOCKER_MOUNT_CHECKOUT": "0",
|
||||||
},
|
},
|
||||||
script_command=("sed -E 's/"
|
script_command=(
|
||||||
"docker run/"
|
"sed -E 's/"
|
||||||
"docker run "
|
"docker run/"
|
||||||
"--cap-add=SYS_PTRACE "
|
"docker run "
|
||||||
"--name ray_container "
|
"--cap-add=SYS_PTRACE "
|
||||||
"-d/g' | "
|
"--name ray_container "
|
||||||
"bash -l"))
|
"-d/g' | "
|
||||||
|
"bash -l"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
repro.create_new_ssh_client()
|
repro.create_new_ssh_client()
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,8 @@ def get_target_expansion_query(targets, tests_only, exclude_manual):
|
||||||
|
|
||||||
if exclude_manual:
|
if exclude_manual:
|
||||||
query = '{} except tests(attr("tags", "manual", set({})))'.format(
|
query = '{} except tests(attr("tags", "manual", set({})))'.format(
|
||||||
query, included_targets)
|
query, included_targets
|
||||||
|
)
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
@ -82,17 +83,16 @@ def get_targets_for_shard(targets, index, count):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
|
||||||
description="Expand and shard Bazel targets.")
|
|
||||||
parser.add_argument("--debug", action="store_true")
|
parser.add_argument("--debug", action="store_true")
|
||||||
parser.add_argument("--tests_only", action="store_true")
|
parser.add_argument("--tests_only", action="store_true")
|
||||||
parser.add_argument("--exclude_manual", action="store_true")
|
parser.add_argument("--exclude_manual", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1))
|
"--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1)
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--count",
|
"--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
|
||||||
type=int,
|
)
|
||||||
default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1))
|
|
||||||
parser.add_argument("targets", nargs="+")
|
parser.add_argument("targets", nargs="+")
|
||||||
args, extra_args = parser.parse_known_args()
|
args, extra_args = parser.parse_known_args()
|
||||||
args.targets = list(args.targets) + list(extra_args)
|
args.targets = list(args.targets) + list(extra_args)
|
||||||
|
@ -100,11 +100,11 @@ def main():
|
||||||
if args.index >= args.count:
|
if args.index >= args.count:
|
||||||
parser.error("--index must be between 0 and {}".format(args.count - 1))
|
parser.error("--index must be between 0 and {}".format(args.count - 1))
|
||||||
|
|
||||||
query = get_target_expansion_query(args.targets, args.tests_only,
|
query = get_target_expansion_query(
|
||||||
args.exclude_manual)
|
args.targets, args.tests_only, args.exclude_manual
|
||||||
|
)
|
||||||
expanded_targets = run_bazel_query(query, args.debug)
|
expanded_targets = run_bazel_query(query, args.debug)
|
||||||
my_targets = get_targets_for_shard(expanded_targets, args.index,
|
my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
|
||||||
args.count)
|
|
||||||
print(" ".join(my_targets))
|
print(" ".join(my_targets))
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -14,10 +14,10 @@ from collections import defaultdict, OrderedDict
|
||||||
|
|
||||||
def textproto_format(space, key, value, json_encoder):
|
def textproto_format(space, key, value, json_encoder):
|
||||||
"""Rewrites a key-value pair from textproto as JSON."""
|
"""Rewrites a key-value pair from textproto as JSON."""
|
||||||
if value.startswith(b"\""):
|
if value.startswith(b'"'):
|
||||||
evaluated = ast.literal_eval(value.decode("utf-8"))
|
evaluated = ast.literal_eval(value.decode("utf-8"))
|
||||||
value = json_encoder.encode(evaluated).encode("utf-8")
|
value = json_encoder.encode(evaluated).encode("utf-8")
|
||||||
return b"%s[\"%s\", %s]" % (space, key, value)
|
return b'%s["%s", %s]' % (space, key, value)
|
||||||
|
|
||||||
|
|
||||||
def textproto_split(input_lines, json_encoder):
|
def textproto_split(input_lines, json_encoder):
|
||||||
|
@ -50,19 +50,21 @@ def textproto_split(input_lines, json_encoder):
|
||||||
pieces = re.split(b"(\\r|\\n)", full_line, 1)
|
pieces = re.split(b"(\\r|\\n)", full_line, 1)
|
||||||
pieces[1:] = [b"".join(pieces[1:])]
|
pieces[1:] = [b"".join(pieces[1:])]
|
||||||
[line, tail] = pieces
|
[line, tail] = pieces
|
||||||
next_line = pat_open.sub(b"\\1[\"\\2\",\\3[", line)
|
next_line = pat_open.sub(b'\\1["\\2",\\3[', line)
|
||||||
outputs.append(b"" if not prev_comma else b"]"
|
outputs.append(
|
||||||
if next_line.endswith(b"}") else b",")
|
b"" if not prev_comma else b"]" if next_line.endswith(b"}") else b","
|
||||||
|
)
|
||||||
next_line = pat_close.sub(b"]", next_line)
|
next_line = pat_close.sub(b"]", next_line)
|
||||||
next_line = pat_line.sub(
|
next_line = pat_line.sub(
|
||||||
lambda m: textproto_format(*(m.groups() + (json_encoder, ))),
|
lambda m: textproto_format(*(m.groups() + (json_encoder,))), next_line
|
||||||
next_line)
|
)
|
||||||
outputs.append(prev_tail + next_line)
|
outputs.append(prev_tail + next_line)
|
||||||
if line == b"}":
|
if line == b"}":
|
||||||
yield b"".join(outputs)
|
yield b"".join(outputs)
|
||||||
del outputs[:]
|
del outputs[:]
|
||||||
prev_comma = line != b"}" and (next_line.endswith(b"]")
|
prev_comma = line != b"}" and (
|
||||||
or next_line.endswith(b"\""))
|
next_line.endswith(b"]") or next_line.endswith(b'"')
|
||||||
|
)
|
||||||
prev_tail = tail
|
prev_tail = tail
|
||||||
if len(outputs) > 0:
|
if len(outputs) > 0:
|
||||||
yield b"".join(outputs)
|
yield b"".join(outputs)
|
||||||
|
@ -80,13 +82,14 @@ class Bazel(object):
|
||||||
def __init__(self, program=None):
|
def __init__(self, program=None):
|
||||||
if program is None:
|
if program is None:
|
||||||
program = os.getenv("BAZEL_EXECUTABLE", "bazel")
|
program = os.getenv("BAZEL_EXECUTABLE", "bazel")
|
||||||
self.argv = (program, )
|
self.argv = (program,)
|
||||||
self.extra_args = ("--show_progress=no", )
|
self.extra_args = ("--show_progress=no",)
|
||||||
|
|
||||||
def _call(self, command, *args):
|
def _call(self, command, *args):
|
||||||
return subprocess.check_output(
|
return subprocess.check_output(
|
||||||
self.argv + (command, ) + args[:1] + self.extra_args + args[1:],
|
self.argv + (command,) + args[:1] + self.extra_args + args[1:],
|
||||||
stdin=subprocess.PIPE)
|
stdin=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
def info(self, *args):
|
def info(self, *args):
|
||||||
result = OrderedDict()
|
result = OrderedDict()
|
||||||
|
@ -248,8 +251,7 @@ def shellcheck(bazel_aquery, *shellcheck_argv):
|
||||||
def main(program, command, *command_args):
|
def main(program, command, *command_args):
|
||||||
result = 0
|
result = 0
|
||||||
if command == textproto2json.__name__:
|
if command == textproto2json.__name__:
|
||||||
result = textproto2json(sys.stdin.buffer, sys.stdout.buffer,
|
result = textproto2json(sys.stdin.buffer, sys.stdout.buffer, *command_args)
|
||||||
*command_args)
|
|
||||||
elif command == shellcheck.__name__:
|
elif command == shellcheck.__name__:
|
||||||
result = shellcheck(*command_args)
|
result = shellcheck(*command_args)
|
||||||
elif command == preclean.__name__:
|
elif command == preclean.__name__:
|
||||||
|
|
|
@ -20,21 +20,16 @@ DOCKER_CLIENT = None
|
||||||
PYTHON_WHL_VERSION = "cp3"
|
PYTHON_WHL_VERSION = "cp3"
|
||||||
|
|
||||||
DOCKER_HUB_DESCRIPTION = {
|
DOCKER_HUB_DESCRIPTION = {
|
||||||
"base-deps": ("Internal Image, refer to "
|
"base-deps": (
|
||||||
"https://hub.docker.com/r/rayproject/ray"),
|
"Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"
|
||||||
"ray-deps": ("Internal Image, refer to "
|
),
|
||||||
"https://hub.docker.com/r/rayproject/ray"),
|
"ray-deps": ("Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"),
|
||||||
"ray": "Official Docker Images for Ray, the distributed computing API.",
|
"ray": "Official Docker Images for Ray, the distributed computing API.",
|
||||||
"ray-ml": "Developer ready Docker Image for Ray.",
|
"ray-ml": "Developer ready Docker Image for Ray.",
|
||||||
"ray-worker-container": "Internal Image for CI test",
|
"ray-worker-container": "Internal Image for CI test",
|
||||||
}
|
}
|
||||||
|
|
||||||
PY_MATRIX = {
|
PY_MATRIX = {"py36": "3.6.12", "py37": "3.7.7", "py38": "3.8.5", "py39": "3.9.5"}
|
||||||
"py36": "3.6.12",
|
|
||||||
"py37": "3.7.7",
|
|
||||||
"py38": "3.8.5",
|
|
||||||
"py39": "3.9.5"
|
|
||||||
}
|
|
||||||
|
|
||||||
BASE_IMAGES = {
|
BASE_IMAGES = {
|
||||||
"cu112": "nvidia/cuda:11.2.0-cudnn8-devel-ubuntu18.04",
|
"cu112": "nvidia/cuda:11.2.0-cudnn8-devel-ubuntu18.04",
|
||||||
|
@ -50,7 +45,7 @@ CUDA_FULL = {
|
||||||
"cu111": "CUDA 11.1",
|
"cu111": "CUDA 11.1",
|
||||||
"cu110": "CUDA 11.0",
|
"cu110": "CUDA 11.0",
|
||||||
"cu102": "CUDA 10.2",
|
"cu102": "CUDA 10.2",
|
||||||
"cu101": "CUDA 10.1"
|
"cu101": "CUDA 10.1",
|
||||||
}
|
}
|
||||||
|
|
||||||
# The CUDA version to use for the ML Docker image.
|
# The CUDA version to use for the ML Docker image.
|
||||||
|
@ -62,8 +57,7 @@ IMAGE_NAMES = list(DOCKER_HUB_DESCRIPTION.keys())
|
||||||
|
|
||||||
|
|
||||||
def _get_branch():
|
def _get_branch():
|
||||||
branch = (os.environ.get("TRAVIS_BRANCH")
|
branch = os.environ.get("TRAVIS_BRANCH") or os.environ.get("BUILDKITE_BRANCH")
|
||||||
or os.environ.get("BUILDKITE_BRANCH"))
|
|
||||||
if not branch:
|
if not branch:
|
||||||
print("Branch not found!")
|
print("Branch not found!")
|
||||||
print(os.environ)
|
print(os.environ)
|
||||||
|
@ -94,8 +88,7 @@ def _get_root_dir():
|
||||||
|
|
||||||
|
|
||||||
def _get_commit_sha():
|
def _get_commit_sha():
|
||||||
sha = (os.environ.get("TRAVIS_COMMIT")
|
sha = os.environ.get("TRAVIS_COMMIT") or os.environ.get("BUILDKITE_COMMIT") or ""
|
||||||
or os.environ.get("BUILDKITE_COMMIT") or "")
|
|
||||||
if len(sha) < 6:
|
if len(sha) < 6:
|
||||||
print("INVALID SHA FOUND")
|
print("INVALID SHA FOUND")
|
||||||
return "ERROR"
|
return "ERROR"
|
||||||
|
@ -105,8 +98,9 @@ def _get_commit_sha():
|
||||||
def _configure_human_version():
|
def _configure_human_version():
|
||||||
global _get_branch
|
global _get_branch
|
||||||
global _get_commit_sha
|
global _get_commit_sha
|
||||||
fake_branch_name = input("Provide a 'branch name'. For releases, it "
|
fake_branch_name = input(
|
||||||
"should be `releases/x.x.x`")
|
"Provide a 'branch name'. For releases, it " "should be `releases/x.x.x`"
|
||||||
|
)
|
||||||
_get_branch = lambda: fake_branch_name # noqa: E731
|
_get_branch = lambda: fake_branch_name # noqa: E731
|
||||||
fake_sha = input("Provide a SHA (used for tag value)")
|
fake_sha = input("Provide a SHA (used for tag value)")
|
||||||
_get_commit_sha = lambda: fake_sha # noqa: E731
|
_get_commit_sha = lambda: fake_sha # noqa: E731
|
||||||
|
@ -115,38 +109,44 @@ def _configure_human_version():
|
||||||
def _get_wheel_name(minor_version_number):
|
def _get_wheel_name(minor_version_number):
|
||||||
if minor_version_number:
|
if minor_version_number:
|
||||||
matches = [
|
matches = [
|
||||||
file for file in glob.glob(
|
file
|
||||||
|
for file in glob.glob(
|
||||||
f"{_get_root_dir()}/.whl/ray-*{PYTHON_WHL_VERSION}"
|
f"{_get_root_dir()}/.whl/ray-*{PYTHON_WHL_VERSION}"
|
||||||
f"{minor_version_number}*-manylinux*")
|
f"{minor_version_number}*-manylinux*"
|
||||||
|
)
|
||||||
if "+" not in file # Exclude dbg, asan builds
|
if "+" not in file # Exclude dbg, asan builds
|
||||||
]
|
]
|
||||||
assert len(matches) == 1, (
|
assert len(matches) == 1, (
|
||||||
f"Found ({len(matches)}) matches for 'ray-*{PYTHON_WHL_VERSION}"
|
f"Found ({len(matches)}) matches for 'ray-*{PYTHON_WHL_VERSION}"
|
||||||
f"{minor_version_number}*-manylinux*' instead of 1.\n"
|
f"{minor_version_number}*-manylinux*' instead of 1.\n"
|
||||||
f"wheel matches: {matches}")
|
f"wheel matches: {matches}"
|
||||||
|
)
|
||||||
return os.path.basename(matches[0])
|
return os.path.basename(matches[0])
|
||||||
else:
|
else:
|
||||||
matches = glob.glob(
|
matches = glob.glob(f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
|
||||||
f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
|
|
||||||
return [os.path.basename(i) for i in matches]
|
return [os.path.basename(i) for i in matches]
|
||||||
|
|
||||||
|
|
||||||
def _check_if_docker_files_modified():
|
def _check_if_docker_files_modified():
|
||||||
stdout = subprocess.check_output([
|
stdout = subprocess.check_output(
|
||||||
sys.executable, f"{_get_curr_dir()}/determine_tests_to_run.py",
|
[
|
||||||
"--output=json"
|
sys.executable,
|
||||||
])
|
f"{_get_curr_dir()}/determine_tests_to_run.py",
|
||||||
|
"--output=json",
|
||||||
|
]
|
||||||
|
)
|
||||||
affected_env_var_list = json.loads(stdout)
|
affected_env_var_list = json.loads(stdout)
|
||||||
affected = ("RAY_CI_DOCKER_AFFECTED" in affected_env_var_list or
|
affected = (
|
||||||
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list)
|
"RAY_CI_DOCKER_AFFECTED" in affected_env_var_list
|
||||||
|
or "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list
|
||||||
|
)
|
||||||
print(f"Docker affected: {affected}")
|
print(f"Docker affected: {affected}")
|
||||||
return affected
|
return affected
|
||||||
|
|
||||||
|
|
||||||
def _build_docker_image(image_name: str,
|
def _build_docker_image(
|
||||||
py_version: str,
|
image_name: str, py_version: str, image_type: str, no_cache=True
|
||||||
image_type: str,
|
):
|
||||||
no_cache=True):
|
|
||||||
"""Builds Docker image with the provided info.
|
"""Builds Docker image with the provided info.
|
||||||
|
|
||||||
image_name (str): The name of the image to build. Must be one of
|
image_name (str): The name of the image to build. Must be one of
|
||||||
|
@ -161,23 +161,27 @@ def _build_docker_image(image_name: str,
|
||||||
if image_name not in IMAGE_NAMES:
|
if image_name not in IMAGE_NAMES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The provided image name {image_name} is not "
|
f"The provided image name {image_name} is not "
|
||||||
f"recognized. Image names must be one of {IMAGE_NAMES}")
|
f"recognized. Image names must be one of {IMAGE_NAMES}"
|
||||||
|
)
|
||||||
|
|
||||||
if py_version not in PY_MATRIX.keys():
|
if py_version not in PY_MATRIX.keys():
|
||||||
raise ValueError(f"The provided python version {py_version} is not "
|
raise ValueError(
|
||||||
f"recognized. Python version must be one of"
|
f"The provided python version {py_version} is not "
|
||||||
f" {PY_MATRIX.keys()}")
|
f"recognized. Python version must be one of"
|
||||||
|
f" {PY_MATRIX.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
if image_type not in BASE_IMAGES.keys():
|
if image_type not in BASE_IMAGES.keys():
|
||||||
raise ValueError(f"The provided CUDA version {image_type} is not "
|
raise ValueError(
|
||||||
f"recognized. CUDA version must be one of"
|
f"The provided CUDA version {image_type} is not "
|
||||||
f" {image_type.keys()}")
|
f"recognized. CUDA version must be one of"
|
||||||
|
f" {image_type.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(https://github.com/ray-project/ray/issues/16599):
|
# TODO(https://github.com/ray-project/ray/issues/16599):
|
||||||
# remove below after supporting ray-ml images with Python 3.9
|
# remove below after supporting ray-ml images with Python 3.9
|
||||||
if image_name == "ray-ml" and py_version == "py39":
|
if image_name == "ray-ml" and py_version == "py39":
|
||||||
print(f"{image_name} image is currently unsupported with "
|
print(f"{image_name} image is currently unsupported with " "Python 3.9")
|
||||||
"Python 3.9")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
build_args = {}
|
build_args = {}
|
||||||
|
@ -212,7 +216,7 @@ def _build_docker_image(image_name: str,
|
||||||
labels = {
|
labels = {
|
||||||
"image-name": image_name,
|
"image-name": image_name,
|
||||||
"python-version": PY_MATRIX[py_version],
|
"python-version": PY_MATRIX[py_version],
|
||||||
"ray-commit": _get_commit_sha()
|
"ray-commit": _get_commit_sha(),
|
||||||
}
|
}
|
||||||
if image_type in CUDA_FULL:
|
if image_type in CUDA_FULL:
|
||||||
labels["cuda-version"] = CUDA_FULL[image_type]
|
labels["cuda-version"] = CUDA_FULL[image_type]
|
||||||
|
@ -222,7 +226,8 @@ def _build_docker_image(image_name: str,
|
||||||
tag=tagged_name,
|
tag=tagged_name,
|
||||||
nocache=no_cache,
|
nocache=no_cache,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
buildargs=build_args)
|
buildargs=build_args,
|
||||||
|
)
|
||||||
|
|
||||||
cmd_output = []
|
cmd_output = []
|
||||||
try:
|
try:
|
||||||
|
@ -230,12 +235,15 @@ def _build_docker_image(image_name: str,
|
||||||
current_iter = start
|
current_iter = start
|
||||||
for line in output:
|
for line in output:
|
||||||
cmd_output.append(line.decode("utf-8"))
|
cmd_output.append(line.decode("utf-8"))
|
||||||
if datetime.datetime.now(
|
if datetime.datetime.now() - current_iter >= datetime.timedelta(
|
||||||
) - current_iter >= datetime.timedelta(minutes=5):
|
minutes=5
|
||||||
|
):
|
||||||
current_iter = datetime.datetime.now()
|
current_iter = datetime.datetime.now()
|
||||||
elapsed = datetime.datetime.now() - start
|
elapsed = datetime.datetime.now() - start
|
||||||
print(f"Still building {tagged_name} after "
|
print(
|
||||||
f"{elapsed.seconds} seconds")
|
f"Still building {tagged_name} after "
|
||||||
|
f"{elapsed.seconds} seconds"
|
||||||
|
)
|
||||||
if elapsed >= datetime.timedelta(minutes=15):
|
if elapsed >= datetime.timedelta(minutes=15):
|
||||||
print("Additional build output:")
|
print("Additional build output:")
|
||||||
print(*cmd_output, sep="\n")
|
print(*cmd_output, sep="\n")
|
||||||
|
@ -259,8 +267,10 @@ def _build_docker_image(image_name: str,
|
||||||
|
|
||||||
def copy_wheels(human_build):
|
def copy_wheels(human_build):
|
||||||
if human_build:
|
if human_build:
|
||||||
print("Please download images using:\n"
|
print(
|
||||||
"`pip download --python-version <py_version> ray==<ray_version>")
|
"Please download images using:\n"
|
||||||
|
"`pip download --python-version <py_version> ray==<ray_version>"
|
||||||
|
)
|
||||||
root_dir = _get_root_dir()
|
root_dir = _get_root_dir()
|
||||||
wheels = _get_wheel_name(None)
|
wheels = _get_wheel_name(None)
|
||||||
for wheel in wheels:
|
for wheel in wheels:
|
||||||
|
@ -268,7 +278,8 @@ def copy_wheels(human_build):
|
||||||
ray_dst = os.path.join(root_dir, "docker/ray/.whl/")
|
ray_dst = os.path.join(root_dir, "docker/ray/.whl/")
|
||||||
ray_dep_dst = os.path.join(root_dir, "docker/ray-deps/.whl/")
|
ray_dep_dst = os.path.join(root_dir, "docker/ray-deps/.whl/")
|
||||||
ray_worker_container_dst = os.path.join(
|
ray_worker_container_dst = os.path.join(
|
||||||
root_dir, "docker/ray-worker-container/.whl/")
|
root_dir, "docker/ray-worker-container/.whl/"
|
||||||
|
)
|
||||||
os.makedirs(ray_dst, exist_ok=True)
|
os.makedirs(ray_dst, exist_ok=True)
|
||||||
shutil.copy(source, ray_dst)
|
shutil.copy(source, ray_dst)
|
||||||
os.makedirs(ray_dep_dst, exist_ok=True)
|
os.makedirs(ray_dep_dst, exist_ok=True)
|
||||||
|
@ -282,8 +293,7 @@ def check_staleness(repository, tag):
|
||||||
|
|
||||||
age = DOCKER_CLIENT.api.inspect_image(f"{repository}:{tag}")["Created"]
|
age = DOCKER_CLIENT.api.inspect_image(f"{repository}:{tag}")["Created"]
|
||||||
short_date = datetime.datetime.strptime(age.split("T")[0], "%Y-%m-%d")
|
short_date = datetime.datetime.strptime(age.split("T")[0], "%Y-%m-%d")
|
||||||
is_stale = (
|
is_stale = (datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
|
||||||
datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
|
|
||||||
return is_stale
|
return is_stale
|
||||||
|
|
||||||
|
|
||||||
|
@ -292,28 +302,23 @@ def build_for_all_versions(image_name, py_versions, image_types, **kwargs):
|
||||||
for py_version in py_versions:
|
for py_version in py_versions:
|
||||||
for image_type in image_types:
|
for image_type in image_types:
|
||||||
_build_docker_image(
|
_build_docker_image(
|
||||||
image_name,
|
image_name, py_version=py_version, image_type=image_type, **kwargs
|
||||||
py_version=py_version,
|
)
|
||||||
image_type=image_type,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def build_base_images(py_versions, image_types):
|
def build_base_images(py_versions, image_types):
|
||||||
build_for_all_versions(
|
build_for_all_versions("base-deps", py_versions, image_types, no_cache=False)
|
||||||
"base-deps", py_versions, image_types, no_cache=False)
|
build_for_all_versions("ray-deps", py_versions, image_types, no_cache=False)
|
||||||
build_for_all_versions(
|
|
||||||
"ray-deps", py_versions, image_types, no_cache=False)
|
|
||||||
|
|
||||||
|
|
||||||
def build_or_pull_base_images(py_versions: List[str],
|
def build_or_pull_base_images(
|
||||||
image_types: List[str],
|
py_versions: List[str], image_types: List[str], rebuild_base_images: bool = True
|
||||||
rebuild_base_images: bool = True) -> bool:
|
) -> bool:
|
||||||
"""Returns images to tag and build."""
|
"""Returns images to tag and build."""
|
||||||
repositories = ["rayproject/base-deps", "rayproject/ray-deps"]
|
repositories = ["rayproject/base-deps", "rayproject/ray-deps"]
|
||||||
tags = [
|
tags = [
|
||||||
f"nightly-{py_version}-{image_type}"
|
f"nightly-{py_version}-{image_type}"
|
||||||
for py_version, image_type in itertools.product(
|
for py_version, image_type in itertools.product(py_versions, image_types)
|
||||||
py_versions, image_types)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -339,12 +344,15 @@ def build_or_pull_base_images(py_versions: List[str],
|
||||||
def prep_ray_ml():
|
def prep_ray_ml():
|
||||||
root_dir = _get_root_dir()
|
root_dir = _get_root_dir()
|
||||||
requirement_files = glob.glob(
|
requirement_files = glob.glob(
|
||||||
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True)
|
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True
|
||||||
|
)
|
||||||
for fl in requirement_files:
|
for fl in requirement_files:
|
||||||
shutil.copy(fl, os.path.join(root_dir, "docker/ray-ml/"))
|
shutil.copy(fl, os.path.join(root_dir, "docker/ray-ml/"))
|
||||||
# Install atari roms script
|
# Install atari roms script
|
||||||
shutil.copy(f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
|
shutil.copy(
|
||||||
os.path.join(root_dir, "docker/ray-ml/"))
|
f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
|
||||||
|
os.path.join(root_dir, "docker/ray-ml/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_docker_creds() -> Tuple[str, str]:
|
def _get_docker_creds() -> Tuple[str, str]:
|
||||||
|
@ -377,10 +385,13 @@ def _tag_and_push(full_image_name, old_tag, new_tag, merge_build=False):
|
||||||
DOCKER_CLIENT.api.tag(
|
DOCKER_CLIENT.api.tag(
|
||||||
image=f"{full_image_name}:{old_tag}",
|
image=f"{full_image_name}:{old_tag}",
|
||||||
repository=full_image_name,
|
repository=full_image_name,
|
||||||
tag=new_tag)
|
tag=new_tag,
|
||||||
|
)
|
||||||
if not merge_build:
|
if not merge_build:
|
||||||
print("This is a PR Build! On a merge build, we would normally push"
|
print(
|
||||||
f"to: {full_image_name}:{new_tag}")
|
"This is a PR Build! On a merge build, we would normally push"
|
||||||
|
f"to: {full_image_name}:{new_tag}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
_docker_push(full_image_name, new_tag)
|
_docker_push(full_image_name, new_tag)
|
||||||
|
|
||||||
|
@ -395,16 +406,17 @@ def _create_new_tags(all_tags, old_str, new_str):
|
||||||
|
|
||||||
# For non-release builds, push "nightly" & "sha"
|
# For non-release builds, push "nightly" & "sha"
|
||||||
# For release builds, push "nightly" & "latest" & "x.x.x"
|
# For release builds, push "nightly" & "latest" & "x.x.x"
|
||||||
def push_and_tag_images(py_versions: List[str],
|
def push_and_tag_images(
|
||||||
image_types: List[str],
|
py_versions: List[str],
|
||||||
push_base_images: bool,
|
image_types: List[str],
|
||||||
merge_build: bool = False):
|
push_base_images: bool,
|
||||||
|
merge_build: bool = False,
|
||||||
|
):
|
||||||
|
|
||||||
date_tag = datetime.datetime.now().strftime("%Y-%m-%d")
|
date_tag = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
sha_tag = _get_commit_sha()
|
sha_tag = _get_commit_sha()
|
||||||
if _release_build():
|
if _release_build():
|
||||||
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*",
|
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*", _get_branch()).group(0)
|
||||||
_get_branch()).group(0)
|
|
||||||
date_tag = release_name
|
date_tag = release_name
|
||||||
sha_tag = release_name
|
sha_tag = release_name
|
||||||
|
|
||||||
|
@ -423,16 +435,19 @@ def push_and_tag_images(py_versions: List[str],
|
||||||
for py_name in py_versions:
|
for py_name in py_versions:
|
||||||
for image_type in image_types:
|
for image_type in image_types:
|
||||||
if image_name == "ray-ml" and image_type != ML_CUDA_VERSION:
|
if image_name == "ray-ml" and image_type != ML_CUDA_VERSION:
|
||||||
print("ML Docker image is not built for the following "
|
print(
|
||||||
f"device type: {image_type}")
|
"ML Docker image is not built for the following "
|
||||||
|
f"device type: {image_type}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# TODO(https://github.com/ray-project/ray/issues/16599):
|
# TODO(https://github.com/ray-project/ray/issues/16599):
|
||||||
# remove below after supporting ray-ml images with Python 3.9
|
# remove below after supporting ray-ml images with Python 3.9
|
||||||
if image_name in ["ray-ml"
|
if image_name in ["ray-ml"] and PY_MATRIX[py_name].startswith("3.9"):
|
||||||
] and PY_MATRIX[py_name].startswith("3.9"):
|
print(
|
||||||
print(f"{image_name} image is currently "
|
f"{image_name} image is currently "
|
||||||
f"unsupported with Python 3.9")
|
f"unsupported with Python 3.9"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tag = f"nightly-{py_name}-{image_type}"
|
tag = f"nightly-{py_name}-{image_type}"
|
||||||
|
@ -445,20 +460,19 @@ def push_and_tag_images(py_versions: List[str],
|
||||||
for old_tag in tag_mapping.keys():
|
for old_tag in tag_mapping.keys():
|
||||||
if "cpu" in old_tag:
|
if "cpu" in old_tag:
|
||||||
new_tags = _create_new_tags(
|
new_tags = _create_new_tags(
|
||||||
tag_mapping[old_tag], old_str="-cpu", new_str="")
|
tag_mapping[old_tag], old_str="-cpu", new_str=""
|
||||||
|
)
|
||||||
tag_mapping[old_tag].extend(new_tags)
|
tag_mapping[old_tag].extend(new_tags)
|
||||||
elif ML_CUDA_VERSION in old_tag:
|
elif ML_CUDA_VERSION in old_tag:
|
||||||
new_tags = _create_new_tags(
|
new_tags = _create_new_tags(
|
||||||
tag_mapping[old_tag],
|
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str="-gpu"
|
||||||
old_str=f"-{ML_CUDA_VERSION}",
|
)
|
||||||
new_str="-gpu")
|
|
||||||
tag_mapping[old_tag].extend(new_tags)
|
tag_mapping[old_tag].extend(new_tags)
|
||||||
|
|
||||||
if image_name == "ray-ml":
|
if image_name == "ray-ml":
|
||||||
new_tags = _create_new_tags(
|
new_tags = _create_new_tags(
|
||||||
tag_mapping[old_tag],
|
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str=""
|
||||||
old_str=f"-{ML_CUDA_VERSION}",
|
)
|
||||||
new_str="")
|
|
||||||
tag_mapping[old_tag].extend(new_tags)
|
tag_mapping[old_tag].extend(new_tags)
|
||||||
|
|
||||||
# No Python version specified should refer to DEFAULT_PYTHON_VERSION
|
# No Python version specified should refer to DEFAULT_PYTHON_VERSION
|
||||||
|
@ -467,7 +481,8 @@ def push_and_tag_images(py_versions: List[str],
|
||||||
new_tags = _create_new_tags(
|
new_tags = _create_new_tags(
|
||||||
tag_mapping[old_tag],
|
tag_mapping[old_tag],
|
||||||
old_str=f"-{DEFAULT_PYTHON_VERSION}",
|
old_str=f"-{DEFAULT_PYTHON_VERSION}",
|
||||||
new_str="")
|
new_str="",
|
||||||
|
)
|
||||||
tag_mapping[old_tag].extend(new_tags)
|
tag_mapping[old_tag].extend(new_tags)
|
||||||
|
|
||||||
# For all tags, create Date/Sha tags
|
# For all tags, create Date/Sha tags
|
||||||
|
@ -475,7 +490,8 @@ def push_and_tag_images(py_versions: List[str],
|
||||||
new_tags = _create_new_tags(
|
new_tags = _create_new_tags(
|
||||||
tag_mapping[old_tag],
|
tag_mapping[old_tag],
|
||||||
old_str="nightly",
|
old_str="nightly",
|
||||||
new_str=date_tag if "-deps" in image_name else sha_tag)
|
new_str=date_tag if "-deps" in image_name else sha_tag,
|
||||||
|
)
|
||||||
tag_mapping[old_tag].extend(new_tags)
|
tag_mapping[old_tag].extend(new_tags)
|
||||||
|
|
||||||
# Sanity checking.
|
# Sanity checking.
|
||||||
|
@ -511,7 +527,8 @@ def push_and_tag_images(py_versions: List[str],
|
||||||
full_image_name,
|
full_image_name,
|
||||||
old_tag=old_tag,
|
old_tag=old_tag,
|
||||||
new_tag=new_tag,
|
new_tag=new_tag,
|
||||||
merge_build=merge_build)
|
merge_build=merge_build,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Push infra here:
|
# Push infra here:
|
||||||
|
@ -527,9 +544,9 @@ def push_readmes(merge_build: bool):
|
||||||
"DOCKER_PASS": password,
|
"DOCKER_PASS": password,
|
||||||
"PUSHRM_FILE": f"/myvol/docker/{image}/README.md",
|
"PUSHRM_FILE": f"/myvol/docker/{image}/README.md",
|
||||||
"PUSHRM_DEBUG": 1,
|
"PUSHRM_DEBUG": 1,
|
||||||
"PUSHRM_SHORT": tag_line
|
"PUSHRM_SHORT": tag_line,
|
||||||
}
|
}
|
||||||
cmd_string = (f"rayproject/{image}")
|
cmd_string = f"rayproject/{image}"
|
||||||
|
|
||||||
print(
|
print(
|
||||||
DOCKER_CLIENT.containers.run(
|
DOCKER_CLIENT.containers.run(
|
||||||
|
@ -546,7 +563,9 @@ def push_readmes(merge_build: bool):
|
||||||
detach=False,
|
detach=False,
|
||||||
stderr=True,
|
stderr=True,
|
||||||
stdout=True,
|
stdout=True,
|
||||||
tty=False))
|
tty=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Build base-deps/ray-deps only on file change, 2 weeks, per release
|
# Build base-deps/ray-deps only on file change, 2 weeks, per release
|
||||||
|
@ -566,63 +585,73 @@ if __name__ == "__main__":
|
||||||
choices=list(PY_MATRIX.keys()),
|
choices=list(PY_MATRIX.keys()),
|
||||||
default="py37",
|
default="py37",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
help="Which python versions to build. "
|
help="Which python versions to build. " "Must be in (py36, py37, py38, py39)",
|
||||||
"Must be in (py36, py37, py38, py39)")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device-types",
|
"--device-types",
|
||||||
choices=list(BASE_IMAGES.keys()),
|
choices=list(BASE_IMAGES.keys()),
|
||||||
default=None,
|
default=None,
|
||||||
nargs="*",
|
nargs="*",
|
||||||
help="Which device types (CPU/CUDA versions) to build images for. "
|
help="Which device types (CPU/CUDA versions) to build images for. "
|
||||||
"If not specified, images will be built for all device types.")
|
"If not specified, images will be built for all device types.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--build-type",
|
"--build-type",
|
||||||
choices=BUILD_TYPES,
|
choices=BUILD_TYPES,
|
||||||
required=True,
|
required=True,
|
||||||
help="Whether to bypass checking if docker is affected")
|
help="Whether to bypass checking if docker is affected",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--build-base",
|
"--build-base",
|
||||||
dest="base",
|
dest="base",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to build base-deps & ray-deps")
|
help="Whether to build base-deps & ray-deps",
|
||||||
|
)
|
||||||
parser.add_argument("--no-build-base", dest="base", action="store_false")
|
parser.add_argument("--no-build-base", dest="base", action="store_false")
|
||||||
parser.set_defaults(base=True)
|
parser.set_defaults(base=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--only-build-worker-container",
|
"--only-build-worker-container",
|
||||||
dest="only_build_worker_container",
|
dest="only_build_worker_container",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether only to build ray-worker-container")
|
help="Whether only to build ray-worker-container",
|
||||||
|
)
|
||||||
parser.set_defaults(only_build_worker_container=False)
|
parser.set_defaults(only_build_worker_container=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
py_versions = args.py_versions
|
py_versions = args.py_versions
|
||||||
py_versions = py_versions if isinstance(py_versions,
|
py_versions = py_versions if isinstance(py_versions, list) else [py_versions]
|
||||||
list) else [py_versions]
|
|
||||||
|
|
||||||
image_types = args.device_types if args.device_types else list(
|
image_types = args.device_types if args.device_types else list(BASE_IMAGES.keys())
|
||||||
BASE_IMAGES.keys())
|
|
||||||
|
|
||||||
assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys())
|
assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys())
|
||||||
|
|
||||||
# Make sure the python images and cuda versions we build here are
|
# Make sure the python images and cuda versions we build here are
|
||||||
# consistent with the ones used with fix-latest-docker.sh script.
|
# consistent with the ones used with fix-latest-docker.sh script.
|
||||||
py_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
|
py_version_file = os.path.join(
|
||||||
"python_versions.txt")
|
_get_root_dir(), "docker/retag-lambda", "python_versions.txt"
|
||||||
|
)
|
||||||
with open(py_version_file) as f:
|
with open(py_version_file) as f:
|
||||||
py_file_versions = f.read().splitlines()
|
py_file_versions = f.read().splitlines()
|
||||||
assert set(PY_MATRIX.keys()) == set(py_file_versions), \
|
assert set(PY_MATRIX.keys()) == set(py_file_versions), (
|
||||||
(PY_MATRIX.keys(), py_file_versions)
|
PY_MATRIX.keys(),
|
||||||
|
py_file_versions,
|
||||||
|
)
|
||||||
|
|
||||||
cuda_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
|
cuda_version_file = os.path.join(
|
||||||
"cuda_versions.txt")
|
_get_root_dir(), "docker/retag-lambda", "cuda_versions.txt"
|
||||||
|
)
|
||||||
|
|
||||||
with open(cuda_version_file) as f:
|
with open(cuda_version_file) as f:
|
||||||
cuda_file_versions = f.read().splitlines()
|
cuda_file_versions = f.read().splitlines()
|
||||||
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]),\
|
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]), (
|
||||||
(BASE_IMAGES.keys(), cuda_file_versions + ["cpu"])
|
BASE_IMAGES.keys(),
|
||||||
|
cuda_file_versions + ["cpu"],
|
||||||
|
)
|
||||||
|
|
||||||
print("Building the following python versions: ",
|
print(
|
||||||
[PY_MATRIX[py_version] for py_version in py_versions])
|
"Building the following python versions: ",
|
||||||
|
[PY_MATRIX[py_version] for py_version in py_versions],
|
||||||
|
)
|
||||||
print("Building images for the following devices: ", image_types)
|
print("Building images for the following devices: ", image_types)
|
||||||
print("Building base images: ", args.base)
|
print("Building base images: ", args.base)
|
||||||
|
|
||||||
|
@ -639,9 +668,11 @@ if __name__ == "__main__":
|
||||||
if build_type == HUMAN:
|
if build_type == HUMAN:
|
||||||
# If manually triggered, request user for branch and SHA value to use.
|
# If manually triggered, request user for branch and SHA value to use.
|
||||||
_configure_human_version()
|
_configure_human_version()
|
||||||
if (build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
|
if (
|
||||||
or _check_if_docker_files_modified()
|
build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
|
||||||
or args.only_build_worker_container):
|
or _check_if_docker_files_modified()
|
||||||
|
or args.only_build_worker_container
|
||||||
|
):
|
||||||
DOCKER_CLIENT = docker.from_env()
|
DOCKER_CLIENT = docker.from_env()
|
||||||
is_merge = build_type == MERGE
|
is_merge = build_type == MERGE
|
||||||
# Buildkite is authenticated in the background.
|
# Buildkite is authenticated in the background.
|
||||||
|
@ -652,11 +683,11 @@ if __name__ == "__main__":
|
||||||
DOCKER_CLIENT.api.login(username=username, password=password)
|
DOCKER_CLIENT.api.login(username=username, password=password)
|
||||||
copy_wheels(build_type == HUMAN)
|
copy_wheels(build_type == HUMAN)
|
||||||
is_base_images_built = build_or_pull_base_images(
|
is_base_images_built = build_or_pull_base_images(
|
||||||
py_versions, image_types, args.base)
|
py_versions, image_types, args.base
|
||||||
|
)
|
||||||
|
|
||||||
if args.only_build_worker_container:
|
if args.only_build_worker_container:
|
||||||
build_for_all_versions("ray-worker-container", py_versions,
|
build_for_all_versions("ray-worker-container", py_versions, image_types)
|
||||||
image_types)
|
|
||||||
# TODO Currently don't push ray_worker_container
|
# TODO Currently don't push ray_worker_container
|
||||||
else:
|
else:
|
||||||
# Build Ray Docker images.
|
# Build Ray Docker images.
|
||||||
|
@ -668,15 +699,19 @@ if __name__ == "__main__":
|
||||||
prep_ray_ml()
|
prep_ray_ml()
|
||||||
# Only build ML Docker for the ML_CUDA_VERSION
|
# Only build ML Docker for the ML_CUDA_VERSION
|
||||||
build_for_all_versions(
|
build_for_all_versions(
|
||||||
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION])
|
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION]
|
||||||
|
)
|
||||||
|
|
||||||
if build_type in {MERGE, PR}:
|
if build_type in {MERGE, PR}:
|
||||||
valid_branch = _valid_branch()
|
valid_branch = _valid_branch()
|
||||||
if (not valid_branch) and is_merge:
|
if (not valid_branch) and is_merge:
|
||||||
print(f"Invalid Branch found: {_get_branch()}")
|
print(f"Invalid Branch found: {_get_branch()}")
|
||||||
push_and_tag_images(py_versions, image_types,
|
push_and_tag_images(
|
||||||
is_base_images_built, valid_branch
|
py_versions,
|
||||||
and is_merge)
|
image_types,
|
||||||
|
is_base_images_built,
|
||||||
|
valid_branch and is_merge,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(ilr) Re-Enable Push READMEs by using a normal password
|
# TODO(ilr) Re-Enable Push READMEs by using a normal password
|
||||||
# (not auth token :/)
|
# (not auth token :/)
|
||||||
|
|
|
@ -20,7 +20,8 @@ def build_multinode_image(source_image: str, target_image: str):
|
||||||
f.write("RUN sudo apt install -y openssh-server\n")
|
f.write("RUN sudo apt install -y openssh-server\n")
|
||||||
|
|
||||||
subprocess.check_output(
|
subprocess.check_output(
|
||||||
f"docker build -t {target_image} .", shell=True, cwd=tempdir)
|
f"docker build -t {target_image} .", shell=True, cwd=tempdir
|
||||||
|
)
|
||||||
|
|
||||||
shutil.rmtree(tempdir)
|
shutil.rmtree(tempdir)
|
||||||
|
|
||||||
|
|
|
@ -25,9 +25,7 @@ def perform_check(raw_xml_string: str):
|
||||||
missing_owners = []
|
missing_owners = []
|
||||||
for rule in tree.findall("rule"):
|
for rule in tree.findall("rule"):
|
||||||
test_name = rule.attrib["name"]
|
test_name = rule.attrib["name"]
|
||||||
tags = [
|
tags = [child.attrib["value"] for child in rule.find("list").getchildren()]
|
||||||
child.attrib["value"] for child in rule.find("list").getchildren()
|
|
||||||
]
|
|
||||||
team_owner = [t for t in tags if t.startswith("team")]
|
team_owner = [t for t in tags if t.startswith("team")]
|
||||||
if len(team_owner) == 0:
|
if len(team_owner) == 0:
|
||||||
missing_owners.append(test_name)
|
missing_owners.append(test_name)
|
||||||
|
@ -36,7 +34,8 @@ def perform_check(raw_xml_string: str):
|
||||||
if len(missing_owners):
|
if len(missing_owners):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Cannot find owner for tests {missing_owners}, please add "
|
f"Cannot find owner for tests {missing_owners}, please add "
|
||||||
"`team:*` to the tags.")
|
"`team:*` to the tags."
|
||||||
|
)
|
||||||
|
|
||||||
print(owners)
|
print(owners)
|
||||||
|
|
||||||
|
|
|
@ -19,11 +19,7 @@ exit_with_error = False
|
||||||
|
|
||||||
|
|
||||||
def check_import(file):
|
def check_import(file):
|
||||||
check_to_lines = {
|
check_to_lines = {"import ray": -1, "import psutil": -1, "import setproctitle": -1}
|
||||||
"import ray": -1,
|
|
||||||
"import psutil": -1,
|
|
||||||
"import setproctitle": -1
|
|
||||||
}
|
|
||||||
|
|
||||||
with io.open(file, "r", encoding="utf-8") as f:
|
with io.open(file, "r", encoding="utf-8") as f:
|
||||||
for i, line in enumerate(f):
|
for i, line in enumerate(f):
|
||||||
|
@ -37,8 +33,10 @@ def check_import(file):
|
||||||
# It will not match the following
|
# It will not match the following
|
||||||
# - submodule import: `import ray.constants as ray_constants`
|
# - submodule import: `import ray.constants as ray_constants`
|
||||||
# - submodule import: `from ray import xyz`
|
# - submodule import: `from ray import xyz`
|
||||||
if re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$",
|
if (
|
||||||
line) and check_to_lines[check] == -1:
|
re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$", line)
|
||||||
|
and check_to_lines[check] == -1
|
||||||
|
):
|
||||||
check_to_lines[check] = i
|
check_to_lines[check] = i
|
||||||
|
|
||||||
for import_lib in ["import psutil", "import setproctitle"]:
|
for import_lib in ["import psutil", "import setproctitle"]:
|
||||||
|
@ -48,8 +46,8 @@ def check_import(file):
|
||||||
if import_ray_line == -1 or import_ray_line > import_psutil_line:
|
if import_ray_line == -1 or import_ray_line > import_psutil_line:
|
||||||
print(
|
print(
|
||||||
"{}:{}".format(str(file), import_psutil_line + 1),
|
"{}:{}".format(str(file), import_psutil_line + 1),
|
||||||
"{} without explicitly import ray before it.".format(
|
"{} without explicitly import ray before it.".format(import_lib),
|
||||||
import_lib))
|
)
|
||||||
global exit_with_error
|
global exit_with_error
|
||||||
exit_with_error = True
|
exit_with_error = True
|
||||||
|
|
||||||
|
@ -59,8 +57,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("path", help="File path to check. e.g. '.' or './src'")
|
parser.add_argument("path", help="File path to check. e.g. '.' or './src'")
|
||||||
# TODO(simon): For the future, consider adding a feature to explicitly
|
# TODO(simon): For the future, consider adding a feature to explicitly
|
||||||
# white-list the path instead of skipping them.
|
# white-list the path instead of skipping them.
|
||||||
parser.add_argument(
|
parser.add_argument("-s", "--skip", action="append", help="Skip certian directory")
|
||||||
"-s", "--skip", action="append", help="Skip certian directory")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
file_path = Path(args.path)
|
file_path = Path(args.path)
|
||||||
|
|
|
@ -18,7 +18,7 @@ DEFAULT_BLACKLIST = [
|
||||||
"gpustat",
|
"gpustat",
|
||||||
"opencensus",
|
"opencensus",
|
||||||
"prometheus_client",
|
"prometheus_client",
|
||||||
"smart_open"
|
"smart_open",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,19 +28,20 @@ def assert_packages_not_installed(blacklist: List[str]):
|
||||||
except ImportError: # pip < 10.0
|
except ImportError: # pip < 10.0
|
||||||
from pip.operations import freeze
|
from pip.operations import freeze
|
||||||
|
|
||||||
installed_packages = [
|
installed_packages = [p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()]
|
||||||
p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()
|
|
||||||
]
|
|
||||||
|
|
||||||
assert not any(p in installed_packages for p in blacklist), \
|
assert not any(p in installed_packages for p in blacklist), (
|
||||||
f"Found blacklisted packages in installed python packages: " \
|
f"Found blacklisted packages in installed python packages: "
|
||||||
f"{[p for p in blacklist if p in installed_packages]}. " \
|
f"{[p for p in blacklist if p in installed_packages]}. "
|
||||||
f"Minimal dependency tests could be tainted by this. " \
|
f"Minimal dependency tests could be tainted by this. "
|
||||||
f"Check the install logs and primary dependencies if any of these " \
|
f"Check the install logs and primary dependencies if any of these "
|
||||||
f"packages were installed as part of another install step."
|
f"packages were installed as part of another install step."
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Confirmed that blacklisted packages are not installed in "
|
print(
|
||||||
f"current Python environment: {blacklist}")
|
f"Confirmed that blacklisted packages are not installed in "
|
||||||
|
f"current Python environment: {blacklist}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -54,7 +54,8 @@ def run_tidy(task_queue, lock, timeout):
|
||||||
command = task_queue.get()
|
command = task_queue.get()
|
||||||
try:
|
try:
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
watchdog = threading.Timer(timeout, proc.kill)
|
watchdog = threading.Timer(timeout, proc.kill)
|
||||||
|
@ -70,22 +71,21 @@ def run_tidy(task_queue, lock, timeout):
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
with lock:
|
with lock:
|
||||||
sys.stderr.write("Failed: " + str(e) + ": ".join(command) +
|
sys.stderr.write("Failed: " + str(e) + ": ".join(command) + "\n")
|
||||||
"\n")
|
|
||||||
finally:
|
finally:
|
||||||
with lock:
|
with lock:
|
||||||
if timeout is not None and watchdog is not None:
|
if timeout is not None and watchdog is not None:
|
||||||
if not watchdog.is_alive():
|
if not watchdog.is_alive():
|
||||||
sys.stderr.write("Terminated by timeout: " +
|
sys.stderr.write(
|
||||||
" ".join(command) + "\n")
|
"Terminated by timeout: " + " ".join(command) + "\n"
|
||||||
|
)
|
||||||
watchdog.cancel()
|
watchdog.cancel()
|
||||||
task_queue.task_done()
|
task_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout):
|
def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout):
|
||||||
for _ in range(max_tasks):
|
for _ in range(max_tasks):
|
||||||
t = threading.Thread(
|
t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout))
|
||||||
target=tidy_caller, args=(task_queue, lock, timeout))
|
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
@ -119,84 +119,87 @@ def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Run clang-tidy against changed files, and "
|
description="Run clang-tidy against changed files, and "
|
||||||
"output diagnostics only for modified "
|
"output diagnostics only for modified "
|
||||||
"lines.")
|
"lines."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-clang-tidy-binary",
|
"-clang-tidy-binary",
|
||||||
metavar="PATH",
|
metavar="PATH",
|
||||||
default="clang-tidy",
|
default="clang-tidy",
|
||||||
help="path to clang-tidy binary")
|
help="path to clang-tidy binary",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
metavar="NUM",
|
metavar="NUM",
|
||||||
default=0,
|
default=0,
|
||||||
help="strip the smallest prefix containing P slashes")
|
help="strip the smallest prefix containing P slashes",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-regex",
|
"-regex",
|
||||||
metavar="PATTERN",
|
metavar="PATTERN",
|
||||||
default=None,
|
default=None,
|
||||||
help="custom pattern selecting file paths to check "
|
help="custom pattern selecting file paths to check "
|
||||||
"(case sensitive, overrides -iregex)")
|
"(case sensitive, overrides -iregex)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-iregex",
|
"-iregex",
|
||||||
metavar="PATTERN",
|
metavar="PATTERN",
|
||||||
default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)",
|
default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)",
|
||||||
help="custom pattern selecting file paths to check "
|
help="custom pattern selecting file paths to check "
|
||||||
"(case insensitive, overridden by -regex)")
|
"(case insensitive, overridden by -regex)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-j",
|
"-j",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="number of tidy instances to be run in parallel.")
|
help="number of tidy instances to be run in parallel.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-timeout",
|
"-timeout", type=int, default=None, help="timeout per each file in seconds."
|
||||||
type=int,
|
)
|
||||||
default=None,
|
|
||||||
help="timeout per each file in seconds.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-fix",
|
"-fix", action="store_true", default=False, help="apply suggested fixes"
|
||||||
action="store_true",
|
)
|
||||||
default=False,
|
|
||||||
help="apply suggested fixes")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-checks",
|
"-checks",
|
||||||
help="checks filter, when not specified, use clang-tidy "
|
help="checks filter, when not specified, use clang-tidy " "default",
|
||||||
"default",
|
default="",
|
||||||
default="")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-path",
|
"-path", dest="build_path", help="Path used to read a compile command database."
|
||||||
dest="build_path",
|
)
|
||||||
help="Path used to read a compile command database.")
|
|
||||||
if yaml:
|
if yaml:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-export-fixes",
|
"-export-fixes",
|
||||||
metavar="FILE",
|
metavar="FILE",
|
||||||
dest="export_fixes",
|
dest="export_fixes",
|
||||||
help="Create a yaml file to store suggested fixes in, "
|
help="Create a yaml file to store suggested fixes in, "
|
||||||
"which can be applied with clang-apply-replacements.")
|
"which can be applied with clang-apply-replacements.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-extra-arg",
|
"-extra-arg",
|
||||||
dest="extra_arg",
|
dest="extra_arg",
|
||||||
action="append",
|
action="append",
|
||||||
default=[],
|
default=[],
|
||||||
help="Additional argument to append to the compiler "
|
help="Additional argument to append to the compiler " "command line.",
|
||||||
"command line.")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-extra-arg-before",
|
"-extra-arg-before",
|
||||||
dest="extra_arg_before",
|
dest="extra_arg_before",
|
||||||
action="append",
|
action="append",
|
||||||
default=[],
|
default=[],
|
||||||
help="Additional argument to prepend to the compiler "
|
help="Additional argument to prepend to the compiler " "command line.",
|
||||||
"command line.")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-quiet",
|
"-quiet",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Run clang-tidy in quiet mode")
|
help="Run clang-tidy in quiet mode",
|
||||||
|
)
|
||||||
clang_tidy_args = []
|
clang_tidy_args = []
|
||||||
argv = sys.argv[1:]
|
argv = sys.argv[1:]
|
||||||
if "--" in argv:
|
if "--" in argv:
|
||||||
clang_tidy_args.extend(argv[argv.index("--"):])
|
clang_tidy_args.extend(argv[argv.index("--") :])
|
||||||
argv = argv[:argv.index("--")]
|
argv = argv[: argv.index("--")]
|
||||||
|
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
|
@ -204,7 +207,7 @@ def main():
|
||||||
filename = None
|
filename = None
|
||||||
lines_by_file = {}
|
lines_by_file = {}
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
match = re.search('^\+\+\+\ \"?(.*?/){%s}([^ \t\n\"]*)' % args.p, line)
|
match = re.search('^\+\+\+\ "?(.*?/){%s}([^ \t\n"]*)' % args.p, line)
|
||||||
if match:
|
if match:
|
||||||
filename = match.group(2)
|
filename = match.group(2)
|
||||||
if filename is None:
|
if filename is None:
|
||||||
|
@ -226,8 +229,7 @@ def main():
|
||||||
if line_count == 0:
|
if line_count == 0:
|
||||||
continue
|
continue
|
||||||
end_line = start_line + line_count - 1
|
end_line = start_line + line_count - 1
|
||||||
lines_by_file.setdefault(filename,
|
lines_by_file.setdefault(filename, []).append([start_line, end_line])
|
||||||
[]).append([start_line, end_line])
|
|
||||||
|
|
||||||
if not any(lines_by_file):
|
if not any(lines_by_file):
|
||||||
print("No relevant changes found.")
|
print("No relevant changes found.")
|
||||||
|
@ -267,11 +269,8 @@ def main():
|
||||||
|
|
||||||
for name in lines_by_file:
|
for name in lines_by_file:
|
||||||
line_filter_json = json.dumps(
|
line_filter_json = json.dumps(
|
||||||
[{
|
[{"name": name, "lines": lines_by_file[name]}], separators=(",", ":")
|
||||||
"name": name,
|
)
|
||||||
"lines": lines_by_file[name]
|
|
||||||
}],
|
|
||||||
separators=(",", ":"))
|
|
||||||
|
|
||||||
# Run clang-tidy on files containing changes.
|
# Run clang-tidy on files containing changes.
|
||||||
command = [args.clang_tidy_binary]
|
command = [args.clang_tidy_binary]
|
||||||
|
|
|
@ -38,8 +38,10 @@ def is_pull_request():
|
||||||
for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]:
|
for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]:
|
||||||
event_type = os.getenv(key, event_type)
|
event_type = os.getenv(key, event_type)
|
||||||
|
|
||||||
if (os.environ.get("BUILDKITE")
|
if (
|
||||||
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"):
|
os.environ.get("BUILDKITE")
|
||||||
|
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"
|
||||||
|
):
|
||||||
event_type = "pull_request"
|
event_type = "pull_request"
|
||||||
|
|
||||||
return event_type == "pull_request"
|
return event_type == "pull_request"
|
||||||
|
@ -67,8 +69,7 @@ def get_commit_range():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--output", type=str, help="json or envvars", default="envvars")
|
||||||
"--output", type=str, help="json or envvars", default="envvars")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
RAY_CI_TUNE_AFFECTED = 0
|
RAY_CI_TUNE_AFFECTED = 0
|
||||||
|
@ -103,8 +104,7 @@ if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
graph = pda.build_dep_graph()
|
graph = pda.build_dep_graph()
|
||||||
rllib_tests = pda.list_rllib_tests()
|
rllib_tests = pda.list_rllib_tests()
|
||||||
print(
|
print("Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
|
||||||
"Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
|
|
||||||
|
|
||||||
impacted = {}
|
impacted = {}
|
||||||
for test in rllib_tests:
|
for test in rllib_tests:
|
||||||
|
@ -120,9 +120,7 @@ if __name__ == "__main__":
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
# End of dry run.
|
# End of dry run.
|
||||||
|
|
||||||
skip_prefix_list = [
|
skip_prefix_list = ["doc/", "examples/", "dev/", "kubernetes/", "site/"]
|
||||||
"doc/", "examples/", "dev/", "kubernetes/", "site/"
|
|
||||||
]
|
|
||||||
|
|
||||||
for changed_file in files:
|
for changed_file in files:
|
||||||
if changed_file.startswith("python/ray/tune"):
|
if changed_file.startswith("python/ray/tune"):
|
||||||
|
@ -181,7 +179,8 @@ if __name__ == "__main__":
|
||||||
# Java also depends on Python CLI to manage processes.
|
# Java also depends on Python CLI to manage processes.
|
||||||
RAY_CI_JAVA_AFFECTED = 1
|
RAY_CI_JAVA_AFFECTED = 1
|
||||||
if changed_file.startswith("python/setup.py") or re.match(
|
if changed_file.startswith("python/setup.py") or re.match(
|
||||||
".*requirements.*\.txt", changed_file):
|
".*requirements.*\.txt", changed_file
|
||||||
|
):
|
||||||
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1
|
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1
|
||||||
elif changed_file.startswith("java/"):
|
elif changed_file.startswith("java/"):
|
||||||
RAY_CI_JAVA_AFFECTED = 1
|
RAY_CI_JAVA_AFFECTED = 1
|
||||||
|
@ -190,12 +189,9 @@ if __name__ == "__main__":
|
||||||
elif changed_file.startswith("docker/"):
|
elif changed_file.startswith("docker/"):
|
||||||
RAY_CI_DOCKER_AFFECTED = 1
|
RAY_CI_DOCKER_AFFECTED = 1
|
||||||
RAY_CI_LINUX_WHEELS_AFFECTED = 1
|
RAY_CI_LINUX_WHEELS_AFFECTED = 1
|
||||||
elif changed_file.startswith("doc/") and changed_file.endswith(
|
elif changed_file.startswith("doc/") and changed_file.endswith(".py"):
|
||||||
".py"):
|
|
||||||
RAY_CI_DOC_AFFECTED = 1
|
RAY_CI_DOC_AFFECTED = 1
|
||||||
elif any(
|
elif any(changed_file.startswith(prefix) for prefix in skip_prefix_list):
|
||||||
changed_file.startswith(prefix)
|
|
||||||
for prefix in skip_prefix_list):
|
|
||||||
# nothing is run but linting in these cases
|
# nothing is run but linting in these cases
|
||||||
pass
|
pass
|
||||||
elif changed_file.endswith("build-docker-images.py"):
|
elif changed_file.endswith("build-docker-images.py"):
|
||||||
|
@ -246,26 +242,28 @@ if __name__ == "__main__":
|
||||||
RAY_CI_DASHBOARD_AFFECTED = 1
|
RAY_CI_DASHBOARD_AFFECTED = 1
|
||||||
|
|
||||||
# Log the modified environment variables visible in console.
|
# Log the modified environment variables visible in console.
|
||||||
output_string = " ".join([
|
output_string = " ".join(
|
||||||
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
|
[
|
||||||
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
|
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
|
||||||
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
|
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
|
||||||
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
|
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
|
||||||
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(
|
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
|
||||||
RAY_CI_RLLIB_DIRECTLY_AFFECTED),
|
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(RAY_CI_RLLIB_DIRECTLY_AFFECTED),
|
||||||
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
|
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
|
||||||
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
|
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
|
||||||
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
|
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
|
||||||
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
|
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
|
||||||
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
|
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
|
||||||
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
|
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
|
||||||
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
|
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
|
||||||
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
|
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
|
||||||
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
|
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
|
||||||
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
|
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
|
||||||
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
|
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
|
||||||
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED),
|
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED
|
||||||
])
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Debug purpose
|
# Debug purpose
|
||||||
print(output_string, file=sys.stderr)
|
print(output_string, file=sys.stderr)
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# Cause the script to exit if a single command fails
|
# Cause the script to exit if a single command fails
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
BLACK_IS_ENABLED=false
|
BLACK_IS_ENABLED=true
|
||||||
|
|
||||||
FLAKE8_VERSION_REQUIRED="3.9.1"
|
FLAKE8_VERSION_REQUIRED="3.9.1"
|
||||||
BLACK_VERSION_REQUIRED="21.12b0"
|
BLACK_VERSION_REQUIRED="21.12b0"
|
||||||
|
|
|
@ -17,19 +17,19 @@ import json
|
||||||
|
|
||||||
def gha_get_self_url():
|
def gha_get_self_url():
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# stringed together api call to get the current check's html url.
|
# stringed together api call to get the current check's html url.
|
||||||
sha = os.environ["GITHUB_SHA"]
|
sha = os.environ["GITHUB_SHA"]
|
||||||
repo = os.environ["GITHUB_REPOSITORY"]
|
repo = os.environ["GITHUB_REPOSITORY"]
|
||||||
resp = requests.get(
|
resp = requests.get(
|
||||||
"https://api.github.com/repos/{}/commits/{}/check-suites".format(
|
"https://api.github.com/repos/{}/commits/{}/check-suites".format(repo, sha)
|
||||||
repo, sha))
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
for check in data["check_suites"]:
|
for check in data["check_suites"]:
|
||||||
slug = check["app"]["slug"]
|
slug = check["app"]["slug"]
|
||||||
if slug == "github-actions":
|
if slug == "github-actions":
|
||||||
run_url = check["check_runs_url"]
|
run_url = check["check_runs_url"]
|
||||||
html_url = (
|
html_url = requests.get(run_url).json()["check_runs"][0]["html_url"]
|
||||||
requests.get(run_url).json()["check_runs"][0]["html_url"])
|
|
||||||
return html_url
|
return html_url
|
||||||
|
|
||||||
# Return a fallback url
|
# Return a fallback url
|
||||||
|
@ -47,10 +47,12 @@ def get_build_env():
|
||||||
if os.environ.get("BUILDKITE"):
|
if os.environ.get("BUILDKITE"):
|
||||||
return {
|
return {
|
||||||
"TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"],
|
"TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"],
|
||||||
"TRAVIS_JOB_WEB_URL": (os.environ["BUILDKITE_BUILD_URL"] + "#" +
|
"TRAVIS_JOB_WEB_URL": (
|
||||||
os.environ["BUILDKITE_BUILD_ID"]),
|
os.environ["BUILDKITE_BUILD_URL"]
|
||||||
"TRAVIS_OS_NAME": # The map is used to stay consistent with Travis
|
+ "#"
|
||||||
{
|
+ os.environ["BUILDKITE_BUILD_ID"]
|
||||||
|
),
|
||||||
|
"TRAVIS_OS_NAME": { # The map is used to stay consistent with Travis
|
||||||
"linux": "linux",
|
"linux": "linux",
|
||||||
"darwin": "osx",
|
"darwin": "osx",
|
||||||
"win32": "windows",
|
"win32": "windows",
|
||||||
|
@ -70,13 +72,10 @@ def get_build_config():
|
||||||
return {"config": {"env": "Windows CI"}}
|
return {"config": {"env": "Windows CI"}}
|
||||||
|
|
||||||
if os.environ.get("BUILDKITE"):
|
if os.environ.get("BUILDKITE"):
|
||||||
return {
|
return {"config": {"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]}}
|
||||||
"config": {
|
|
||||||
"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
url = "https://api.travis-ci.com/job/{job_id}?include=job.config"
|
url = "https://api.travis-ci.com/job/{job_id}?include=job.config"
|
||||||
url = url.format(job_id=os.environ["TRAVIS_JOB_ID"])
|
url = url.format(job_id=os.environ["TRAVIS_JOB_ID"])
|
||||||
resp = requests.get(url, headers={"Travis-API-Version": "3"})
|
resp = requests.get(url, headers={"Travis-API-Version": "3"})
|
||||||
|
@ -87,9 +86,4 @@ if __name__ == "__main__":
|
||||||
build_env = get_build_env()
|
build_env = get_build_env()
|
||||||
build_config = get_build_config()
|
build_config = get_build_config()
|
||||||
|
|
||||||
print(
|
print(json.dumps({"build_env": build_env, "build_config": build_config}, indent=2))
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"build_env": build_env,
|
|
||||||
"build_config": build_config
|
|
||||||
}, indent=2))
|
|
||||||
|
|
|
@ -43,7 +43,8 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
|
||||||
test: only return information about a specific test.
|
test: only return information about a specific test.
|
||||||
"""
|
"""
|
||||||
tests_res = _run_shell(
|
tests_res = _run_shell(
|
||||||
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"])
|
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"]
|
||||||
|
)
|
||||||
|
|
||||||
all_tests = []
|
all_tests = []
|
||||||
|
|
||||||
|
@ -53,15 +54,18 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
|
||||||
if test and t != test:
|
if test and t != test:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
src_out = _run_shell([
|
src_out = _run_shell(
|
||||||
"bazel", "query", "kind(\"source file\", deps({}))".format(t),
|
[
|
||||||
"--output", "label"
|
"bazel",
|
||||||
])
|
"query",
|
||||||
|
'kind("source file", deps({}))'.format(t),
|
||||||
|
"--output",
|
||||||
|
"label",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
srcs = [f.strip() for f in src_out.splitlines()]
|
srcs = [f.strip() for f in src_out.splitlines()]
|
||||||
srcs = [
|
srcs = [f for f in srcs if f.startswith("//python") and f.endswith(".py")]
|
||||||
f for f in srcs if f.startswith("//python") and f.endswith(".py")
|
|
||||||
]
|
|
||||||
if srcs:
|
if srcs:
|
||||||
all_tests.append((t, srcs))
|
all_tests.append((t, srcs))
|
||||||
|
|
||||||
|
@ -73,8 +77,7 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
|
||||||
|
|
||||||
|
|
||||||
def _new_dep(graph: DepGraph, src_module: str, dep: str):
|
def _new_dep(graph: DepGraph, src_module: str, dep: str):
|
||||||
"""Create a new dependency between src_module and dep.
|
"""Create a new dependency between src_module and dep."""
|
||||||
"""
|
|
||||||
if dep not in graph.ids:
|
if dep not in graph.ids:
|
||||||
graph.ids[dep] = len(graph.ids)
|
graph.ids[dep] = len(graph.ids)
|
||||||
|
|
||||||
|
@ -87,8 +90,7 @@ def _new_dep(graph: DepGraph, src_module: str, dep: str):
|
||||||
|
|
||||||
|
|
||||||
def _new_import(graph: DepGraph, src_module: str, dep_module: str):
|
def _new_import(graph: DepGraph, src_module: str, dep_module: str):
|
||||||
"""Process a new import statement in src_module.
|
"""Process a new import statement in src_module."""
|
||||||
"""
|
|
||||||
# We don't care about system imports.
|
# We don't care about system imports.
|
||||||
if not dep_module.startswith("ray"):
|
if not dep_module.startswith("ray"):
|
||||||
return
|
return
|
||||||
|
@ -97,8 +99,7 @@ def _new_import(graph: DepGraph, src_module: str, dep_module: str):
|
||||||
|
|
||||||
|
|
||||||
def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
|
def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
|
||||||
"""Figure out if base.sub is a python module or not.
|
"""Figure out if base.sub is a python module or not."""
|
||||||
"""
|
|
||||||
# Special handling for _raylet, which is a C++ lib.
|
# Special handling for _raylet, which is a C++ lib.
|
||||||
if module == "ray._raylet":
|
if module == "ray._raylet":
|
||||||
return False
|
return False
|
||||||
|
@ -110,10 +111,10 @@ def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
|
def _new_from_import(
|
||||||
dep_name: str, _base_dir: str):
|
graph: DepGraph, src_module: str, dep_module: str, dep_name: str, _base_dir: str
|
||||||
"""Process a new "from ... import ..." statement in src_module.
|
):
|
||||||
"""
|
"""Process a new "from ... import ..." statement in src_module."""
|
||||||
# We don't care about imports outside of ray package.
|
# We don't care about imports outside of ray package.
|
||||||
if not dep_module or not dep_module.startswith("ray"):
|
if not dep_module or not dep_module.startswith("ray"):
|
||||||
return
|
return
|
||||||
|
@ -126,10 +127,7 @@ def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
|
||||||
_new_dep(graph, src_module, dep_module)
|
_new_dep(graph, src_module, dep_module)
|
||||||
|
|
||||||
|
|
||||||
def _process_file(graph: DepGraph,
|
def _process_file(graph: DepGraph, src_path: str, src_module: str, _base_dir=""):
|
||||||
src_path: str,
|
|
||||||
src_module: str,
|
|
||||||
_base_dir=""):
|
|
||||||
"""Create dependencies from src_module to all the valid imports in src_path.
|
"""Create dependencies from src_module to all the valid imports in src_path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -147,13 +145,13 @@ def _process_file(graph: DepGraph,
|
||||||
_new_import(graph, src_module, alias.name)
|
_new_import(graph, src_module, alias.name)
|
||||||
elif isinstance(node, ast.ImportFrom):
|
elif isinstance(node, ast.ImportFrom):
|
||||||
for alias in node.names:
|
for alias in node.names:
|
||||||
_new_from_import(graph, src_module, node.module,
|
_new_from_import(
|
||||||
alias.name, _base_dir)
|
graph, src_module, node.module, alias.name, _base_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_dep_graph() -> DepGraph:
|
def build_dep_graph() -> DepGraph:
|
||||||
"""Build index from py files to their immediate dependees.
|
"""Build index from py files to their immediate dependees."""
|
||||||
"""
|
|
||||||
graph = DepGraph()
|
graph = DepGraph()
|
||||||
|
|
||||||
# Assuming we run from root /ray directory.
|
# Assuming we run from root /ray directory.
|
||||||
|
@ -197,8 +195,7 @@ def _full_module_path(module, f) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _should_skip(d: str) -> bool:
|
def _should_skip(d: str) -> bool:
|
||||||
"""Skip directories that should not contain py sources.
|
"""Skip directories that should not contain py sources."""
|
||||||
"""
|
|
||||||
if d.startswith("python/.eggs/"):
|
if d.startswith("python/.eggs/"):
|
||||||
return True
|
return True
|
||||||
if d.startswith("python/."):
|
if d.startswith("python/."):
|
||||||
|
@ -224,14 +221,14 @@ def _bazel_path_to_module_path(d: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _file_path_to_module_path(f: str) -> str:
|
def _file_path_to_module_path(f: str) -> str:
|
||||||
"""Return the corresponding module path for a .py file.
|
"""Return the corresponding module path for a .py file."""
|
||||||
"""
|
|
||||||
dir, fn = os.path.split(f)
|
dir, fn = os.path.split(f)
|
||||||
return _full_module_path(_bazel_path_to_module_path(dir), fn)
|
return _full_module_path(_bazel_path_to_module_path(dir), fn)
|
||||||
|
|
||||||
|
|
||||||
def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
|
def _depends(
|
||||||
qid: int) -> List[int]:
|
graph: DepGraph, visited: Dict[int, bool], tid: int, qid: int
|
||||||
|
) -> List[int]:
|
||||||
"""Whether there is a dependency path from module tid to module qid.
|
"""Whether there is a dependency path from module tid to module qid.
|
||||||
|
|
||||||
Given graph, and without going through visited.
|
Given graph, and without going through visited.
|
||||||
|
@ -253,8 +250,9 @@ def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def test_depends_on_file(graph: DepGraph, test: Tuple[str, Tuple[str]],
|
def test_depends_on_file(
|
||||||
path: str) -> List[int]:
|
graph: DepGraph, test: Tuple[str, Tuple[str]], path: str
|
||||||
|
) -> List[int]:
|
||||||
"""Give dependency graph, check if a test depends on a specific .py file.
|
"""Give dependency graph, check if a test depends on a specific .py file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -307,8 +305,7 @@ def _find_circular_dep_impl(graph: DepGraph, id: str, branch: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]:
|
def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]:
|
||||||
"""Find circular dependencies among a dependency graph.
|
"""Find circular dependencies among a dependency graph."""
|
||||||
"""
|
|
||||||
known = {}
|
known = {}
|
||||||
circles = {}
|
circles = {}
|
||||||
for m, id in graph.ids.items():
|
for m, id in graph.ids.items():
|
||||||
|
@ -334,25 +331,29 @@ if __name__ == "__main__":
|
||||||
"--mode",
|
"--mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="test-dep",
|
default="test-dep",
|
||||||
help=("test-dep: find dependencies for a specified test. "
|
help=(
|
||||||
"circular-dep: find circular dependencies in "
|
"test-dep: find dependencies for a specified test. "
|
||||||
"the specific codebase."))
|
"circular-dep: find circular dependencies in "
|
||||||
|
"the specific codebase."
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--file",
|
"--file", type=str, help="Path of a .py source file relative to --base_dir."
|
||||||
type=str,
|
)
|
||||||
help="Path of a .py source file relative to --base_dir.")
|
|
||||||
parser.add_argument("--test", type=str, help="Specific test to check.")
|
parser.add_argument("--test", type=str, help="Specific test to check.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--smoke-test",
|
"--smoke-test", action="store_true", help="Load only a few tests for testing."
|
||||||
action="store_true",
|
)
|
||||||
help="Load only a few tests for testing.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print("building dep graph ...")
|
print("building dep graph ...")
|
||||||
graph = build_dep_graph()
|
graph = build_dep_graph()
|
||||||
print("done. total {} files, {} of which have dependencies.".format(
|
print(
|
||||||
len(graph.ids), len(graph.edges)))
|
"done. total {} files, {} of which have dependencies.".format(
|
||||||
|
len(graph.ids), len(graph.edges)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if args.mode == "circular-dep":
|
if args.mode == "circular-dep":
|
||||||
circles = find_circular_dep(graph)
|
circles = find_circular_dep(graph)
|
||||||
|
|
|
@ -12,30 +12,33 @@ class TestPyDepAnalysis(unittest.TestCase):
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def test_full_module_path(self):
|
def test_full_module_path(self):
|
||||||
self.assertEqual(
|
self.assertEqual(pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
|
||||||
pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
|
self.assertEqual(pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
|
||||||
self.assertEqual(
|
|
||||||
pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
|
|
||||||
self.assertEqual(pda._full_module_path("", "dd.py"), "dd")
|
self.assertEqual(pda._full_module_path("", "dd.py"), "dd")
|
||||||
|
|
||||||
def test_bazel_path_to_module_path(self):
|
def test_bazel_path_to_module_path(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"),
|
pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"),
|
||||||
"ray.rllib.xxx.yyy.dd")
|
"ray.rllib.xxx.yyy.dd",
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"),
|
pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"),
|
||||||
"ray.rllib.xxx.yyy.dd")
|
"ray.rllib.xxx.yyy.dd",
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"),
|
pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"),
|
||||||
"ray.rllib.xxx.yyy.dd")
|
"ray.rllib.xxx.yyy.dd",
|
||||||
|
)
|
||||||
|
|
||||||
def test_file_path_to_module_path(self):
|
def test_file_path_to_module_path(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pda._file_path_to_module_path("python/ray/rllib/env/env.py"),
|
pda._file_path_to_module_path("python/ray/rllib/env/env.py"),
|
||||||
"ray.rllib.env.env")
|
"ray.rllib.env.env",
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"),
|
pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"),
|
||||||
"ray.rllib.env")
|
"ray.rllib.env",
|
||||||
|
)
|
||||||
|
|
||||||
def test_import_line_continuation(self):
|
def test_import_line_continuation(self):
|
||||||
graph = pda.DepGraph()
|
graph = pda.DepGraph()
|
||||||
|
@ -44,11 +47,13 @@ class TestPyDepAnalysis(unittest.TestCase):
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
src_path = os.path.join(tmpdir, "continuation1.py")
|
src_path = os.path.join(tmpdir, "continuation1.py")
|
||||||
self.create_tmp_file(
|
self.create_tmp_file(
|
||||||
src_path, """
|
src_path,
|
||||||
|
"""
|
||||||
import ray.rllib.env.\\
|
import ray.rllib.env.\\
|
||||||
mock_env
|
mock_env
|
||||||
b = 2
|
b = 2
|
||||||
""")
|
""",
|
||||||
|
)
|
||||||
pda._process_file(graph, src_path, "ray")
|
pda._process_file(graph, src_path, "ray")
|
||||||
|
|
||||||
self.assertEqual(len(graph.ids), 2)
|
self.assertEqual(len(graph.ids), 2)
|
||||||
|
@ -64,11 +69,13 @@ b = 2
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
src_path = os.path.join(tmpdir, "continuation1.py")
|
src_path = os.path.join(tmpdir, "continuation1.py")
|
||||||
self.create_tmp_file(
|
self.create_tmp_file(
|
||||||
src_path, """
|
src_path,
|
||||||
|
"""
|
||||||
from ray.rllib.env import (ClassName,
|
from ray.rllib.env import (ClassName,
|
||||||
module1, module2)
|
module1, module2)
|
||||||
b = 2
|
b = 2
|
||||||
""")
|
""",
|
||||||
|
)
|
||||||
pda._process_file(graph, src_path, "ray")
|
pda._process_file(graph, src_path, "ray")
|
||||||
|
|
||||||
self.assertEqual(len(graph.ids), 2)
|
self.assertEqual(len(graph.ids), 2)
|
||||||
|
@ -84,11 +91,13 @@ b = 2
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
src_path = "multi_line_comment_3.py"
|
src_path = "multi_line_comment_3.py"
|
||||||
self.create_tmp_file(
|
self.create_tmp_file(
|
||||||
os.path.join(tmpdir, src_path), """
|
os.path.join(tmpdir, src_path),
|
||||||
|
"""
|
||||||
from ray.rllib.env import mock_env
|
from ray.rllib.env import mock_env
|
||||||
a = 1
|
a = 1
|
||||||
b = 2
|
b = 2
|
||||||
""")
|
""",
|
||||||
|
)
|
||||||
# Touch ray/rllib/env/mock_env.py in tmpdir,
|
# Touch ray/rllib/env/mock_env.py in tmpdir,
|
||||||
# so that it looks like a module.
|
# so that it looks like a module.
|
||||||
module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env")
|
module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env")
|
||||||
|
@ -112,11 +121,13 @@ b = 2
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
src_path = "multi_line_comment_3.py"
|
src_path = "multi_line_comment_3.py"
|
||||||
self.create_tmp_file(
|
self.create_tmp_file(
|
||||||
os.path.join(tmpdir, src_path), """
|
os.path.join(tmpdir, src_path),
|
||||||
|
"""
|
||||||
from ray.rllib.env import MockEnv
|
from ray.rllib.env import MockEnv
|
||||||
a = 1
|
a = 1
|
||||||
b = 2
|
b = 2
|
||||||
""")
|
""",
|
||||||
|
)
|
||||||
# Touch ray/rllib/env.py in tmpdir,
|
# Touch ray/rllib/env.py in tmpdir,
|
||||||
# MockEnv is a class on env module.
|
# MockEnv is a class on env module.
|
||||||
module_dir = os.path.join(tmpdir, "python", "ray", "rllib")
|
module_dir = os.path.join(tmpdir, "python", "ray", "rllib")
|
||||||
|
@ -138,4 +149,5 @@ b = 2
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -22,8 +22,11 @@ import ray.ray_constants as ray_constants
|
||||||
import ray._private.services
|
import ray._private.services
|
||||||
import ray._private.utils
|
import ray._private.utils
|
||||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
|
||||||
from ray._private.gcs_utils import GcsClient, \
|
from ray._private.gcs_utils import (
|
||||||
get_gcs_address_from_redis, use_gcs_for_bootstrap
|
GcsClient,
|
||||||
|
get_gcs_address_from_redis,
|
||||||
|
use_gcs_for_bootstrap,
|
||||||
|
)
|
||||||
from ray.core.generated import agent_manager_pb2
|
from ray.core.generated import agent_manager_pb2
|
||||||
from ray.core.generated import agent_manager_pb2_grpc
|
from ray.core.generated import agent_manager_pb2_grpc
|
||||||
from ray._private.ray_logging import setup_component_logger
|
from ray._private.ray_logging import setup_component_logger
|
||||||
|
@ -42,23 +45,25 @@ aiogrpc.init_grpc_aio()
|
||||||
|
|
||||||
|
|
||||||
class DashboardAgent(object):
|
class DashboardAgent(object):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
node_ip_address,
|
self,
|
||||||
redis_address,
|
node_ip_address,
|
||||||
dashboard_agent_port,
|
redis_address,
|
||||||
gcs_address,
|
dashboard_agent_port,
|
||||||
minimal,
|
gcs_address,
|
||||||
redis_password=None,
|
minimal,
|
||||||
temp_dir=None,
|
redis_password=None,
|
||||||
session_dir=None,
|
temp_dir=None,
|
||||||
runtime_env_dir=None,
|
session_dir=None,
|
||||||
log_dir=None,
|
runtime_env_dir=None,
|
||||||
metrics_export_port=None,
|
log_dir=None,
|
||||||
node_manager_port=None,
|
metrics_export_port=None,
|
||||||
listen_port=0,
|
node_manager_port=None,
|
||||||
object_store_name=None,
|
listen_port=0,
|
||||||
raylet_name=None,
|
object_store_name=None,
|
||||||
logging_params=None):
|
raylet_name=None,
|
||||||
|
logging_params=None,
|
||||||
|
):
|
||||||
"""Initialize the DashboardAgent object."""
|
"""Initialize the DashboardAgent object."""
|
||||||
# Public attributes are accessible for all agent modules.
|
# Public attributes are accessible for all agent modules.
|
||||||
self.ip = node_ip_address
|
self.ip = node_ip_address
|
||||||
|
@ -92,15 +97,16 @@ class DashboardAgent(object):
|
||||||
self.ppid = int(os.environ["RAY_RAYLET_PID"])
|
self.ppid = int(os.environ["RAY_RAYLET_PID"])
|
||||||
assert self.ppid > 0
|
assert self.ppid > 0
|
||||||
logger.info("Parent pid is %s", self.ppid)
|
logger.info("Parent pid is %s", self.ppid)
|
||||||
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
|
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
|
||||||
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
|
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
|
||||||
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
|
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
|
||||||
self.server, f"{grpc_ip}:{self.dashboard_agent_port}")
|
self.server, f"{grpc_ip}:{self.dashboard_agent_port}"
|
||||||
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip,
|
)
|
||||||
self.grpc_port)
|
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port)
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
|
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
|
||||||
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True)
|
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
|
||||||
|
)
|
||||||
|
|
||||||
# If the agent is started as non-minimal version, http server should
|
# If the agent is started as non-minimal version, http server should
|
||||||
# be configured to communicate with the dashboard in a head node.
|
# be configured to communicate with the dashboard in a head node.
|
||||||
|
@ -108,6 +114,7 @@ class DashboardAgent(object):
|
||||||
|
|
||||||
async def _configure_http_server(self, modules):
|
async def _configure_http_server(self, modules):
|
||||||
from ray.dashboard.http_server_agent import HttpServerAgent
|
from ray.dashboard.http_server_agent import HttpServerAgent
|
||||||
|
|
||||||
http_server = HttpServerAgent(self.ip, self.listen_port)
|
http_server = HttpServerAgent(self.ip, self.listen_port)
|
||||||
await http_server.start(modules)
|
await http_server.start(modules)
|
||||||
return http_server
|
return http_server
|
||||||
|
@ -116,10 +123,12 @@ class DashboardAgent(object):
|
||||||
"""Load dashboard agent modules."""
|
"""Load dashboard agent modules."""
|
||||||
modules = []
|
modules = []
|
||||||
agent_cls_list = dashboard_utils.get_all_modules(
|
agent_cls_list = dashboard_utils.get_all_modules(
|
||||||
dashboard_utils.DashboardAgentModule)
|
dashboard_utils.DashboardAgentModule
|
||||||
|
)
|
||||||
for cls in agent_cls_list:
|
for cls in agent_cls_list:
|
||||||
logger.info("Loading %s: %s",
|
logger.info(
|
||||||
dashboard_utils.DashboardAgentModule.__name__, cls)
|
"Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
|
||||||
|
)
|
||||||
c = cls(self)
|
c = cls(self)
|
||||||
modules.append(c)
|
modules.append(c)
|
||||||
logger.info("Loaded %d modules.", len(modules))
|
logger.info("Loaded %d modules.", len(modules))
|
||||||
|
@ -137,13 +146,12 @@ class DashboardAgent(object):
|
||||||
curr_proc = psutil.Process()
|
curr_proc = psutil.Process()
|
||||||
while True:
|
while True:
|
||||||
parent = curr_proc.parent()
|
parent = curr_proc.parent()
|
||||||
if (parent is None or parent.pid == 1
|
if parent is None or parent.pid == 1 or self.ppid != parent.pid:
|
||||||
or self.ppid != parent.pid):
|
|
||||||
logger.error("Raylet is dead, exiting.")
|
logger.error("Raylet is dead, exiting.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
dashboard_consts.
|
dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS
|
||||||
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Failed to check parent PID, exiting.")
|
logger.error("Failed to check parent PID, exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
@ -154,15 +162,17 @@ class DashboardAgent(object):
|
||||||
if not use_gcs_for_bootstrap():
|
if not use_gcs_for_bootstrap():
|
||||||
# Create an aioredis client for all modules.
|
# Create an aioredis client for all modules.
|
||||||
try:
|
try:
|
||||||
self.aioredis_client = \
|
self.aioredis_client = await dashboard_utils.get_aioredis_client(
|
||||||
await dashboard_utils.get_aioredis_client(
|
self.redis_address,
|
||||||
self.redis_address, self.redis_password,
|
self.redis_password,
|
||||||
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
|
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
|
||||||
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
|
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
|
||||||
|
)
|
||||||
except (socket.gaierror, ConnectionRefusedError):
|
except (socket.gaierror, ConnectionRefusedError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Dashboard agent exiting: "
|
"Dashboard agent exiting: " "Failed to connect to redis at %s",
|
||||||
"Failed to connect to redis at %s", self.redis_address)
|
self.redis_address,
|
||||||
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# Start a grpc asyncio server.
|
# Start a grpc asyncio server.
|
||||||
|
@ -170,7 +180,8 @@ class DashboardAgent(object):
|
||||||
|
|
||||||
if not use_gcs_for_bootstrap():
|
if not use_gcs_for_bootstrap():
|
||||||
gcs_address = await self.aioredis_client.get(
|
gcs_address = await self.aioredis_client.get(
|
||||||
dashboard_consts.GCS_SERVER_ADDRESS)
|
dashboard_consts.GCS_SERVER_ADDRESS
|
||||||
|
)
|
||||||
self.gcs_client = GcsClient(address=gcs_address.decode())
|
self.gcs_client = GcsClient(address=gcs_address.decode())
|
||||||
else:
|
else:
|
||||||
self.gcs_client = GcsClient(address=self.gcs_address)
|
self.gcs_client = GcsClient(address=self.gcs_address)
|
||||||
|
@ -192,17 +203,21 @@ class DashboardAgent(object):
|
||||||
internal_kv._internal_kv_put(
|
internal_kv._internal_kv_put(
|
||||||
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
|
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
|
||||||
json.dumps([http_port, self.grpc_port]),
|
json.dumps([http_port, self.grpc_port]),
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
|
|
||||||
# Register agent to agent manager.
|
# Register agent to agent manager.
|
||||||
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
|
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
|
||||||
self.aiogrpc_raylet_channel)
|
self.aiogrpc_raylet_channel
|
||||||
|
)
|
||||||
|
|
||||||
await raylet_stub.RegisterAgent(
|
await raylet_stub.RegisterAgent(
|
||||||
agent_manager_pb2.RegisterAgentRequest(
|
agent_manager_pb2.RegisterAgentRequest(
|
||||||
agent_pid=os.getpid(),
|
agent_pid=os.getpid(),
|
||||||
agent_port=self.grpc_port,
|
agent_port=self.grpc_port,
|
||||||
agent_ip_address=self.ip))
|
agent_ip_address=self.ip,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
tasks = [m.run(self.server) for m in modules]
|
tasks = [m.run(self.server) for m in modules]
|
||||||
if sys.platform not in ["win32", "cygwin"]:
|
if sys.platform not in ["win32", "cygwin"]:
|
||||||
|
@ -221,123 +236,139 @@ if __name__ == "__main__":
|
||||||
"--node-ip-address",
|
"--node-ip-address",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
help="the IP address of this node.")
|
help="the IP address of this node.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gcs-address",
|
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
|
||||||
required=False,
|
)
|
||||||
type=str,
|
|
||||||
help="The address (ip:port) of GCS.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--redis-address",
|
"--redis-address", required=True, type=str, help="The address to use for Redis."
|
||||||
required=True,
|
)
|
||||||
type=str,
|
|
||||||
help="The address to use for Redis.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--metrics-export-port",
|
"--metrics-export-port",
|
||||||
required=True,
|
required=True,
|
||||||
type=int,
|
type=int,
|
||||||
help="The port to expose metrics through Prometheus.")
|
help="The port to expose metrics through Prometheus.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dashboard-agent-port",
|
"--dashboard-agent-port",
|
||||||
required=True,
|
required=True,
|
||||||
type=int,
|
type=int,
|
||||||
help="The port on which the dashboard agent will receive GRPCs.")
|
help="The port on which the dashboard agent will receive GRPCs.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--node-manager-port",
|
"--node-manager-port",
|
||||||
required=True,
|
required=True,
|
||||||
type=int,
|
type=int,
|
||||||
help="The port to use for starting the node manager")
|
help="The port to use for starting the node manager",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--object-store-name",
|
"--object-store-name",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The socket name of the plasma store")
|
help="The socket name of the plasma store",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--listen-port",
|
"--listen-port",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Port for HTTP server to listen on")
|
help="Port for HTTP server to listen on",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--raylet-name",
|
"--raylet-name",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The socket path of the raylet process")
|
help="The socket path of the raylet process",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--redis-password",
|
"--redis-password",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The password to use for Redis")
|
help="The password to use for Redis",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-level",
|
"--logging-level",
|
||||||
required=False,
|
required=False,
|
||||||
type=lambda s: logging.getLevelName(s.upper()),
|
type=lambda s: logging.getLevelName(s.upper()),
|
||||||
default=ray_constants.LOGGER_LEVEL,
|
default=ray_constants.LOGGER_LEVEL,
|
||||||
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
||||||
help=ray_constants.LOGGER_LEVEL_HELP)
|
help=ray_constants.LOGGER_LEVEL_HELP,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-format",
|
"--logging-format",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=ray_constants.LOGGER_FORMAT,
|
default=ray_constants.LOGGER_FORMAT,
|
||||||
help=ray_constants.LOGGER_FORMAT_HELP)
|
help=ray_constants.LOGGER_FORMAT_HELP,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-filename",
|
"--logging-filename",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
|
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
|
||||||
help="Specify the name of log file, "
|
help="Specify the name of log file, "
|
||||||
"log to stdout if set empty, default is \"{}\".".format(
|
'log to stdout if set empty, default is "{}".'.format(
|
||||||
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME))
|
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-rotate-bytes",
|
"--logging-rotate-bytes",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=ray_constants.LOGGING_ROTATE_BYTES,
|
default=ray_constants.LOGGING_ROTATE_BYTES,
|
||||||
help="Specify the max bytes for rotating "
|
help="Specify the max bytes for rotating "
|
||||||
"log file, default is {} bytes.".format(
|
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
|
||||||
ray_constants.LOGGING_ROTATE_BYTES))
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-rotate-backup-count",
|
"--logging-rotate-backup-count",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
||||||
help="Specify the backup count of rotated log file, default is {}.".
|
help="Specify the backup count of rotated log file, default is {}.".format(
|
||||||
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
|
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-dir",
|
"--log-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of log directory.")
|
help="Specify the path of log directory.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp-dir",
|
"--temp-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of the temporary directory use by Ray process.")
|
help="Specify the path of the temporary directory use by Ray process.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--session-dir",
|
"--session-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of this session.")
|
help="Specify the path of this session.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--runtime-env-dir",
|
"--runtime-env-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of the resource directory used by runtime_env.")
|
help="Specify the path of the resource directory used by runtime_env.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--minimal",
|
"--minimal",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=(
|
||||||
"Minimal agent only contains a subset of features that don't "
|
"Minimal agent only contains a subset of features that don't "
|
||||||
"require additional dependencies installed when ray is installed "
|
"require additional dependencies installed when ray is installed "
|
||||||
"by `pip install ray[default]`."))
|
"by `pip install ray[default]`."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
|
@ -347,7 +378,8 @@ if __name__ == "__main__":
|
||||||
log_dir=args.log_dir,
|
log_dir=args.log_dir,
|
||||||
filename=args.logging_filename,
|
filename=args.logging_filename,
|
||||||
max_bytes=args.logging_rotate_bytes,
|
max_bytes=args.logging_rotate_bytes,
|
||||||
backup_count=args.logging_rotate_backup_count)
|
backup_count=args.logging_rotate_backup_count,
|
||||||
|
)
|
||||||
setup_component_logger(**logging_params)
|
setup_component_logger(**logging_params)
|
||||||
|
|
||||||
agent = DashboardAgent(
|
agent = DashboardAgent(
|
||||||
|
@ -366,7 +398,8 @@ if __name__ == "__main__":
|
||||||
listen_port=args.listen_port,
|
listen_port=args.listen_port,
|
||||||
object_store_name=args.object_store_name,
|
object_store_name=args.object_store_name,
|
||||||
raylet_name=args.raylet_name,
|
raylet_name=args.raylet_name,
|
||||||
logging_params=logging_params)
|
logging_params=logging_params,
|
||||||
|
)
|
||||||
if os.environ.get("_RAY_AGENT_FAILING"):
|
if os.environ.get("_RAY_AGENT_FAILING"):
|
||||||
raise Exception("Failure injection failure.")
|
raise Exception("Failure injection failure.")
|
||||||
|
|
||||||
|
@ -390,15 +423,19 @@ if __name__ == "__main__":
|
||||||
gcs_publisher = GcsPublisher(args.gcs_address)
|
gcs_publisher = GcsPublisher(args.gcs_address)
|
||||||
else:
|
else:
|
||||||
redis_client = ray._private.services.create_redis_client(
|
redis_client = ray._private.services.create_redis_client(
|
||||||
args.redis_address, password=args.redis_password)
|
args.redis_address, password=args.redis_password
|
||||||
|
)
|
||||||
gcs_publisher = GcsPublisher(
|
gcs_publisher = GcsPublisher(
|
||||||
address=get_gcs_address_from_redis(redis_client))
|
address=get_gcs_address_from_redis(redis_client)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
redis_client = ray._private.services.create_redis_client(
|
redis_client = ray._private.services.create_redis_client(
|
||||||
args.redis_address, password=args.redis_password)
|
args.redis_address, password=args.redis_password
|
||||||
|
)
|
||||||
|
|
||||||
traceback_str = ray._private.utils.format_error_message(
|
traceback_str = ray._private.utils.format_error_message(
|
||||||
traceback.format_exc())
|
traceback.format_exc()
|
||||||
|
)
|
||||||
message = (
|
message = (
|
||||||
f"(ip={node_ip}) "
|
f"(ip={node_ip}) "
|
||||||
f"The agent on node {platform.uname()[1]} failed to "
|
f"The agent on node {platform.uname()[1]} failed to "
|
||||||
|
@ -409,12 +446,14 @@ if __name__ == "__main__":
|
||||||
"\n 2. Metrics on this node won't be reported."
|
"\n 2. Metrics on this node won't be reported."
|
||||||
"\n 3. runtime_env APIs won't work."
|
"\n 3. runtime_env APIs won't work."
|
||||||
"\nCheck out the `dashboard_agent.log` to see the "
|
"\nCheck out the `dashboard_agent.log` to see the "
|
||||||
"detailed failure messages.")
|
"detailed failure messages."
|
||||||
|
)
|
||||||
ray._private.utils.publish_error_to_driver(
|
ray._private.utils.publish_error_to_driver(
|
||||||
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
|
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
|
||||||
message,
|
message,
|
||||||
redis_client=redis_client,
|
redis_client=redis_client,
|
||||||
gcs_publisher=gcs_publisher)
|
gcs_publisher=gcs_publisher,
|
||||||
|
)
|
||||||
logger.error(message)
|
logger.error(message)
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
|
@ -12,12 +12,13 @@ DASHBOARD_RPC_ADDRESS = "dashboard_rpc"
|
||||||
GCS_SERVER_ADDRESS = "GcsServerAddress"
|
GCS_SERVER_ADDRESS = "GcsServerAddress"
|
||||||
# GCS check alive
|
# GCS check alive
|
||||||
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR = env_integer(
|
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR = env_integer(
|
||||||
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10)
|
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10
|
||||||
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer(
|
)
|
||||||
"GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
|
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
|
||||||
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
|
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
|
||||||
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
|
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
|
||||||
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2)
|
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
|
||||||
|
)
|
||||||
# aiohttp_cache
|
# aiohttp_cache
|
||||||
AIOHTTP_CACHE_TTL_SECONDS = 2
|
AIOHTTP_CACHE_TTL_SECONDS = 2
|
||||||
AIOHTTP_CACHE_MAX_SIZE = 128
|
AIOHTTP_CACHE_MAX_SIZE = 128
|
||||||
|
|
|
@ -38,17 +38,21 @@ class FrontendNotFoundError(OSError):
|
||||||
|
|
||||||
def setup_static_dir():
|
def setup_static_dir():
|
||||||
build_dir = os.path.join(
|
build_dir = os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)), "client", "build")
|
os.path.dirname(os.path.abspath(__file__)), "client", "build"
|
||||||
|
)
|
||||||
module_name = os.path.basename(os.path.dirname(__file__))
|
module_name = os.path.basename(os.path.dirname(__file__))
|
||||||
if not os.path.isdir(build_dir):
|
if not os.path.isdir(build_dir):
|
||||||
raise FrontendNotFoundError(
|
raise FrontendNotFoundError(
|
||||||
errno.ENOENT, "Dashboard build directory not found. If installing "
|
errno.ENOENT,
|
||||||
|
"Dashboard build directory not found. If installing "
|
||||||
"from source, please follow the additional steps "
|
"from source, please follow the additional steps "
|
||||||
"required to build the dashboard"
|
"required to build the dashboard"
|
||||||
f"(cd python/ray/{module_name}/client "
|
f"(cd python/ray/{module_name}/client "
|
||||||
"&& npm install "
|
"&& npm install "
|
||||||
"&& npm ci "
|
"&& npm ci "
|
||||||
"&& npm run build)", build_dir)
|
"&& npm run build)",
|
||||||
|
build_dir,
|
||||||
|
)
|
||||||
|
|
||||||
static_dir = os.path.join(build_dir, "static")
|
static_dir = os.path.join(build_dir, "static")
|
||||||
routes.static("/static", static_dir, follow_symlinks=True)
|
routes.static("/static", static_dir, follow_symlinks=True)
|
||||||
|
@ -72,14 +76,16 @@ class Dashboard:
|
||||||
log_dir(str): Log directory of dashboard.
|
log_dir(str): Log directory of dashboard.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
host,
|
self,
|
||||||
port,
|
host,
|
||||||
port_retries,
|
port,
|
||||||
gcs_address,
|
port_retries,
|
||||||
redis_address,
|
gcs_address,
|
||||||
redis_password=None,
|
redis_address,
|
||||||
log_dir=None):
|
redis_password=None,
|
||||||
|
log_dir=None,
|
||||||
|
):
|
||||||
self.dashboard_head = dashboard_head.DashboardHead(
|
self.dashboard_head = dashboard_head.DashboardHead(
|
||||||
http_host=host,
|
http_host=host,
|
||||||
http_port=port,
|
http_port=port,
|
||||||
|
@ -87,7 +93,8 @@ class Dashboard:
|
||||||
gcs_address=gcs_address,
|
gcs_address=gcs_address,
|
||||||
redis_address=redis_address,
|
redis_address=redis_address,
|
||||||
redis_password=redis_password,
|
redis_password=redis_password,
|
||||||
log_dir=log_dir)
|
log_dir=log_dir,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup Dashboard Routes
|
# Setup Dashboard Routes
|
||||||
try:
|
try:
|
||||||
|
@ -107,15 +114,17 @@ class Dashboard:
|
||||||
async def get_index(self, req) -> aiohttp.web.FileResponse:
|
async def get_index(self, req) -> aiohttp.web.FileResponse:
|
||||||
return aiohttp.web.FileResponse(
|
return aiohttp.web.FileResponse(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)),
|
os.path.dirname(os.path.abspath(__file__)), "client/build/index.html"
|
||||||
"client/build/index.html"))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/favicon.ico")
|
@routes.get("/favicon.ico")
|
||||||
async def get_favicon(self, req) -> aiohttp.web.FileResponse:
|
async def get_favicon(self, req) -> aiohttp.web.FileResponse:
|
||||||
return aiohttp.web.FileResponse(
|
return aiohttp.web.FileResponse(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)),
|
os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico"
|
||||||
"client/build/favicon.ico"))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.dashboard_head.run()
|
await self.dashboard_head.run()
|
||||||
|
@ -124,92 +133,96 @@ class Dashboard:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Ray dashboard.")
|
parser = argparse.ArgumentParser(description="Ray dashboard.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host", required=True, type=str, help="The host to use for the HTTP server."
|
||||||
required=True,
|
)
|
||||||
type=str,
|
|
||||||
help="The host to use for the HTTP server.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port",
|
"--port", required=True, type=int, help="The port to use for the HTTP server."
|
||||||
required=True,
|
)
|
||||||
type=int,
|
|
||||||
help="The port to use for the HTTP server.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port-retries",
|
"--port-retries",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="The retry times to select a valid port.")
|
help="The retry times to select a valid port.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gcs-address",
|
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
|
||||||
required=False,
|
)
|
||||||
type=str,
|
|
||||||
help="The address (ip:port) of GCS.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--redis-address",
|
"--redis-address", required=True, type=str, help="The address to use for Redis."
|
||||||
required=True,
|
)
|
||||||
type=str,
|
|
||||||
help="The address to use for Redis.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--redis-password",
|
"--redis-password",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The password to use for Redis")
|
help="The password to use for Redis",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-level",
|
"--logging-level",
|
||||||
required=False,
|
required=False,
|
||||||
type=lambda s: logging.getLevelName(s.upper()),
|
type=lambda s: logging.getLevelName(s.upper()),
|
||||||
default=ray_constants.LOGGER_LEVEL,
|
default=ray_constants.LOGGER_LEVEL,
|
||||||
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
||||||
help=ray_constants.LOGGER_LEVEL_HELP)
|
help=ray_constants.LOGGER_LEVEL_HELP,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-format",
|
"--logging-format",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=ray_constants.LOGGER_FORMAT,
|
default=ray_constants.LOGGER_FORMAT,
|
||||||
help=ray_constants.LOGGER_FORMAT_HELP)
|
help=ray_constants.LOGGER_FORMAT_HELP,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-filename",
|
"--logging-filename",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=dashboard_consts.DASHBOARD_LOG_FILENAME,
|
default=dashboard_consts.DASHBOARD_LOG_FILENAME,
|
||||||
help="Specify the name of log file, "
|
help="Specify the name of log file, "
|
||||||
"log to stdout if set empty, default is \"{}\"".format(
|
'log to stdout if set empty, default is "{}"'.format(
|
||||||
dashboard_consts.DASHBOARD_LOG_FILENAME))
|
dashboard_consts.DASHBOARD_LOG_FILENAME
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-rotate-bytes",
|
"--logging-rotate-bytes",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=ray_constants.LOGGING_ROTATE_BYTES,
|
default=ray_constants.LOGGING_ROTATE_BYTES,
|
||||||
help="Specify the max bytes for rotating "
|
help="Specify the max bytes for rotating "
|
||||||
"log file, default is {} bytes.".format(
|
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
|
||||||
ray_constants.LOGGING_ROTATE_BYTES))
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--logging-rotate-backup-count",
|
"--logging-rotate-backup-count",
|
||||||
required=False,
|
required=False,
|
||||||
type=int,
|
type=int,
|
||||||
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
||||||
help="Specify the backup count of rotated log file, default is {}.".
|
help="Specify the backup count of rotated log file, default is {}.".format(
|
||||||
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
|
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-dir",
|
"--log-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of log directory.")
|
help="Specify the path of log directory.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp-dir",
|
"--temp-dir",
|
||||||
required=True,
|
required=True,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the path of the temporary directory use by Ray process.")
|
help="Specify the path of the temporary directory use by Ray process.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--minimal",
|
"--minimal",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=(
|
||||||
"Minimal dashboard only contains a subset of features that don't "
|
"Minimal dashboard only contains a subset of features that don't "
|
||||||
"require additional dependencies installed when ray is installed "
|
"require additional dependencies installed when ray is installed "
|
||||||
"by `pip install ray[default]`."))
|
"by `pip install ray[default]`."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -226,7 +239,8 @@ if __name__ == "__main__":
|
||||||
log_dir=args.log_dir,
|
log_dir=args.log_dir,
|
||||||
filename=args.logging_filename,
|
filename=args.logging_filename,
|
||||||
max_bytes=args.logging_rotate_bytes,
|
max_bytes=args.logging_rotate_bytes,
|
||||||
backup_count=args.logging_rotate_backup_count)
|
backup_count=args.logging_rotate_backup_count,
|
||||||
|
)
|
||||||
|
|
||||||
dashboard = Dashboard(
|
dashboard = Dashboard(
|
||||||
args.host,
|
args.host,
|
||||||
|
@ -235,25 +249,27 @@ if __name__ == "__main__":
|
||||||
args.gcs_address,
|
args.gcs_address,
|
||||||
args.redis_address,
|
args.redis_address,
|
||||||
redis_password=args.redis_password,
|
redis_password=args.redis_password,
|
||||||
log_dir=args.log_dir)
|
log_dir=args.log_dir,
|
||||||
|
)
|
||||||
# TODO(fyrestone): Avoid using ray.state in dashboard, it's not
|
# TODO(fyrestone): Avoid using ray.state in dashboard, it's not
|
||||||
# asynchronous and will lead to low performance. ray disconnect()
|
# asynchronous and will lead to low performance. ray disconnect()
|
||||||
# will be hang when the ray.state is connected and the GCS is exit.
|
# will be hang when the ray.state is connected and the GCS is exit.
|
||||||
# Please refer to: https://github.com/ray-project/ray/issues/16328
|
# Please refer to: https://github.com/ray-project/ray/issues/16328
|
||||||
service_discovery = PrometheusServiceDiscoveryWriter(
|
service_discovery = PrometheusServiceDiscoveryWriter(
|
||||||
args.redis_address, args.redis_password, args.gcs_address,
|
args.redis_address, args.redis_password, args.gcs_address, args.temp_dir
|
||||||
args.temp_dir)
|
)
|
||||||
# Need daemon True to avoid dashboard hangs at exit.
|
# Need daemon True to avoid dashboard hangs at exit.
|
||||||
service_discovery.daemon = True
|
service_discovery.daemon = True
|
||||||
service_discovery.start()
|
service_discovery.start()
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.run_until_complete(dashboard.run())
|
loop.run_until_complete(dashboard.run())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback_str = ray._private.utils.format_error_message(
|
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
|
||||||
traceback.format_exc())
|
message = (
|
||||||
message = f"The dashboard on node {platform.uname()[1]} " \
|
f"The dashboard on node {platform.uname()[1]} "
|
||||||
f"failed with the following " \
|
f"failed with the following "
|
||||||
f"error:\n{traceback_str}"
|
f"error:\n{traceback_str}"
|
||||||
|
)
|
||||||
if isinstance(e, FrontendNotFoundError):
|
if isinstance(e, FrontendNotFoundError):
|
||||||
logger.warning(message)
|
logger.warning(message)
|
||||||
else:
|
else:
|
||||||
|
@ -268,17 +284,21 @@ if __name__ == "__main__":
|
||||||
gcs_publisher = GcsPublisher(args.gcs_address)
|
gcs_publisher = GcsPublisher(args.gcs_address)
|
||||||
else:
|
else:
|
||||||
redis_client = ray._private.services.create_redis_client(
|
redis_client = ray._private.services.create_redis_client(
|
||||||
args.redis_address, password=args.redis_password)
|
args.redis_address, password=args.redis_password
|
||||||
|
)
|
||||||
gcs_publisher = GcsPublisher(
|
gcs_publisher = GcsPublisher(
|
||||||
address=gcs_utils.get_gcs_address_from_redis(redis_client))
|
address=gcs_utils.get_gcs_address_from_redis(redis_client)
|
||||||
|
)
|
||||||
redis_client = None
|
redis_client = None
|
||||||
else:
|
else:
|
||||||
redis_client = ray._private.services.create_redis_client(
|
redis_client = ray._private.services.create_redis_client(
|
||||||
args.redis_address, password=args.redis_password)
|
args.redis_address, password=args.redis_password
|
||||||
|
)
|
||||||
|
|
||||||
ray._private.utils.publish_error_to_driver(
|
ray._private.utils.publish_error_to_driver(
|
||||||
redis_client,
|
redis_client,
|
||||||
ray_constants.DASHBOARD_DIED_ERROR,
|
ray_constants.DASHBOARD_DIED_ERROR,
|
||||||
message,
|
message,
|
||||||
redis_client=redis_client,
|
redis_client=redis_client,
|
||||||
gcs_publisher=gcs_publisher)
|
gcs_publisher=gcs_publisher,
|
||||||
|
)
|
||||||
|
|
|
@ -2,9 +2,9 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import ray.dashboard.consts as dashboard_consts
|
import ray.dashboard.consts as dashboard_consts
|
||||||
import ray.dashboard.memory_utils as memory_utils
|
import ray.dashboard.memory_utils as memory_utils
|
||||||
|
|
||||||
# TODO(fyrestone): Not import from dashboard module.
|
# TODO(fyrestone): Not import from dashboard module.
|
||||||
from ray.dashboard.modules.actor.actor_utils import \
|
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
|
||||||
actor_classname_from_task_spec
|
|
||||||
from ray.dashboard.utils import Dict, Signal, async_loop_forever
|
from ray.dashboard.utils import Dict, Signal, async_loop_forever
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -132,9 +132,9 @@ class DataOrganizer:
|
||||||
worker["errorCount"] = len(node_errs.get(str(pid), []))
|
worker["errorCount"] = len(node_errs.get(str(pid), []))
|
||||||
worker["coreWorkerStats"] = pid_to_worker_stats.get(pid, [])
|
worker["coreWorkerStats"] = pid_to_worker_stats.get(pid, [])
|
||||||
worker["language"] = pid_to_language.get(
|
worker["language"] = pid_to_language.get(
|
||||||
pid, dashboard_consts.DEFAULT_LANGUAGE)
|
pid, dashboard_consts.DEFAULT_LANGUAGE
|
||||||
worker["jobId"] = pid_to_job_id.get(
|
)
|
||||||
pid, dashboard_consts.DEFAULT_JOB_ID)
|
worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID)
|
||||||
|
|
||||||
await GlobalSignals.worker_info_fetched.send(node_id, worker)
|
await GlobalSignals.worker_info_fetched.send(node_id, worker)
|
||||||
|
|
||||||
|
@ -143,8 +143,7 @@ class DataOrganizer:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_node_info(cls, node_id):
|
async def get_node_info(cls, node_id):
|
||||||
node_physical_stats = dict(
|
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
|
||||||
DataSource.node_physical_stats.get(node_id, {}))
|
|
||||||
node_stats = dict(DataSource.node_stats.get(node_id, {}))
|
node_stats = dict(DataSource.node_stats.get(node_id, {}))
|
||||||
node = DataSource.nodes.get(node_id, {})
|
node = DataSource.nodes.get(node_id, {})
|
||||||
node_ip = DataSource.node_id_to_ip.get(node_id)
|
node_ip = DataSource.node_id_to_ip.get(node_id)
|
||||||
|
@ -162,8 +161,8 @@ class DataOrganizer:
|
||||||
|
|
||||||
view_data = node_stats.get("viewData", [])
|
view_data = node_stats.get("viewData", [])
|
||||||
ray_stats = cls._extract_view_data(
|
ray_stats = cls._extract_view_data(
|
||||||
view_data,
|
view_data, {"object_store_used_memory", "object_store_available_memory"}
|
||||||
{"object_store_used_memory", "object_store_available_memory"})
|
)
|
||||||
|
|
||||||
node_info = node_physical_stats
|
node_info = node_physical_stats
|
||||||
# Merge node stats to node physical stats under raylet
|
# Merge node stats to node physical stats under raylet
|
||||||
|
@ -184,8 +183,7 @@ class DataOrganizer:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_node_summary(cls, node_id):
|
async def get_node_summary(cls, node_id):
|
||||||
node_physical_stats = dict(
|
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
|
||||||
DataSource.node_physical_stats.get(node_id, {}))
|
|
||||||
node_stats = dict(DataSource.node_stats.get(node_id, {}))
|
node_stats = dict(DataSource.node_stats.get(node_id, {}))
|
||||||
node = DataSource.nodes.get(node_id, {})
|
node = DataSource.nodes.get(node_id, {})
|
||||||
|
|
||||||
|
@ -193,8 +191,8 @@ class DataOrganizer:
|
||||||
node_stats.pop("workersStats", None)
|
node_stats.pop("workersStats", None)
|
||||||
view_data = node_stats.get("viewData", [])
|
view_data = node_stats.get("viewData", [])
|
||||||
ray_stats = cls._extract_view_data(
|
ray_stats = cls._extract_view_data(
|
||||||
view_data,
|
view_data, {"object_store_used_memory", "object_store_available_memory"}
|
||||||
{"object_store_used_memory", "object_store_available_memory"})
|
)
|
||||||
node_stats.pop("viewData", None)
|
node_stats.pop("viewData", None)
|
||||||
|
|
||||||
node_summary = node_physical_stats
|
node_summary = node_physical_stats
|
||||||
|
@ -244,8 +242,9 @@ class DataOrganizer:
|
||||||
actor = dict(actor)
|
actor = dict(actor)
|
||||||
worker_id = actor["address"]["workerId"]
|
worker_id = actor["address"]["workerId"]
|
||||||
core_worker_stats = DataSource.core_worker_stats.get(worker_id, {})
|
core_worker_stats = DataSource.core_worker_stats.get(worker_id, {})
|
||||||
actor_constructor = core_worker_stats.get("actorTitle",
|
actor_constructor = core_worker_stats.get(
|
||||||
"Unknown actor constructor")
|
"actorTitle", "Unknown actor constructor"
|
||||||
|
)
|
||||||
actor["actorConstructor"] = actor_constructor
|
actor["actorConstructor"] = actor_constructor
|
||||||
actor.update(core_worker_stats)
|
actor.update(core_worker_stats)
|
||||||
|
|
||||||
|
@ -275,8 +274,12 @@ class DataOrganizer:
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_actor_creation_tasks(cls):
|
async def get_actor_creation_tasks(cls):
|
||||||
infeasible_tasks = sum(
|
infeasible_tasks = sum(
|
||||||
(list(node_stats.get("infeasibleTasks", []))
|
(
|
||||||
for node_stats in DataSource.node_stats.values()), [])
|
list(node_stats.get("infeasibleTasks", []))
|
||||||
|
for node_stats in DataSource.node_stats.values()
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
new_infeasible_tasks = []
|
new_infeasible_tasks = []
|
||||||
for task in infeasible_tasks:
|
for task in infeasible_tasks:
|
||||||
task = dict(task)
|
task = dict(task)
|
||||||
|
@ -285,8 +288,12 @@ class DataOrganizer:
|
||||||
new_infeasible_tasks.append(task)
|
new_infeasible_tasks.append(task)
|
||||||
|
|
||||||
resource_pending_tasks = sum(
|
resource_pending_tasks = sum(
|
||||||
(list(data.get("readyTasks", []))
|
(
|
||||||
for data in DataSource.node_stats.values()), [])
|
list(data.get("readyTasks", []))
|
||||||
|
for data in DataSource.node_stats.values()
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
new_resource_pending_tasks = []
|
new_resource_pending_tasks = []
|
||||||
for task in resource_pending_tasks:
|
for task in resource_pending_tasks:
|
||||||
task = dict(task)
|
task = dict(task)
|
||||||
|
@ -301,14 +308,17 @@ class DataOrganizer:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_memory_table(cls,
|
async def get_memory_table(
|
||||||
sort_by=memory_utils.SortingType.OBJECT_SIZE,
|
cls,
|
||||||
group_by=memory_utils.GroupByType.STACK_TRACE):
|
sort_by=memory_utils.SortingType.OBJECT_SIZE,
|
||||||
|
group_by=memory_utils.GroupByType.STACK_TRACE,
|
||||||
|
):
|
||||||
all_worker_stats = []
|
all_worker_stats = []
|
||||||
for node_stats in DataSource.node_stats.values():
|
for node_stats in DataSource.node_stats.values():
|
||||||
all_worker_stats.extend(node_stats.get("coreWorkersStats", []))
|
all_worker_stats.extend(node_stats.get("coreWorkersStats", []))
|
||||||
memory_information = memory_utils.construct_memory_table(
|
memory_information = memory_utils.construct_memory_table(
|
||||||
all_worker_stats, group_by=group_by, sort_by=sort_by)
|
all_worker_stats, group_by=group_by, sort_by=sort_by
|
||||||
|
)
|
||||||
return memory_information
|
return memory_information
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -10,6 +10,7 @@ from queue import Queue
|
||||||
|
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from grpc import aio as aiogrpc
|
from grpc import aio as aiogrpc
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -23,8 +24,11 @@ import ray.dashboard.consts as dashboard_consts
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray import ray_constants
|
from ray import ray_constants
|
||||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
|
from ray._private.gcs_pubsub import (
|
||||||
GcsAioErrorSubscriber, GcsAioLogSubscriber
|
gcs_pubsub_enabled,
|
||||||
|
GcsAioErrorSubscriber,
|
||||||
|
GcsAioLogSubscriber,
|
||||||
|
)
|
||||||
from ray.core.generated import gcs_service_pb2
|
from ray.core.generated import gcs_service_pb2
|
||||||
from ray.core.generated import gcs_service_pb2_grpc
|
from ray.core.generated import gcs_service_pb2_grpc
|
||||||
from ray.dashboard.datacenter import DataOrganizer
|
from ray.dashboard.datacenter import DataOrganizer
|
||||||
|
@ -42,33 +46,33 @@ aiogrpc.init_grpc_aio()
|
||||||
GRPC_CHANNEL_OPTIONS = (
|
GRPC_CHANNEL_OPTIONS = (
|
||||||
("grpc.enable_http_proxy", 0),
|
("grpc.enable_http_proxy", 0),
|
||||||
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||||
("grpc.max_receive_message_length",
|
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||||
ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_gcs_address_with_retry(redis_client) -> str:
|
async def get_gcs_address_with_retry(redis_client) -> str:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
gcs_address = (await redis_client.get(
|
gcs_address = (
|
||||||
dashboard_consts.GCS_SERVER_ADDRESS)).decode()
|
await redis_client.get(dashboard_consts.GCS_SERVER_ADDRESS)
|
||||||
|
).decode()
|
||||||
if not gcs_address:
|
if not gcs_address:
|
||||||
raise Exception("GCS address not found.")
|
raise Exception("GCS address not found.")
|
||||||
logger.info("Connect to GCS at %s", gcs_address)
|
logger.info("Connect to GCS at %s", gcs_address)
|
||||||
return gcs_address
|
return gcs_address
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.error("Connect to GCS failed: %s, retry...", ex)
|
logger.error("Connect to GCS failed: %s, retry...", ex)
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
|
||||||
dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
|
|
||||||
|
|
||||||
|
|
||||||
class GCSHealthCheckThread(threading.Thread):
|
class GCSHealthCheckThread(threading.Thread):
|
||||||
def __init__(self, gcs_address: str):
|
def __init__(self, gcs_address: str):
|
||||||
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
|
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
|
||||||
gcs_address, options=GRPC_CHANNEL_OPTIONS)
|
gcs_address, options=GRPC_CHANNEL_OPTIONS
|
||||||
self.gcs_heartbeat_info_stub = (
|
)
|
||||||
gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
|
self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
|
||||||
self.grpc_gcs_channel))
|
self.grpc_gcs_channel
|
||||||
|
)
|
||||||
self.work_queue = Queue()
|
self.work_queue = Queue()
|
||||||
|
|
||||||
super().__init__(daemon=True)
|
super().__init__(daemon=True)
|
||||||
|
@ -83,10 +87,10 @@ class GCSHealthCheckThread(threading.Thread):
|
||||||
request = gcs_service_pb2.CheckAliveRequest()
|
request = gcs_service_pb2.CheckAliveRequest()
|
||||||
try:
|
try:
|
||||||
reply = self.gcs_heartbeat_info_stub.CheckAlive(
|
reply = self.gcs_heartbeat_info_stub.CheckAlive(
|
||||||
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT)
|
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
|
||||||
|
)
|
||||||
if reply.status.code != 0:
|
if reply.status.code != 0:
|
||||||
logger.exception(
|
logger.exception(f"Failed to CheckAlive: {reply.status.message}")
|
||||||
f"Failed to CheckAlive: {reply.status.message}")
|
|
||||||
return False
|
return False
|
||||||
except grpc.RpcError: # Deadline Exceeded
|
except grpc.RpcError: # Deadline Exceeded
|
||||||
logger.exception("Got RpcError when checking GCS is alive")
|
logger.exception("Got RpcError when checking GCS is alive")
|
||||||
|
@ -95,9 +99,9 @@ class GCSHealthCheckThread(threading.Thread):
|
||||||
|
|
||||||
async def check_once(self) -> bool:
|
async def check_once(self) -> bool:
|
||||||
"""Ask the thread to perform a healthcheck."""
|
"""Ask the thread to perform a healthcheck."""
|
||||||
assert threading.current_thread != self, (
|
assert (
|
||||||
"caller shouldn't be from the same thread as GCSHealthCheckThread."
|
threading.current_thread != self
|
||||||
)
|
), "caller shouldn't be from the same thread as GCSHealthCheckThread."
|
||||||
|
|
||||||
future = Future()
|
future = Future()
|
||||||
self.work_queue.put(future)
|
self.work_queue.put(future)
|
||||||
|
@ -105,8 +109,16 @@ class GCSHealthCheckThread(threading.Thread):
|
||||||
|
|
||||||
|
|
||||||
class DashboardHead:
|
class DashboardHead:
|
||||||
def __init__(self, http_host, http_port, http_port_retries, gcs_address,
|
def __init__(
|
||||||
redis_address, redis_password, log_dir):
|
self,
|
||||||
|
http_host,
|
||||||
|
http_port,
|
||||||
|
http_port_retries,
|
||||||
|
gcs_address,
|
||||||
|
redis_address,
|
||||||
|
redis_password,
|
||||||
|
log_dir,
|
||||||
|
):
|
||||||
self.health_check_thread: GCSHealthCheckThread = None
|
self.health_check_thread: GCSHealthCheckThread = None
|
||||||
self._gcs_rpc_error_counter = 0
|
self._gcs_rpc_error_counter = 0
|
||||||
# Public attributes are accessible for all head modules.
|
# Public attributes are accessible for all head modules.
|
||||||
|
@ -134,12 +146,12 @@ class DashboardHead:
|
||||||
else:
|
else:
|
||||||
ip, port = gcs_address.split(":")
|
ip, port = gcs_address.split(":")
|
||||||
|
|
||||||
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
|
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
|
||||||
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
|
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
|
||||||
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
|
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
|
||||||
self.server, f"{grpc_ip}:0")
|
self.server, f"{grpc_ip}:0"
|
||||||
logger.info("Dashboard head grpc address: %s:%s", grpc_ip,
|
)
|
||||||
self.grpc_port)
|
logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)
|
||||||
|
|
||||||
@async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
|
@async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
|
||||||
async def _gcs_check_alive(self):
|
async def _gcs_check_alive(self):
|
||||||
|
@ -149,7 +161,8 @@ class DashboardHead:
|
||||||
# Otherwise, the dashboard will always think that gcs is alive.
|
# Otherwise, the dashboard will always think that gcs is alive.
|
||||||
try:
|
try:
|
||||||
is_alive = await asyncio.wait_for(
|
is_alive = await asyncio.wait_for(
|
||||||
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1)
|
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("Failed to check gcs health, client timed out.")
|
logger.error("Failed to check gcs health, client timed out.")
|
||||||
is_alive = False
|
is_alive = False
|
||||||
|
@ -158,13 +171,16 @@ class DashboardHead:
|
||||||
self._gcs_rpc_error_counter = 0
|
self._gcs_rpc_error_counter = 0
|
||||||
else:
|
else:
|
||||||
self._gcs_rpc_error_counter += 1
|
self._gcs_rpc_error_counter += 1
|
||||||
if self._gcs_rpc_error_counter > \
|
if (
|
||||||
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR:
|
self._gcs_rpc_error_counter
|
||||||
|
> dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
|
||||||
|
):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Dashboard exiting because it received too many GCS RPC "
|
"Dashboard exiting because it received too many GCS RPC "
|
||||||
"errors count: %s, threshold is %s.",
|
"errors count: %s, threshold is %s.",
|
||||||
self._gcs_rpc_error_counter,
|
self._gcs_rpc_error_counter,
|
||||||
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR)
|
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
|
||||||
|
)
|
||||||
# TODO(fyrestone): Do not use ray.state in
|
# TODO(fyrestone): Do not use ray.state in
|
||||||
# PrometheusServiceDiscoveryWriter.
|
# PrometheusServiceDiscoveryWriter.
|
||||||
# Currently, we use os._exit() here to avoid hanging at the ray
|
# Currently, we use os._exit() here to avoid hanging at the ray
|
||||||
|
@ -176,10 +192,12 @@ class DashboardHead:
|
||||||
"""Load dashboard head modules."""
|
"""Load dashboard head modules."""
|
||||||
modules = []
|
modules = []
|
||||||
head_cls_list = dashboard_utils.get_all_modules(
|
head_cls_list = dashboard_utils.get_all_modules(
|
||||||
dashboard_utils.DashboardHeadModule)
|
dashboard_utils.DashboardHeadModule
|
||||||
|
)
|
||||||
for cls in head_cls_list:
|
for cls in head_cls_list:
|
||||||
logger.info("Loading %s: %s",
|
logger.info(
|
||||||
dashboard_utils.DashboardHeadModule.__name__, cls)
|
"Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
|
||||||
|
)
|
||||||
c = cls(self)
|
c = cls(self)
|
||||||
dashboard_optional_utils.ClassMethodRouteTable.bind(c)
|
dashboard_optional_utils.ClassMethodRouteTable.bind(c)
|
||||||
modules.append(c)
|
modules.append(c)
|
||||||
|
@ -192,15 +210,17 @@ class DashboardHead:
|
||||||
return self.gcs_address
|
return self.gcs_address
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.aioredis_client = \
|
self.aioredis_client = await dashboard_utils.get_aioredis_client(
|
||||||
await dashboard_utils.get_aioredis_client(
|
self.redis_address,
|
||||||
self.redis_address, self.redis_password,
|
self.redis_password,
|
||||||
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
|
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
|
||||||
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
|
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
|
||||||
|
)
|
||||||
except (socket.gaierror, ConnectionError):
|
except (socket.gaierror, ConnectionError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Dashboard head exiting: "
|
"Dashboard head exiting: " "Failed to connect to redis at %s",
|
||||||
"Failed to connect to redis at %s", self.redis_address)
|
self.redis_address,
|
||||||
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
return await get_gcs_address_with_retry(self.aioredis_client)
|
return await get_gcs_address_with_retry(self.aioredis_client)
|
||||||
|
|
||||||
|
@ -209,22 +229,20 @@ class DashboardHead:
|
||||||
# Create a http session for all modules.
|
# Create a http session for all modules.
|
||||||
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
|
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
|
||||||
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
|
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
|
||||||
self.http_session = aiohttp.ClientSession(
|
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
|
||||||
loop=asyncio.get_event_loop())
|
|
||||||
else:
|
else:
|
||||||
self.http_session = aiohttp.ClientSession()
|
self.http_session = aiohttp.ClientSession()
|
||||||
|
|
||||||
gcs_address = await self.get_gcs_address()
|
gcs_address = await self.get_gcs_address()
|
||||||
|
|
||||||
# Dashboard will handle connection failure automatically
|
# Dashboard will handle connection failure automatically
|
||||||
self.gcs_client = GcsClient(
|
self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
|
||||||
address=gcs_address, nums_reconnect_retry=0)
|
|
||||||
internal_kv._initialize_internal_kv(self.gcs_client)
|
internal_kv._initialize_internal_kv(self.gcs_client)
|
||||||
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
|
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
|
||||||
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
|
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
|
||||||
|
)
|
||||||
if gcs_pubsub_enabled():
|
if gcs_pubsub_enabled():
|
||||||
self.gcs_error_subscriber = GcsAioErrorSubscriber(
|
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
|
||||||
address=gcs_address)
|
|
||||||
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
|
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
|
||||||
await self.gcs_error_subscriber.subscribe()
|
await self.gcs_error_subscriber.subscribe()
|
||||||
await self.gcs_log_subscriber.subscribe()
|
await self.gcs_log_subscriber.subscribe()
|
||||||
|
@ -248,7 +266,7 @@ class DashboardHead:
|
||||||
|
|
||||||
# Http server should be initialized after all modules loaded.
|
# Http server should be initialized after all modules loaded.
|
||||||
# working_dir uploads for job submission can be up to 100MiB.
|
# working_dir uploads for job submission can be up to 100MiB.
|
||||||
app = aiohttp.web.Application(client_max_size=100 * 1024**2)
|
app = aiohttp.web.Application(client_max_size=100 * 1024 ** 2)
|
||||||
app.add_routes(routes=routes.bound_routes())
|
app.add_routes(routes=routes.bound_routes())
|
||||||
|
|
||||||
runner = aiohttp.web.AppRunner(app)
|
runner = aiohttp.web.AppRunner(app)
|
||||||
|
@ -256,8 +274,7 @@ class DashboardHead:
|
||||||
last_ex = None
|
last_ex = None
|
||||||
for i in range(1 + self.http_port_retries):
|
for i in range(1 + self.http_port_retries):
|
||||||
try:
|
try:
|
||||||
site = aiohttp.web.TCPSite(runner, self.http_host,
|
site = aiohttp.web.TCPSite(runner, self.http_host, self.http_port)
|
||||||
self.http_port)
|
|
||||||
await site.start()
|
await site.start()
|
||||||
break
|
break
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
|
@ -265,11 +282,14 @@ class DashboardHead:
|
||||||
self.http_port += 1
|
self.http_port += 1
|
||||||
logger.warning("Try to use port %s: %s", self.http_port, e)
|
logger.warning("Try to use port %s: %s", self.http_port, e)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Failed to find a valid port for dashboard after "
|
raise Exception(
|
||||||
f"{self.http_port_retries} retries: {last_ex}")
|
f"Failed to find a valid port for dashboard after "
|
||||||
|
f"{self.http_port_retries} retries: {last_ex}"
|
||||||
|
)
|
||||||
http_host, http_port, *_ = site._server.sockets[0].getsockname()
|
http_host, http_port, *_ = site._server.sockets[0].getsockname()
|
||||||
http_host = self.ip if ipaddress.ip_address(
|
http_host = (
|
||||||
http_host).is_unspecified else http_host
|
self.ip if ipaddress.ip_address(http_host).is_unspecified else http_host
|
||||||
|
)
|
||||||
logger.info("Dashboard head http address: %s:%s", http_host, http_port)
|
logger.info("Dashboard head http address: %s:%s", http_host, http_port)
|
||||||
|
|
||||||
# TODO: Use async version if performance is an issue
|
# TODO: Use async version if performance is an issue
|
||||||
|
@ -277,16 +297,16 @@ class DashboardHead:
|
||||||
internal_kv._internal_kv_put(
|
internal_kv._internal_kv_put(
|
||||||
ray_constants.DASHBOARD_ADDRESS,
|
ray_constants.DASHBOARD_ADDRESS,
|
||||||
f"{http_host}:{http_port}",
|
f"{http_host}:{http_port}",
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
internal_kv._internal_kv_put(
|
internal_kv._internal_kv_put(
|
||||||
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
||||||
f"{self.ip}:{self.grpc_port}",
|
f"{self.ip}:{self.grpc_port}",
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
|
|
||||||
# Dump registered http routes.
|
# Dump registered http routes.
|
||||||
dump_routes = [
|
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
|
||||||
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
|
|
||||||
]
|
|
||||||
for r in dump_routes:
|
for r in dump_routes:
|
||||||
logger.info(r)
|
logger.info(r)
|
||||||
logger.info("Registered %s routes.", len(dump_routes))
|
logger.info("Registered %s routes.", len(dump_routes))
|
||||||
|
@ -299,6 +319,5 @@ class DashboardHead:
|
||||||
DataOrganizer.purge(),
|
DataOrganizer.purge(),
|
||||||
DataOrganizer.organize(),
|
DataOrganizer.organize(),
|
||||||
]
|
]
|
||||||
await asyncio.gather(*concurrent_tasks,
|
await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
|
||||||
*(m.run(self.server) for m in modules))
|
|
||||||
await self.server.wait_for_termination()
|
await self.server.wait_for_termination()
|
||||||
|
|
|
@ -23,8 +23,7 @@ class HttpServerAgent:
|
||||||
# Create a http session for all modules.
|
# Create a http session for all modules.
|
||||||
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
|
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
|
||||||
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
|
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
|
||||||
self.http_session = aiohttp.ClientSession(
|
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
|
||||||
loop=asyncio.get_event_loop())
|
|
||||||
else:
|
else:
|
||||||
self.http_session = aiohttp.ClientSession()
|
self.http_session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
@ -47,25 +46,26 @@ class HttpServerAgent:
|
||||||
allow_methods="*",
|
allow_methods="*",
|
||||||
allow_headers=("Content-Type", "X-Header"),
|
allow_headers=("Content-Type", "X-Header"),
|
||||||
)
|
)
|
||||||
})
|
},
|
||||||
|
)
|
||||||
for route in list(app.router.routes()):
|
for route in list(app.router.routes()):
|
||||||
cors.add(route)
|
cors.add(route)
|
||||||
|
|
||||||
self.runner = aiohttp.web.AppRunner(app)
|
self.runner = aiohttp.web.AppRunner(app)
|
||||||
await self.runner.setup()
|
await self.runner.setup()
|
||||||
site = aiohttp.web.TCPSite(
|
site = aiohttp.web.TCPSite(
|
||||||
self.runner, "127.0.0.1"
|
self.runner,
|
||||||
if self.ip == "127.0.0.1" else "0.0.0.0", self.listen_port)
|
"127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0",
|
||||||
|
self.listen_port,
|
||||||
|
)
|
||||||
await site.start()
|
await site.start()
|
||||||
self.http_host, self.http_port, *_ = (
|
self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
|
||||||
site._server.sockets[0].getsockname())
|
logger.info(
|
||||||
logger.info("Dashboard agent http address: %s:%s", self.http_host,
|
"Dashboard agent http address: %s:%s", self.http_host, self.http_port
|
||||||
self.http_port)
|
)
|
||||||
|
|
||||||
# Dump registered http routes.
|
# Dump registered http routes.
|
||||||
dump_routes = [
|
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
|
||||||
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
|
|
||||||
]
|
|
||||||
for r in dump_routes:
|
for r in dump_routes:
|
||||||
logger.info(r)
|
logger.info(r)
|
||||||
logger.info("Registered %s routes.", len(dump_routes))
|
logger.info("Registered %s routes.", len(dump_routes))
|
||||||
|
|
|
@ -31,7 +31,7 @@ def cpu_percent():
|
||||||
delta in total host cpu usage, averaged over host's cpus.
|
delta in total host cpu usage, averaged over host's cpus.
|
||||||
|
|
||||||
Since deltas are not initially available, return 0.0 on first call.
|
Since deltas are not initially available, return 0.0 on first call.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
global last_system_usage
|
global last_system_usage
|
||||||
global last_cpu_usage
|
global last_cpu_usage
|
||||||
try:
|
try:
|
||||||
|
@ -43,12 +43,10 @@ def cpu_percent():
|
||||||
else:
|
else:
|
||||||
cpu_delta = cpu_usage - last_cpu_usage
|
cpu_delta = cpu_usage - last_cpu_usage
|
||||||
# "System time passed." (Typically close to clock time.)
|
# "System time passed." (Typically close to clock time.)
|
||||||
system_delta = (
|
system_delta = (system_usage - last_system_usage) / _host_num_cpus()
|
||||||
(system_usage - last_system_usage) / _host_num_cpus())
|
|
||||||
|
|
||||||
quotient = cpu_delta / system_delta
|
quotient = cpu_delta / system_delta
|
||||||
cpu_percent = round(
|
cpu_percent = round(quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
|
||||||
quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
|
|
||||||
last_system_usage = system_usage
|
last_system_usage = system_usage
|
||||||
last_cpu_usage = cpu_usage
|
last_cpu_usage = cpu_usage
|
||||||
# Computed percentage might be slightly above 100%.
|
# Computed percentage might be slightly above 100%.
|
||||||
|
@ -73,14 +71,14 @@ def _system_usage():
|
||||||
|
|
||||||
See also the /proc/stat entry here:
|
See also the /proc/stat entry here:
|
||||||
https://man7.org/linux/man-pages/man5/proc.5.html
|
https://man7.org/linux/man-pages/man5/proc.5.html
|
||||||
""" # noqa
|
""" # noqa
|
||||||
cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0]
|
cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0]
|
||||||
parts = cpu_summary_str.split()
|
parts = cpu_summary_str.split()
|
||||||
assert parts[0] == "cpu"
|
assert parts[0] == "cpu"
|
||||||
usage_data = parts[1:8]
|
usage_data = parts[1:8]
|
||||||
total_clock_ticks = sum(int(entry) for entry in usage_data)
|
total_clock_ticks = sum(int(entry) for entry in usage_data)
|
||||||
# 100 clock ticks per second, 10^9 ns per second
|
# 100 clock ticks per second, 10^9 ns per second
|
||||||
usage_ns = total_clock_ticks * 10**7
|
usage_ns = total_clock_ticks * 10 ** 7
|
||||||
return usage_ns
|
return usage_ns
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,7 +89,8 @@ def _host_num_cpus():
|
||||||
proc_stat_lines = open(PROC_STAT_PATH).read().split("\n")
|
proc_stat_lines = open(PROC_STAT_PATH).read().split("\n")
|
||||||
split_proc_stat_lines = [line.split() for line in proc_stat_lines]
|
split_proc_stat_lines = [line.split() for line in proc_stat_lines]
|
||||||
cpu_lines = [
|
cpu_lines = [
|
||||||
split_line for split_line in split_proc_stat_lines
|
split_line
|
||||||
|
for split_line in split_proc_stat_lines
|
||||||
if len(split_line) > 0 and "cpu" in split_line[0]
|
if len(split_line) > 0 and "cpu" in split_line[0]
|
||||||
]
|
]
|
||||||
# Number of lines starting with a word including 'cpu', subtracting
|
# Number of lines starting with a word including 'cpu', subtracting
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import List
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
from ray._raylet import (TaskID, ActorID, JobID)
|
from ray._raylet import TaskID, ActorID, JobID
|
||||||
from ray.internal.internal_api import node_stats
|
from ray.internal.internal_api import node_stats
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -69,8 +69,10 @@ def get_sorting_type(sort_by: str):
|
||||||
elif sort_by == "REFERENCE_TYPE":
|
elif sort_by == "REFERENCE_TYPE":
|
||||||
return SortingType.REFERENCE_TYPE
|
return SortingType.REFERENCE_TYPE
|
||||||
else:
|
else:
|
||||||
raise Exception("The sort-by input provided is not one of\
|
raise Exception(
|
||||||
PID, OBJECT_SIZE, or REFERENCE_TYPE.")
|
"The sort-by input provided is not one of\
|
||||||
|
PID, OBJECT_SIZE, or REFERENCE_TYPE."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_group_by_type(group_by: str):
|
def get_group_by_type(group_by: str):
|
||||||
|
@ -81,13 +83,16 @@ def get_group_by_type(group_by: str):
|
||||||
elif group_by == "STACK_TRACE":
|
elif group_by == "STACK_TRACE":
|
||||||
return GroupByType.STACK_TRACE
|
return GroupByType.STACK_TRACE
|
||||||
else:
|
else:
|
||||||
raise Exception("The group-by input provided is not one of\
|
raise Exception(
|
||||||
NODE_ADDRESS or STACK_TRACE.")
|
"The group-by input provided is not one of\
|
||||||
|
NODE_ADDRESS or STACK_TRACE."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryTableEntry:
|
class MemoryTableEntry:
|
||||||
def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
|
def __init__(
|
||||||
pid: int):
|
self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int
|
||||||
|
):
|
||||||
# worker info
|
# worker info
|
||||||
self.is_driver = is_driver
|
self.is_driver = is_driver
|
||||||
self.pid = pid
|
self.pid = pid
|
||||||
|
@ -97,13 +102,13 @@ class MemoryTableEntry:
|
||||||
self.object_size = int(object_ref.get("objectSize", -1))
|
self.object_size = int(object_ref.get("objectSize", -1))
|
||||||
self.call_site = object_ref.get("callSite", "<Unknown>")
|
self.call_site = object_ref.get("callSite", "<Unknown>")
|
||||||
self.object_ref = ray.ObjectRef(
|
self.object_ref = ray.ObjectRef(
|
||||||
decode_object_ref_if_needed(object_ref["objectId"]))
|
decode_object_ref_if_needed(object_ref["objectId"])
|
||||||
|
)
|
||||||
|
|
||||||
# reference info
|
# reference info
|
||||||
self.local_ref_count = int(object_ref.get("localRefCount", 0))
|
self.local_ref_count = int(object_ref.get("localRefCount", 0))
|
||||||
self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
|
self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
|
||||||
self.submitted_task_ref_count = int(
|
self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0))
|
||||||
object_ref.get("submittedTaskRefCount", 0))
|
|
||||||
self.contained_in_owned = [
|
self.contained_in_owned = [
|
||||||
ray.ObjectRef(decode_object_ref_if_needed(object_ref))
|
ray.ObjectRef(decode_object_ref_if_needed(object_ref))
|
||||||
for object_ref in object_ref.get("containedInOwned", [])
|
for object_ref in object_ref.get("containedInOwned", [])
|
||||||
|
@ -113,9 +118,12 @@ class MemoryTableEntry:
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
# If the entry doesn't have a reference type or some invalid state,
|
# If the entry doesn't have a reference type or some invalid state,
|
||||||
# (e.g., no object ref presented), it is considered invalid.
|
# (e.g., no object ref presented), it is considered invalid.
|
||||||
if (not self.pinned_in_memory and self.local_ref_count == 0
|
if (
|
||||||
and self.submitted_task_ref_count == 0
|
not self.pinned_in_memory
|
||||||
and len(self.contained_in_owned) == 0):
|
and self.local_ref_count == 0
|
||||||
|
and self.submitted_task_ref_count == 0
|
||||||
|
and len(self.contained_in_owned) == 0
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
elif self.object_ref.is_nil():
|
elif self.object_ref.is_nil():
|
||||||
return False
|
return False
|
||||||
|
@ -153,10 +161,10 @@ class MemoryTableEntry:
|
||||||
# are not all 'f', that means it is an actor creation
|
# are not all 'f', that means it is an actor creation
|
||||||
# task, which is an actor handle.
|
# task, which is an actor handle.
|
||||||
random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
|
random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
|
||||||
actor_random_bits = object_ref_hex[TASKID_RANDOM_BITS_SIZE:
|
actor_random_bits = object_ref_hex[
|
||||||
TASKID_RANDOM_BITS_SIZE +
|
TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
|
||||||
ACTORID_RANDOM_BITS_SIZE]
|
]
|
||||||
if (random_bits == "f" * 16 and not actor_random_bits == "f" * 24):
|
if random_bits == "f" * 16 and not actor_random_bits == "f" * 24:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
@ -175,7 +183,7 @@ class MemoryTableEntry:
|
||||||
"contained_in_owned": [
|
"contained_in_owned": [
|
||||||
object_ref.hex() for object_ref in self.contained_in_owned
|
object_ref.hex() for object_ref in self.contained_in_owned
|
||||||
],
|
],
|
||||||
"type": "Driver" if self.is_driver else "Worker"
|
"type": "Driver" if self.is_driver else "Worker",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -186,10 +194,12 @@ class MemoryTableEntry:
|
||||||
|
|
||||||
|
|
||||||
class MemoryTable:
|
class MemoryTable:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
entries: List[MemoryTableEntry],
|
self,
|
||||||
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
|
entries: List[MemoryTableEntry],
|
||||||
sort_by_type: SortingType = SortingType.PID):
|
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
|
||||||
|
sort_by_type: SortingType = SortingType.PID,
|
||||||
|
):
|
||||||
self.table = entries
|
self.table = entries
|
||||||
# Group is a list of memory tables grouped by a group key.
|
# Group is a list of memory tables grouped by a group key.
|
||||||
self.group = {}
|
self.group = {}
|
||||||
|
@ -247,7 +257,7 @@ class MemoryTable:
|
||||||
"total_pinned_in_memory": total_pinned_in_memory,
|
"total_pinned_in_memory": total_pinned_in_memory,
|
||||||
"total_used_by_pending_task": total_used_by_pending_task,
|
"total_used_by_pending_task": total_used_by_pending_task,
|
||||||
"total_captured_in_objects": total_captured_in_objects,
|
"total_captured_in_objects": total_captured_in_objects,
|
||||||
"total_actor_handles": total_actor_handles
|
"total_actor_handles": total_actor_handles,
|
||||||
}
|
}
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -278,7 +288,8 @@ class MemoryTable:
|
||||||
# Build a group table.
|
# Build a group table.
|
||||||
for group_key, entries in group.items():
|
for group_key, entries in group.items():
|
||||||
self.group[group_key] = MemoryTable(
|
self.group[group_key] = MemoryTable(
|
||||||
entries, group_by_type=None, sort_by_type=None)
|
entries, group_by_type=None, sort_by_type=None
|
||||||
|
)
|
||||||
for group_key, group_memory_table in self.group.items():
|
for group_key, group_memory_table in self.group.items():
|
||||||
group_memory_table.summarize()
|
group_memory_table.summarize()
|
||||||
return self
|
return self
|
||||||
|
@ -289,10 +300,10 @@ class MemoryTable:
|
||||||
"group": {
|
"group": {
|
||||||
group_key: {
|
group_key: {
|
||||||
"entries": group_memory_table.get_entries(),
|
"entries": group_memory_table.get_entries(),
|
||||||
"summary": group_memory_table.summary
|
"summary": group_memory_table.summary,
|
||||||
}
|
}
|
||||||
for group_key, group_memory_table in self.group.items()
|
for group_key, group_memory_table in self.group.items()
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_entries(self) -> List[dict]:
|
def get_entries(self) -> List[dict]:
|
||||||
|
@ -305,9 +316,11 @@ class MemoryTable:
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
|
|
||||||
def construct_memory_table(workers_stats: List,
|
def construct_memory_table(
|
||||||
group_by: GroupByType = GroupByType.NODE_ADDRESS,
|
workers_stats: List,
|
||||||
sort_by=SortingType.OBJECT_SIZE) -> MemoryTable:
|
group_by: GroupByType = GroupByType.NODE_ADDRESS,
|
||||||
|
sort_by=SortingType.OBJECT_SIZE,
|
||||||
|
) -> MemoryTable:
|
||||||
memory_table_entries = []
|
memory_table_entries = []
|
||||||
for core_worker_stats in workers_stats:
|
for core_worker_stats in workers_stats:
|
||||||
pid = core_worker_stats["pid"]
|
pid = core_worker_stats["pid"]
|
||||||
|
@ -320,11 +333,13 @@ def construct_memory_table(workers_stats: List,
|
||||||
object_ref=object_ref,
|
object_ref=object_ref,
|
||||||
node_address=node_address,
|
node_address=node_address,
|
||||||
is_driver=is_driver,
|
is_driver=is_driver,
|
||||||
pid=pid)
|
pid=pid,
|
||||||
|
)
|
||||||
if memory_table_entry.is_valid():
|
if memory_table_entry.is_valid():
|
||||||
memory_table_entries.append(memory_table_entry)
|
memory_table_entries.append(memory_table_entry)
|
||||||
memory_table = MemoryTable(
|
memory_table = MemoryTable(
|
||||||
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by)
|
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by
|
||||||
|
)
|
||||||
return memory_table
|
return memory_table
|
||||||
|
|
||||||
|
|
||||||
|
@ -337,7 +352,7 @@ def track_reference_size(group):
|
||||||
"PINNED_IN_MEMORY": "total_pinned_in_memory",
|
"PINNED_IN_MEMORY": "total_pinned_in_memory",
|
||||||
"USED_BY_PENDING_TASK": "total_used_by_pending_task",
|
"USED_BY_PENDING_TASK": "total_used_by_pending_task",
|
||||||
"CAPTURED_IN_OBJECT": "total_captured_in_objects",
|
"CAPTURED_IN_OBJECT": "total_captured_in_objects",
|
||||||
"ACTOR_HANDLE": "total_actor_handles"
|
"ACTOR_HANDLE": "total_actor_handles",
|
||||||
}
|
}
|
||||||
for entry in group["entries"]:
|
for entry in group["entries"]:
|
||||||
size = entry["object_size"]
|
size = entry["object_size"]
|
||||||
|
@ -348,51 +363,64 @@ def track_reference_size(group):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def memory_summary(state,
|
def memory_summary(
|
||||||
group_by="NODE_ADDRESS",
|
state,
|
||||||
sort_by="OBJECT_SIZE",
|
group_by="NODE_ADDRESS",
|
||||||
line_wrap=True,
|
sort_by="OBJECT_SIZE",
|
||||||
unit="B",
|
line_wrap=True,
|
||||||
num_entries=None) -> str:
|
unit="B",
|
||||||
from ray.dashboard.modules.node.node_head\
|
num_entries=None,
|
||||||
import node_stats_to_dict
|
) -> str:
|
||||||
|
from ray.dashboard.modules.node.node_head import node_stats_to_dict
|
||||||
|
|
||||||
# Get terminal size
|
# Get terminal size
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
size = shutil.get_terminal_size((80, 20)).columns
|
size = shutil.get_terminal_size((80, 20)).columns
|
||||||
line_wrap_threshold = 137
|
line_wrap_threshold = 137
|
||||||
|
|
||||||
# Unit conversions
|
# Unit conversions
|
||||||
units = {"B": 10**0, "KB": 10**3, "MB": 10**6, "GB": 10**9}
|
units = {"B": 10 ** 0, "KB": 10 ** 3, "MB": 10 ** 6, "GB": 10 ** 9}
|
||||||
|
|
||||||
# Fetch core memory worker stats, store as a dictionary
|
# Fetch core memory worker stats, store as a dictionary
|
||||||
core_worker_stats = []
|
core_worker_stats = []
|
||||||
for raylet in state.node_table():
|
for raylet in state.node_table():
|
||||||
stats = node_stats_to_dict(
|
stats = node_stats_to_dict(
|
||||||
node_stats(raylet["NodeManagerAddress"],
|
node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
|
||||||
raylet["NodeManagerPort"]))
|
)
|
||||||
core_worker_stats.extend(stats["coreWorkersStats"])
|
core_worker_stats.extend(stats["coreWorkersStats"])
|
||||||
assert type(stats) is dict and "coreWorkersStats" in stats
|
assert type(stats) is dict and "coreWorkersStats" in stats
|
||||||
|
|
||||||
# Build memory table with "group_by" and "sort_by" parameters
|
# Build memory table with "group_by" and "sort_by" parameters
|
||||||
group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
|
group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
|
||||||
memory_table = construct_memory_table(core_worker_stats, group_by,
|
memory_table = construct_memory_table(
|
||||||
sort_by).as_dict()
|
core_worker_stats, group_by, sort_by
|
||||||
|
).as_dict()
|
||||||
assert "summary" in memory_table and "group" in memory_table
|
assert "summary" in memory_table and "group" in memory_table
|
||||||
|
|
||||||
# Build memory summary
|
# Build memory summary
|
||||||
mem = ""
|
mem = ""
|
||||||
group_by, sort_by = group_by.name.lower().replace(
|
group_by, sort_by = group_by.name.lower().replace(
|
||||||
"_", " "), sort_by.name.lower().replace("_", " ")
|
"_", " "
|
||||||
|
), sort_by.name.lower().replace("_", " ")
|
||||||
summary_labels = [
|
summary_labels = [
|
||||||
"Mem Used by Objects", "Local References", "Pinned", "Pending Tasks",
|
"Mem Used by Objects",
|
||||||
"Captured in Objects", "Actor Handles"
|
"Local References",
|
||||||
|
"Pinned",
|
||||||
|
"Pending Tasks",
|
||||||
|
"Captured in Objects",
|
||||||
|
"Actor Handles",
|
||||||
]
|
]
|
||||||
summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"
|
summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"
|
||||||
|
|
||||||
object_ref_labels = [
|
object_ref_labels = [
|
||||||
"IP Address", "PID", "Type", "Call Site", "Size", "Reference Type",
|
"IP Address",
|
||||||
"Object Ref"
|
"PID",
|
||||||
|
"Type",
|
||||||
|
"Call Site",
|
||||||
|
"Size",
|
||||||
|
"Reference Type",
|
||||||
|
"Object Ref",
|
||||||
]
|
]
|
||||||
object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
|
object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
|
||||||
| {:<8} | {:<14} | {:<10}\n"
|
| {:<8} | {:<14} | {:<10}\n"
|
||||||
|
@ -416,22 +444,21 @@ entries per group...\n\n\n"
|
||||||
else:
|
else:
|
||||||
summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"
|
summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"
|
||||||
mem += f"--- Summary for {group_by}: {key} ---\n"
|
mem += f"--- Summary for {group_by}: {key} ---\n"
|
||||||
mem += summary_string\
|
mem += summary_string.format(*summary_labels)
|
||||||
.format(*summary_labels)
|
mem += summary_string.format(*summary.values()) + "\n"
|
||||||
mem += summary_string\
|
|
||||||
.format(*summary.values()) + "\n"
|
|
||||||
|
|
||||||
# Memory table per group
|
# Memory table per group
|
||||||
mem += f"--- Object references for {group_by}: {key} ---\n"
|
mem += f"--- Object references for {group_by}: {key} ---\n"
|
||||||
mem += object_ref_string\
|
mem += object_ref_string.format(*object_ref_labels)
|
||||||
.format(*object_ref_labels)
|
|
||||||
n = 1 # Counter for num entries per group
|
n = 1 # Counter for num entries per group
|
||||||
for entry in group["entries"]:
|
for entry in group["entries"]:
|
||||||
if num_entries is not None and n > num_entries:
|
if num_entries is not None and n > num_entries:
|
||||||
break
|
break
|
||||||
entry["object_size"] = str(
|
entry["object_size"] = (
|
||||||
entry["object_size"] /
|
str(entry["object_size"] / units[unit]) + f" {unit}"
|
||||||
units[unit]) + f" {unit}" if entry["object_size"] > -1 else "?"
|
if entry["object_size"] > -1
|
||||||
|
else "?"
|
||||||
|
)
|
||||||
num_lines = 1
|
num_lines = 1
|
||||||
if size > line_wrap_threshold and line_wrap:
|
if size > line_wrap_threshold and line_wrap:
|
||||||
call_site_length = 22
|
call_site_length = 22
|
||||||
|
@ -439,30 +466,36 @@ entries per group...\n\n\n"
|
||||||
entry["call_site"] = ["disabled"]
|
entry["call_site"] = ["disabled"]
|
||||||
else:
|
else:
|
||||||
entry["call_site"] = [
|
entry["call_site"] = [
|
||||||
entry["call_site"][i:i + call_site_length] for i in
|
entry["call_site"][i : i + call_site_length]
|
||||||
range(0, len(entry["call_site"]), call_site_length)
|
for i in range(0, len(entry["call_site"]), call_site_length)
|
||||||
]
|
]
|
||||||
num_lines = len(entry["call_site"])
|
num_lines = len(entry["call_site"])
|
||||||
else:
|
else:
|
||||||
mem += "\n"
|
mem += "\n"
|
||||||
object_ref_values = [
|
object_ref_values = [
|
||||||
entry["node_ip_address"], entry["pid"], entry["type"],
|
entry["node_ip_address"],
|
||||||
entry["call_site"], entry["object_size"],
|
entry["pid"],
|
||||||
entry["reference_type"], entry["object_ref"]
|
entry["type"],
|
||||||
|
entry["call_site"],
|
||||||
|
entry["object_size"],
|
||||||
|
entry["reference_type"],
|
||||||
|
entry["object_ref"],
|
||||||
]
|
]
|
||||||
for i in range(len(object_ref_values)):
|
for i in range(len(object_ref_values)):
|
||||||
if not isinstance(object_ref_values[i], list):
|
if not isinstance(object_ref_values[i], list):
|
||||||
object_ref_values[i] = [object_ref_values[i]]
|
object_ref_values[i] = [object_ref_values[i]]
|
||||||
object_ref_values[i].extend(
|
object_ref_values[i].extend(
|
||||||
["" for x in range(num_lines - len(object_ref_values[i]))])
|
["" for x in range(num_lines - len(object_ref_values[i]))]
|
||||||
|
)
|
||||||
for i in range(num_lines):
|
for i in range(num_lines):
|
||||||
row = [elem[i] for elem in object_ref_values]
|
row = [elem[i] for elem in object_ref_values]
|
||||||
mem += object_ref_string\
|
mem += object_ref_string.format(*row)
|
||||||
.format(*row)
|
|
||||||
mem += "\n"
|
mem += "\n"
|
||||||
n += 1
|
n += 1
|
||||||
|
|
||||||
mem += "To record callsite information for each ObjectRef created, set " \
|
mem += (
|
||||||
"env variable RAY_record_ref_creation_sites=1\n\n"
|
"To record callsite information for each ObjectRef created, set "
|
||||||
|
"env variable RAY_record_ref_creation_sites=1\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
return mem
|
return mem
|
||||||
|
|
|
@ -4,6 +4,7 @@ import aiohttp.web
|
||||||
import ray._private.utils
|
import ray._private.utils
|
||||||
from ray.dashboard.modules.actor import actor_utils
|
from ray.dashboard.modules.actor import actor_utils
|
||||||
from aioredis.pubsub import Receiver
|
from aioredis.pubsub import Receiver
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from grpc import aio as aiogrpc
|
from grpc import aio as aiogrpc
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -15,8 +16,7 @@ import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray.dashboard.optional_utils import rest_response
|
from ray.dashboard.optional_utils import rest_response
|
||||||
from ray.dashboard.modules.actor import actor_consts
|
from ray.dashboard.modules.actor import actor_consts
|
||||||
from ray.dashboard.modules.actor.actor_utils import \
|
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
|
||||||
actor_classname_from_task_spec
|
|
||||||
from ray.core.generated import node_manager_pb2_grpc
|
from ray.core.generated import node_manager_pb2_grpc
|
||||||
from ray.core.generated import gcs_service_pb2
|
from ray.core.generated import gcs_service_pb2
|
||||||
from ray.core.generated import gcs_service_pb2_grpc
|
from ray.core.generated import gcs_service_pb2_grpc
|
||||||
|
@ -30,12 +30,22 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
|
||||||
|
|
||||||
def actor_table_data_to_dict(message):
|
def actor_table_data_to_dict(message):
|
||||||
orig_message = dashboard_utils.message_to_dict(
|
orig_message = dashboard_utils.message_to_dict(
|
||||||
message, {
|
message,
|
||||||
"actorId", "parentId", "jobId", "workerId", "rayletId",
|
{
|
||||||
"actorCreationDummyObjectId", "callerId", "taskId", "parentTaskId",
|
"actorId",
|
||||||
"sourceActorId", "placementGroupId"
|
"parentId",
|
||||||
|
"jobId",
|
||||||
|
"workerId",
|
||||||
|
"rayletId",
|
||||||
|
"actorCreationDummyObjectId",
|
||||||
|
"callerId",
|
||||||
|
"taskId",
|
||||||
|
"parentTaskId",
|
||||||
|
"sourceActorId",
|
||||||
|
"placementGroupId",
|
||||||
},
|
},
|
||||||
including_default_value_fields=True)
|
including_default_value_fields=True,
|
||||||
|
)
|
||||||
# The complete schema for actor table is here:
|
# The complete schema for actor table is here:
|
||||||
# src/ray/protobuf/gcs.proto
|
# src/ray/protobuf/gcs.proto
|
||||||
# It is super big and for dashboard, we don't need that much information.
|
# It is super big and for dashboard, we don't need that much information.
|
||||||
|
@ -58,8 +68,7 @@ def actor_table_data_to_dict(message):
|
||||||
light_message["actorClass"] = actor_class
|
light_message["actorClass"] = actor_class
|
||||||
if "functionDescriptor" in light_message["taskSpec"]:
|
if "functionDescriptor" in light_message["taskSpec"]:
|
||||||
light_message["taskSpec"] = {
|
light_message["taskSpec"] = {
|
||||||
"functionDescriptor": light_message["taskSpec"][
|
"functionDescriptor": light_message["taskSpec"]["functionDescriptor"]
|
||||||
"functionDescriptor"]
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
light_message.pop("taskSpec")
|
light_message.pop("taskSpec")
|
||||||
|
@ -81,11 +90,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
if change.new:
|
if change.new:
|
||||||
# TODO(fyrestone): Handle exceptions.
|
# TODO(fyrestone): Handle exceptions.
|
||||||
node_id, node_info = change.new
|
node_id, node_info = change.new
|
||||||
address = "{}:{}".format(node_info["nodeManagerAddress"],
|
address = "{}:{}".format(
|
||||||
int(node_info["nodeManagerPort"]))
|
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
)
|
||||||
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
channel = ray._private.utils.init_grpc_channel(
|
channel = ray._private.utils.init_grpc_channel(
|
||||||
address, options, asynchronous=True)
|
address, options, asynchronous=True
|
||||||
|
)
|
||||||
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
|
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
|
||||||
self._stubs[node_id] = stub
|
self._stubs[node_id] = stub
|
||||||
|
|
||||||
|
@ -96,7 +107,8 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
logger.info("Getting all actor info from GCS.")
|
logger.info("Getting all actor info from GCS.")
|
||||||
request = gcs_service_pb2.GetAllActorInfoRequest()
|
request = gcs_service_pb2.GetAllActorInfoRequest()
|
||||||
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
|
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
|
||||||
request, timeout=5)
|
request, timeout=5
|
||||||
|
)
|
||||||
if reply.status.code == 0:
|
if reply.status.code == 0:
|
||||||
actors = {}
|
actors = {}
|
||||||
for message in reply.actor_table_data:
|
for message in reply.actor_table_data:
|
||||||
|
@ -110,24 +122,25 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
for actor_id, actor_table_data in actors.items():
|
for actor_id, actor_table_data in actors.items():
|
||||||
job_id = actor_table_data["jobId"]
|
job_id = actor_table_data["jobId"]
|
||||||
node_id = actor_table_data["address"]["rayletId"]
|
node_id = actor_table_data["address"]["rayletId"]
|
||||||
job_actors.setdefault(job_id,
|
job_actors.setdefault(job_id, {})[actor_id] = actor_table_data
|
||||||
{})[actor_id] = actor_table_data
|
|
||||||
# Update only when node_id is not Nil.
|
# Update only when node_id is not Nil.
|
||||||
if node_id != actor_consts.NIL_NODE_ID:
|
if node_id != actor_consts.NIL_NODE_ID:
|
||||||
node_actors.setdefault(
|
node_actors.setdefault(node_id, {})[
|
||||||
node_id, {})[actor_id] = actor_table_data
|
actor_id
|
||||||
|
] = actor_table_data
|
||||||
DataSource.job_actors.reset(job_actors)
|
DataSource.job_actors.reset(job_actors)
|
||||||
DataSource.node_actors.reset(node_actors)
|
DataSource.node_actors.reset(node_actors)
|
||||||
logger.info("Received %d actor info from GCS.",
|
logger.info("Received %d actor info from GCS.", len(actors))
|
||||||
len(actors))
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to GetAllActorInfo: {reply.status.message}")
|
f"Failed to GetAllActorInfo: {reply.status.message}"
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error Getting all actor info from GCS.")
|
logger.exception("Error Getting all actor info from GCS.")
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)
|
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS
|
||||||
|
)
|
||||||
|
|
||||||
state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
|
state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
|
||||||
|
|
||||||
|
@ -167,8 +180,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
if actor_id is not None:
|
if actor_id is not None:
|
||||||
# Convert to lower case hex ID.
|
# Convert to lower case hex ID.
|
||||||
actor_id = actor_id.hex()
|
actor_id = actor_id.hex()
|
||||||
process_actor_data_from_pubsub(actor_id,
|
process_actor_data_from_pubsub(actor_id, actor_table_data)
|
||||||
actor_table_data)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing actor info from GCS.")
|
logger.exception("Error processing actor info from GCS.")
|
||||||
|
|
||||||
|
@ -183,12 +195,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
async for sender, msg in receiver.iter():
|
async for sender, msg in receiver.iter():
|
||||||
try:
|
try:
|
||||||
actor_id, actor_table_data = msg
|
actor_id, actor_table_data = msg
|
||||||
actor_id = actor_id.decode("UTF-8")[len(
|
actor_id = actor_id.decode("UTF-8")[
|
||||||
gcs_utils.TablePrefix_ACTOR_string + ":"):]
|
len(gcs_utils.TablePrefix_ACTOR_string + ":") :
|
||||||
|
]
|
||||||
pubsub_message = gcs_utils.PubSubMessage.FromString(
|
pubsub_message = gcs_utils.PubSubMessage.FromString(
|
||||||
actor_table_data)
|
actor_table_data
|
||||||
|
)
|
||||||
actor_table_data = gcs_utils.ActorTableData.FromString(
|
actor_table_data = gcs_utils.ActorTableData.FromString(
|
||||||
pubsub_message.data)
|
pubsub_message.data
|
||||||
|
)
|
||||||
process_actor_data_from_pubsub(actor_id, actor_table_data)
|
process_actor_data_from_pubsub(actor_id, actor_table_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing actor info from Redis.")
|
logger.exception("Error processing actor info from Redis.")
|
||||||
|
@ -203,17 +218,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
actors.update(actor_creation_tasks)
|
actors.update(actor_creation_tasks)
|
||||||
actor_groups = actor_utils.construct_actor_groups(actors)
|
actor_groups = actor_utils.construct_actor_groups(actors)
|
||||||
return rest_response(
|
return rest_response(
|
||||||
success=True,
|
success=True, message="Fetched actor groups.", actor_groups=actor_groups
|
||||||
message="Fetched actor groups.",
|
)
|
||||||
actor_groups=actor_groups)
|
|
||||||
|
|
||||||
@routes.get("/logical/actors")
|
@routes.get("/logical/actors")
|
||||||
@dashboard_optional_utils.aiohttp_cache
|
@dashboard_optional_utils.aiohttp_cache
|
||||||
async def get_all_actors(self, req) -> aiohttp.web.Response:
|
async def get_all_actors(self, req) -> aiohttp.web.Response:
|
||||||
return rest_response(
|
return rest_response(
|
||||||
success=True,
|
success=True, message="All actors fetched.", actors=DataSource.actors
|
||||||
message="All actors fetched.",
|
)
|
||||||
actors=DataSource.actors)
|
|
||||||
|
|
||||||
@routes.get("/logical/kill_actor")
|
@routes.get("/logical/kill_actor")
|
||||||
async def kill_actor(self, req) -> aiohttp.web.Response:
|
async def kill_actor(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -224,15 +237,17 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return rest_response(success=False, message="Bad Request")
|
return rest_response(success=False, message="Bad Request")
|
||||||
try:
|
try:
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
channel = ray._private.utils.init_grpc_channel(
|
channel = ray._private.utils.init_grpc_channel(
|
||||||
f"{ip_address}:{port}", options=options, asynchronous=True)
|
f"{ip_address}:{port}", options=options, asynchronous=True
|
||||||
|
)
|
||||||
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
|
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
|
||||||
|
|
||||||
await stub.KillActor(
|
await stub.KillActor(
|
||||||
core_worker_pb2.KillActorRequest(
|
core_worker_pb2.KillActorRequest(
|
||||||
intended_actor_id=ray._private.utils.hex_to_binary(
|
intended_actor_id=ray._private.utils.hex_to_binary(actor_id)
|
||||||
actor_id)))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except aiogrpc.AioRpcError:
|
except aiogrpc.AioRpcError:
|
||||||
# This always throws an exception because the worker
|
# This always throws an exception because the worker
|
||||||
|
@ -240,13 +255,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
||||||
# before this handler, however it deletes the actor correctly.
|
# before this handler, however it deletes the actor correctly.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return rest_response(
|
return rest_response(success=True, message=f"Killed actor with id {actor_id}")
|
||||||
success=True, message=f"Killed actor with id {actor_id}")
|
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
||||||
self._gcs_actor_info_stub = \
|
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
|
||||||
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(gcs_channel)
|
gcs_channel
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.gather(self._update_actors())
|
await asyncio.gather(self._update_actors())
|
||||||
|
|
||||||
|
|
|
@ -7,27 +7,29 @@ PYCLASSNAME_RE = re.compile(r"(.+?)\(")
|
||||||
|
|
||||||
def construct_actor_groups(actors):
|
def construct_actor_groups(actors):
|
||||||
"""actors is a dict from actor id to an actor or an
|
"""actors is a dict from actor id to an actor or an
|
||||||
actor creation task The shared fields currently are
|
actor creation task The shared fields currently are
|
||||||
"actorClass", "actorId", and "state" """
|
"actorClass", "actorId", and "state" """
|
||||||
actor_groups = _group_actors_by_python_class(actors)
|
actor_groups = _group_actors_by_python_class(actors)
|
||||||
stats_by_group = {
|
stats_by_group = {
|
||||||
name: _get_actor_group_stats(group)
|
name: _get_actor_group_stats(group) for name, group in actor_groups.items()
|
||||||
for name, group in actor_groups.items()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
summarized_actor_groups = {}
|
summarized_actor_groups = {}
|
||||||
for name, group in actor_groups.items():
|
for name, group in actor_groups.items():
|
||||||
summarized_actor_groups[name] = {
|
summarized_actor_groups[name] = {
|
||||||
"entries": group,
|
"entries": group,
|
||||||
"summary": stats_by_group[name]
|
"summary": stats_by_group[name],
|
||||||
}
|
}
|
||||||
return summarized_actor_groups
|
return summarized_actor_groups
|
||||||
|
|
||||||
|
|
||||||
def actor_classname_from_task_spec(task_spec):
|
def actor_classname_from_task_spec(task_spec):
|
||||||
return task_spec.get("functionDescriptor", {})\
|
return (
|
||||||
.get("pythonFunctionDescriptor", {})\
|
task_spec.get("functionDescriptor", {})
|
||||||
.get("className", "Unknown actor class").split(".")[-1]
|
.get("pythonFunctionDescriptor", {})
|
||||||
|
.get("className", "Unknown actor class")
|
||||||
|
.split(".")[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _group_actors_by_python_class(actors):
|
def _group_actors_by_python_class(actors):
|
||||||
|
|
|
@ -50,8 +50,7 @@ def test_actor_groups(ray_start_with_dashboard):
|
||||||
response = requests.get(webui_url + "/logical/actor_groups")
|
response = requests.get(webui_url + "/logical/actor_groups")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
actor_groups_resp = response.json()
|
actor_groups_resp = response.json()
|
||||||
assert actor_groups_resp["result"] is True, actor_groups_resp[
|
assert actor_groups_resp["result"] is True, actor_groups_resp["msg"]
|
||||||
"msg"]
|
|
||||||
actor_groups = actor_groups_resp["data"]["actorGroups"]
|
actor_groups = actor_groups_resp["data"]["actorGroups"]
|
||||||
assert "Foo" in actor_groups
|
assert "Foo" in actor_groups
|
||||||
summary = actor_groups["Foo"]["summary"]
|
summary = actor_groups["Foo"]["summary"]
|
||||||
|
@ -78,9 +77,13 @@ def test_actor_groups(ray_start_with_dashboard):
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
@ -135,9 +138,13 @@ def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
@ -183,8 +190,9 @@ def test_kill_actor(ray_start_with_dashboard):
|
||||||
params={
|
params={
|
||||||
"actorId": actor["actorId"],
|
"actorId": actor["actorId"],
|
||||||
"ipAddress": actor["ipAddress"],
|
"ipAddress": actor["ipAddress"],
|
||||||
"port": actor["port"]
|
"port": actor["port"],
|
||||||
})
|
},
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
resp_json = resp.json()
|
resp_json = resp.json()
|
||||||
assert resp_json["result"] is True, "msg" in resp_json
|
assert resp_json["result"] is True, "msg" in resp_json
|
||||||
|
@ -199,19 +207,17 @@ def test_kill_actor(ray_start_with_dashboard):
|
||||||
break
|
break
|
||||||
except (KeyError, AssertionError) as e:
|
except (KeyError, AssertionError) as e:
|
||||||
last_exc = e
|
last_exc = e
|
||||||
time.sleep(.1)
|
time.sleep(0.1)
|
||||||
assert last_exc is None
|
assert last_exc is None
|
||||||
|
|
||||||
|
|
||||||
def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
timeout = 5
|
timeout = 5
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
address_info = ray_start_with_dashboard
|
address_info = ray_start_with_dashboard
|
||||||
|
|
||||||
if gcs_pubsub.gcs_pubsub_enabled():
|
if gcs_pubsub.gcs_pubsub_enabled():
|
||||||
sub = gcs_pubsub.GcsActorSubscriber(
|
sub = gcs_pubsub.GcsActorSubscriber(address=address_info["gcs_address"])
|
||||||
address=address_info["gcs_address"])
|
|
||||||
sub.subscribe()
|
sub.subscribe()
|
||||||
else:
|
else:
|
||||||
address = address_info["redis_address"]
|
address = address_info["redis_address"]
|
||||||
|
@ -221,7 +227,8 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
client = redis.StrictRedis(
|
client = redis.StrictRedis(
|
||||||
host=address[0],
|
host=address[0],
|
||||||
port=int(address[1]),
|
port=int(address[1]),
|
||||||
password=ray_constants.REDIS_DEFAULT_PASSWORD)
|
password=ray_constants.REDIS_DEFAULT_PASSWORD,
|
||||||
|
)
|
||||||
|
|
||||||
sub = client.pubsub(ignore_subscribe_messages=True)
|
sub = client.pubsub(ignore_subscribe_messages=True)
|
||||||
sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)
|
sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)
|
||||||
|
@ -245,8 +252,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
continue
|
continue
|
||||||
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
|
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
|
||||||
actor_data = gcs_utils.ActorTableData.FromString(
|
actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
|
||||||
pubsub_msg.data)
|
|
||||||
if actor_data is None:
|
if actor_data is None:
|
||||||
continue
|
continue
|
||||||
msgs.append(actor_data)
|
msgs.append(actor_data)
|
||||||
|
@ -266,12 +272,22 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
|
|
||||||
def actor_table_data_to_dict(message):
|
def actor_table_data_to_dict(message):
|
||||||
return dashboard_utils.message_to_dict(
|
return dashboard_utils.message_to_dict(
|
||||||
message, {
|
message,
|
||||||
"actorId", "parentId", "jobId", "workerId", "rayletId",
|
{
|
||||||
"actorCreationDummyObjectId", "callerId", "taskId",
|
"actorId",
|
||||||
"parentTaskId", "sourceActorId", "placementGroupId"
|
"parentId",
|
||||||
|
"jobId",
|
||||||
|
"workerId",
|
||||||
|
"rayletId",
|
||||||
|
"actorCreationDummyObjectId",
|
||||||
|
"callerId",
|
||||||
|
"taskId",
|
||||||
|
"parentTaskId",
|
||||||
|
"sourceActorId",
|
||||||
|
"placementGroupId",
|
||||||
},
|
},
|
||||||
including_default_value_fields=False)
|
including_default_value_fields=False,
|
||||||
|
)
|
||||||
|
|
||||||
non_state_keys = ("actorId", "jobId", "taskSpec")
|
non_state_keys = ("actorId", "jobId", "taskSpec")
|
||||||
|
|
||||||
|
@ -287,23 +303,31 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
# be published.
|
# be published.
|
||||||
elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
|
elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
|
||||||
assert actor_data_dict.keys() >= {
|
assert actor_data_dict.keys() >= {
|
||||||
"state", "address", "timestamp", "pid", "rayNamespace"
|
"state",
|
||||||
|
"address",
|
||||||
|
"timestamp",
|
||||||
|
"pid",
|
||||||
|
"rayNamespace",
|
||||||
}
|
}
|
||||||
elif actor_data_dict["state"] == "PENDING_CREATION":
|
elif actor_data_dict["state"] == "PENDING_CREATION":
|
||||||
assert actor_data_dict.keys() == {
|
assert actor_data_dict.keys() == {
|
||||||
"state", "address", "actorId", "actorCreationDummyObjectId",
|
"state",
|
||||||
"jobId", "ownerAddress", "taskSpec", "className",
|
"address",
|
||||||
"serializedRuntimeEnv", "rayNamespace"
|
"actorId",
|
||||||
|
"actorCreationDummyObjectId",
|
||||||
|
"jobId",
|
||||||
|
"ownerAddress",
|
||||||
|
"taskSpec",
|
||||||
|
"className",
|
||||||
|
"serializedRuntimeEnv",
|
||||||
|
"rayNamespace",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise Exception("Unknown state: {}".format(
|
raise Exception("Unknown state: {}".format(actor_data_dict["state"]))
|
||||||
actor_data_dict["state"]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_nil_node(enable_test_module, disable_aiohttp_cache,
|
def test_nil_node(enable_test_module, disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
ray_start_with_dashboard):
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
assert wait_until_server_available(webui_url)
|
assert wait_until_server_available(webui_url)
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
@ -334,9 +358,13 @@ def test_nil_node(enable_test_module, disable_aiohttp_cache,
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
|
|
@ -24,13 +24,11 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
|
||||||
os.makedirs(self._event_dir, exist_ok=True)
|
os.makedirs(self._event_dir, exist_ok=True)
|
||||||
self._monitor: Union[asyncio.Task, None] = None
|
self._monitor: Union[asyncio.Task, None] = None
|
||||||
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
|
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
|
||||||
self._cached_events = asyncio.Queue(
|
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
|
||||||
event_consts.EVENT_AGENT_CACHE_SIZE)
|
logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)
|
||||||
logger.info("Event agent cache buffer size: %s",
|
|
||||||
self._cached_events.maxsize)
|
|
||||||
|
|
||||||
async def _connect_to_dashboard(self):
|
async def _connect_to_dashboard(self):
|
||||||
""" Connect to the dashboard. If the dashboard is not started, then
|
"""Connect to the dashboard. If the dashboard is not started, then
|
||||||
this method will never returns.
|
this method will never returns.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -41,23 +39,24 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
|
||||||
# TODO: Use async version if performance is an issue
|
# TODO: Use async version if performance is an issue
|
||||||
dashboard_rpc_address = internal_kv._internal_kv_get(
|
dashboard_rpc_address = internal_kv._internal_kv_get(
|
||||||
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
if dashboard_rpc_address:
|
if dashboard_rpc_address:
|
||||||
logger.info("Report events to %s", dashboard_rpc_address)
|
logger.info("Report events to %s", dashboard_rpc_address)
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
channel = utils.init_grpc_channel(
|
channel = utils.init_grpc_channel(
|
||||||
dashboard_rpc_address,
|
dashboard_rpc_address, options=options, asynchronous=True
|
||||||
options=options,
|
)
|
||||||
asynchronous=True)
|
|
||||||
return event_pb2_grpc.ReportEventServiceStub(channel)
|
return event_pb2_grpc.ReportEventServiceStub(channel)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Connect to dashboard failed.")
|
logger.exception("Connect to dashboard failed.")
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS)
|
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS
|
||||||
|
)
|
||||||
|
|
||||||
@async_loop_forever(event_consts.EVENT_AGENT_REPORT_INTERVAL_SECONDS)
|
@async_loop_forever(event_consts.EVENT_AGENT_REPORT_INTERVAL_SECONDS)
|
||||||
async def report_events(self):
|
async def report_events(self):
|
||||||
""" Report events from cached events queue. Reconnect to dashboard if
|
"""Report events from cached events queue. Reconnect to dashboard if
|
||||||
report failed. Log error after retry EVENT_AGENT_RETRY_TIMES.
|
report failed. Log error after retry EVENT_AGENT_RETRY_TIMES.
|
||||||
|
|
||||||
This method will never returns.
|
This method will never returns.
|
||||||
|
@ -70,14 +69,15 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
|
||||||
await self._stub.ReportEvents(request)
|
await self._stub.ReportEvents(request)
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Report event failed, reconnect to the "
|
logger.exception("Report event failed, reconnect to the " "dashboard.")
|
||||||
"dashboard.")
|
|
||||||
self._stub = await self._connect_to_dashboard()
|
self._stub = await self._connect_to_dashboard()
|
||||||
else:
|
else:
|
||||||
data_str = str(data)
|
data_str = str(data)
|
||||||
limit = event_consts.LOG_ERROR_EVENT_STRING_LENGTH_LIMIT
|
limit = event_consts.LOG_ERROR_EVENT_STRING_LENGTH_LIMIT
|
||||||
logger.error("Report event failed: %s",
|
logger.error(
|
||||||
data_str[:limit] + (data_str[limit:] and "..."))
|
"Report event failed: %s",
|
||||||
|
data_str[:limit] + (data_str[limit:] and "..."),
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
# Connect to dashboard.
|
# Connect to dashboard.
|
||||||
|
@ -86,7 +86,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
|
||||||
self._monitor = monitor_events(
|
self._monitor = monitor_events(
|
||||||
self._event_dir,
|
self._event_dir,
|
||||||
lambda data: create_task(self._cached_events.put(data)),
|
lambda data: create_task(self._cached_events.put(data)),
|
||||||
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES)
|
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES,
|
||||||
|
)
|
||||||
# Start reporting events.
|
# Start reporting events.
|
||||||
await self.report_events()
|
await self.report_events()
|
||||||
|
|
||||||
|
|
|
@ -4,22 +4,20 @@ from ray.core.generated import event_pb2
|
||||||
LOG_ERROR_EVENT_STRING_LENGTH_LIMIT = 1000
|
LOG_ERROR_EVENT_STRING_LENGTH_LIMIT = 1000
|
||||||
RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS = 2
|
RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS = 2
|
||||||
# Monitor events
|
# Monitor events
|
||||||
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer(
|
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer("SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
|
||||||
"SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
|
|
||||||
SCAN_EVENT_START_OFFSET_SECONDS = -30 * 60
|
SCAN_EVENT_START_OFFSET_SECONDS = -30 * 60
|
||||||
CONCURRENT_READ_LIMIT = 50
|
CONCURRENT_READ_LIMIT = 50
|
||||||
EVENT_READ_LINE_COUNT_LIMIT = 200
|
EVENT_READ_LINE_COUNT_LIMIT = 200
|
||||||
EVENT_READ_LINE_LENGTH_LIMIT = env_integer("EVENT_READ_LINE_LENGTH_LIMIT",
|
EVENT_READ_LINE_LENGTH_LIMIT = env_integer(
|
||||||
2 * 1024 * 1024) # 2MB
|
"EVENT_READ_LINE_LENGTH_LIMIT", 2 * 1024 * 1024
|
||||||
|
) # 2MB
|
||||||
# Report events
|
# Report events
|
||||||
EVENT_AGENT_REPORT_INTERVAL_SECONDS = 0.1
|
EVENT_AGENT_REPORT_INTERVAL_SECONDS = 0.1
|
||||||
EVENT_AGENT_RETRY_TIMES = 10
|
EVENT_AGENT_RETRY_TIMES = 10
|
||||||
EVENT_AGENT_CACHE_SIZE = 10240
|
EVENT_AGENT_CACHE_SIZE = 10240
|
||||||
# Event sources
|
# Event sources
|
||||||
EVENT_HEAD_MONITOR_SOURCE_TYPES = [
|
EVENT_HEAD_MONITOR_SOURCE_TYPES = [event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)]
|
||||||
event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
|
|
||||||
]
|
|
||||||
EVENT_AGENT_MONITOR_SOURCE_TYPES = list(
|
EVENT_AGENT_MONITOR_SOURCE_TYPES = list(
|
||||||
set(event_pb2.Event.SourceType.keys()) -
|
set(event_pb2.Event.SourceType.keys()) - set(EVENT_HEAD_MONITOR_SOURCE_TYPES)
|
||||||
set(EVENT_HEAD_MONITOR_SOURCE_TYPES))
|
)
|
||||||
EVENT_SOURCE_ALL = event_pb2.Event.SourceType.keys()
|
EVENT_SOURCE_ALL = event_pb2.Event.SourceType.keys()
|
||||||
|
|
|
@ -24,8 +24,9 @@ JobEvents = OrderedDict
|
||||||
dashboard_utils._json_compatible_types.add(JobEvents)
|
dashboard_utils._json_compatible_types.add(JobEvents)
|
||||||
|
|
||||||
|
|
||||||
class EventHead(dashboard_utils.DashboardHeadModule,
|
class EventHead(
|
||||||
event_pb2_grpc.ReportEventServiceServicer):
|
dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer
|
||||||
|
):
|
||||||
def __init__(self, dashboard_head):
|
def __init__(self, dashboard_head):
|
||||||
super().__init__(dashboard_head)
|
super().__init__(dashboard_head)
|
||||||
self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
|
self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
|
||||||
|
@ -70,21 +71,24 @@ class EventHead(dashboard_utils.DashboardHeadModule,
|
||||||
for job_id, job_events in DataSource.events.items()
|
for job_id, job_events in DataSource.events.items()
|
||||||
}
|
}
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="All events fetched.", events=all_events)
|
success=True, message="All events fetched.", events=all_events
|
||||||
|
)
|
||||||
|
|
||||||
job_events = DataSource.events.get(job_id, {})
|
job_events = DataSource.events.get(job_id, {})
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Job events fetched.",
|
message="Job events fetched.",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
events=list(job_events.values()))
|
events=list(job_events.values()),
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
|
event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
|
||||||
self._monitor = monitor_events(
|
self._monitor = monitor_events(
|
||||||
self._event_dir,
|
self._event_dir,
|
||||||
lambda data: self._update_events(parse_event_strings(data)),
|
lambda data: self._update_events(parse_event_strings(data)),
|
||||||
source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES)
|
source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_minimal_module():
|
def is_minimal_module():
|
||||||
|
|
|
@ -19,8 +19,7 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
|
||||||
source_files = {}
|
source_files = {}
|
||||||
all_source_types = set(event_consts.EVENT_SOURCE_ALL)
|
all_source_types = set(event_consts.EVENT_SOURCE_ALL)
|
||||||
for source_type in source_types or event_consts.EVENT_SOURCE_ALL:
|
for source_type in source_types or event_consts.EVENT_SOURCE_ALL:
|
||||||
assert source_type in all_source_types, \
|
assert source_type in all_source_types, f"Invalid source type: {source_type}"
|
||||||
f"Invalid source type: {source_type}"
|
|
||||||
files = []
|
files = []
|
||||||
for n in event_log_names:
|
for n in event_log_names:
|
||||||
if fnmatch.fnmatch(n, f"*{source_type}*"):
|
if fnmatch.fnmatch(n, f"*{source_type}*"):
|
||||||
|
@ -35,9 +34,9 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
|
||||||
|
|
||||||
def _restore_newline(event_dict):
|
def _restore_newline(event_dict):
|
||||||
try:
|
try:
|
||||||
event_dict["message"] = event_dict["message"]\
|
event_dict["message"] = (
|
||||||
.replace("\\n", "\n")\
|
event_dict["message"].replace("\\n", "\n").replace("\\r", "\n")
|
||||||
.replace("\\r", "\n")
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Restore newline for event failed: %s", event_dict)
|
logger.exception("Restore newline for event failed: %s", event_dict)
|
||||||
return event_dict
|
return event_dict
|
||||||
|
@ -61,13 +60,13 @@ def parse_event_strings(event_string_list):
|
||||||
|
|
||||||
|
|
||||||
ReadFileResult = collections.namedtuple(
|
ReadFileResult = collections.namedtuple(
|
||||||
"ReadFileResult", ["fid", "size", "mtime", "position", "lines"])
|
"ReadFileResult", ["fid", "size", "mtime", "position", "lines"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _read_file(file,
|
def _read_file(
|
||||||
pos,
|
file, pos, n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT, closefd=True
|
||||||
n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT,
|
):
|
||||||
closefd=True):
|
|
||||||
with open(file, "rb", closefd=closefd) as f:
|
with open(file, "rb", closefd=closefd) as f:
|
||||||
# The ino may be 0 on Windows.
|
# The ino may be 0 on Windows.
|
||||||
stat = os.stat(f.fileno())
|
stat = os.stat(f.fileno())
|
||||||
|
@ -82,24 +81,25 @@ def _read_file(file,
|
||||||
if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT:
|
if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT:
|
||||||
lines.append(mm[start:sep].decode("utf-8"))
|
lines.append(mm[start:sep].decode("utf-8"))
|
||||||
else:
|
else:
|
||||||
truncated_size = min(
|
truncated_size = min(100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
|
||||||
100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Ignored long string: %s...(%s chars)",
|
"Ignored long string: %s...(%s chars)",
|
||||||
mm[start:start + truncated_size].decode("utf-8"),
|
mm[start : start + truncated_size].decode("utf-8"),
|
||||||
sep - start)
|
sep - start,
|
||||||
|
)
|
||||||
start = sep + 1
|
start = sep + 1
|
||||||
return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines)
|
return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines)
|
||||||
|
|
||||||
|
|
||||||
def monitor_events(
|
def monitor_events(
|
||||||
event_dir,
|
event_dir,
|
||||||
callback,
|
callback,
|
||||||
scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
|
scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
|
||||||
start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
|
start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
|
||||||
monitor_files=None,
|
monitor_files=None,
|
||||||
source_types=None):
|
source_types=None,
|
||||||
""" Monitor events in directory. New events will be read and passed to the
|
):
|
||||||
|
"""Monitor events in directory. New events will be read and passed to the
|
||||||
callback.
|
callback.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -121,20 +121,22 @@ def monitor_events(
|
||||||
monitor_files = {}
|
monitor_files = {}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Monitor events logs modified after %s on %s, "
|
"Monitor events logs modified after %s on %s, " "the source types are %s.",
|
||||||
"the source types are %s.", start_mtime, event_dir, "all"
|
start_mtime,
|
||||||
if source_types is None else source_types)
|
event_dir,
|
||||||
|
"all" if source_types is None else source_types,
|
||||||
|
)
|
||||||
|
|
||||||
MonitorFile = collections.namedtuple("MonitorFile",
|
MonitorFile = collections.namedtuple("MonitorFile", ["size", "mtime", "position"])
|
||||||
["size", "mtime", "position"])
|
|
||||||
|
|
||||||
def _source_file_filter(source_file):
|
def _source_file_filter(source_file):
|
||||||
stat = os.stat(source_file)
|
stat = os.stat(source_file)
|
||||||
return stat.st_mtime > start_mtime
|
return stat.st_mtime > start_mtime
|
||||||
|
|
||||||
def _read_monitor_file(file, pos):
|
def _read_monitor_file(file, pos):
|
||||||
assert isinstance(file, str), \
|
assert isinstance(
|
||||||
f"File should be a str, but a {type(file)}({file}) found"
|
file, str
|
||||||
|
), f"File should be a str, but a {type(file)}({file}) found"
|
||||||
fd = os.open(file, os.O_RDONLY)
|
fd = os.open(file, os.O_RDONLY)
|
||||||
try:
|
try:
|
||||||
stat = os.stat(fd)
|
stat = os.stat(fd)
|
||||||
|
@ -145,12 +147,14 @@ def monitor_events(
|
||||||
fid = stat.st_ino or file
|
fid = stat.st_ino or file
|
||||||
monitor_file = monitor_files.get(fid)
|
monitor_file = monitor_files.get(fid)
|
||||||
if monitor_file:
|
if monitor_file:
|
||||||
if (monitor_file.position == monitor_file.size
|
if (
|
||||||
and monitor_file.size == stat.st_size
|
monitor_file.position == monitor_file.size
|
||||||
and monitor_file.mtime == stat.st_mtime):
|
and monitor_file.size == stat.st_size
|
||||||
|
and monitor_file.mtime == stat.st_mtime
|
||||||
|
):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Skip reading the file because "
|
"Skip reading the file because " "there is no change: %s", file
|
||||||
"there is no change: %s", file)
|
)
|
||||||
return []
|
return []
|
||||||
position = monitor_file.position
|
position = monitor_file.position
|
||||||
else:
|
else:
|
||||||
|
@ -169,22 +173,23 @@ def monitor_events(
|
||||||
@async_loop_forever(scan_interval_seconds, cancellable=True)
|
@async_loop_forever(scan_interval_seconds, cancellable=True)
|
||||||
async def _scan_event_log_files():
|
async def _scan_event_log_files():
|
||||||
# Scan event files.
|
# Scan event files.
|
||||||
source_files = await loop.run_in_executor(None, _get_source_files,
|
source_files = await loop.run_in_executor(
|
||||||
event_dir, source_types,
|
None, _get_source_files, event_dir, source_types, _source_file_filter
|
||||||
_source_file_filter)
|
)
|
||||||
|
|
||||||
# Limit concurrent read to avoid fd exhaustion.
|
# Limit concurrent read to avoid fd exhaustion.
|
||||||
semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT)
|
semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT)
|
||||||
|
|
||||||
async def _concurrent_coro(filename):
|
async def _concurrent_coro(filename):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await loop.run_in_executor(None, _read_monitor_file,
|
return await loop.run_in_executor(None, _read_monitor_file, filename, 0)
|
||||||
filename, 0)
|
|
||||||
|
|
||||||
# Read files.
|
# Read files.
|
||||||
await asyncio.gather(*[
|
await asyncio.gather(
|
||||||
_concurrent_coro(filename)
|
*[
|
||||||
for filename in list(itertools.chain(*source_files.values()))
|
_concurrent_coro(filename)
|
||||||
])
|
for filename in list(itertools.chain(*source_files.values()))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return create_task(_scan_event_log_files())
|
return create_task(_scan_event_log_files())
|
||||||
|
|
|
@ -23,7 +23,8 @@ from ray._private.test_utils import (
|
||||||
wait_for_condition,
|
wait_for_condition,
|
||||||
)
|
)
|
||||||
from ray.dashboard.modules.event.event_utils import (
|
from ray.dashboard.modules.event.event_utils import (
|
||||||
monitor_events, )
|
monitor_events,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,7 +33,8 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
|
||||||
return {
|
return {
|
||||||
"event_id": binary_to_hex(np.random.bytes(18)),
|
"event_id": binary_to_hex(np.random.bytes(18)),
|
||||||
"source_type": random.choice(event_pb2.Event.SourceType.keys())
|
"source_type": random.choice(event_pb2.Event.SourceType.keys())
|
||||||
if source_type is None else source_type,
|
if source_type is None
|
||||||
|
else source_type,
|
||||||
"host_name": "po-dev.inc.alipay.net",
|
"host_name": "po-dev.inc.alipay.net",
|
||||||
"pid": random.randint(1, 65536),
|
"pid": random.randint(1, 65536),
|
||||||
"label": "",
|
"label": "",
|
||||||
|
@ -41,16 +43,18 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
|
||||||
"severity": "INFO",
|
"severity": "INFO",
|
||||||
"custom_fields": {
|
"custom_fields": {
|
||||||
"job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
|
"job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
|
||||||
if job_id is None else job_id,
|
if job_id is None
|
||||||
|
else job_id,
|
||||||
"node_id": "",
|
"node_id": "",
|
||||||
"task_id": "",
|
"task_id": "",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _test_logger(name, log_file, max_bytes, backup_count):
|
def _test_logger(name, log_file, max_bytes, backup_count):
|
||||||
handler = logging.handlers.RotatingFileHandler(
|
handler = logging.handlers.RotatingFileHandler(
|
||||||
log_file, maxBytes=max_bytes, backupCount=backup_count)
|
log_file, maxBytes=max_bytes, backupCount=backup_count
|
||||||
|
)
|
||||||
formatter = logging.Formatter("%(message)s")
|
formatter = logging.Formatter("%(message)s")
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
@ -63,15 +67,14 @@ def _test_logger(name, log_file, max_bytes, backup_count):
|
||||||
|
|
||||||
|
|
||||||
def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
|
def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||||
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
||||||
session_dir = ray_start_with_dashboard["session_dir"]
|
session_dir = ray_start_with_dashboard["session_dir"]
|
||||||
event_dir = os.path.join(session_dir, "logs", "events")
|
event_dir = os.path.join(session_dir, "logs", "events")
|
||||||
job_id = ray.JobID.from_int(100).hex()
|
job_id = ray.JobID.from_int(100).hex()
|
||||||
|
|
||||||
source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
|
source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
|
||||||
source_type_raylet = event_pb2.Event.SourceType.Name(
|
source_type_raylet = event_pb2.Event.SourceType.Name(event_pb2.Event.RAYLET)
|
||||||
event_pb2.Event.RAYLET)
|
|
||||||
test_count = 20
|
test_count = 20
|
||||||
|
|
||||||
for source_type in [source_type_gcs, source_type_raylet]:
|
for source_type in [source_type_gcs, source_type_raylet]:
|
||||||
|
@ -80,10 +83,10 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
__name__ + str(random.random()),
|
__name__ + str(random.random()),
|
||||||
test_log_file,
|
test_log_file,
|
||||||
max_bytes=2000,
|
max_bytes=2000,
|
||||||
backup_count=1000)
|
backup_count=1000,
|
||||||
|
)
|
||||||
for i in range(test_count):
|
for i in range(test_count):
|
||||||
sample_event = _get_event(
|
sample_event = _get_event(str(i), job_id=job_id, source_type=source_type)
|
||||||
str(i), job_id=job_id, source_type=source_type)
|
|
||||||
test_logger.info("%s", json.dumps(sample_event))
|
test_logger.info("%s", json.dumps(sample_event))
|
||||||
|
|
||||||
def _check_events():
|
def _check_events():
|
||||||
|
@ -112,10 +115,11 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
wait_for_condition(_check_events, timeout=15)
|
wait_for_condition(_check_events, timeout=15)
|
||||||
|
|
||||||
|
|
||||||
def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
|
def test_event_message_limit(
|
||||||
ray_start_with_dashboard):
|
small_event_line_limit, disable_aiohttp_cache, ray_start_with_dashboard
|
||||||
|
):
|
||||||
event_read_line_length_limit = small_event_line_limit
|
event_read_line_length_limit = small_event_line_limit
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||||
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
||||||
session_dir = ray_start_with_dashboard["session_dir"]
|
session_dir = ray_start_with_dashboard["session_dir"]
|
||||||
event_dir = os.path.join(session_dir, "logs", "events")
|
event_dir = os.path.join(session_dir, "logs", "events")
|
||||||
|
@ -148,8 +152,8 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
os.rename(
|
os.rename(
|
||||||
os.path.join(event_dir, "tmp.log"),
|
os.path.join(event_dir, "tmp.log"), os.path.join(event_dir, "event_GCS.log")
|
||||||
os.path.join(event_dir, "event_GCS.log"))
|
)
|
||||||
|
|
||||||
def _check_events():
|
def _check_events():
|
||||||
try:
|
try:
|
||||||
|
@ -157,14 +161,14 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
all_events = result["data"]["events"]
|
all_events = result["data"]["events"]
|
||||||
assert len(all_events[job_id]
|
assert (
|
||||||
) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
|
len(all_events[job_id]) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
|
||||||
|
)
|
||||||
messages = [e["message"] for e in all_events[job_id]]
|
messages = [e["message"] for e in all_events[job_id]]
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
assert str(i) * message_len in messages
|
assert str(i) * message_len in messages
|
||||||
assert "2" * (message_len + 1) not in messages
|
assert "2" * (message_len + 1) not in messages
|
||||||
assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT -
|
assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT - 1) in messages
|
||||||
1) in messages
|
|
||||||
return True
|
return True
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(ex)
|
logger.exception(ex)
|
||||||
|
@ -179,15 +183,12 @@ async def test_monitor_events():
|
||||||
common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
|
common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
|
||||||
common_log = os.path.join(temp_dir, f"event_{common}.log")
|
common_log = os.path.join(temp_dir, f"event_{common}.log")
|
||||||
test_logger = _test_logger(
|
test_logger = _test_logger(
|
||||||
__name__ + str(random.random()),
|
__name__ + str(random.random()), common_log, max_bytes=10, backup_count=10
|
||||||
common_log,
|
)
|
||||||
max_bytes=10,
|
|
||||||
backup_count=10)
|
|
||||||
test_events1 = []
|
test_events1 = []
|
||||||
monitor_task = monitor_events(
|
monitor_task = monitor_events(
|
||||||
temp_dir,
|
temp_dir, lambda x: test_events1.extend(x), scan_interval_seconds=0.01
|
||||||
lambda x: test_events1.extend(x),
|
)
|
||||||
scan_interval_seconds=0.01)
|
|
||||||
assert not monitor_task.done()
|
assert not monitor_task.done()
|
||||||
count = 10
|
count = 10
|
||||||
|
|
||||||
|
@ -206,7 +207,8 @@ async def test_monitor_events():
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Timeout, read events: {sorted_events}, "
|
f"Timeout, read events: {sorted_events}, "
|
||||||
f"expect events: {expect_events}")
|
f"expect events: {expect_events}"
|
||||||
|
)
|
||||||
if len(sorted_events) == len(expect_events):
|
if len(sorted_events) == len(expect_events):
|
||||||
if sorted_events == expect_events:
|
if sorted_events == expect_events:
|
||||||
break
|
break
|
||||||
|
@ -214,40 +216,37 @@ async def test_monitor_events():
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
_writer(count, read_events=test_events1),
|
_writer(count, read_events=test_events1),
|
||||||
_check_events(
|
_check_events([str(i) for i in range(count)], read_events=test_events1),
|
||||||
[str(i) for i in range(count)], read_events=test_events1))
|
)
|
||||||
|
|
||||||
monitor_task.cancel()
|
monitor_task.cancel()
|
||||||
test_events2 = []
|
test_events2 = []
|
||||||
monitor_task = monitor_events(
|
monitor_task = monitor_events(
|
||||||
temp_dir,
|
temp_dir, lambda x: test_events2.extend(x), scan_interval_seconds=0.1
|
||||||
lambda x: test_events2.extend(x),
|
)
|
||||||
scan_interval_seconds=0.1)
|
|
||||||
|
|
||||||
await _check_events(
|
await _check_events([str(i) for i in range(count)], read_events=test_events2)
|
||||||
[str(i) for i in range(count)], read_events=test_events2)
|
|
||||||
|
|
||||||
await _writer(count, count * 2, read_events=test_events2)
|
await _writer(count, count * 2, read_events=test_events2)
|
||||||
await _check_events(
|
await _check_events(
|
||||||
[str(i) for i in range(count * 2)], read_events=test_events2)
|
[str(i) for i in range(count * 2)], read_events=test_events2
|
||||||
|
)
|
||||||
|
|
||||||
log_file_count = len(os.listdir(temp_dir))
|
log_file_count = len(os.listdir(temp_dir))
|
||||||
|
|
||||||
test_logger = _test_logger(
|
test_logger = _test_logger(
|
||||||
__name__ + str(random.random()),
|
__name__ + str(random.random()), common_log, max_bytes=1000, backup_count=10
|
||||||
common_log,
|
)
|
||||||
max_bytes=1000,
|
|
||||||
backup_count=10)
|
|
||||||
assert len(os.listdir(temp_dir)) == log_file_count
|
assert len(os.listdir(temp_dir)) == log_file_count
|
||||||
|
|
||||||
await _writer(
|
await _writer(count * 2, count * 3, spin=False, read_events=test_events2)
|
||||||
count * 2, count * 3, spin=False, read_events=test_events2)
|
|
||||||
await _check_events(
|
await _check_events(
|
||||||
[str(i) for i in range(count * 3)], read_events=test_events2)
|
[str(i) for i in range(count * 3)], read_events=test_events2
|
||||||
await _writer(
|
)
|
||||||
count * 3, count * 4, spin=False, read_events=test_events2)
|
await _writer(count * 3, count * 4, spin=False, read_events=test_events2)
|
||||||
await _check_events(
|
await _check_events(
|
||||||
[str(i) for i in range(count * 4)], read_events=test_events2)
|
[str(i) for i in range(count * 4)], read_events=test_events2
|
||||||
|
)
|
||||||
|
|
||||||
# Test cancel monitor task.
|
# Test cancel monitor task.
|
||||||
monitor_task.cancel()
|
monitor_task.cancel()
|
||||||
|
@ -255,8 +254,7 @@ async def test_monitor_events():
|
||||||
await monitor_task
|
await monitor_task
|
||||||
assert monitor_task.done()
|
assert monitor_task.done()
|
||||||
|
|
||||||
assert len(
|
assert len(os.listdir(temp_dir)) > 1, "Event log should have rollovers."
|
||||||
os.listdir(temp_dir)) > 1, "Event log should have rollovers."
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -7,21 +7,21 @@ import yaml
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from ray.autoscaler._private.cli_logger import (add_click_logging_options,
|
from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf
|
||||||
cli_logger, cf)
|
|
||||||
from ray.dashboard.modules.job.common import JobStatus
|
from ray.dashboard.modules.job.common import JobStatus
|
||||||
from ray.dashboard.modules.job.sdk import JobSubmissionClient
|
from ray.dashboard.modules.job.sdk import JobSubmissionClient
|
||||||
|
|
||||||
|
|
||||||
def _get_sdk_client(address: Optional[str],
|
def _get_sdk_client(
|
||||||
create_cluster_if_needed: bool = False
|
address: Optional[str], create_cluster_if_needed: bool = False
|
||||||
) -> JobSubmissionClient:
|
) -> JobSubmissionClient:
|
||||||
|
|
||||||
if address is None:
|
if address is None:
|
||||||
if "RAY_ADDRESS" not in os.environ:
|
if "RAY_ADDRESS" not in os.environ:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Address must be specified using either the --address flag "
|
"Address must be specified using either the --address flag "
|
||||||
"or RAY_ADDRESS environment variable.")
|
"or RAY_ADDRESS environment variable."
|
||||||
|
)
|
||||||
address = os.environ["RAY_ADDRESS"]
|
address = os.environ["RAY_ADDRESS"]
|
||||||
|
|
||||||
cli_logger.labeled_value("Job submission server address", address)
|
cli_logger.labeled_value("Job submission server address", address)
|
||||||
|
@ -73,55 +73,67 @@ def job_cli_group():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@job_cli_group.command(
|
@job_cli_group.command("submit", help="Submit a job to be executed on the cluster.")
|
||||||
"submit", help="Submit a job to be executed on the cluster.")
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--address",
|
"--address",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Address of the Ray cluster to connect to. Can also be specified "
|
help=(
|
||||||
"using the RAY_ADDRESS environment variable."))
|
"Address of the Ray cluster to connect to. Can also be specified "
|
||||||
|
"using the RAY_ADDRESS environment variable."
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--job-id",
|
"--job-id",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Job ID to specify for the job. "
|
help=("Job ID to specify for the job. " "If not provided, one will be generated."),
|
||||||
"If not provided, one will be generated."))
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--runtime-env",
|
"--runtime-env",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help="Path to a local YAML file containing a runtime_env definition.")
|
help="Path to a local YAML file containing a runtime_env definition.",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--runtime-env-json",
|
"--runtime-env-json",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help="JSON-serialized runtime_env dictionary.")
|
help="JSON-serialized runtime_env dictionary.",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--working-dir",
|
"--working-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Directory containing files that your job will run in. Can be a "
|
help=(
|
||||||
"local directory or a remote URI to a .zip file (S3, GS, HTTP). "
|
"Directory containing files that your job will run in. Can be a "
|
||||||
"If specified, this overrides the option in --runtime-env."),
|
"local directory or a remote URI to a .zip file (S3, GS, HTTP). "
|
||||||
|
"If specified, this overrides the option in --runtime-env."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--no-wait",
|
"--no-wait",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="If set, will not stream logs and wait for the job to exit.")
|
help="If set, will not stream logs and wait for the job to exit.",
|
||||||
|
)
|
||||||
@add_click_logging_options
|
@add_click_logging_options
|
||||||
@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
|
@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
|
||||||
def job_submit(address: Optional[str], job_id: Optional[str],
|
def job_submit(
|
||||||
runtime_env: Optional[str], runtime_env_json: Optional[str],
|
address: Optional[str],
|
||||||
working_dir: Optional[str], entrypoint: Tuple[str],
|
job_id: Optional[str],
|
||||||
no_wait: bool):
|
runtime_env: Optional[str],
|
||||||
|
runtime_env_json: Optional[str],
|
||||||
|
working_dir: Optional[str],
|
||||||
|
entrypoint: Tuple[str],
|
||||||
|
no_wait: bool,
|
||||||
|
):
|
||||||
"""Submits a job to be run on the cluster.
|
"""Submits a job to be run on the cluster.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -132,8 +144,9 @@ def job_submit(address: Optional[str], job_id: Optional[str],
|
||||||
final_runtime_env = {}
|
final_runtime_env = {}
|
||||||
if runtime_env is not None:
|
if runtime_env is not None:
|
||||||
if runtime_env_json is not None:
|
if runtime_env_json is not None:
|
||||||
raise ValueError("Only one of --runtime_env and "
|
raise ValueError(
|
||||||
"--runtime-env-json can be provided.")
|
"Only one of --runtime_env and " "--runtime-env-json can be provided."
|
||||||
|
)
|
||||||
with open(runtime_env, "r") as f:
|
with open(runtime_env, "r") as f:
|
||||||
final_runtime_env = yaml.safe_load(f)
|
final_runtime_env = yaml.safe_load(f)
|
||||||
|
|
||||||
|
@ -143,14 +156,14 @@ def job_submit(address: Optional[str], job_id: Optional[str],
|
||||||
if working_dir is not None:
|
if working_dir is not None:
|
||||||
if "working_dir" in final_runtime_env:
|
if "working_dir" in final_runtime_env:
|
||||||
cli_logger.warning(
|
cli_logger.warning(
|
||||||
"Overriding runtime_env working_dir with --working-dir option")
|
"Overriding runtime_env working_dir with --working-dir option"
|
||||||
|
)
|
||||||
|
|
||||||
final_runtime_env["working_dir"] = working_dir
|
final_runtime_env["working_dir"] = working_dir
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint=" ".join(entrypoint),
|
entrypoint=" ".join(entrypoint), job_id=job_id, runtime_env=final_runtime_env
|
||||||
job_id=job_id,
|
)
|
||||||
runtime_env=final_runtime_env)
|
|
||||||
|
|
||||||
_log_big_success_msg(f"Job '{job_id}' submitted successfully")
|
_log_big_success_msg(f"Job '{job_id}' submitted successfully")
|
||||||
|
|
||||||
|
@ -172,15 +185,16 @@ def job_submit(address: Optional[str], job_id: Optional[str],
|
||||||
# sdk version 0 does not have log streaming
|
# sdk version 0 does not have log streaming
|
||||||
if not no_wait:
|
if not no_wait:
|
||||||
if int(sdk_version) > 0:
|
if int(sdk_version) > 0:
|
||||||
cli_logger.print("Tailing logs until the job exits "
|
cli_logger.print(
|
||||||
"(disable with --no-wait):")
|
"Tailing logs until the job exits " "(disable with --no-wait):"
|
||||||
asyncio.get_event_loop().run_until_complete(
|
)
|
||||||
_tail_logs(client, job_id))
|
asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
|
||||||
else:
|
else:
|
||||||
cli_logger.warning(
|
cli_logger.warning(
|
||||||
"Tailing logs is not enabled for job sdk client version "
|
"Tailing logs is not enabled for job sdk client version "
|
||||||
f"{sdk_version}. Please upgrade your ray to latest version "
|
f"{sdk_version}. Please upgrade your ray to latest version "
|
||||||
"for this feature.")
|
"for this feature."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@job_cli_group.command("status", help="Get the status of a running job.")
|
@job_cli_group.command("status", help="Get the status of a running job.")
|
||||||
|
@ -189,8 +203,11 @@ def job_submit(address: Optional[str], job_id: Optional[str],
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Address of the Ray cluster to connect to. Can also be specified "
|
help=(
|
||||||
"using the RAY_ADDRESS environment variable."))
|
"Address of the Ray cluster to connect to. Can also be specified "
|
||||||
|
"using the RAY_ADDRESS environment variable."
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.argument("job-id", type=str)
|
@click.argument("job-id", type=str)
|
||||||
@add_click_logging_options
|
@add_click_logging_options
|
||||||
def job_status(address: Optional[str], job_id: str):
|
def job_status(address: Optional[str], job_id: str):
|
||||||
|
@ -209,14 +226,18 @@ def job_status(address: Optional[str], job_id: str):
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Address of the Ray cluster to connect to. Can also be specified "
|
help=(
|
||||||
"using the RAY_ADDRESS environment variable."))
|
"Address of the Ray cluster to connect to. Can also be specified "
|
||||||
|
"using the RAY_ADDRESS environment variable."
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--no-wait",
|
"--no-wait",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="If set, will not wait for the job to exit.")
|
help="If set, will not wait for the job to exit.",
|
||||||
|
)
|
||||||
@click.argument("job-id", type=str)
|
@click.argument("job-id", type=str)
|
||||||
@add_click_logging_options
|
@add_click_logging_options
|
||||||
def job_stop(address: Optional[str], no_wait: bool, job_id: str):
|
def job_stop(address: Optional[str], no_wait: bool, job_id: str):
|
||||||
|
@ -232,14 +253,13 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
|
||||||
if no_wait:
|
if no_wait:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
cli_logger.print(f"Waiting for job '{job_id}' to exit "
|
cli_logger.print(
|
||||||
f"(disable with --no-wait):")
|
f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):"
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
status = client.get_job_status(job_id)
|
status = client.get_job_status(job_id)
|
||||||
if status.status in {
|
if status.status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}:
|
||||||
JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED
|
|
||||||
}:
|
|
||||||
_log_job_status(client, job_id)
|
_log_job_status(client, job_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -253,8 +273,11 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help=("Address of the Ray cluster to connect to. Can also be specified "
|
help=(
|
||||||
"using the RAY_ADDRESS environment variable."))
|
"Address of the Ray cluster to connect to. Can also be specified "
|
||||||
|
"using the RAY_ADDRESS environment variable."
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.argument("job-id", type=str)
|
@click.argument("job-id", type=str)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-f",
|
"-f",
|
||||||
|
@ -262,7 +285,8 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="If set, follow the logs (like `tail -f`).")
|
help="If set, follow the logs (like `tail -f`).",
|
||||||
|
)
|
||||||
@add_click_logging_options
|
@add_click_logging_options
|
||||||
def job_logs(address: Optional[str], job_id: str, follow: bool):
|
def job_logs(address: Optional[str], job_id: str, follow: bool):
|
||||||
"""Gets the logs of a job.
|
"""Gets the logs of a job.
|
||||||
|
@ -275,12 +299,12 @@ def job_logs(address: Optional[str], job_id: str, follow: bool):
|
||||||
# sdk version 0 did not have log streaming
|
# sdk version 0 did not have log streaming
|
||||||
if follow:
|
if follow:
|
||||||
if int(sdk_version) > 0:
|
if int(sdk_version) > 0:
|
||||||
asyncio.get_event_loop().run_until_complete(
|
asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
|
||||||
_tail_logs(client, job_id))
|
|
||||||
else:
|
else:
|
||||||
cli_logger.warning(
|
cli_logger.warning(
|
||||||
"Tailing logs is not enabled for job sdk client version "
|
"Tailing logs is not enabled for job sdk client version "
|
||||||
f"{sdk_version}. Please upgrade your ray to latest version "
|
f"{sdk_version}. Please upgrade your ray to latest version "
|
||||||
"for this feature.")
|
"for this feature."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(client.get_job_logs(job_id), end="")
|
print(client.get_job_logs(job_id), end="")
|
||||||
|
|
|
@ -39,8 +39,10 @@ class JobStatusInfo:
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.message is None:
|
if self.message is None:
|
||||||
if self.status == JobStatus.PENDING:
|
if self.status == JobStatus.PENDING:
|
||||||
self.message = ("Job has not started yet, likely waiting "
|
self.message = (
|
||||||
"for the runtime_env to be set up.")
|
"Job has not started yet, likely waiting "
|
||||||
|
"for the runtime_env to be set up."
|
||||||
|
)
|
||||||
elif self.status == JobStatus.RUNNING:
|
elif self.status == JobStatus.RUNNING:
|
||||||
self.message = "Job is currently running."
|
self.message = "Job is currently running."
|
||||||
elif self.status == JobStatus.STOPPED:
|
elif self.status == JobStatus.STOPPED:
|
||||||
|
@ -55,6 +57,7 @@ class JobStatusStorageClient:
|
||||||
"""
|
"""
|
||||||
Handles formatting of status storage key given job id.
|
Handles formatting of status storage key given job id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"
|
JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -69,12 +72,14 @@ class JobStatusStorageClient:
|
||||||
_internal_kv_put(
|
_internal_kv_put(
|
||||||
self.JOB_STATUS_KEY.format(job_id=job_id),
|
self.JOB_STATUS_KEY.format(job_id=job_id),
|
||||||
pickle.dumps(status),
|
pickle.dumps(status),
|
||||||
namespace=ray_constants.KV_NAMESPACE_JOB)
|
namespace=ray_constants.KV_NAMESPACE_JOB,
|
||||||
|
)
|
||||||
|
|
||||||
def get_status(self, job_id: str) -> Optional[JobStatusInfo]:
|
def get_status(self, job_id: str) -> Optional[JobStatusInfo]:
|
||||||
pickled_status = _internal_kv_get(
|
pickled_status = _internal_kv_get(
|
||||||
self.JOB_STATUS_KEY.format(job_id=job_id),
|
self.JOB_STATUS_KEY.format(job_id=job_id),
|
||||||
namespace=ray_constants.KV_NAMESPACE_JOB)
|
namespace=ray_constants.KV_NAMESPACE_JOB,
|
||||||
|
)
|
||||||
if pickled_status is None:
|
if pickled_status is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
@ -87,18 +92,16 @@ def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
|
||||||
# We need to strip the gcs:// prefix and .zip suffix to make it
|
# We need to strip the gcs:// prefix and .zip suffix to make it
|
||||||
# possible to pass the package_uri over HTTP.
|
# possible to pass the package_uri over HTTP.
|
||||||
protocol, package_name = parse_uri(package_uri)
|
protocol, package_name = parse_uri(package_uri)
|
||||||
return protocol.value, package_name[:-len(".zip")]
|
return protocol.value, package_name[: -len(".zip")]
|
||||||
|
|
||||||
|
|
||||||
def http_uri_components_to_uri(protocol: str, package_name: str) -> str:
|
def http_uri_components_to_uri(protocol: str, package_name: str) -> str:
|
||||||
if package_name.endswith(".zip"):
|
if package_name.endswith(".zip"):
|
||||||
raise ValueError(
|
raise ValueError(f"package_name ({package_name}) should not end in .zip")
|
||||||
f"package_name ({package_name}) should not end in .zip")
|
|
||||||
return f"{protocol}://{package_name}.zip"
|
return f"{protocol}://{package_name}.zip"
|
||||||
|
|
||||||
|
|
||||||
def validate_request_type(json_data: Dict[str, Any],
|
def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any:
|
||||||
request_type: dataclass) -> Any:
|
|
||||||
return request_type(**json_data)
|
return request_type(**json_data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,8 +127,7 @@ class JobSubmitRequest:
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not isinstance(self.entrypoint, str):
|
if not isinstance(self.entrypoint, str):
|
||||||
raise TypeError(
|
raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}")
|
||||||
f"entrypoint must be a string, got {type(self.entrypoint)}")
|
|
||||||
|
|
||||||
if self.job_id is not None and not isinstance(self.job_id, str):
|
if self.job_id is not None and not isinstance(self.job_id, str):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
@ -141,21 +143,21 @@ class JobSubmitRequest:
|
||||||
for k in self.runtime_env.keys():
|
for k in self.runtime_env.keys():
|
||||||
if not isinstance(k, str):
|
if not isinstance(k, str):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"runtime_env keys must be strings, got {type(k)}")
|
f"runtime_env keys must be strings, got {type(k)}"
|
||||||
|
)
|
||||||
|
|
||||||
if self.metadata is not None:
|
if self.metadata is not None:
|
||||||
if not isinstance(self.metadata, dict):
|
if not isinstance(self.metadata, dict):
|
||||||
raise TypeError(
|
raise TypeError(f"metadata must be a dict, got {type(self.metadata)}")
|
||||||
f"metadata must be a dict, got {type(self.metadata)}")
|
|
||||||
else:
|
else:
|
||||||
for k in self.metadata.keys():
|
for k in self.metadata.keys():
|
||||||
if not isinstance(k, str):
|
if not isinstance(k, str):
|
||||||
raise TypeError(
|
raise TypeError(f"metadata keys must be strings, got {type(k)}")
|
||||||
f"metadata keys must be strings, got {type(k)}")
|
|
||||||
for v in self.metadata.values():
|
for v in self.metadata.values():
|
||||||
if not isinstance(v, str):
|
if not isinstance(v, str):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"metadata values must be strings, got {type(v)}")
|
f"metadata values must be strings, got {type(v)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -12,8 +12,7 @@ import ray
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray._private.gcs_utils import use_gcs_for_bootstrap
|
from ray._private.gcs_utils import use_gcs_for_bootstrap
|
||||||
from ray._private.runtime_env.packaging import (package_exists,
|
from ray._private.runtime_env.packaging import package_exists, upload_package_to_gcs
|
||||||
upload_package_to_gcs)
|
|
||||||
from ray.dashboard.modules.job.common import (
|
from ray.dashboard.modules.job.common import (
|
||||||
CURRENT_VERSION,
|
CURRENT_VERSION,
|
||||||
http_uri_components_to_uri,
|
http_uri_components_to_uri,
|
||||||
|
@ -45,19 +44,20 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
|
||||||
if use_gcs_for_bootstrap():
|
if use_gcs_for_bootstrap():
|
||||||
address = self._dashboard_head.gcs_address
|
address = self._dashboard_head.gcs_address
|
||||||
redis_pw = None
|
redis_pw = None
|
||||||
logger.info(
|
logger.info(f"Connecting to ray with address={address}")
|
||||||
f"Connecting to ray with address={address}")
|
|
||||||
else:
|
else:
|
||||||
ip, port = self._dashboard_head.redis_address
|
ip, port = self._dashboard_head.redis_address
|
||||||
redis_pw = self._dashboard_head.redis_password
|
redis_pw = self._dashboard_head.redis_password
|
||||||
address = f"{ip}:{port}"
|
address = f"{ip}:{port}"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Connecting to ray with address={address}, "
|
f"Connecting to ray with address={address}, "
|
||||||
f"redis_pw={redis_pw}")
|
f"redis_pw={redis_pw}"
|
||||||
|
)
|
||||||
ray.init(
|
ray.init(
|
||||||
address=address,
|
address=address,
|
||||||
namespace=RAY_INTERNAL_JOBS_NAMESPACE,
|
namespace=RAY_INTERNAL_JOBS_NAMESPACE,
|
||||||
_redis_password=redis_pw)
|
_redis_password=redis_pw,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
raise e from None
|
raise e from None
|
||||||
|
@ -67,7 +67,8 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
|
||||||
logger.exception(f"Unexpected error in handler: {e}")
|
logger.exception(f"Unexpected error in handler: {e}")
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPInternalServerError.status_code)
|
status=aiohttp.web.HTTPInternalServerError.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
return check
|
return check
|
||||||
|
|
||||||
|
@ -77,8 +78,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
super().__init__(dashboard_head)
|
super().__init__(dashboard_head)
|
||||||
self._job_manager = None
|
self._job_manager = None
|
||||||
|
|
||||||
async def _parse_and_validate_request(self, req: Request,
|
async def _parse_and_validate_request(
|
||||||
request_type: dataclass) -> Any:
|
self, req: Request, request_type: dataclass
|
||||||
|
) -> Any:
|
||||||
"""Parse request and cast to request type. If parsing failed, return a
|
"""Parse request and cast to request type. If parsing failed, return a
|
||||||
Response object with status 400 and stacktrace instead.
|
Response object with status 400 and stacktrace instead.
|
||||||
"""
|
"""
|
||||||
|
@ -88,7 +90,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
logger.info(f"Got invalid request type: {e}")
|
logger.info(f"Got invalid request type: {e}")
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPBadRequest.status_code)
|
status=aiohttp.web.HTTPBadRequest.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
def job_exists(self, job_id: str) -> bool:
|
def job_exists(self, job_id: str) -> bool:
|
||||||
status = self._job_manager.get_job_status(job_id)
|
status = self._job_manager.get_job_status(job_id)
|
||||||
|
@ -101,7 +104,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
resp = VersionResponse(
|
resp = VersionResponse(
|
||||||
version=CURRENT_VERSION,
|
version=CURRENT_VERSION,
|
||||||
ray_version=ray.__version__,
|
ray_version=ray.__version__,
|
||||||
ray_commit=ray.__commit__)
|
ray_commit=ray.__commit__,
|
||||||
|
)
|
||||||
return Response(
|
return Response(
|
||||||
text=json.dumps(dataclasses.asdict(resp)),
|
text=json.dumps(dataclasses.asdict(resp)),
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
|
@ -113,12 +117,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
async def get_package(self, req: Request) -> Response:
|
async def get_package(self, req: Request) -> Response:
|
||||||
package_uri = http_uri_components_to_uri(
|
package_uri = http_uri_components_to_uri(
|
||||||
protocol=req.match_info["protocol"],
|
protocol=req.match_info["protocol"],
|
||||||
package_name=req.match_info["package_name"])
|
package_name=req.match_info["package_name"],
|
||||||
|
)
|
||||||
|
|
||||||
if not package_exists(package_uri):
|
if not package_exists(package_uri):
|
||||||
return Response(
|
return Response(
|
||||||
text=f"Package {package_uri} does not exist",
|
text=f"Package {package_uri} does not exist",
|
||||||
status=aiohttp.web.HTTPNotFound.status_code)
|
status=aiohttp.web.HTTPNotFound.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
return Response()
|
return Response()
|
||||||
|
|
||||||
|
@ -127,14 +133,16 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
async def upload_package(self, req: Request):
|
async def upload_package(self, req: Request):
|
||||||
package_uri = http_uri_components_to_uri(
|
package_uri = http_uri_components_to_uri(
|
||||||
protocol=req.match_info["protocol"],
|
protocol=req.match_info["protocol"],
|
||||||
package_name=req.match_info["package_name"])
|
package_name=req.match_info["package_name"],
|
||||||
|
)
|
||||||
logger.info(f"Uploading package {package_uri} to the GCS.")
|
logger.info(f"Uploading package {package_uri} to the GCS.")
|
||||||
try:
|
try:
|
||||||
upload_package_to_gcs(package_uri, await req.read())
|
upload_package_to_gcs(package_uri, await req.read())
|
||||||
except Exception:
|
except Exception:
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPInternalServerError.status_code)
|
status=aiohttp.web.HTTPInternalServerError.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
return Response(status=aiohttp.web.HTTPOk.status_code)
|
return Response(status=aiohttp.web.HTTPOk.status_code)
|
||||||
|
|
||||||
|
@ -153,17 +161,20 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
entrypoint=submit_request.entrypoint,
|
entrypoint=submit_request.entrypoint,
|
||||||
job_id=submit_request.job_id,
|
job_id=submit_request.job_id,
|
||||||
runtime_env=submit_request.runtime_env,
|
runtime_env=submit_request.runtime_env,
|
||||||
metadata=submit_request.metadata)
|
metadata=submit_request.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
resp = JobSubmitResponse(job_id=job_id)
|
resp = JobSubmitResponse(job_id=job_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPBadRequest.status_code)
|
status=aiohttp.web.HTTPBadRequest.status_code,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPInternalServerError.status_code)
|
status=aiohttp.web.HTTPInternalServerError.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
text=json.dumps(dataclasses.asdict(resp)),
|
text=json.dumps(dataclasses.asdict(resp)),
|
||||||
|
@ -178,7 +189,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
if not self.job_exists(job_id):
|
if not self.job_exists(job_id):
|
||||||
return Response(
|
return Response(
|
||||||
text=f"Job {job_id} does not exist",
|
text=f"Job {job_id} does not exist",
|
||||||
status=aiohttp.web.HTTPNotFound.status_code)
|
status=aiohttp.web.HTTPNotFound.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stopped = self._job_manager.stop_job(job_id)
|
stopped = self._job_manager.stop_job(job_id)
|
||||||
|
@ -186,11 +198,12 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
except Exception:
|
except Exception:
|
||||||
return Response(
|
return Response(
|
||||||
text=traceback.format_exc(),
|
text=traceback.format_exc(),
|
||||||
status=aiohttp.web.HTTPInternalServerError.status_code)
|
status=aiohttp.web.HTTPInternalServerError.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
text=json.dumps(dataclasses.asdict(resp)),
|
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
|
||||||
content_type="application/json")
|
)
|
||||||
|
|
||||||
@routes.get("/api/jobs/{job_id}")
|
@routes.get("/api/jobs/{job_id}")
|
||||||
@_init_ray_and_catch_exceptions
|
@_init_ray_and_catch_exceptions
|
||||||
|
@ -199,13 +212,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
if not self.job_exists(job_id):
|
if not self.job_exists(job_id):
|
||||||
return Response(
|
return Response(
|
||||||
text=f"Job {job_id} does not exist",
|
text=f"Job {job_id} does not exist",
|
||||||
status=aiohttp.web.HTTPNotFound.status_code)
|
status=aiohttp.web.HTTPNotFound.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
status: JobStatusInfo = self._job_manager.get_job_status(job_id)
|
status: JobStatusInfo = self._job_manager.get_job_status(job_id)
|
||||||
resp = JobStatusResponse(status=status.status, message=status.message)
|
resp = JobStatusResponse(status=status.status, message=status.message)
|
||||||
return Response(
|
return Response(
|
||||||
text=json.dumps(dataclasses.asdict(resp)),
|
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
|
||||||
content_type="application/json")
|
)
|
||||||
|
|
||||||
@routes.get("/api/jobs/{job_id}/logs")
|
@routes.get("/api/jobs/{job_id}/logs")
|
||||||
@_init_ray_and_catch_exceptions
|
@_init_ray_and_catch_exceptions
|
||||||
|
@ -214,12 +228,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
if not self.job_exists(job_id):
|
if not self.job_exists(job_id):
|
||||||
return Response(
|
return Response(
|
||||||
text=f"Job {job_id} does not exist",
|
text=f"Job {job_id} does not exist",
|
||||||
status=aiohttp.web.HTTPNotFound.status_code)
|
status=aiohttp.web.HTTPNotFound.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id))
|
resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id))
|
||||||
return Response(
|
return Response(
|
||||||
text=json.dumps(dataclasses.asdict(resp)),
|
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
|
||||||
content_type="application/json")
|
)
|
||||||
|
|
||||||
@routes.get("/api/jobs/{job_id}/logs/tail")
|
@routes.get("/api/jobs/{job_id}/logs/tail")
|
||||||
@_init_ray_and_catch_exceptions
|
@_init_ray_and_catch_exceptions
|
||||||
|
@ -228,7 +243,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
|
||||||
if not self.job_exists(job_id):
|
if not self.job_exists(job_id):
|
||||||
return Response(
|
return Response(
|
||||||
text=f"Job {job_id} does not exist",
|
text=f"Job {job_id} does not exist",
|
||||||
status=aiohttp.web.HTTPNotFound.status_code)
|
status=aiohttp.web.HTTPNotFound.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
ws = aiohttp.web.WebSocketResponse()
|
ws = aiohttp.web.WebSocketResponse()
|
||||||
await ws.prepare(req)
|
await ws.prepare(req)
|
||||||
|
|
|
@ -15,8 +15,12 @@ from ray.exceptions import RuntimeEnvSetupError
|
||||||
import ray.ray_constants as ray_constants
|
import ray.ray_constants as ray_constants
|
||||||
from ray.actor import ActorHandle
|
from ray.actor import ActorHandle
|
||||||
from ray.dashboard.modules.job.common import (
|
from ray.dashboard.modules.job.common import (
|
||||||
JobStatus, JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY,
|
JobStatus,
|
||||||
JOB_NAME_METADATA_KEY)
|
JobStatusInfo,
|
||||||
|
JobStatusStorageClient,
|
||||||
|
JOB_ID_METADATA_KEY,
|
||||||
|
JOB_NAME_METADATA_KEY,
|
||||||
|
)
|
||||||
from ray.dashboard.modules.job.utils import file_tail_iterator
|
from ray.dashboard.modules.job.utils import file_tail_iterator
|
||||||
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
|
||||||
|
|
||||||
|
@ -36,8 +40,8 @@ def generate_job_id() -> str:
|
||||||
"""
|
"""
|
||||||
rand = random.SystemRandom()
|
rand = random.SystemRandom()
|
||||||
possible_characters = list(
|
possible_characters = list(
|
||||||
set(string.ascii_letters + string.digits) -
|
set(string.ascii_letters + string.digits)
|
||||||
{"I", "l", "o", "O", "0"} # No confusing characters
|
- {"I", "l", "o", "O", "0"} # No confusing characters
|
||||||
)
|
)
|
||||||
id_part = "".join(rand.choices(possible_characters, k=16))
|
id_part = "".join(rand.choices(possible_characters, k=16))
|
||||||
return f"raysubmit_{id_part}"
|
return f"raysubmit_{id_part}"
|
||||||
|
@ -47,6 +51,7 @@ class JobLogStorageClient:
|
||||||
"""
|
"""
|
||||||
Disk storage for stdout / stderr of driver script logs.
|
Disk storage for stdout / stderr of driver script logs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
JOB_LOGS_PATH = "job-driver-{job_id}.log"
|
JOB_LOGS_PATH = "job-driver-{job_id}.log"
|
||||||
# Number of last N lines to put in job message upon failure.
|
# Number of last N lines to put in job message upon failure.
|
||||||
NUM_LOG_LINES_ON_ERROR = 10
|
NUM_LOG_LINES_ON_ERROR = 10
|
||||||
|
@ -61,9 +66,9 @@ class JobLogStorageClient:
|
||||||
def tail_logs(self, job_id: str) -> Iterator[str]:
|
def tail_logs(self, job_id: str) -> Iterator[str]:
|
||||||
return file_tail_iterator(self.get_log_file_path(job_id))
|
return file_tail_iterator(self.get_log_file_path(job_id))
|
||||||
|
|
||||||
def get_last_n_log_lines(self,
|
def get_last_n_log_lines(
|
||||||
job_id: str,
|
self, job_id: str, num_log_lines=NUM_LOG_LINES_ON_ERROR
|
||||||
num_log_lines=NUM_LOG_LINES_ON_ERROR) -> str:
|
) -> str:
|
||||||
log_tail_iter = self.tail_logs(job_id)
|
log_tail_iter = self.tail_logs(job_id)
|
||||||
log_tail_deque = deque(maxlen=num_log_lines)
|
log_tail_deque = deque(maxlen=num_log_lines)
|
||||||
for line in log_tail_iter:
|
for line in log_tail_iter:
|
||||||
|
@ -80,7 +85,8 @@ class JobLogStorageClient:
|
||||||
"""
|
"""
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
ray.worker._global_node.get_logs_dir_path(),
|
ray.worker._global_node.get_logs_dir_path(),
|
||||||
self.JOB_LOGS_PATH.format(job_id=job_id))
|
self.JOB_LOGS_PATH.format(job_id=job_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JobSupervisor:
|
class JobSupervisor:
|
||||||
|
@ -95,8 +101,7 @@ class JobSupervisor:
|
||||||
|
|
||||||
SUBPROCESS_POLL_PERIOD_S = 0.1
|
SUBPROCESS_POLL_PERIOD_S = 0.1
|
||||||
|
|
||||||
def __init__(self, job_id: str, entrypoint: str,
|
def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
|
||||||
user_metadata: Dict[str, str]):
|
|
||||||
self._job_id = job_id
|
self._job_id = job_id
|
||||||
self._status_client = JobStatusStorageClient()
|
self._status_client = JobStatusStorageClient()
|
||||||
self._log_client = JobLogStorageClient()
|
self._log_client = JobLogStorageClient()
|
||||||
|
@ -104,10 +109,7 @@ class JobSupervisor:
|
||||||
self._entrypoint = entrypoint
|
self._entrypoint = entrypoint
|
||||||
|
|
||||||
# Default metadata if not passed by the user.
|
# Default metadata if not passed by the user.
|
||||||
self._metadata = {
|
self._metadata = {JOB_ID_METADATA_KEY: job_id, JOB_NAME_METADATA_KEY: job_id}
|
||||||
JOB_ID_METADATA_KEY: job_id,
|
|
||||||
JOB_NAME_METADATA_KEY: job_id
|
|
||||||
}
|
|
||||||
self._metadata.update(user_metadata)
|
self._metadata.update(user_metadata)
|
||||||
|
|
||||||
# fire and forget call from outer job manager to this actor
|
# fire and forget call from outer job manager to this actor
|
||||||
|
@ -142,7 +144,8 @@ class JobSupervisor:
|
||||||
shell=True,
|
shell=True,
|
||||||
start_new_session=True,
|
start_new_session=True,
|
||||||
stdout=logs_file,
|
stdout=logs_file,
|
||||||
stderr=subprocess.STDOUT)
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
parent_pid = os.getpid()
|
parent_pid = os.getpid()
|
||||||
# Create new pgid with new subprocess to execute driver command
|
# Create new pgid with new subprocess to execute driver command
|
||||||
child_pid = child_process.pid
|
child_pid = child_process.pid
|
||||||
|
@ -177,9 +180,10 @@ class JobSupervisor:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
# Signal actor used in testing to capture PENDING -> RUNNING cases
|
# Signal actor used in testing to capture PENDING -> RUNNING cases
|
||||||
_start_signal_actor: Optional[ActorHandle] = None):
|
_start_signal_actor: Optional[ActorHandle] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Stop and start both happen asynchrously, coordinated by asyncio event
|
Stop and start both happen asynchrously, coordinated by asyncio event
|
||||||
and coroutine, respectively.
|
and coroutine, respectively.
|
||||||
|
@ -190,26 +194,26 @@ class JobSupervisor:
|
||||||
3) Handle concurrent events of driver execution and
|
3) Handle concurrent events of driver execution and
|
||||||
"""
|
"""
|
||||||
cur_status = self._get_status()
|
cur_status = self._get_status()
|
||||||
assert cur_status.status == JobStatus.PENDING, (
|
assert cur_status.status == JobStatus.PENDING, "Run should only be called once."
|
||||||
"Run should only be called once.")
|
|
||||||
|
|
||||||
if _start_signal_actor:
|
if _start_signal_actor:
|
||||||
# Block in PENDING state until start signal received.
|
# Block in PENDING state until start signal received.
|
||||||
await _start_signal_actor.wait.remote()
|
await _start_signal_actor.wait.remote()
|
||||||
|
|
||||||
self._status_client.put_status(self._job_id,
|
self._status_client.put_status(self._job_id, JobStatusInfo(JobStatus.RUNNING))
|
||||||
JobStatusInfo(JobStatus.RUNNING))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set JobConfig for the child process (runtime_env, metadata).
|
# Set JobConfig for the child process (runtime_env, metadata).
|
||||||
os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps({
|
os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps(
|
||||||
"runtime_env": self._runtime_env,
|
{
|
||||||
"metadata": self._metadata,
|
"runtime_env": self._runtime_env,
|
||||||
})
|
"metadata": self._metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
# Set RAY_ADDRESS to local Ray address, if it is not set.
|
# Set RAY_ADDRESS to local Ray address, if it is not set.
|
||||||
os.environ[
|
os.environ[
|
||||||
ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = \
|
ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE
|
||||||
ray._private.services.get_ray_address_from_environment()
|
] = ray._private.services.get_ray_address_from_environment()
|
||||||
# Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
|
# Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
|
||||||
# only streaming them upon completion of the job.
|
# only streaming them upon completion of the job.
|
||||||
os.environ["PYTHONUNBUFFERED"] = "1"
|
os.environ["PYTHONUNBUFFERED"] = "1"
|
||||||
|
@ -218,8 +222,8 @@ class JobSupervisor:
|
||||||
|
|
||||||
polling_task = create_task(self._polling(child_process))
|
polling_task = create_task(self._polling(child_process))
|
||||||
finished, _ = await asyncio.wait(
|
finished, _ = await asyncio.wait(
|
||||||
[polling_task, self._stop_event.wait()],
|
[polling_task, self._stop_event.wait()], return_when=FIRST_COMPLETED
|
||||||
return_when=FIRST_COMPLETED)
|
)
|
||||||
|
|
||||||
if self._stop_event.is_set():
|
if self._stop_event.is_set():
|
||||||
polling_task.cancel()
|
polling_task.cancel()
|
||||||
|
@ -229,29 +233,29 @@ class JobSupervisor:
|
||||||
else:
|
else:
|
||||||
# Child process finished execution and no stop event is set
|
# Child process finished execution and no stop event is set
|
||||||
# at the same time
|
# at the same time
|
||||||
assert len(
|
assert len(finished) == 1, "Should have only one coroutine done"
|
||||||
finished) == 1, "Should have only one coroutine done"
|
|
||||||
[child_process_task] = finished
|
[child_process_task] = finished
|
||||||
return_code = child_process_task.result()
|
return_code = child_process_task.result()
|
||||||
if return_code == 0:
|
if return_code == 0:
|
||||||
self._status_client.put_status(self._job_id,
|
self._status_client.put_status(self._job_id, JobStatus.SUCCEEDED)
|
||||||
JobStatus.SUCCEEDED)
|
|
||||||
else:
|
else:
|
||||||
log_tail = self._log_client.get_last_n_log_lines(
|
log_tail = self._log_client.get_last_n_log_lines(self._job_id)
|
||||||
self._job_id)
|
|
||||||
if log_tail is not None and log_tail != "":
|
if log_tail is not None and log_tail != "":
|
||||||
message = ("Job failed due to an application error, "
|
message = (
|
||||||
"last available logs:\n" + log_tail)
|
"Job failed due to an application error, "
|
||||||
|
"last available logs:\n" + log_tail
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
message = None
|
message = None
|
||||||
self._status_client.put_status(
|
self._status_client.put_status(
|
||||||
self._job_id,
|
self._job_id,
|
||||||
JobStatusInfo(
|
JobStatusInfo(status=JobStatus.FAILED, message=message),
|
||||||
status=JobStatus.FAILED, message=message))
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Got unexpected exception while trying to execute driver "
|
"Got unexpected exception while trying to execute driver "
|
||||||
f"command. {traceback.format_exc()}")
|
f"command. {traceback.format_exc()}"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# clean up actor after tasks are finished
|
# clean up actor after tasks are finished
|
||||||
ray.actor.exit_actor()
|
ray.actor.exit_actor()
|
||||||
|
@ -260,8 +264,7 @@ class JobSupervisor:
|
||||||
return self._status_client.get_status(self._job_id)
|
return self._status_client.get_status(self._job_id)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Set step_event and let run() handle the rest in its asyncio.wait().
|
"""Set step_event and let run() handle the rest in its asyncio.wait()."""
|
||||||
"""
|
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,6 +274,7 @@ class JobManager:
|
||||||
It does not provide persistence, all info will be lost if the cluster
|
It does not provide persistence, all info will be lost if the cluster
|
||||||
goes down.
|
goes down.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
|
JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
|
||||||
# Time that we will sleep while tailing logs if no new log line is
|
# Time that we will sleep while tailing logs if no new log line is
|
||||||
# available.
|
# available.
|
||||||
|
@ -300,11 +304,9 @@ class JobManager:
|
||||||
if key.startswith("node:"):
|
if key.startswith("node:"):
|
||||||
return key
|
return key
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Cannot find the node dictionary for current node.")
|
||||||
"Cannot find the node dictionary for current node.")
|
|
||||||
|
|
||||||
def _handle_supervisor_startup(self, job_id: str,
|
def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]):
|
||||||
result: Optional[Exception]):
|
|
||||||
"""Handle the result of starting a job supervisor actor.
|
"""Handle the result of starting a job supervisor actor.
|
||||||
|
|
||||||
If started successfully, result should be None. Otherwise it should be
|
If started successfully, result should be None. Otherwise it should be
|
||||||
|
@ -321,26 +323,30 @@ class JobManager:
|
||||||
job_id,
|
job_id,
|
||||||
JobStatusInfo(
|
JobStatusInfo(
|
||||||
status=JobStatus.FAILED,
|
status=JobStatus.FAILED,
|
||||||
message=(f"runtime_env setup failed: {result}")))
|
message=(f"runtime_env setup failed: {result}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
elif isinstance(result, Exception):
|
elif isinstance(result, Exception):
|
||||||
logger.error(
|
logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
|
||||||
f"Failed to start supervisor for job {job_id}: {result}.")
|
|
||||||
self._status_client.put_status(
|
self._status_client.put_status(
|
||||||
job_id,
|
job_id,
|
||||||
JobStatusInfo(
|
JobStatusInfo(
|
||||||
status=JobStatus.FAILED,
|
status=JobStatus.FAILED,
|
||||||
message=f"Error occurred while starting the job: {result}")
|
message=f"Error occurred while starting the job: {result}",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert False, "This should not be reached."
|
assert False, "This should not be reached."
|
||||||
|
|
||||||
def submit_job(self,
|
def submit_job(
|
||||||
*,
|
self,
|
||||||
entrypoint: str,
|
*,
|
||||||
job_id: Optional[str] = None,
|
entrypoint: str,
|
||||||
runtime_env: Optional[Dict[str, Any]] = None,
|
job_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
runtime_env: Optional[Dict[str, Any]] = None,
|
||||||
_start_signal_actor: Optional[ActorHandle] = None) -> str:
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
|
_start_signal_actor: Optional[ActorHandle] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Job execution happens asynchronously.
|
Job execution happens asynchronously.
|
||||||
|
|
||||||
|
@ -390,8 +396,8 @@ class JobManager:
|
||||||
resources={
|
resources={
|
||||||
self._get_current_node_resource_key(): 0.001,
|
self._get_current_node_resource_key(): 0.001,
|
||||||
},
|
},
|
||||||
runtime_env=runtime_env).remote(job_id, entrypoint, metadata
|
runtime_env=runtime_env,
|
||||||
or {})
|
).remote(job_id, entrypoint, metadata or {})
|
||||||
actor.run.remote(_start_signal_actor=_start_signal_actor)
|
actor.run.remote(_start_signal_actor=_start_signal_actor)
|
||||||
|
|
||||||
def callback(result: Optional[Exception]):
|
def callback(result: Optional[Exception]):
|
||||||
|
@ -441,7 +447,8 @@ class JobManager:
|
||||||
# updating GCS with latest status.
|
# updating GCS with latest status.
|
||||||
last_status = self._status_client.get_status(job_id)
|
last_status = self._status_client.get_status(job_id)
|
||||||
if last_status and last_status.status in {
|
if last_status and last_status.status in {
|
||||||
JobStatus.PENDING, JobStatus.RUNNING
|
JobStatus.PENDING,
|
||||||
|
JobStatus.RUNNING,
|
||||||
}:
|
}:
|
||||||
self._status_client.put_status(job_id, JobStatus.FAILED)
|
self._status_client.put_status(job_id, JobStatus.FAILED)
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,19 @@ except ImportError:
|
||||||
requests = None
|
requests = None
|
||||||
|
|
||||||
from ray._private.runtime_env.packaging import (
|
from ray._private.runtime_env.packaging import (
|
||||||
create_package, get_uri_for_directory, parse_uri)
|
create_package,
|
||||||
|
get_uri_for_directory,
|
||||||
|
parse_uri,
|
||||||
|
)
|
||||||
from ray.dashboard.modules.job.common import (
|
from ray.dashboard.modules.job.common import (
|
||||||
JobSubmitRequest, JobSubmitResponse, JobStopResponse, JobStatusInfo,
|
JobSubmitRequest,
|
||||||
JobStatusResponse, JobLogsResponse, uri_to_http_components)
|
JobSubmitResponse,
|
||||||
|
JobStopResponse,
|
||||||
|
JobStatusInfo,
|
||||||
|
JobStatusResponse,
|
||||||
|
JobLogsResponse,
|
||||||
|
uri_to_http_components,
|
||||||
|
)
|
||||||
|
|
||||||
from ray.client_builder import _split_address
|
from ray.client_builder import _split_address
|
||||||
|
|
||||||
|
@ -33,51 +42,49 @@ class ClusterInfo:
|
||||||
|
|
||||||
|
|
||||||
def get_job_submission_client_cluster_info(
|
def get_job_submission_client_cluster_info(
|
||||||
address: str,
|
address: str,
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
*,
|
*,
|
||||||
# only used in importlib case in parse_cluster_info, but needed
|
# only used in importlib case in parse_cluster_info, but needed
|
||||||
# in function signature.
|
# in function signature.
|
||||||
create_cluster_if_needed: Optional[bool] = False,
|
create_cluster_if_needed: Optional[bool] = False,
|
||||||
cookies: Optional[Dict[str, Any]] = None,
|
cookies: Optional[Dict[str, Any]] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
|
headers: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> ClusterInfo:
|
||||||
"""Get address, cookies, and metadata used for JobSubmissionClient.
|
"""Get address, cookies, and metadata used for JobSubmissionClient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address (str): Address without the module prefix that is passed
|
address (str): Address without the module prefix that is passed
|
||||||
to JobSubmissionClient.
|
to JobSubmissionClient.
|
||||||
create_cluster_if_needed (bool): Indicates whether the cluster
|
create_cluster_if_needed (bool): Indicates whether the cluster
|
||||||
of the address returned needs to be running. Ray doesn't
|
of the address returned needs to be running. Ray doesn't
|
||||||
start a cluster before interacting with jobs, but other
|
start a cluster before interacting with jobs, but other
|
||||||
implementations may do so.
|
implementations may do so.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ClusterInfo object consisting of address, cookies, and metadata
|
ClusterInfo object consisting of address, cookies, and metadata
|
||||||
for JobSubmissionClient to use.
|
for JobSubmissionClient to use.
|
||||||
"""
|
"""
|
||||||
return ClusterInfo(
|
return ClusterInfo(
|
||||||
address="http://" + address,
|
address="http://" + address, cookies=cookies, metadata=metadata, headers=headers
|
||||||
cookies=cookies,
|
)
|
||||||
metadata=metadata,
|
|
||||||
headers=headers)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_cluster_info(
|
def parse_cluster_info(
|
||||||
address: str,
|
address: str,
|
||||||
create_cluster_if_needed: bool = False,
|
create_cluster_if_needed: bool = False,
|
||||||
cookies: Optional[Dict[str, Any]] = None,
|
cookies: Optional[Dict[str, Any]] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
|
headers: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> ClusterInfo:
|
||||||
module_string, inner_address = _split_address(address.rstrip("/"))
|
module_string, inner_address = _split_address(address.rstrip("/"))
|
||||||
|
|
||||||
# If user passes in a raw HTTP(S) address, just pass it through.
|
# If user passes in a raw HTTP(S) address, just pass it through.
|
||||||
if module_string == "http" or module_string == "https":
|
if module_string == "http" or module_string == "https":
|
||||||
return ClusterInfo(
|
return ClusterInfo(
|
||||||
address=address,
|
address=address, cookies=cookies, metadata=metadata, headers=headers
|
||||||
cookies=cookies,
|
)
|
||||||
metadata=metadata,
|
|
||||||
headers=headers)
|
|
||||||
# If user passes in a Ray address, convert it to HTTP.
|
# If user passes in a Ray address, convert it to HTTP.
|
||||||
elif module_string == "ray":
|
elif module_string == "ray":
|
||||||
return get_job_submission_client_cluster_info(
|
return get_job_submission_client_cluster_info(
|
||||||
|
@ -85,7 +92,8 @@ def parse_cluster_info(
|
||||||
create_cluster_if_needed=create_cluster_if_needed,
|
create_cluster_if_needed=create_cluster_if_needed,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
headers=headers)
|
headers=headers,
|
||||||
|
)
|
||||||
# Try to dynamically import the function to get cluster info.
|
# Try to dynamically import the function to get cluster info.
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -93,33 +101,40 @@ def parse_cluster_info(
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Module: {module_string} does not exist.\n"
|
f"Module: {module_string} does not exist.\n"
|
||||||
f"This module was parsed from Address: {address}") from None
|
f"This module was parsed from Address: {address}"
|
||||||
|
) from None
|
||||||
assert "get_job_submission_client_cluster_info" in dir(module), (
|
assert "get_job_submission_client_cluster_info" in dir(module), (
|
||||||
f"Module: {module_string} does "
|
f"Module: {module_string} does "
|
||||||
"not have `get_job_submission_client_cluster_info`.")
|
"not have `get_job_submission_client_cluster_info`."
|
||||||
|
)
|
||||||
|
|
||||||
return module.get_job_submission_client_cluster_info(
|
return module.get_job_submission_client_cluster_info(
|
||||||
inner_address,
|
inner_address,
|
||||||
create_cluster_if_needed=create_cluster_if_needed,
|
create_cluster_if_needed=create_cluster_if_needed,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
headers=headers)
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JobSubmissionClient:
|
class JobSubmissionClient:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
address: str,
|
self,
|
||||||
create_cluster_if_needed=False,
|
address: str,
|
||||||
cookies: Optional[Dict[str, Any]] = None,
|
create_cluster_if_needed=False,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
cookies: Optional[Dict[str, Any]] = None,
|
||||||
headers: Optional[Dict[str, Any]] = None):
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
headers: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
if requests is None:
|
if requests is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"The Ray jobs CLI & SDK require the ray[default] "
|
"The Ray jobs CLI & SDK require the ray[default] "
|
||||||
"installation: `pip install 'ray[default']``")
|
"installation: `pip install 'ray[default']``"
|
||||||
|
)
|
||||||
|
|
||||||
cluster_info = parse_cluster_info(address, create_cluster_if_needed,
|
cluster_info = parse_cluster_info(
|
||||||
cookies, metadata, headers)
|
address, create_cluster_if_needed, cookies, metadata, headers
|
||||||
|
)
|
||||||
self._address = cluster_info.address
|
self._address = cluster_info.address
|
||||||
self._cookies = cluster_info.cookies
|
self._cookies = cluster_info.cookies
|
||||||
self._default_metadata = cluster_info.metadata or {}
|
self._default_metadata = cluster_info.metadata or {}
|
||||||
|
@ -136,38 +151,43 @@ class JobSubmissionClient:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Jobs API not supported on the Ray cluster. "
|
"Jobs API not supported on the Ray cluster. "
|
||||||
"Please ensure the cluster is running "
|
"Please ensure the cluster is running "
|
||||||
"Ray 1.9 or higher.")
|
"Ray 1.9 or higher."
|
||||||
|
)
|
||||||
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
# TODO(edoakes): check the version if/when we break compatibility.
|
# TODO(edoakes): check the version if/when we break compatibility.
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Failed to connect to Ray at address: {self._address}.")
|
f"Failed to connect to Ray at address: {self._address}."
|
||||||
|
)
|
||||||
|
|
||||||
def _raise_error(self, r: "requests.Response"):
|
def _raise_error(self, r: "requests.Response"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Request failed with status code {r.status_code}: {r.text}.")
|
f"Request failed with status code {r.status_code}: {r.text}."
|
||||||
|
)
|
||||||
|
|
||||||
def _do_request(self,
|
def _do_request(
|
||||||
method: str,
|
self,
|
||||||
endpoint: str,
|
method: str,
|
||||||
*,
|
endpoint: str,
|
||||||
data: Optional[bytes] = None,
|
*,
|
||||||
json_data: Optional[dict] = None) -> Optional[object]:
|
data: Optional[bytes] = None,
|
||||||
|
json_data: Optional[dict] = None,
|
||||||
|
) -> Optional[object]:
|
||||||
url = self._address + endpoint
|
url = self._address + endpoint
|
||||||
logger.debug(
|
logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
|
||||||
f"Sending request to {url} with json data: {json_data or {}}.")
|
|
||||||
return requests.request(
|
return requests.request(
|
||||||
method,
|
method,
|
||||||
url,
|
url,
|
||||||
cookies=self._cookies,
|
cookies=self._cookies,
|
||||||
data=data,
|
data=data,
|
||||||
json=json_data,
|
json=json_data,
|
||||||
headers=self._headers)
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
|
||||||
def _package_exists(
|
def _package_exists(
|
||||||
self,
|
self,
|
||||||
package_uri: str,
|
package_uri: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
protocol, package_name = uri_to_http_components(package_uri)
|
protocol, package_name = uri_to_http_components(package_uri)
|
||||||
r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
|
r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
|
||||||
|
@ -181,11 +201,13 @@ class JobSubmissionClient:
|
||||||
else:
|
else:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
|
|
||||||
def _upload_package(self,
|
def _upload_package(
|
||||||
package_uri: str,
|
self,
|
||||||
package_path: str,
|
package_uri: str,
|
||||||
include_parent_dir: Optional[bool] = False,
|
package_path: str,
|
||||||
excludes: Optional[List[str]] = None) -> bool:
|
include_parent_dir: Optional[bool] = False,
|
||||||
|
excludes: Optional[List[str]] = None,
|
||||||
|
) -> bool:
|
||||||
logger.info(f"Uploading package {package_uri}.")
|
logger.info(f"Uploading package {package_uri}.")
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
protocol, package_name = uri_to_http_components(package_uri)
|
protocol, package_name = uri_to_http_components(package_uri)
|
||||||
|
@ -194,26 +216,27 @@ class JobSubmissionClient:
|
||||||
package_path,
|
package_path,
|
||||||
package_file,
|
package_file,
|
||||||
include_parent_dir=include_parent_dir,
|
include_parent_dir=include_parent_dir,
|
||||||
excludes=excludes)
|
excludes=excludes,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
r = self._do_request(
|
r = self._do_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"/api/packages/{protocol}/{package_name}",
|
f"/api/packages/{protocol}/{package_name}",
|
||||||
data=package_file.read_bytes())
|
data=package_file.read_bytes(),
|
||||||
|
)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
finally:
|
finally:
|
||||||
package_file.unlink()
|
package_file.unlink()
|
||||||
|
|
||||||
def _upload_package_if_needed(self,
|
def _upload_package_if_needed(
|
||||||
package_path: str,
|
self, package_path: str, excludes: Optional[List[str]] = None
|
||||||
excludes: Optional[List[str]] = None) -> str:
|
) -> str:
|
||||||
package_uri = get_uri_for_directory(package_path, excludes=excludes)
|
package_uri = get_uri_for_directory(package_path, excludes=excludes)
|
||||||
if not self._package_exists(package_uri):
|
if not self._package_exists(package_uri):
|
||||||
self._upload_package(package_uri, package_path, excludes=excludes)
|
self._upload_package(package_uri, package_path, excludes=excludes)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"Package {package_uri} already exists, skipping upload.")
|
||||||
f"Package {package_uri} already exists, skipping upload.")
|
|
||||||
|
|
||||||
return package_uri
|
return package_uri
|
||||||
|
|
||||||
|
@ -230,7 +253,8 @@ class JobSubmissionClient:
|
||||||
if not is_uri:
|
if not is_uri:
|
||||||
logger.debug("working_dir is not a URI, attempting to upload.")
|
logger.debug("working_dir is not a URI, attempting to upload.")
|
||||||
package_uri = self._upload_package_if_needed(
|
package_uri = self._upload_package_if_needed(
|
||||||
working_dir, excludes=runtime_env.get("excludes", None))
|
working_dir, excludes=runtime_env.get("excludes", None)
|
||||||
|
)
|
||||||
runtime_env["working_dir"] = package_uri
|
runtime_env["working_dir"] = package_uri
|
||||||
|
|
||||||
def get_version(self) -> str:
|
def get_version(self) -> str:
|
||||||
|
@ -241,12 +265,12 @@ class JobSubmissionClient:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
|
|
||||||
def submit_job(
|
def submit_job(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
entrypoint: str,
|
entrypoint: str,
|
||||||
job_id: Optional[str] = None,
|
job_id: Optional[str] = None,
|
||||||
runtime_env: Optional[Dict[str, Any]] = None,
|
runtime_env: Optional[Dict[str, Any]] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
runtime_env = runtime_env or {}
|
runtime_env = runtime_env or {}
|
||||||
metadata = metadata or {}
|
metadata = metadata or {}
|
||||||
|
@ -257,11 +281,11 @@ class JobSubmissionClient:
|
||||||
entrypoint=entrypoint,
|
entrypoint=entrypoint,
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
runtime_env=runtime_env,
|
runtime_env=runtime_env,
|
||||||
metadata=metadata)
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Submitting job with job_id={job_id}.")
|
logger.debug(f"Submitting job with job_id={job_id}.")
|
||||||
r = self._do_request(
|
r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))
|
||||||
"POST", "/api/jobs/", json_data=dataclasses.asdict(req))
|
|
||||||
|
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
return JobSubmitResponse(**r.json()).job_id
|
return JobSubmitResponse(**r.json()).job_id
|
||||||
|
@ -269,8 +293,8 @@ class JobSubmissionClient:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
|
|
||||||
def stop_job(
|
def stop_job(
|
||||||
self,
|
self,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
logger.debug(f"Stopping job with job_id={job_id}.")
|
logger.debug(f"Stopping job with job_id={job_id}.")
|
||||||
r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
|
r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
|
||||||
|
@ -281,15 +305,14 @@ class JobSubmissionClient:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
|
|
||||||
def get_job_status(
|
def get_job_status(
|
||||||
self,
|
self,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> JobStatusInfo:
|
) -> JobStatusInfo:
|
||||||
r = self._do_request("GET", f"/api/jobs/{job_id}")
|
r = self._do_request("GET", f"/api/jobs/{job_id}")
|
||||||
|
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
response = JobStatusResponse(**r.json())
|
response = JobStatusResponse(**r.json())
|
||||||
return JobStatusInfo(
|
return JobStatusInfo(status=response.status, message=response.message)
|
||||||
status=response.status, message=response.message)
|
|
||||||
else:
|
else:
|
||||||
self._raise_error(r)
|
self._raise_error(r)
|
||||||
|
|
||||||
|
@ -304,7 +327,8 @@ class JobSubmissionClient:
|
||||||
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
|
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
|
||||||
async with aiohttp.ClientSession(cookies=self._cookies) as session:
|
async with aiohttp.ClientSession(cookies=self._cookies) as session:
|
||||||
ws = await session.ws_connect(
|
ws = await session.ws_connect(
|
||||||
f"{self._address}/api/jobs/{job_id}/logs/tail")
|
f"{self._address}/api/jobs/{job_id}/logs/tail"
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
msg = await ws.receive()
|
msg = await ws.receive()
|
||||||
|
|
|
@ -7,13 +7,13 @@ we ended up using job submission API call's runtime_env instead of scripts
|
||||||
def run():
|
def run():
|
||||||
import ray
|
import ray
|
||||||
import os
|
import os
|
||||||
|
|
||||||
ray.init(
|
ray.init(
|
||||||
address=os.environ["RAY_ADDRESS"],
|
address=os.environ["RAY_ADDRESS"],
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"env_vars": {
|
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"}
|
||||||
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"
|
},
|
||||||
}
|
)
|
||||||
})
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
def foo():
|
def foo():
|
||||||
|
|
|
@ -21,15 +21,14 @@ def conda_env(env_name):
|
||||||
# Clean up created conda env upon test exit to prevent leaking
|
# Clean up created conda env upon test exit to prevent leaking
|
||||||
del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"]
|
del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"]
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
f"conda env remove -y --name {env_name}",
|
f"conda env remove -y --name {env_name}", shell=True, stdout=subprocess.PIPE
|
||||||
shell=True,
|
)
|
||||||
stdout=subprocess.PIPE)
|
|
||||||
|
|
||||||
|
|
||||||
def _compatibility_script_path(file_name: str) -> str:
|
def _compatibility_script_path(file_name: str) -> str:
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
os.path.dirname(__file__), "backwards_compatibility_scripts",
|
os.path.dirname(__file__), "backwards_compatibility_scripts", file_name
|
||||||
file_name)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestBackwardsCompatibility:
|
class TestBackwardsCompatibility:
|
||||||
|
@ -48,8 +47,7 @@ class TestBackwardsCompatibility:
|
||||||
shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501
|
shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_output(
|
subprocess.check_output(shell_cmd, shell=True, stderr=subprocess.STDOUT)
|
||||||
shell_cmd, shell=True, stderr=subprocess.STDOUT)
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
logger.error(e.stdout.decode())
|
logger.error(e.stdout.decode())
|
||||||
|
|
|
@ -34,8 +34,7 @@ def mock_sdk_client():
|
||||||
|
|
||||||
if "RAY_ADDRESS" in os.environ:
|
if "RAY_ADDRESS" in os.environ:
|
||||||
del os.environ["RAY_ADDRESS"]
|
del os.environ["RAY_ADDRESS"]
|
||||||
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient"
|
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient") as mock_client:
|
||||||
) as mock_client:
|
|
||||||
# In python 3.6 it will fail with error
|
# In python 3.6 it will fail with error
|
||||||
# 'async for' requires an object with __aiter__ method, got MagicMock"
|
# 'async for' requires an object with __aiter__ method, got MagicMock"
|
||||||
mock_client().tail_job_logs.return_value = AsyncIterator(range(10))
|
mock_client().tail_job_logs.return_value = AsyncIterator(range(10))
|
||||||
|
@ -52,9 +51,7 @@ def runtime_env_formats():
|
||||||
"working_dir": "s3://bogus.zip",
|
"working_dir": "s3://bogus.zip",
|
||||||
"conda": "conda_env",
|
"conda": "conda_env",
|
||||||
"pip": ["pip-install-test"],
|
"pip": ["pip-install-test"],
|
||||||
"env_vars": {
|
"env_vars": {"hi": "hi2"},
|
||||||
"hi": "hi2"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
yaml_file = path / "env.yaml"
|
yaml_file = path / "env.yaml"
|
||||||
|
@ -86,14 +83,13 @@ class TestSubmit:
|
||||||
|
|
||||||
# Test passing address via command line.
|
# Test passing address via command line.
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group, ["submit", "--address=arg_addr", "--", "echo hello"]
|
||||||
["submit", "--address=arg_addr", "--", "echo hello"])
|
)
|
||||||
assert mock_sdk_client.called_with("arg_addr")
|
assert mock_sdk_client.called_with("arg_addr")
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
# Test passing address via env var.
|
# Test passing address via env var.
|
||||||
with set_env_var("RAY_ADDRESS", "env_addr"):
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
||||||
result = runner.invoke(job_cli_group,
|
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
|
||||||
["submit", "--", "echo hello"])
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_sdk_client.called_with("env_addr")
|
assert mock_sdk_client.called_with("env_addr")
|
||||||
# Test passing no address.
|
# Test passing no address.
|
||||||
|
@ -106,24 +102,22 @@ class TestSubmit:
|
||||||
mock_client_instance = mock_sdk_client.return_value
|
mock_client_instance = mock_sdk_client.return_value
|
||||||
|
|
||||||
with set_env_var("RAY_ADDRESS", "env_addr"):
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
||||||
result = runner.invoke(job_cli_group,
|
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
|
||||||
["submit", "--", "echo hello"])
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(runtime_env={})
|
assert mock_client_instance.called_with(runtime_env={})
|
||||||
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group,
|
||||||
["submit", "--", "--working-dir", "blah", "--", "echo hello"])
|
["submit", "--", "--working-dir", "blah", "--", "echo hello"],
|
||||||
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(
|
assert mock_client_instance.called_with(runtime_env={"working_dir": "blah"})
|
||||||
runtime_env={"working_dir": "blah"})
|
|
||||||
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group, ["submit", "--", "--working-dir='.'", "--", "echo hello"]
|
||||||
["submit", "--", "--working-dir='.'", "--", "echo hello"])
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(
|
assert mock_client_instance.called_with(runtime_env={"working_dir": "."})
|
||||||
runtime_env={"working_dir": "."})
|
|
||||||
|
|
||||||
def test_runtime_env(self, mock_sdk_client, runtime_env_formats):
|
def test_runtime_env(self, mock_sdk_client, runtime_env_formats):
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
@ -133,39 +127,64 @@ class TestSubmit:
|
||||||
with set_env_var("RAY_ADDRESS", "env_addr"):
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
||||||
# Test passing via file.
|
# Test passing via file.
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group, ["submit", "--runtime-env", env_yaml, "--", "echo hello"]
|
||||||
["submit", "--runtime-env", env_yaml, "--", "echo hello"])
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(runtime_env=env_dict)
|
assert mock_client_instance.called_with(runtime_env=env_dict)
|
||||||
|
|
||||||
# Test passing via json.
|
# Test passing via json.
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group,
|
||||||
["submit", "--runtime-env-json", env_json, "--", "echo hello"])
|
["submit", "--runtime-env-json", env_json, "--", "echo hello"],
|
||||||
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(runtime_env=env_dict)
|
assert mock_client_instance.called_with(runtime_env=env_dict)
|
||||||
|
|
||||||
# Test passing both throws an error.
|
# Test passing both throws an error.
|
||||||
result = runner.invoke(job_cli_group, [
|
result = runner.invoke(
|
||||||
"submit", "--runtime-env", env_yaml, "--runtime-env-json",
|
job_cli_group,
|
||||||
env_json, "--", "echo hello"
|
[
|
||||||
])
|
"submit",
|
||||||
|
"--runtime-env",
|
||||||
|
env_yaml,
|
||||||
|
"--runtime-env-json",
|
||||||
|
env_json,
|
||||||
|
"--",
|
||||||
|
"echo hello",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert result.exit_code == 1
|
assert result.exit_code == 1
|
||||||
assert "Only one of" in str(result.exception)
|
assert "Only one of" in str(result.exception)
|
||||||
|
|
||||||
# Test overriding working_dir.
|
# Test overriding working_dir.
|
||||||
env_dict.update(working_dir=".")
|
env_dict.update(working_dir=".")
|
||||||
result = runner.invoke(job_cli_group, [
|
result = runner.invoke(
|
||||||
"submit", "--runtime-env", env_yaml, "--working-dir", ".",
|
job_cli_group,
|
||||||
"--", "echo hello"
|
[
|
||||||
])
|
"submit",
|
||||||
|
"--runtime-env",
|
||||||
|
env_yaml,
|
||||||
|
"--working-dir",
|
||||||
|
".",
|
||||||
|
"--",
|
||||||
|
"echo hello",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(runtime_env=env_dict)
|
assert mock_client_instance.called_with(runtime_env=env_dict)
|
||||||
|
|
||||||
result = runner.invoke(job_cli_group, [
|
result = runner.invoke(
|
||||||
"submit", "--runtime-env-json", env_json, "--working-dir", ".",
|
job_cli_group,
|
||||||
"--", "echo hello"
|
[
|
||||||
])
|
"submit",
|
||||||
|
"--runtime-env-json",
|
||||||
|
env_json,
|
||||||
|
"--working-dir",
|
||||||
|
".",
|
||||||
|
"--",
|
||||||
|
"echo hello",
|
||||||
|
],
|
||||||
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(runtime_env=env_dict)
|
assert mock_client_instance.called_with(runtime_env=env_dict)
|
||||||
|
|
||||||
|
@ -174,18 +193,18 @@ class TestSubmit:
|
||||||
mock_client_instance = mock_sdk_client.return_value
|
mock_client_instance = mock_sdk_client.return_value
|
||||||
|
|
||||||
with set_env_var("RAY_ADDRESS", "env_addr"):
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
||||||
result = runner.invoke(job_cli_group,
|
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
|
||||||
["submit", "--", "echo hello"])
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(job_id=None)
|
assert mock_client_instance.called_with(job_id=None)
|
||||||
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
job_cli_group,
|
job_cli_group, ["submit", "--", "--job-id=my_job_id", "echo hello"]
|
||||||
["submit", "--", "--job-id=my_job_id", "echo hello"])
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert mock_client_instance.called_with(job_id="my_job_id")
|
assert mock_client_instance.called_with(job_id="my_job_id")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -60,25 +60,27 @@ class TestRayAddress:
|
||||||
def test_empty_ray_address(self, ray_start_stop):
|
def test_empty_ray_address(self, ray_start_stop):
|
||||||
with set_env_var("RAY_ADDRESS", None):
|
with set_env_var("RAY_ADDRESS", None):
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", "echo hello"],
|
["ray", "job", "submit", "--", "echo hello"], stderr=subprocess.PIPE
|
||||||
stderr=subprocess.PIPE)
|
)
|
||||||
stderr = completed_process.stderr.decode("utf-8")
|
stderr = completed_process.stderr.decode("utf-8")
|
||||||
# Current dashboard module that raises no exception from requests..
|
# Current dashboard module that raises no exception from requests..
|
||||||
assert ("Address must be specified using either the "
|
assert (
|
||||||
"--address flag or RAY_ADDRESS environment") in stderr
|
"Address must be specified using either the "
|
||||||
|
"--address flag or RAY_ADDRESS environment"
|
||||||
|
) in stderr
|
||||||
|
|
||||||
def test_ray_client_address(self, ray_start_stop):
|
def test_ray_client_address(self, ray_start_stop):
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", "echo hello"],
|
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello" in stdout
|
assert "hello" in stdout
|
||||||
assert "succeeded" in stdout
|
assert "succeeded" in stdout
|
||||||
|
|
||||||
def test_valid_http_ray_address(self, ray_start_stop):
|
def test_valid_http_ray_address(self, ray_start_stop):
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", "echo hello"],
|
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello" in stdout
|
assert "hello" in stdout
|
||||||
assert "succeeded" in stdout
|
assert "succeeded" in stdout
|
||||||
|
@ -87,8 +89,8 @@ class TestRayAddress:
|
||||||
with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"):
|
with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"):
|
||||||
with ray_cluster_manager():
|
with ray_cluster_manager():
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", "echo hello"],
|
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello" in stdout
|
assert "hello" in stdout
|
||||||
assert "succeeded" in stdout
|
assert "succeeded" in stdout
|
||||||
|
@ -97,8 +99,8 @@ class TestRayAddress:
|
||||||
with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"):
|
with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"):
|
||||||
with ray_cluster_manager():
|
with ray_cluster_manager():
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", "echo hello"],
|
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello" in stdout
|
assert "hello" in stdout
|
||||||
assert "succeeded" in stdout
|
assert "succeeded" in stdout
|
||||||
|
@ -109,7 +111,8 @@ class TestJobSubmit:
|
||||||
"""Should tail logs and wait for process to exit."""
|
"""Should tail logs and wait for process to exit."""
|
||||||
cmd = "sleep 1 && echo hello && sleep 1 && echo hello"
|
cmd = "sleep 1 && echo hello && sleep 1 && echo hello"
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE)
|
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE
|
||||||
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello\nhello" in stdout
|
assert "hello\nhello" in stdout
|
||||||
assert "succeeded" in stdout
|
assert "succeeded" in stdout
|
||||||
|
@ -118,8 +121,8 @@ class TestJobSubmit:
|
||||||
"""Should exit immediately w/o printing logs."""
|
"""Should exit immediately w/o printing logs."""
|
||||||
cmd = "echo hello && sleep 1000"
|
cmd = "echo hello && sleep 1000"
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "submit", "--no-wait", "--", cmd],
|
["ray", "job", "submit", "--no-wait", "--", cmd], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "hello" not in stdout
|
assert "hello" not in stdout
|
||||||
assert "Tailing logs until the job exits" not in stdout
|
assert "Tailing logs until the job exits" not in stdout
|
||||||
|
@ -130,13 +133,13 @@ class TestJobStop:
|
||||||
"""Should wait until the job is stopped."""
|
"""Should wait until the job is stopped."""
|
||||||
cmd = "sleep 1000"
|
cmd = "sleep 1000"
|
||||||
job_id = "test_basic_stop"
|
job_id = "test_basic_stop"
|
||||||
completed_process = subprocess.run([
|
completed_process = subprocess.run(
|
||||||
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
|
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
|
||||||
cmd
|
)
|
||||||
])
|
|
||||||
|
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "stop", job_id], stdout=subprocess.PIPE)
|
["ray", "job", "stop", job_id], stdout=subprocess.PIPE
|
||||||
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "Waiting for job" in stdout
|
assert "Waiting for job" in stdout
|
||||||
assert f"Job '{job_id}' was stopped" in stdout
|
assert f"Job '{job_id}' was stopped" in stdout
|
||||||
|
@ -145,14 +148,13 @@ class TestJobStop:
|
||||||
"""Should not wait until the job is stopped."""
|
"""Should not wait until the job is stopped."""
|
||||||
cmd = "echo hello && sleep 1000"
|
cmd = "echo hello && sleep 1000"
|
||||||
job_id = "test_stop_no_wait"
|
job_id = "test_stop_no_wait"
|
||||||
completed_process = subprocess.run([
|
completed_process = subprocess.run(
|
||||||
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
|
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
|
||||||
cmd
|
)
|
||||||
])
|
|
||||||
|
|
||||||
completed_process = subprocess.run(
|
completed_process = subprocess.run(
|
||||||
["ray", "job", "stop", "--no-wait", job_id],
|
["ray", "job", "stop", "--no-wait", job_id], stdout=subprocess.PIPE
|
||||||
stdout=subprocess.PIPE)
|
)
|
||||||
stdout = completed_process.stdout.decode("utf-8")
|
stdout = completed_process.stdout.decode("utf-8")
|
||||||
assert "Waiting for job" not in stdout
|
assert "Waiting for job" not in stdout
|
||||||
assert f"Job '{job_id}' was stopped" not in stdout
|
assert f"Job '{job_id}' was stopped" not in stdout
|
||||||
|
|
|
@ -24,82 +24,61 @@ class TestJobSubmitRequestValidation:
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.job_id is None
|
assert r.job_id is None
|
||||||
|
|
||||||
r = validate_request_type({
|
r = validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "job_id": "123"}, JobSubmitRequest
|
||||||
"job_id": "123"
|
)
|
||||||
}, JobSubmitRequest)
|
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.job_id == "123"
|
assert r.job_id == "123"
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="must be a string"):
|
with pytest.raises(TypeError, match="must be a string"):
|
||||||
validate_request_type({
|
validate_request_type({"entrypoint": 123, "job_id": 1}, JobSubmitRequest)
|
||||||
"entrypoint": 123,
|
|
||||||
"job_id": 1
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
def test_validate_runtime_env(self):
|
def test_validate_runtime_env(self):
|
||||||
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
|
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.runtime_env is None
|
assert r.runtime_env is None
|
||||||
|
|
||||||
r = validate_request_type({
|
r = validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "runtime_env": {"hi": "hi2"}}, JobSubmitRequest
|
||||||
"runtime_env": {
|
)
|
||||||
"hi": "hi2"
|
|
||||||
}
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.runtime_env == {"hi": "hi2"}
|
assert r.runtime_env == {"hi": "hi2"}
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="must be a dict"):
|
with pytest.raises(TypeError, match="must be a dict"):
|
||||||
validate_request_type({
|
validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "runtime_env": 123}, JobSubmitRequest
|
||||||
"runtime_env": 123
|
)
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="keys must be strings"):
|
with pytest.raises(TypeError, match="keys must be strings"):
|
||||||
validate_request_type({
|
validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "runtime_env": {1: "hi"}}, JobSubmitRequest
|
||||||
"runtime_env": {
|
)
|
||||||
1: "hi"
|
|
||||||
}
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
def test_validate_metadata(self):
|
def test_validate_metadata(self):
|
||||||
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
|
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.metadata is None
|
assert r.metadata is None
|
||||||
|
|
||||||
r = validate_request_type({
|
r = validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "metadata": {"hi": "hi2"}}, JobSubmitRequest
|
||||||
"metadata": {
|
)
|
||||||
"hi": "hi2"
|
|
||||||
}
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
assert r.entrypoint == "abc"
|
assert r.entrypoint == "abc"
|
||||||
assert r.metadata == {"hi": "hi2"}
|
assert r.metadata == {"hi": "hi2"}
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="must be a dict"):
|
with pytest.raises(TypeError, match="must be a dict"):
|
||||||
validate_request_type({
|
validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "metadata": 123}, JobSubmitRequest
|
||||||
"metadata": 123
|
)
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="keys must be strings"):
|
with pytest.raises(TypeError, match="keys must be strings"):
|
||||||
validate_request_type({
|
validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "metadata": {1: "hi"}}, JobSubmitRequest
|
||||||
"metadata": {
|
)
|
||||||
1: "hi"
|
|
||||||
}
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="values must be strings"):
|
with pytest.raises(TypeError, match="values must be strings"):
|
||||||
validate_request_type({
|
validate_request_type(
|
||||||
"entrypoint": "abc",
|
{"entrypoint": "abc", "metadata": {"hi": 1}}, JobSubmitRequest
|
||||||
"metadata": {
|
)
|
||||||
"hi": 1
|
|
||||||
}
|
|
||||||
}, JobSubmitRequest)
|
|
||||||
|
|
||||||
|
|
||||||
def test_uri_to_http_and_back():
|
def test_uri_to_http_and_back():
|
||||||
|
@ -127,4 +106,5 @@ def test_uri_to_http_and_back():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -8,11 +8,17 @@ import pytest
|
||||||
import ray
|
import ray
|
||||||
from ray.dashboard.tests.conftest import * # noqa
|
from ray.dashboard.tests.conftest import * # noqa
|
||||||
from ray.tests.conftest import _ray_start
|
from ray.tests.conftest import _ray_start
|
||||||
from ray._private.test_utils import (format_web_url, wait_for_condition,
|
from ray._private.test_utils import (
|
||||||
wait_until_server_available)
|
format_web_url,
|
||||||
|
wait_for_condition,
|
||||||
|
wait_until_server_available,
|
||||||
|
)
|
||||||
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
|
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
|
||||||
from ray.dashboard.modules.job.sdk import (ClusterInfo, JobSubmissionClient,
|
from ray.dashboard.modules.job.sdk import (
|
||||||
parse_cluster_info)
|
ClusterInfo,
|
||||||
|
JobSubmissionClient,
|
||||||
|
parse_cluster_info,
|
||||||
|
)
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -50,8 +56,8 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="module",
|
scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"]
|
||||||
params=["no_working_dir", "local_working_dir", "s3_working_dir"])
|
)
|
||||||
def working_dir_option(request):
|
def working_dir_option(request):
|
||||||
if request.param == "no_working_dir":
|
if request.param == "no_working_dir":
|
||||||
yield {
|
yield {
|
||||||
|
@ -81,9 +87,7 @@ def working_dir_option(request):
|
||||||
f.write("from test_module.test import run_test\n")
|
f.write("from test_module.test import run_test\n")
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"runtime_env": {
|
"runtime_env": {"working_dir": tmp_dir},
|
||||||
"working_dir": tmp_dir
|
|
||||||
},
|
|
||||||
"entrypoint": "python test.py",
|
"entrypoint": "python test.py",
|
||||||
"expected_logs": "Hello from test_module!\n",
|
"expected_logs": "Hello from test_module!\n",
|
||||||
}
|
}
|
||||||
|
@ -104,7 +108,8 @@ def test_submit_job(job_sdk_client, working_dir_option):
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint=working_dir_option["entrypoint"],
|
entrypoint=working_dir_option["entrypoint"],
|
||||||
runtime_env=working_dir_option["runtime_env"])
|
runtime_env=working_dir_option["runtime_env"],
|
||||||
|
)
|
||||||
|
|
||||||
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
|
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
|
||||||
|
|
||||||
|
@ -133,7 +138,8 @@ def test_http_bad_request(job_sdk_client):
|
||||||
def test_invalid_runtime_env(job_sdk_client):
|
def test_invalid_runtime_env(job_sdk_client):
|
||||||
client = job_sdk_client
|
client = job_sdk_client
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"})
|
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
|
||||||
|
)
|
||||||
|
|
||||||
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
||||||
status = client.get_job_status(job_id)
|
status = client.get_job_status(job_id)
|
||||||
|
@ -143,8 +149,8 @@ def test_invalid_runtime_env(job_sdk_client):
|
||||||
def test_runtime_env_setup_failure(job_sdk_client):
|
def test_runtime_env_setup_failure(job_sdk_client):
|
||||||
client = job_sdk_client
|
client = job_sdk_client
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint="echo hello",
|
entrypoint="echo hello", runtime_env={"working_dir": "s3://does_not_exist.zip"}
|
||||||
runtime_env={"working_dir": "s3://does_not_exist.zip"})
|
)
|
||||||
|
|
||||||
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
||||||
status = client.get_job_status(job_id)
|
status = client.get_job_status(job_id)
|
||||||
|
@ -168,8 +174,8 @@ raise RuntimeError('Intentionally failed.')
|
||||||
file.write(driver_script)
|
file.write(driver_script)
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint="python test_script.py",
|
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
|
||||||
runtime_env={"working_dir": tmp_dir})
|
)
|
||||||
|
|
||||||
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
||||||
logs = client.get_job_logs(job_id)
|
logs = client.get_job_logs(job_id)
|
||||||
|
@ -196,8 +202,8 @@ raise RuntimeError('Intentionally failed.')
|
||||||
file.write(driver_script)
|
file.write(driver_script)
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint="python test_script.py",
|
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
|
||||||
runtime_env={"working_dir": tmp_dir})
|
)
|
||||||
assert client.stop_job(job_id) is True
|
assert client.stop_job(job_id) is True
|
||||||
wait_for_condition(_check_job_stopped, client=client, job_id=job_id)
|
wait_for_condition(_check_job_stopped, client=client, job_id=job_id)
|
||||||
|
|
||||||
|
@ -206,28 +212,31 @@ def test_job_metadata(job_sdk_client):
|
||||||
client = job_sdk_client
|
client = job_sdk_client
|
||||||
|
|
||||||
print_metadata_cmd = (
|
print_metadata_cmd = (
|
||||||
"python -c\""
|
'python -c"'
|
||||||
"import ray;"
|
"import ray;"
|
||||||
"ray.init();"
|
"ray.init();"
|
||||||
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
|
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
|
||||||
"print(dict(sorted(job_config.metadata.items())))"
|
"print(dict(sorted(job_config.metadata.items())))"
|
||||||
"\"")
|
'"'
|
||||||
|
)
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint=print_metadata_cmd,
|
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
|
||||||
metadata={
|
)
|
||||||
"key1": "val1",
|
|
||||||
"key2": "val2"
|
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
|
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
|
||||||
|
|
||||||
assert str({
|
assert (
|
||||||
"job_name": job_id,
|
str(
|
||||||
"job_submission_id": job_id,
|
{
|
||||||
"key1": "val1",
|
"job_name": job_id,
|
||||||
"key2": "val2"
|
"job_submission_id": job_id,
|
||||||
}) in client.get_job_logs(job_id)
|
"key1": "val1",
|
||||||
|
"key2": "val2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
in client.get_job_logs(job_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_pass_job_id(job_sdk_client):
|
def test_pass_job_id(job_sdk_client):
|
||||||
|
@ -261,19 +270,19 @@ def test_submit_optional_args(job_sdk_client):
|
||||||
json_data={"entrypoint": "ls"},
|
json_data={"entrypoint": "ls"},
|
||||||
)
|
)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(_check_job_succeeded, client=client, job_id=r.json()["job_id"])
|
||||||
_check_job_succeeded, client=client, job_id=r.json()["job_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_resources(job_sdk_client):
|
def test_missing_resources(job_sdk_client):
|
||||||
"""Check that 404s are raised for resources that don't exist."""
|
"""Check that 404s are raised for resources that don't exist."""
|
||||||
client = job_sdk_client
|
client = job_sdk_client
|
||||||
|
|
||||||
conditions = [("GET",
|
conditions = [
|
||||||
"/api/jobs/fake_job_id"), ("GET",
|
("GET", "/api/jobs/fake_job_id"),
|
||||||
"/api/jobs/fake_job_id/logs"),
|
("GET", "/api/jobs/fake_job_id/logs"),
|
||||||
("POST", "/api/jobs/fake_job_id/stop"),
|
("POST", "/api/jobs/fake_job_id/stop"),
|
||||||
("GET", "/api/packages/fake_package_uri")]
|
("GET", "/api/packages/fake_package_uri"),
|
||||||
|
]
|
||||||
|
|
||||||
for method, route in conditions:
|
for method, route in conditions:
|
||||||
assert client._do_request(method, route).status_code == 404
|
assert client._do_request(method, route).status_code == 404
|
||||||
|
@ -287,7 +296,7 @@ def test_version_endpoint(job_sdk_client):
|
||||||
assert r.json() == {
|
assert r.json() == {
|
||||||
"version": CURRENT_VERSION,
|
"version": CURRENT_VERSION,
|
||||||
"ray_version": ray.__version__,
|
"ray_version": ray.__version__,
|
||||||
"ray_commit": ray.__commit__
|
"ray_commit": ray.__commit__,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -306,26 +315,31 @@ def test_request_headers(job_sdk_client):
|
||||||
cookies=None,
|
cookies=None,
|
||||||
data=None,
|
data=None,
|
||||||
json={"entrypoint": "ls"},
|
json={"entrypoint": "ls"},
|
||||||
headers={
|
headers={"Connection": "keep-alive", "Authorization": "TOK:<MY_TOKEN>"},
|
||||||
"Connection": "keep-alive",
|
)
|
||||||
"Authorization": "TOK:<MY_TOKEN>"
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("address", [
|
@pytest.mark.parametrize(
|
||||||
"http://127.0.0.1", "https://127.0.0.1", "ray://127.0.0.1",
|
"address",
|
||||||
"fake_module://127.0.0.1"
|
[
|
||||||
])
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"ray://127.0.0.1",
|
||||||
|
"fake_module://127.0.0.1",
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_parse_cluster_info(address: str):
|
def test_parse_cluster_info(address: str):
|
||||||
if address.startswith("ray"):
|
if address.startswith("ray"):
|
||||||
assert parse_cluster_info(address, False) == ClusterInfo(
|
assert parse_cluster_info(address, False) == ClusterInfo(
|
||||||
address="http" + address[address.index("://"):],
|
address="http" + address[address.index("://") :],
|
||||||
cookies=None,
|
cookies=None,
|
||||||
metadata=None,
|
metadata=None,
|
||||||
headers=None)
|
headers=None,
|
||||||
|
)
|
||||||
elif address.startswith("http") or address.startswith("https"):
|
elif address.startswith("http") or address.startswith("https"):
|
||||||
assert parse_cluster_info(address, False) == ClusterInfo(
|
assert parse_cluster_info(address, False) == ClusterInfo(
|
||||||
address=address, cookies=None, metadata=None, headers=None)
|
address=address, cookies=None, metadata=None, headers=None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
parse_cluster_info(address, False)
|
parse_cluster_info(address, False)
|
||||||
|
@ -347,8 +361,8 @@ for i in range(100):
|
||||||
f.write(driver_script)
|
f.write(driver_script)
|
||||||
|
|
||||||
job_id = client.submit_job(
|
job_id = client.submit_job(
|
||||||
entrypoint="python test_script.py",
|
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
|
||||||
runtime_env={"working_dir": tmp_dir})
|
)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
async for lines in client.tail_job_logs(job_id):
|
async for lines in client.tail_job_logs(job_id):
|
||||||
|
|
|
@ -8,8 +8,11 @@ import signal
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.dashboard.modules.job.common import (JobStatus, JOB_ID_METADATA_KEY,
|
from ray.dashboard.modules.job.common import (
|
||||||
JOB_NAME_METADATA_KEY)
|
JobStatus,
|
||||||
|
JOB_ID_METADATA_KEY,
|
||||||
|
JOB_NAME_METADATA_KEY,
|
||||||
|
)
|
||||||
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
|
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
|
||||||
from ray._private.test_utils import SignalActor, wait_for_condition
|
from ray._private.test_utils import SignalActor, wait_for_condition
|
||||||
|
|
||||||
|
@ -28,7 +31,8 @@ def job_manager(shared_ray_instance):
|
||||||
|
|
||||||
def _driver_script_path(file_name: str) -> str:
|
def _driver_script_path(file_name: str) -> str:
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
os.path.dirname(__file__), "subprocess_driver_scripts", file_name)
|
os.path.dirname(__file__), "subprocess_driver_scripts", file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
|
def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
|
||||||
|
@ -36,12 +40,15 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
|
||||||
pid_file = os.path.join(tmp_dir, "pid")
|
pid_file = os.path.join(tmp_dir, "pid")
|
||||||
|
|
||||||
# Write subprocess pid to pid_file and block until tmp_file is present.
|
# Write subprocess pid to pid_file and block until tmp_file is present.
|
||||||
wait_for_file_cmd = (f"echo $$ > {pid_file} && "
|
wait_for_file_cmd = (
|
||||||
f"until [ -f {tmp_file} ]; "
|
f"echo $$ > {pid_file} && "
|
||||||
"do echo 'Waiting...' && sleep 1; "
|
f"until [ -f {tmp_file} ]; "
|
||||||
"done")
|
"do echo 'Waiting...' && sleep 1; "
|
||||||
|
"done"
|
||||||
|
)
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor)
|
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
|
||||||
|
)
|
||||||
|
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
if start_signal_actor:
|
if start_signal_actor:
|
||||||
|
@ -50,11 +57,9 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
|
||||||
logs = job_manager.get_job_logs(job_id)
|
logs = job_manager.get_job_logs(job_id)
|
||||||
assert logs == ""
|
assert logs == ""
|
||||||
else:
|
else:
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_running, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))
|
||||||
lambda: "Waiting..." in job_manager.get_job_logs(job_id))
|
|
||||||
|
|
||||||
return pid_file, tmp_file, job_id
|
return pid_file, tmp_file, job_id
|
||||||
|
|
||||||
|
@ -63,25 +68,19 @@ def check_job_succeeded(job_manager, job_id):
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
if status.status == JobStatus.FAILED:
|
if status.status == JobStatus.FAILED:
|
||||||
raise RuntimeError(f"Job failed! {status.message}")
|
raise RuntimeError(f"Job failed! {status.message}")
|
||||||
assert status.status in {
|
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED}
|
||||||
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED
|
|
||||||
}
|
|
||||||
return status.status == JobStatus.SUCCEEDED
|
return status.status == JobStatus.SUCCEEDED
|
||||||
|
|
||||||
|
|
||||||
def check_job_failed(job_manager, job_id):
|
def check_job_failed(job_manager, job_id):
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
assert status.status in {
|
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
|
||||||
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED
|
|
||||||
}
|
|
||||||
return status.status == JobStatus.FAILED
|
return status.status == JobStatus.FAILED
|
||||||
|
|
||||||
|
|
||||||
def check_job_stopped(job_manager, job_id):
|
def check_job_stopped(job_manager, job_id):
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
assert status.status in {
|
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
|
||||||
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED
|
|
||||||
}
|
|
||||||
return status.status == JobStatus.STOPPED
|
return status.status == JobStatus.STOPPED
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,12 +110,10 @@ def test_generate_job_id():
|
||||||
def test_pass_job_id(job_manager):
|
def test_pass_job_id(job_manager):
|
||||||
job_id = "my_custom_id"
|
job_id = "my_custom_id"
|
||||||
|
|
||||||
returned_id = job_manager.submit_job(
|
returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
|
||||||
entrypoint="echo hello", job_id=job_id)
|
|
||||||
assert returned_id == job_id
|
assert returned_id == job_id
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
# Check that the same job_id is rejected.
|
# Check that the same job_id is rejected.
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
|
@ -127,23 +124,20 @@ class TestShellScriptExecution:
|
||||||
def test_submit_basic_echo(self, job_manager):
|
def test_submit_basic_echo(self, job_manager):
|
||||||
job_id = job_manager.submit_job(entrypoint="echo hello")
|
job_id = job_manager.submit_job(entrypoint="echo hello")
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
assert job_manager.get_job_logs(job_id) == "hello\n"
|
assert job_manager.get_job_logs(job_id) == "hello\n"
|
||||||
|
|
||||||
def test_submit_stderr(self, job_manager):
|
def test_submit_stderr(self, job_manager):
|
||||||
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")
|
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
assert job_manager.get_job_logs(job_id) == "error\n"
|
assert job_manager.get_job_logs(job_id) == "error\n"
|
||||||
|
|
||||||
def test_submit_ls_grep(self, job_manager):
|
def test_submit_ls_grep(self, job_manager):
|
||||||
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
|
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
|
||||||
job_id = job_manager.submit_job(entrypoint=grep_cmd)
|
job_id = job_manager.submit_job(entrypoint=grep_cmd)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
|
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
|
||||||
|
|
||||||
def test_subprocess_exception(self, job_manager):
|
def test_subprocess_exception(self, job_manager):
|
||||||
|
@ -161,8 +155,7 @@ class TestShellScriptExecution:
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
if status.status != JobStatus.FAILED:
|
if status.status != JobStatus.FAILED:
|
||||||
return False
|
return False
|
||||||
if ("Exception: Script failed with exception !" not in
|
if "Exception: Script failed with exception !" not in status.message:
|
||||||
status.message):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return job_manager._get_actor_for_job(job_id) is None
|
return job_manager._get_actor_for_job(job_id) is None
|
||||||
|
@ -172,14 +165,13 @@ class TestShellScriptExecution:
|
||||||
def test_submit_with_s3_runtime_env(self, job_manager):
|
def test_submit_with_s3_runtime_env(self, job_manager):
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint="python script.py",
|
entrypoint="python script.py",
|
||||||
runtime_env={
|
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
|
||||||
"working_dir": "s3://runtime-env-test/script_runtime_env.zip"
|
)
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
assert (
|
||||||
assert job_manager.get_job_logs(
|
job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
|
||||||
job_id) == "Executing main() from script.py !!\n"
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestRuntimeEnv:
|
class TestRuntimeEnv:
|
||||||
|
@ -193,14 +185,10 @@ class TestRuntimeEnv:
|
||||||
"""
|
"""
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
|
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
|
||||||
runtime_env={
|
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
|
||||||
"env_vars": {
|
)
|
||||||
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
assert job_manager.get_job_logs(job_id) == "233\n"
|
assert job_manager.get_job_logs(job_id) == "233\n"
|
||||||
|
|
||||||
def test_multiple_runtime_envs(self, job_manager):
|
def test_multiple_runtime_envs(self, job_manager):
|
||||||
|
@ -208,28 +196,32 @@ class TestRuntimeEnv:
|
||||||
job_id_1 = job_manager.submit_job(
|
job_id_1 = job_manager.submit_job(
|
||||||
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
|
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"env_vars": {
|
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
|
||||||
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
|
},
|
||||||
}
|
)
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id_1)
|
check_job_succeeded, job_manager=job_manager, job_id=job_id_1
|
||||||
|
)
|
||||||
logs = job_manager.get_job_logs(job_id_1)
|
logs = job_manager.get_job_logs(job_id_1)
|
||||||
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs # noqa: E501
|
assert (
|
||||||
|
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
|
||||||
|
) # noqa: E501
|
||||||
|
|
||||||
job_id_2 = job_manager.submit_job(
|
job_id_2 = job_manager.submit_job(
|
||||||
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
|
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"env_vars": {
|
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
|
||||||
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"
|
},
|
||||||
}
|
)
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id_2)
|
check_job_succeeded, job_manager=job_manager, job_id=job_id_2
|
||||||
|
)
|
||||||
logs = job_manager.get_job_logs(job_id_2)
|
logs = job_manager.get_job_logs(job_id_2)
|
||||||
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs # noqa: E501
|
assert (
|
||||||
|
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
|
||||||
|
) # noqa: E501
|
||||||
|
|
||||||
def test_env_var_and_driver_job_config_warning(self, job_manager):
|
def test_env_var_and_driver_job_config_warning(self, job_manager):
|
||||||
"""Ensure we got error message from worker.py and job logs
|
"""Ensure we got error message from worker.py and job logs
|
||||||
|
@ -238,17 +230,15 @@ class TestRuntimeEnv:
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=f"python {_driver_script_path('override_env_var.py')}",
|
entrypoint=f"python {_driver_script_path('override_env_var.py')}",
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"env_vars": {
|
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
|
||||||
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
|
},
|
||||||
}
|
)
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
|
||||||
logs = job_manager.get_job_logs(job_id)
|
logs = job_manager.get_job_logs(job_id)
|
||||||
assert logs.startswith(
|
assert logs.startswith(
|
||||||
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) "
|
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
|
||||||
"are provided")
|
)
|
||||||
assert "JOB_1_VAR" in logs
|
assert "JOB_1_VAR" in logs
|
||||||
|
|
||||||
def test_failed_runtime_env_validation(self, job_manager):
|
def test_failed_runtime_env_validation(self, job_manager):
|
||||||
|
@ -257,7 +247,8 @@ class TestRuntimeEnv:
|
||||||
"""
|
"""
|
||||||
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
|
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"})
|
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
|
||||||
|
)
|
||||||
|
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
assert status.status == JobStatus.FAILED
|
assert status.status == JobStatus.FAILED
|
||||||
|
@ -269,11 +260,10 @@ class TestRuntimeEnv:
|
||||||
"""
|
"""
|
||||||
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
|
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=run_cmd,
|
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
|
||||||
runtime_env={"working_dir": "s3://does_not_exist.zip"})
|
)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_failed, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
status = job_manager.get_job_status(job_id)
|
status = job_manager.get_job_status(job_id)
|
||||||
assert "runtime_env setup failed" in status.message
|
assert "runtime_env setup failed" in status.message
|
||||||
|
@ -283,69 +273,67 @@ class TestRuntimeEnv:
|
||||||
return str(dict(sorted(d.items())))
|
return str(dict(sorted(d.items())))
|
||||||
|
|
||||||
print_metadata_cmd = (
|
print_metadata_cmd = (
|
||||||
"python -c\""
|
'python -c"'
|
||||||
"import ray;"
|
"import ray;"
|
||||||
"ray.init();"
|
"ray.init();"
|
||||||
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
|
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
|
||||||
"print(dict(sorted(job_config.metadata.items())))"
|
"print(dict(sorted(job_config.metadata.items())))"
|
||||||
"\"")
|
'"'
|
||||||
|
)
|
||||||
|
|
||||||
# Check that we default to only the job ID and job name.
|
# Check that we default to only the job ID and job name.
|
||||||
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)
|
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
assert dict_to_str(
|
||||||
assert dict_to_str({
|
{JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
|
||||||
JOB_NAME_METADATA_KEY: job_id,
|
) in job_manager.get_job_logs(job_id)
|
||||||
JOB_ID_METADATA_KEY: job_id
|
|
||||||
}) in job_manager.get_job_logs(job_id)
|
|
||||||
|
|
||||||
# Check that we can pass custom metadata.
|
# Check that we can pass custom metadata.
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=print_metadata_cmd,
|
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
|
||||||
metadata={
|
)
|
||||||
"key1": "val1",
|
|
||||||
"key2": "val2"
|
|
||||||
})
|
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
assert (
|
||||||
assert dict_to_str({
|
dict_to_str(
|
||||||
JOB_NAME_METADATA_KEY: job_id,
|
{
|
||||||
JOB_ID_METADATA_KEY: job_id,
|
JOB_NAME_METADATA_KEY: job_id,
|
||||||
"key1": "val1",
|
JOB_ID_METADATA_KEY: job_id,
|
||||||
"key2": "val2"
|
"key1": "val1",
|
||||||
}) in job_manager.get_job_logs(job_id)
|
"key2": "val2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
in job_manager.get_job_logs(job_id)
|
||||||
|
)
|
||||||
|
|
||||||
# Check that we can override job name.
|
# Check that we can override job name.
|
||||||
job_id = job_manager.submit_job(
|
job_id = job_manager.submit_job(
|
||||||
entrypoint=print_metadata_cmd,
|
entrypoint=print_metadata_cmd,
|
||||||
metadata={JOB_NAME_METADATA_KEY: "custom_name"})
|
metadata={JOB_NAME_METADATA_KEY: "custom_name"},
|
||||||
|
)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
assert dict_to_str(
|
||||||
assert dict_to_str({
|
{JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
|
||||||
JOB_NAME_METADATA_KEY: "custom_name",
|
) in job_manager.get_job_logs(job_id)
|
||||||
JOB_ID_METADATA_KEY: job_id
|
|
||||||
}) in job_manager.get_job_logs(job_id)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncAPI:
|
class TestAsyncAPI:
|
||||||
def test_status_and_logs_while_blocking(self, job_manager):
|
def test_status_and_logs_while_blocking(self, job_manager):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
pid_file, tmp_file, job_id = _run_hanging_command(
|
pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||||
job_manager, tmp_dir)
|
|
||||||
with open(pid_file, "r") as file:
|
with open(pid_file, "r") as file:
|
||||||
pid = int(file.read())
|
pid = int(file.read())
|
||||||
assert psutil.pid_exists(pid), (
|
assert psutil.pid_exists(pid), "driver subprocess should be running"
|
||||||
"driver subprocess should be running")
|
|
||||||
|
|
||||||
# Signal the job to exit by writing to the file.
|
# Signal the job to exit by writing to the file.
|
||||||
with open(tmp_file, "w") as f:
|
with open(tmp_file, "w") as f:
|
||||||
print("hello", file=f)
|
print("hello", file=f)
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
check_job_succeeded, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
# Ensure driver subprocess gets cleaned up after job reached
|
# Ensure driver subprocess gets cleaned up after job reached
|
||||||
# termination state
|
# termination state
|
||||||
wait_for_condition(check_subprocess_cleaned, pid=pid)
|
wait_for_condition(check_subprocess_cleaned, pid=pid)
|
||||||
|
@ -356,7 +344,8 @@ class TestAsyncAPI:
|
||||||
|
|
||||||
assert job_manager.stop_job(job_id) is True
|
assert job_manager.stop_job(job_id) is True
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_stopped, job_manager=job_manager, job_id=job_id)
|
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
# Assert re-stopping a stopped job also returns False
|
# Assert re-stopping a stopped job also returns False
|
||||||
wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
|
wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
|
||||||
# Assert stopping non-existent job returns False
|
# Assert stopping non-existent job returns False
|
||||||
|
@ -375,13 +364,11 @@ class TestAsyncAPI:
|
||||||
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||||
with open(pid_file, "r") as file:
|
with open(pid_file, "r") as file:
|
||||||
pid = int(file.read())
|
pid = int(file.read())
|
||||||
assert psutil.pid_exists(pid), (
|
assert psutil.pid_exists(pid), "driver subprocess should be running"
|
||||||
"driver subprocess should be running")
|
|
||||||
|
|
||||||
actor = job_manager._get_actor_for_job(job_id)
|
actor = job_manager._get_actor_for_job(job_id)
|
||||||
ray.kill(actor, no_restart=True)
|
ray.kill(actor, no_restart=True)
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_failed, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
# Ensure driver subprocess gets cleaned up after job reached
|
# Ensure driver subprocess gets cleaned up after job reached
|
||||||
# termination state
|
# termination state
|
||||||
|
@ -398,16 +385,18 @@ class TestAsyncAPI:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
pid_file, _, job_id = _run_hanging_command(
|
pid_file, _, job_id = _run_hanging_command(
|
||||||
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
|
job_manager, tmp_dir, start_signal_actor=start_signal_actor
|
||||||
|
)
|
||||||
assert not os.path.exists(pid_file), (
|
assert not os.path.exists(pid_file), (
|
||||||
"driver subprocess should NOT be running while job is "
|
"driver subprocess should NOT be running while job is " "still PENDING."
|
||||||
"still PENDING.")
|
)
|
||||||
|
|
||||||
assert job_manager.stop_job(job_id) is True
|
assert job_manager.stop_job(job_id) is True
|
||||||
# Send run signal to unblock run function
|
# Send run signal to unblock run function
|
||||||
ray.get(start_signal_actor.send.remote())
|
ray.get(start_signal_actor.send.remote())
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_stopped, job_manager=job_manager, job_id=job_id)
|
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
|
|
||||||
def test_kill_job_actor_in_pending(self, job_manager):
|
def test_kill_job_actor_in_pending(self, job_manager):
|
||||||
"""
|
"""
|
||||||
|
@ -420,16 +409,16 @@ class TestAsyncAPI:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
pid_file, _, job_id = _run_hanging_command(
|
pid_file, _, job_id = _run_hanging_command(
|
||||||
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
|
job_manager, tmp_dir, start_signal_actor=start_signal_actor
|
||||||
|
)
|
||||||
|
|
||||||
assert not os.path.exists(pid_file), (
|
assert not os.path.exists(pid_file), (
|
||||||
"driver subprocess should NOT be running while job is "
|
"driver subprocess should NOT be running while job is " "still PENDING."
|
||||||
"still PENDING.")
|
)
|
||||||
|
|
||||||
actor = job_manager._get_actor_for_job(job_id)
|
actor = job_manager._get_actor_for_job(job_id)
|
||||||
ray.kill(actor, no_restart=True)
|
ray.kill(actor, no_restart=True)
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_failed, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
|
def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
|
||||||
"""
|
"""
|
||||||
|
@ -442,12 +431,12 @@ class TestAsyncAPI:
|
||||||
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||||
with open(pid_file, "r") as file:
|
with open(pid_file, "r") as file:
|
||||||
pid = int(file.read())
|
pid = int(file.read())
|
||||||
assert psutil.pid_exists(pid), (
|
assert psutil.pid_exists(pid), "driver subprocess should be running"
|
||||||
"driver subprocess should be running")
|
|
||||||
|
|
||||||
assert job_manager.stop_job(job_id) is True
|
assert job_manager.stop_job(job_id) is True
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_stopped, job_manager=job_manager, job_id=job_id)
|
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure driver subprocess gets cleaned up after job reached
|
# Ensure driver subprocess gets cleaned up after job reached
|
||||||
# termination state
|
# termination state
|
||||||
|
@ -455,11 +444,9 @@ class TestAsyncAPI:
|
||||||
|
|
||||||
|
|
||||||
class TestTailLogs:
|
class TestTailLogs:
|
||||||
async def _tail_and_assert_logs(self,
|
async def _tail_and_assert_logs(
|
||||||
job_id,
|
self, job_id, job_manager, expected_log="", num_iteration=5
|
||||||
job_manager,
|
):
|
||||||
expected_log="",
|
|
||||||
num_iteration=5):
|
|
||||||
i = 0
|
i = 0
|
||||||
async for lines in job_manager.tail_job_logs(job_id):
|
async for lines in job_manager.tail_job_logs(job_id):
|
||||||
assert all(s == expected_log for s in lines.strip().split("\n"))
|
assert all(s == expected_log for s in lines.strip().split("\n"))
|
||||||
|
@ -470,8 +457,7 @@ class TestTailLogs:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_job(self, job_manager):
|
async def test_unknown_job(self, job_manager):
|
||||||
with pytest.raises(
|
with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
|
||||||
RuntimeError, match="Job 'unknown' does not exist."):
|
|
||||||
async for _ in job_manager.tail_job_logs("unknown"):
|
async for _ in job_manager.tail_job_logs("unknown"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -482,33 +468,31 @@ class TestTailLogs:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
_, tmp_file, job_id = _run_hanging_command(
|
_, tmp_file, job_id = _run_hanging_command(
|
||||||
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
|
job_manager, tmp_dir, start_signal_actor=start_signal_actor
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(edoakes): check we get no logs before actor starts (not sure
|
# TODO(edoakes): check we get no logs before actor starts (not sure
|
||||||
# how to timeout the iterator call).
|
# how to timeout the iterator call).
|
||||||
assert job_manager.get_job_status(
|
assert job_manager.get_job_status(job_id).status == JobStatus.PENDING
|
||||||
job_id).status == JobStatus.PENDING
|
|
||||||
|
|
||||||
# Signal job to start.
|
# Signal job to start.
|
||||||
ray.get(start_signal_actor.send.remote())
|
ray.get(start_signal_actor.send.remote())
|
||||||
|
|
||||||
await self._tail_and_assert_logs(
|
await self._tail_and_assert_logs(
|
||||||
job_id,
|
job_id, job_manager, expected_log="Waiting...", num_iteration=5
|
||||||
job_manager,
|
)
|
||||||
expected_log="Waiting...",
|
|
||||||
num_iteration=5)
|
|
||||||
|
|
||||||
# Signal the job to exit by writing to the file.
|
# Signal the job to exit by writing to the file.
|
||||||
with open(tmp_file, "w") as f:
|
with open(tmp_file, "w") as f:
|
||||||
print("hello", file=f)
|
print("hello", file=f)
|
||||||
|
|
||||||
async for lines in job_manager.tail_job_logs(job_id):
|
async for lines in job_manager.tail_job_logs(job_id):
|
||||||
assert all(
|
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
|
||||||
s == "Waiting..." for s in lines.strip().split("\n"))
|
|
||||||
print(lines, end="")
|
print(lines, end="")
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_succeeded, job_manager=job_manager, job_id=job_id)
|
check_job_succeeded, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_failed_job(self, job_manager):
|
async def test_failed_job(self, job_manager):
|
||||||
|
@ -517,22 +501,18 @@ class TestTailLogs:
|
||||||
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||||
|
|
||||||
await self._tail_and_assert_logs(
|
await self._tail_and_assert_logs(
|
||||||
job_id,
|
job_id, job_manager, expected_log="Waiting...", num_iteration=5
|
||||||
job_manager,
|
)
|
||||||
expected_log="Waiting...",
|
|
||||||
num_iteration=5)
|
|
||||||
|
|
||||||
# Kill the job unexpectedly.
|
# Kill the job unexpectedly.
|
||||||
with open(pid_file, "r") as f:
|
with open(pid_file, "r") as f:
|
||||||
os.kill(int(f.read()), signal.SIGKILL)
|
os.kill(int(f.read()), signal.SIGKILL)
|
||||||
|
|
||||||
async for lines in job_manager.tail_job_logs(job_id):
|
async for lines in job_manager.tail_job_logs(job_id):
|
||||||
assert all(
|
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
|
||||||
s == "Waiting..." for s in lines.strip().split("\n"))
|
|
||||||
print(lines, end="")
|
print(lines, end="")
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||||
check_job_failed, job_manager=job_manager, job_id=job_id)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stopped_job(self, job_manager):
|
async def test_stopped_job(self, job_manager):
|
||||||
|
@ -541,21 +521,19 @@ class TestTailLogs:
|
||||||
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||||
|
|
||||||
await self._tail_and_assert_logs(
|
await self._tail_and_assert_logs(
|
||||||
job_id,
|
job_id, job_manager, expected_log="Waiting...", num_iteration=5
|
||||||
job_manager,
|
)
|
||||||
expected_log="Waiting...",
|
|
||||||
num_iteration=5)
|
|
||||||
|
|
||||||
# Stop the job via the API.
|
# Stop the job via the API.
|
||||||
job_manager.stop_job(job_id)
|
job_manager.stop_job(job_id)
|
||||||
|
|
||||||
async for lines in job_manager.tail_job_logs(job_id):
|
async for lines in job_manager.tail_job_logs(job_id):
|
||||||
assert all(
|
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
|
||||||
s == "Waiting..." for s in lines.strip().split("\n"))
|
|
||||||
print(lines, end="")
|
print(lines, end="")
|
||||||
|
|
||||||
wait_for_condition(
|
wait_for_condition(
|
||||||
check_job_stopped, job_manager=job_manager, job_id=job_id)
|
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_logs_streaming(job_manager):
|
def test_logs_streaming(job_manager):
|
||||||
|
@ -568,7 +546,7 @@ while True:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stream_logs_cmd = f"python -c \"{stream_logs_script}\""
|
stream_logs_cmd = f'python -c "{stream_logs_script}"'
|
||||||
|
|
||||||
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
|
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
|
||||||
wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
|
wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
|
||||||
|
|
|
@ -12,7 +12,7 @@ def tmp():
|
||||||
yield f.name
|
yield f.name
|
||||||
|
|
||||||
|
|
||||||
class TestIterLine():
|
class TestIterLine:
|
||||||
def test_invalid_type(self):
|
def test_invalid_type(self):
|
||||||
with pytest.raises(TypeError, match="path must be a string"):
|
with pytest.raises(TypeError, match="path must be a string"):
|
||||||
next(file_tail_iterator(1))
|
next(file_tail_iterator(1))
|
||||||
|
|
|
@ -19,10 +19,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
|
||||||
self._proxy_session = aiohttp.ClientSession(auto_decompress=False)
|
self._proxy_session = aiohttp.ClientSession(auto_decompress=False)
|
||||||
log_utils.register_mimetypes()
|
log_utils.register_mimetypes()
|
||||||
routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
|
routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
|
||||||
GlobalSignals.node_info_fetched.append(
|
GlobalSignals.node_info_fetched.append(self.insert_log_url_to_node_info)
|
||||||
self.insert_log_url_to_node_info)
|
GlobalSignals.node_summary_fetched.append(self.insert_log_url_to_node_info)
|
||||||
GlobalSignals.node_summary_fetched.append(
|
|
||||||
self.insert_log_url_to_node_info)
|
|
||||||
|
|
||||||
async def insert_log_url_to_node_info(self, node_info):
|
async def insert_log_url_to_node_info(self, node_info):
|
||||||
node_id = node_info.get("raylet", {}).get("nodeId")
|
node_id = node_info.get("raylet", {}).get("nodeId")
|
||||||
|
@ -33,7 +31,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
|
||||||
return
|
return
|
||||||
agent_http_port, _ = agent_port
|
agent_http_port, _ = agent_port
|
||||||
log_url = self.LOG_URL_TEMPLATE.format(
|
log_url = self.LOG_URL_TEMPLATE.format(
|
||||||
ip=node_info.get("ip"), port=agent_http_port)
|
ip=node_info.get("ip"), port=agent_http_port
|
||||||
|
)
|
||||||
node_info["logUrl"] = log_url
|
node_info["logUrl"] = log_url
|
||||||
|
|
||||||
@routes.get("/log_index")
|
@routes.get("/log_index")
|
||||||
|
@ -43,15 +42,16 @@ class LogHead(dashboard_utils.DashboardHeadModule):
|
||||||
for node_id, ports in DataSource.agents.items():
|
for node_id, ports in DataSource.agents.items():
|
||||||
ip = DataSource.node_id_to_ip[node_id]
|
ip = DataSource.node_id_to_ip[node_id]
|
||||||
agent_ips.append(ip)
|
agent_ips.append(ip)
|
||||||
url_list.append(
|
url_list.append(self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
|
||||||
self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
|
|
||||||
if self._dashboard_head.ip not in agent_ips:
|
if self._dashboard_head.ip not in agent_ips:
|
||||||
url_list.append(
|
url_list.append(
|
||||||
self.LOG_URL_TEMPLATE.format(
|
self.LOG_URL_TEMPLATE.format(
|
||||||
ip=self._dashboard_head.ip,
|
ip=self._dashboard_head.ip, port=self._dashboard_head.http_port
|
||||||
port=self._dashboard_head.http_port))
|
)
|
||||||
|
)
|
||||||
return aiohttp.web.Response(
|
return aiohttp.web.Response(
|
||||||
text=self._directory_as_html(url_list), content_type="text/html")
|
text=self._directory_as_html(url_list), content_type="text/html"
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/log_proxy")
|
@routes.get("/log_proxy")
|
||||||
async def get_log_from_proxy(self, req) -> aiohttp.web.StreamResponse:
|
async def get_log_from_proxy(self, req) -> aiohttp.web.StreamResponse:
|
||||||
|
@ -60,9 +60,11 @@ class LogHead(dashboard_utils.DashboardHeadModule):
|
||||||
raise Exception("url is None.")
|
raise Exception("url is None.")
|
||||||
body = await req.read()
|
body = await req.read()
|
||||||
async with self._proxy_session.request(
|
async with self._proxy_session.request(
|
||||||
req.method, url, data=body, headers=req.headers) as r:
|
req.method, url, data=body, headers=req.headers
|
||||||
|
) as r:
|
||||||
sr = aiohttp.web.StreamResponse(
|
sr = aiohttp.web.StreamResponse(
|
||||||
status=r.status, reason=r.reason, headers=req.headers)
|
status=r.status, reason=r.reason, headers=req.headers
|
||||||
|
)
|
||||||
sr.content_length = r.content_length
|
sr.content_length = r.content_length
|
||||||
sr.content_type = r.content_type
|
sr.content_type = r.content_type
|
||||||
sr.charset = r.charset
|
sr.charset = r.charset
|
||||||
|
|
|
@ -40,8 +40,7 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
|
|
||||||
test_log_text = "test_log_text"
|
test_log_text = "test_log_text"
|
||||||
ray.get(write_log.remote(test_log_text))
|
ray.get(write_log.remote(test_log_text))
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
node_id = ray_start_with_dashboard["node_id"]
|
node_id = ray_start_with_dashboard["node_id"]
|
||||||
|
@ -82,8 +81,8 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
|
|
||||||
# Test range request.
|
# Test range request.
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
webui_url + "/logs/dashboard.log",
|
webui_url + "/logs/dashboard.log", headers={"Range": "bytes=44-52"}
|
||||||
headers={"Range": "bytes=44-52"})
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
assert response.text == "Dashboard"
|
assert response.text == "Dashboard"
|
||||||
|
|
||||||
|
@ -100,16 +99,19 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
|
||||||
def test_log_proxy(ray_start_with_dashboard):
|
def test_log_proxy(ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
|
@ -122,21 +124,27 @@ def test_log_proxy(ray_start_with_dashboard):
|
||||||
# Test range request.
|
# Test range request.
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{webui_url}/log_proxy?url={webui_url}/logs/dashboard.log",
|
f"{webui_url}/log_proxy?url={webui_url}/logs/dashboard.log",
|
||||||
headers={"Range": "bytes=44-52"})
|
headers={"Range": "bytes=44-52"},
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
assert response.text == "Dashboard"
|
assert response.text == "Dashboard"
|
||||||
# Test 404.
|
# Test 404.
|
||||||
response = requests.get(f"{webui_url}/log_proxy?"
|
response = requests.get(
|
||||||
f"url={webui_url}/logs/not_exist_file.log")
|
f"{webui_url}/log_proxy?" f"url={webui_url}/logs/not_exist_file.log"
|
||||||
|
)
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
break
|
break
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,10 @@ import ray._private.utils
|
||||||
import ray._private.gcs_utils as gcs_utils
|
import ray._private.gcs_utils as gcs_utils
|
||||||
from ray import ray_constants
|
from ray import ray_constants
|
||||||
from ray.dashboard.modules.node import node_consts
|
from ray.dashboard.modules.node import node_consts
|
||||||
from ray.dashboard.modules.node.node_consts import (MAX_LOGS_TO_CACHE,
|
from ray.dashboard.modules.node.node_consts import (
|
||||||
LOG_PRUNE_THREASHOLD)
|
MAX_LOGS_TO_CACHE,
|
||||||
|
LOG_PRUNE_THREASHOLD,
|
||||||
|
)
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
import ray.dashboard.consts as dashboard_consts
|
import ray.dashboard.consts as dashboard_consts
|
||||||
|
@ -28,13 +30,21 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
|
||||||
|
|
||||||
def gcs_node_info_to_dict(message):
|
def gcs_node_info_to_dict(message):
|
||||||
return dashboard_utils.message_to_dict(
|
return dashboard_utils.message_to_dict(
|
||||||
message, {"nodeId"}, including_default_value_fields=True)
|
message, {"nodeId"}, including_default_value_fields=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def node_stats_to_dict(message):
|
def node_stats_to_dict(message):
|
||||||
decode_keys = {
|
decode_keys = {
|
||||||
"actorId", "jobId", "taskId", "parentTaskId", "sourceActorId",
|
"actorId",
|
||||||
"callerId", "rayletId", "workerId", "placementGroupId"
|
"jobId",
|
||||||
|
"taskId",
|
||||||
|
"parentTaskId",
|
||||||
|
"sourceActorId",
|
||||||
|
"callerId",
|
||||||
|
"rayletId",
|
||||||
|
"workerId",
|
||||||
|
"placementGroupId",
|
||||||
}
|
}
|
||||||
core_workers_stats = message.core_workers_stats
|
core_workers_stats = message.core_workers_stats
|
||||||
message.ClearField("core_workers_stats")
|
message.ClearField("core_workers_stats")
|
||||||
|
@ -42,7 +52,8 @@ def node_stats_to_dict(message):
|
||||||
result = dashboard_utils.message_to_dict(message, decode_keys)
|
result = dashboard_utils.message_to_dict(message, decode_keys)
|
||||||
result["coreWorkersStats"] = [
|
result["coreWorkersStats"] = [
|
||||||
dashboard_utils.message_to_dict(
|
dashboard_utils.message_to_dict(
|
||||||
m, decode_keys, including_default_value_fields=True)
|
m, decode_keys, including_default_value_fields=True
|
||||||
|
)
|
||||||
for m in core_workers_stats
|
for m in core_workers_stats
|
||||||
]
|
]
|
||||||
return result
|
return result
|
||||||
|
@ -66,11 +77,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if change.new:
|
if change.new:
|
||||||
# TODO(fyrestone): Handle exceptions.
|
# TODO(fyrestone): Handle exceptions.
|
||||||
node_id, node_info = change.new
|
node_id, node_info = change.new
|
||||||
address = "{}:{}".format(node_info["nodeManagerAddress"],
|
address = "{}:{}".format(
|
||||||
int(node_info["nodeManagerPort"]))
|
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
)
|
||||||
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
channel = ray._private.utils.init_grpc_channel(
|
channel = ray._private.utils.init_grpc_channel(
|
||||||
address, options, asynchronous=True)
|
address, options, asynchronous=True
|
||||||
|
)
|
||||||
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
|
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
|
||||||
self._stubs[node_id] = stub
|
self._stubs[node_id] = stub
|
||||||
|
|
||||||
|
@ -81,8 +94,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
A dict of information about the nodes in the cluster.
|
A dict of information about the nodes in the cluster.
|
||||||
"""
|
"""
|
||||||
request = gcs_service_pb2.GetAllNodeInfoRequest()
|
request = gcs_service_pb2.GetAllNodeInfoRequest()
|
||||||
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
|
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=2)
|
||||||
request, timeout=2)
|
|
||||||
if reply.status.code == 0:
|
if reply.status.code == 0:
|
||||||
result = {}
|
result = {}
|
||||||
for node_info in reply.node_info_list:
|
for node_info in reply.node_info_list:
|
||||||
|
@ -116,11 +128,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
|
|
||||||
agents = dict(DataSource.agents)
|
agents = dict(DataSource.agents)
|
||||||
for node_id in alive_node_ids:
|
for node_id in alive_node_ids:
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" \
|
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
|
||||||
f"{node_id}"
|
|
||||||
# TODO: Use async version if performance is an issue
|
# TODO: Use async version if performance is an issue
|
||||||
agent_port = ray.experimental.internal_kv._internal_kv_get(
|
agent_port = ray.experimental.internal_kv._internal_kv_get(
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
||||||
|
)
|
||||||
if agent_port:
|
if agent_port:
|
||||||
agents[node_id] = json.loads(agent_port)
|
agents[node_id] = json.loads(agent_port)
|
||||||
for node_id in agents.keys() - set(alive_node_ids):
|
for node_id in agents.keys() - set(alive_node_ids):
|
||||||
|
@ -142,9 +154,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if view == "summary":
|
if view == "summary":
|
||||||
all_node_summary = await DataOrganizer.get_all_node_summary()
|
all_node_summary = await DataOrganizer.get_all_node_summary()
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True, message="Node summary fetched.", summary=all_node_summary
|
||||||
message="Node summary fetched.",
|
)
|
||||||
summary=all_node_summary)
|
|
||||||
elif view == "details":
|
elif view == "details":
|
||||||
all_node_details = await DataOrganizer.get_all_node_details()
|
all_node_details = await DataOrganizer.get_all_node_details()
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
|
@ -160,10 +171,12 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Node hostname list fetched.",
|
message="Node hostname list fetched.",
|
||||||
host_name_list=list(alive_hostnames))
|
host_name_list=list(alive_hostnames),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=False, message=f"Unknown view {view}")
|
success=False, message=f"Unknown view {view}"
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/nodes/{node_id}")
|
@routes.get("/nodes/{node_id}")
|
||||||
@dashboard_optional_utils.aiohttp_cache
|
@dashboard_optional_utils.aiohttp_cache
|
||||||
|
@ -171,7 +184,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
node_id = req.match_info.get("node_id")
|
node_id = req.match_info.get("node_id")
|
||||||
node_info = await DataOrganizer.get_node_info(node_id)
|
node_info = await DataOrganizer.get_node_info(node_id)
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="Node details fetched.", detail=node_info)
|
success=True, message="Node details fetched.", detail=node_info
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/memory/memory_table")
|
@routes.get("/memory/memory_table")
|
||||||
async def get_memory_table(self, req) -> aiohttp.web.Response:
|
async def get_memory_table(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -187,7 +201,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Fetched memory table",
|
message="Fetched memory table",
|
||||||
memory_table=memory_table.as_dict())
|
memory_table=memory_table.as_dict(),
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/memory/set_fetch")
|
@routes.get("/memory/set_fetch")
|
||||||
async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:
|
async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -198,11 +213,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
self._collect_memory_info = False
|
self._collect_memory_info = False
|
||||||
else:
|
else:
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=False,
|
success=False, message=f"Unknown argument to set_fetch {should_fetch}"
|
||||||
message=f"Unknown argument to set_fetch {should_fetch}")
|
)
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True, message=f"Successfully set fetching to {should_fetch}"
|
||||||
message=f"Successfully set fetching to {should_fetch}")
|
)
|
||||||
|
|
||||||
@routes.get("/node_logs")
|
@routes.get("/node_logs")
|
||||||
async def get_logs(self, req) -> aiohttp.web.Response:
|
async def get_logs(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -212,7 +227,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if pid:
|
if pid:
|
||||||
node_logs = {str(pid): node_logs.get(pid, [])}
|
node_logs = {str(pid): node_logs.get(pid, [])}
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="Fetched logs.", logs=node_logs)
|
success=True, message="Fetched logs.", logs=node_logs
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/node_errors")
|
@routes.get("/node_errors")
|
||||||
async def get_errors(self, req) -> aiohttp.web.Response:
|
async def get_errors(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -222,7 +238,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if pid:
|
if pid:
|
||||||
node_errors = {str(pid): node_errors.get(pid, [])}
|
node_errors = {str(pid): node_errors.get(pid, [])}
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="Fetched errors.", errors=node_errors)
|
success=True, message="Fetched errors.", errors=node_errors
|
||||||
|
)
|
||||||
|
|
||||||
@async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
|
@async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
|
||||||
async def _update_node_stats(self):
|
async def _update_node_stats(self):
|
||||||
|
@ -234,8 +251,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
try:
|
try:
|
||||||
reply = await stub.GetNodeStats(
|
reply = await stub.GetNodeStats(
|
||||||
node_manager_pb2.GetNodeStatsRequest(
|
node_manager_pb2.GetNodeStatsRequest(
|
||||||
include_memory_info=self._collect_memory_info),
|
include_memory_info=self._collect_memory_info
|
||||||
timeout=2)
|
),
|
||||||
|
timeout=2,
|
||||||
|
)
|
||||||
reply_dict = node_stats_to_dict(reply)
|
reply_dict = node_stats_to_dict(reply)
|
||||||
DataSource.node_stats[node_id] = reply_dict
|
DataSource.node_stats[node_id] = reply_dict
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -263,8 +282,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if self._dashboard_head.gcs_log_subscriber:
|
if self._dashboard_head.gcs_log_subscriber:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
log_batch = await \
|
log_batch = await self._dashboard_head.gcs_log_subscriber.poll()
|
||||||
self._dashboard_head.gcs_log_subscriber.poll()
|
|
||||||
if log_batch is None:
|
if log_batch is None:
|
||||||
continue
|
continue
|
||||||
process_log_batch(log_batch)
|
process_log_batch(log_batch)
|
||||||
|
@ -296,11 +314,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
ip = match.group(2)
|
ip = match.group(2)
|
||||||
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
|
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
|
||||||
pid_errors = list(errs_for_ip.get(pid, []))
|
pid_errors = list(errs_for_ip.get(pid, []))
|
||||||
pid_errors.append({
|
pid_errors.append(
|
||||||
"message": message,
|
{
|
||||||
"timestamp": error_data.timestamp,
|
"message": message,
|
||||||
"type": error_data.type
|
"timestamp": error_data.timestamp,
|
||||||
})
|
"type": error_data.type,
|
||||||
|
}
|
||||||
|
)
|
||||||
errs_for_ip[pid] = pid_errors
|
errs_for_ip[pid] = pid_errors
|
||||||
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
|
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
|
||||||
logger.info(f"Received error entry for {ip} {pid}")
|
logger.info(f"Received error entry for {ip} {pid}")
|
||||||
|
@ -308,8 +328,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
if self._dashboard_head.gcs_error_subscriber:
|
if self._dashboard_head.gcs_error_subscriber:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
_, error_data = await \
|
(
|
||||||
self._dashboard_head.gcs_error_subscriber.poll()
|
_,
|
||||||
|
error_data,
|
||||||
|
) = await self._dashboard_head.gcs_error_subscriber.poll()
|
||||||
if error_data is None:
|
if error_data is None:
|
||||||
continue
|
continue
|
||||||
process_error(error_data)
|
process_error(error_data)
|
||||||
|
@ -328,20 +350,23 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
||||||
try:
|
try:
|
||||||
_, data = msg
|
_, data = msg
|
||||||
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
|
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
|
||||||
error_data = gcs_utils.ErrorTableData.FromString(
|
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
|
||||||
pubsub_msg.data)
|
|
||||||
process_error(error_data)
|
process_error(error_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error receiving error info from Redis.")
|
logger.exception("Error receiving error info from Redis.")
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
|
||||||
self._gcs_node_info_stub = \
|
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
|
||||||
gcs_service_pb2_grpc.NodeInfoGcsServiceStub(gcs_channel)
|
gcs_channel
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.gather(self._update_nodes(), self._update_node_stats(),
|
await asyncio.gather(
|
||||||
self._update_log_info(),
|
self._update_nodes(),
|
||||||
self._update_error_info())
|
self._update_node_stats(),
|
||||||
|
self._update_log_info(),
|
||||||
|
self._update_error_info(),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_minimal_module():
|
def is_minimal_module():
|
||||||
|
|
|
@ -10,19 +10,23 @@ import ray
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from ray.cluster_utils import Cluster
|
from ray.cluster_utils import Cluster
|
||||||
from ray.dashboard.modules.node.node_consts import (LOG_PRUNE_THREASHOLD,
|
from ray.dashboard.modules.node.node_consts import (
|
||||||
MAX_LOGS_TO_CACHE)
|
LOG_PRUNE_THREASHOLD,
|
||||||
|
MAX_LOGS_TO_CACHE,
|
||||||
|
)
|
||||||
from ray.dashboard.tests.conftest import * # noqa
|
from ray.dashboard.tests.conftest import * # noqa
|
||||||
from ray._private.test_utils import (
|
from ray._private.test_utils import (
|
||||||
format_web_url, wait_until_server_available, wait_for_condition,
|
format_web_url,
|
||||||
wait_until_succeeded_without_exception)
|
wait_until_server_available,
|
||||||
|
wait_for_condition,
|
||||||
|
wait_until_succeeded_without_exception,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def test_nodes_update(enable_test_module, ray_start_with_dashboard):
|
def test_nodes_update(enable_test_module, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
|
@ -44,8 +48,7 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard):
|
||||||
assert len(dump_data["agents"]) == 1
|
assert len(dump_data["agents"]) == 1
|
||||||
assert len(dump_data["nodeIdToIp"]) == 1
|
assert len(dump_data["nodeIdToIp"]) == 1
|
||||||
assert len(dump_data["nodeIdToHostname"]) == 1
|
assert len(dump_data["nodeIdToHostname"]) == 1
|
||||||
assert dump_data["nodes"].keys() == dump_data[
|
assert dump_data["nodes"].keys() == dump_data["nodeIdToHostname"].keys()
|
||||||
"nodeIdToHostname"].keys()
|
|
||||||
|
|
||||||
response = requests.get(webui_url + "/test/notified_agents")
|
response = requests.get(webui_url + "/test/notified_agents")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -77,8 +80,7 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
actor_pids = [actor.getpid.remote() for actor in actors]
|
actor_pids = [actor.getpid.remote() for actor in actors]
|
||||||
actor_pids = set(ray.get(actor_pids))
|
actor_pids = set(ray.get(actor_pids))
|
||||||
|
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
node_id = ray_start_with_dashboard["node_id"]
|
node_id = ray_start_with_dashboard["node_id"]
|
||||||
|
@ -134,15 +136,19 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
|
||||||
def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
|
def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
class ActorWithObjs:
|
class ActorWithObjs:
|
||||||
|
@ -156,8 +162,7 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
actors = [ActorWithObjs.remote() for _ in range(2)] # noqa
|
actors = [ActorWithObjs.remote() for _ in range(2)] # noqa
|
||||||
results = ray.get([actor.get_obj.remote() for actor in actors]) # noqa
|
results = ray.get([actor.get_obj.remote() for actor in actors]) # noqa
|
||||||
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
||||||
resp = requests.get(
|
resp = requests.get(webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
|
||||||
webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
def check_mem_table():
|
def check_mem_table():
|
||||||
|
@ -172,11 +177,12 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
assert summary["totalLocalRefCount"] == 3
|
assert summary["totalLocalRefCount"] == 3
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
check_mem_table, (AssertionError, ), timeout_ms=10000)
|
check_mem_table, (AssertionError,), timeout_ms=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
|
def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||||
|
|
||||||
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
|
||||||
|
|
||||||
|
@ -220,21 +226,25 @@ def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||||
last_ex = ex
|
last_ex = ex
|
||||||
finally:
|
finally:
|
||||||
if time.time() > start_time + timeout_seconds:
|
if time.time() > start_time + timeout_seconds:
|
||||||
ex_stack = traceback.format_exception(
|
ex_stack = (
|
||||||
type(last_ex), last_ex,
|
traceback.format_exception(
|
||||||
last_ex.__traceback__) if last_ex else []
|
type(last_ex), last_ex, last_ex.__traceback__
|
||||||
|
)
|
||||||
|
if last_ex
|
||||||
|
else []
|
||||||
|
)
|
||||||
ex_stack = "".join(ex_stack)
|
ex_stack = "".join(ex_stack)
|
||||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_cluster_head", [{
|
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
|
||||||
"include_dashboard": True
|
)
|
||||||
}], indirect=True)
|
def test_multi_nodes_info(
|
||||||
def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
|
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
|
||||||
ray_start_cluster_head):
|
):
|
||||||
cluster: Cluster = ray_start_cluster_head
|
cluster: Cluster = ray_start_cluster_head
|
||||||
assert (wait_until_server_available(cluster.webui_url) is True)
|
assert wait_until_server_available(cluster.webui_url) is True
|
||||||
webui_url = cluster.webui_url
|
webui_url = cluster.webui_url
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
cluster.add_node()
|
cluster.add_node()
|
||||||
|
@ -269,13 +279,13 @@ def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_cluster_head", [{
|
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
|
||||||
"include_dashboard": True
|
)
|
||||||
}], indirect=True)
|
def test_multi_node_churn(
|
||||||
def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
|
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
|
||||||
ray_start_cluster_head):
|
):
|
||||||
cluster: Cluster = ray_start_cluster_head
|
cluster: Cluster = ray_start_cluster_head
|
||||||
assert (wait_until_server_available(cluster.webui_url) is True)
|
assert wait_until_server_available(cluster.webui_url) is True
|
||||||
webui_url = format_web_url(cluster.webui_url)
|
webui_url = format_web_url(cluster.webui_url)
|
||||||
|
|
||||||
def cluster_chaos_monkey():
|
def cluster_chaos_monkey():
|
||||||
|
@ -315,13 +325,11 @@ def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_cluster_head", [{
|
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
|
||||||
"include_dashboard": True
|
)
|
||||||
}], indirect=True)
|
def test_logs(enable_test_module, disable_aiohttp_cache, ray_start_cluster_head):
|
||||||
def test_logs(enable_test_module, disable_aiohttp_cache,
|
|
||||||
ray_start_cluster_head):
|
|
||||||
cluster = ray_start_cluster_head
|
cluster = ray_start_cluster_head
|
||||||
assert (wait_until_server_available(cluster.webui_url) is True)
|
assert wait_until_server_available(cluster.webui_url) is True
|
||||||
webui_url = cluster.webui_url
|
webui_url = cluster.webui_url
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
nodes = ray.nodes()
|
nodes = ray.nodes()
|
||||||
|
@ -348,21 +356,18 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
|
||||||
|
|
||||||
def check_logs():
|
def check_logs():
|
||||||
node_logs_response = requests.get(
|
node_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs", params={"ip": node_ip})
|
f"{webui_url}/node_logs", params={"ip": node_ip}
|
||||||
|
)
|
||||||
node_logs_response.raise_for_status()
|
node_logs_response.raise_for_status()
|
||||||
node_logs = node_logs_response.json()
|
node_logs = node_logs_response.json()
|
||||||
assert node_logs["result"]
|
assert node_logs["result"]
|
||||||
assert type(node_logs["data"]["logs"]) is dict
|
assert type(node_logs["data"]["logs"]) is dict
|
||||||
assert all(
|
assert all(pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
|
||||||
pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
|
|
||||||
assert len(node_logs["data"]["logs"][la2_pid]) == 1
|
assert len(node_logs["data"]["logs"][la2_pid]) == 1
|
||||||
|
|
||||||
actor_one_logs_response = requests.get(
|
actor_one_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs",
|
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
|
||||||
params={
|
)
|
||||||
"ip": node_ip,
|
|
||||||
"pid": str(la_pid)
|
|
||||||
})
|
|
||||||
actor_one_logs_response.raise_for_status()
|
actor_one_logs_response.raise_for_status()
|
||||||
actor_one_logs = actor_one_logs_response.json()
|
actor_one_logs = actor_one_logs_response.json()
|
||||||
assert actor_one_logs["result"]
|
assert actor_one_logs["result"]
|
||||||
|
@ -370,19 +375,19 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
|
||||||
assert len(actor_one_logs["data"]["logs"][la_pid]) == 4
|
assert len(actor_one_logs["data"]["logs"][la_pid]) == 4
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
check_logs, (AssertionError, ), timeout_ms=1000)
|
check_logs, (AssertionError,), timeout_ms=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_cluster_head", [{
|
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
|
||||||
"include_dashboard": True
|
)
|
||||||
}], indirect=True)
|
def test_logs_clean_up(
|
||||||
def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
|
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
|
||||||
ray_start_cluster_head):
|
):
|
||||||
"""Check if logs from the dead pids are GC'ed.
|
"""Check if logs from the dead pids are GC'ed."""
|
||||||
"""
|
|
||||||
cluster = ray_start_cluster_head
|
cluster = ray_start_cluster_head
|
||||||
assert (wait_until_server_available(cluster.webui_url) is True)
|
assert wait_until_server_available(cluster.webui_url) is True
|
||||||
webui_url = cluster.webui_url
|
webui_url = cluster.webui_url
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
nodes = ray.nodes()
|
nodes = ray.nodes()
|
||||||
|
@ -406,38 +411,41 @@ def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
|
||||||
|
|
||||||
def check_logs():
|
def check_logs():
|
||||||
node_logs_response = requests.get(
|
node_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs", params={"ip": node_ip})
|
f"{webui_url}/node_logs", params={"ip": node_ip}
|
||||||
|
)
|
||||||
node_logs_response.raise_for_status()
|
node_logs_response.raise_for_status()
|
||||||
node_logs = node_logs_response.json()
|
node_logs = node_logs_response.json()
|
||||||
assert node_logs["result"]
|
assert node_logs["result"]
|
||||||
assert la_pid in node_logs["data"]["logs"]
|
assert la_pid in node_logs["data"]["logs"]
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
check_logs, (AssertionError, ), timeout_ms=1000)
|
check_logs, (AssertionError,), timeout_ms=1000
|
||||||
|
)
|
||||||
ray.kill(la)
|
ray.kill(la)
|
||||||
|
|
||||||
def check_logs_not_exist():
|
def check_logs_not_exist():
|
||||||
node_logs_response = requests.get(
|
node_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs", params={"ip": node_ip})
|
f"{webui_url}/node_logs", params={"ip": node_ip}
|
||||||
|
)
|
||||||
node_logs_response.raise_for_status()
|
node_logs_response.raise_for_status()
|
||||||
node_logs = node_logs_response.json()
|
node_logs = node_logs_response.json()
|
||||||
assert node_logs["result"]
|
assert node_logs["result"]
|
||||||
assert la_pid not in node_logs["data"]["logs"]
|
assert la_pid not in node_logs["data"]["logs"]
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
check_logs_not_exist, (AssertionError, ), timeout_ms=10000)
|
check_logs_not_exist, (AssertionError,), timeout_ms=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_cluster_head", [{
|
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
|
||||||
"include_dashboard": True
|
)
|
||||||
}], indirect=True)
|
def test_logs_max_count(
|
||||||
def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
|
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
|
||||||
ray_start_cluster_head):
|
):
|
||||||
"""Test that each Ray worker cannot cache more than 1000 logs at a time.
|
"""Test that each Ray worker cannot cache more than 1000 logs at a time."""
|
||||||
"""
|
|
||||||
cluster = ray_start_cluster_head
|
cluster = ray_start_cluster_head
|
||||||
assert (wait_until_server_available(cluster.webui_url) is True)
|
assert wait_until_server_available(cluster.webui_url) is True
|
||||||
webui_url = cluster.webui_url
|
webui_url = cluster.webui_url
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
nodes = ray.nodes()
|
nodes = ray.nodes()
|
||||||
|
@ -461,7 +469,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
|
||||||
|
|
||||||
def check_logs():
|
def check_logs():
|
||||||
node_logs_response = requests.get(
|
node_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs", params={"ip": node_ip})
|
f"{webui_url}/node_logs", params={"ip": node_ip}
|
||||||
|
)
|
||||||
node_logs_response.raise_for_status()
|
node_logs_response.raise_for_status()
|
||||||
node_logs = node_logs_response.json()
|
node_logs = node_logs_response.json()
|
||||||
assert node_logs["result"]
|
assert node_logs["result"]
|
||||||
|
@ -472,11 +481,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
|
||||||
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
|
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
|
||||||
|
|
||||||
actor_one_logs_response = requests.get(
|
actor_one_logs_response = requests.get(
|
||||||
f"{webui_url}/node_logs",
|
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
|
||||||
params={
|
)
|
||||||
"ip": node_ip,
|
|
||||||
"pid": str(la_pid)
|
|
||||||
})
|
|
||||||
actor_one_logs_response.raise_for_status()
|
actor_one_logs_response.raise_for_status()
|
||||||
actor_one_logs = actor_one_logs_response.json()
|
actor_one_logs = actor_one_logs_response.json()
|
||||||
assert actor_one_logs["result"]
|
assert actor_one_logs["result"]
|
||||||
|
@ -486,7 +492,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
|
||||||
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
|
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
check_logs, (AssertionError, ), timeout_ms=10000)
|
check_logs, (AssertionError,), timeout_ms=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -39,10 +39,12 @@ try:
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
gpustat = None
|
gpustat = None
|
||||||
if log_once("gpustat_import_warning"):
|
if log_once("gpustat_import_warning"):
|
||||||
warnings.warn("`gpustat` package is not installed. GPU monitoring is "
|
warnings.warn(
|
||||||
"not available. To have full functionality of the "
|
"`gpustat` package is not installed. GPU monitoring is "
|
||||||
"dashboard please install `pip install ray["
|
"not available. To have full functionality of the "
|
||||||
"default]`.)")
|
"dashboard please install `pip install ray["
|
||||||
|
"default]`.)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def recursive_asdict(o):
|
def recursive_asdict(o):
|
||||||
|
@ -68,68 +70,81 @@ def jsonify_asdict(o) -> str:
|
||||||
|
|
||||||
# A list of gauges to record and export metrics.
|
# A list of gauges to record and export metrics.
|
||||||
METRICS_GAUGES = {
|
METRICS_GAUGES = {
|
||||||
"node_cpu_utilization": Gauge("node_cpu_utilization",
|
"node_cpu_utilization": Gauge(
|
||||||
"Total CPU usage on a ray node",
|
"node_cpu_utilization", "Total CPU usage on a ray node", "percentage", ["ip"]
|
||||||
"percentage", ["ip"]),
|
),
|
||||||
"node_cpu_count": Gauge("node_cpu_count",
|
"node_cpu_count": Gauge(
|
||||||
"Total CPUs available on a ray node", "cores",
|
"node_cpu_count", "Total CPUs available on a ray node", "cores", ["ip"]
|
||||||
["ip"]),
|
),
|
||||||
"node_mem_used": Gauge("node_mem_used", "Memory usage on a ray node",
|
"node_mem_used": Gauge(
|
||||||
"bytes", ["ip"]),
|
"node_mem_used", "Memory usage on a ray node", "bytes", ["ip"]
|
||||||
"node_mem_available": Gauge("node_mem_available",
|
),
|
||||||
"Memory available on a ray node", "bytes",
|
"node_mem_available": Gauge(
|
||||||
["ip"]),
|
"node_mem_available", "Memory available on a ray node", "bytes", ["ip"]
|
||||||
"node_mem_total": Gauge("node_mem_total", "Total memory on a ray node",
|
),
|
||||||
"bytes", ["ip"]),
|
"node_mem_total": Gauge(
|
||||||
"node_gpus_available": Gauge("node_gpus_available",
|
"node_mem_total", "Total memory on a ray node", "bytes", ["ip"]
|
||||||
"Total GPUs available on a ray node",
|
),
|
||||||
"percentage", ["ip"]),
|
"node_gpus_available": Gauge(
|
||||||
"node_gpus_utilization": Gauge("node_gpus_utilization",
|
"node_gpus_available",
|
||||||
"Total GPUs usage on a ray node",
|
"Total GPUs available on a ray node",
|
||||||
"percentage", ["ip"]),
|
"percentage",
|
||||||
"node_gram_used": Gauge("node_gram_used",
|
["ip"],
|
||||||
"Total GPU RAM usage on a ray node", "bytes",
|
),
|
||||||
["ip"]),
|
"node_gpus_utilization": Gauge(
|
||||||
"node_gram_available": Gauge("node_gram_available",
|
"node_gpus_utilization", "Total GPUs usage on a ray node", "percentage", ["ip"]
|
||||||
"Total GPU RAM available on a ray node",
|
),
|
||||||
"bytes", ["ip"]),
|
"node_gram_used": Gauge(
|
||||||
"node_disk_usage": Gauge("node_disk_usage",
|
"node_gram_used", "Total GPU RAM usage on a ray node", "bytes", ["ip"]
|
||||||
"Total disk usage (bytes) on a ray node", "bytes",
|
),
|
||||||
["ip"]),
|
"node_gram_available": Gauge(
|
||||||
"node_disk_free": Gauge("node_disk_free",
|
"node_gram_available", "Total GPU RAM available on a ray node", "bytes", ["ip"]
|
||||||
"Total disk free (bytes) on a ray node", "bytes",
|
),
|
||||||
["ip"]),
|
"node_disk_usage": Gauge(
|
||||||
|
"node_disk_usage", "Total disk usage (bytes) on a ray node", "bytes", ["ip"]
|
||||||
|
),
|
||||||
|
"node_disk_free": Gauge(
|
||||||
|
"node_disk_free", "Total disk free (bytes) on a ray node", "bytes", ["ip"]
|
||||||
|
),
|
||||||
"node_disk_utilization_percentage": Gauge(
|
"node_disk_utilization_percentage": Gauge(
|
||||||
"node_disk_utilization_percentage",
|
"node_disk_utilization_percentage",
|
||||||
"Total disk utilization (percentage) on a ray node", "percentage",
|
"Total disk utilization (percentage) on a ray node",
|
||||||
["ip"]),
|
"percentage",
|
||||||
"node_network_sent": Gauge("node_network_sent", "Total network sent",
|
["ip"],
|
||||||
"bytes", ["ip"]),
|
),
|
||||||
"node_network_received": Gauge("node_network_received",
|
"node_network_sent": Gauge(
|
||||||
"Total network received", "bytes", ["ip"]),
|
"node_network_sent", "Total network sent", "bytes", ["ip"]
|
||||||
|
),
|
||||||
|
"node_network_received": Gauge(
|
||||||
|
"node_network_received", "Total network received", "bytes", ["ip"]
|
||||||
|
),
|
||||||
"node_network_send_speed": Gauge(
|
"node_network_send_speed": Gauge(
|
||||||
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]),
|
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]
|
||||||
"node_network_receive_speed": Gauge("node_network_receive_speed",
|
),
|
||||||
"Network receive speed", "bytes/sec",
|
"node_network_receive_speed": Gauge(
|
||||||
["ip"]),
|
"node_network_receive_speed", "Network receive speed", "bytes/sec", ["ip"]
|
||||||
"raylet_cpu": Gauge("raylet_cpu", "CPU usage of the raylet on a node.",
|
),
|
||||||
"percentage", ["ip", "pid"]),
|
"raylet_cpu": Gauge(
|
||||||
"raylet_mem": Gauge("raylet_mem", "Memory usage of the raylet on a node",
|
"raylet_cpu", "CPU usage of the raylet on a node.", "percentage", ["ip", "pid"]
|
||||||
"mb", ["ip", "pid"]),
|
),
|
||||||
"cluster_active_nodes": Gauge("cluster_active_nodes",
|
"raylet_mem": Gauge(
|
||||||
"Active nodes on the cluster", "count",
|
"raylet_mem", "Memory usage of the raylet on a node", "mb", ["ip", "pid"]
|
||||||
["node_type"]),
|
),
|
||||||
"cluster_failed_nodes": Gauge("cluster_failed_nodes",
|
"cluster_active_nodes": Gauge(
|
||||||
"Failed nodes on the cluster", "count",
|
"cluster_active_nodes", "Active nodes on the cluster", "count", ["node_type"]
|
||||||
["node_type"]),
|
),
|
||||||
"cluster_pending_nodes": Gauge("cluster_pending_nodes",
|
"cluster_failed_nodes": Gauge(
|
||||||
"Pending nodes on the cluster", "count",
|
"cluster_failed_nodes", "Failed nodes on the cluster", "count", ["node_type"]
|
||||||
["node_type"]),
|
),
|
||||||
|
"cluster_pending_nodes": Gauge(
|
||||||
|
"cluster_pending_nodes", "Pending nodes on the cluster", "count", ["node_type"]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
class ReporterAgent(
|
||||||
reporter_pb2_grpc.ReporterServiceServicer):
|
dashboard_utils.DashboardAgentModule, reporter_pb2_grpc.ReporterServiceServicer
|
||||||
|
):
|
||||||
"""A monitor process for monitoring Ray nodes.
|
"""A monitor process for monitoring Ray nodes.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -145,37 +160,39 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
cpu_count = ray._private.utils.get_num_cpus()
|
cpu_count = ray._private.utils.get_num_cpus()
|
||||||
self._cpu_counts = (cpu_count, cpu_count)
|
self._cpu_counts = (cpu_count, cpu_count)
|
||||||
else:
|
else:
|
||||||
self._cpu_counts = (psutil.cpu_count(),
|
self._cpu_counts = (psutil.cpu_count(), psutil.cpu_count(logical=False))
|
||||||
psutil.cpu_count(logical=False))
|
|
||||||
|
|
||||||
self._ip = dashboard_agent.ip
|
self._ip = dashboard_agent.ip
|
||||||
if not use_gcs_for_bootstrap():
|
if not use_gcs_for_bootstrap():
|
||||||
self._redis_address, _ = dashboard_agent.redis_address
|
self._redis_address, _ = dashboard_agent.redis_address
|
||||||
self._is_head_node = (self._ip == self._redis_address)
|
self._is_head_node = self._ip == self._redis_address
|
||||||
else:
|
else:
|
||||||
self._is_head_node = (
|
self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
|
||||||
self._ip == dashboard_agent.gcs_address.split(":")[0])
|
|
||||||
self._hostname = socket.gethostname()
|
self._hostname = socket.gethostname()
|
||||||
self._workers = set()
|
self._workers = set()
|
||||||
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
|
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
|
||||||
self._metrics_agent = MetricsAgent(
|
self._metrics_agent = MetricsAgent(
|
||||||
"127.0.0.1" if self._ip == "127.0.0.1" else "",
|
"127.0.0.1" if self._ip == "127.0.0.1" else "",
|
||||||
dashboard_agent.metrics_export_port)
|
dashboard_agent.metrics_export_port,
|
||||||
self._key = f"{reporter_consts.REPORTER_PREFIX}" \
|
)
|
||||||
f"{self._dashboard_agent.node_id}"
|
self._key = (
|
||||||
|
f"{reporter_consts.REPORTER_PREFIX}" f"{self._dashboard_agent.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
async def GetProfilingStats(self, request, context):
|
async def GetProfilingStats(self, request, context):
|
||||||
pid = request.pid
|
pid = request.pid
|
||||||
duration = request.duration
|
duration = request.duration
|
||||||
profiling_file_path = os.path.join(
|
profiling_file_path = os.path.join(
|
||||||
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt")
|
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt"
|
||||||
|
)
|
||||||
sudo = "sudo" if ray._private.utils.get_user() != "root" else ""
|
sudo = "sudo" if ray._private.utils.get_user() != "root" else ""
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
f"{sudo} $(which py-spy) record "
|
f"{sudo} $(which py-spy) record "
|
||||||
f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
|
f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
shell=True)
|
shell=True,
|
||||||
|
)
|
||||||
stdout, stderr = await process.communicate()
|
stdout, stderr = await process.communicate()
|
||||||
if process.returncode != 0:
|
if process.returncode != 0:
|
||||||
profiling_stats = ""
|
profiling_stats = ""
|
||||||
|
@ -183,14 +200,14 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
with open(profiling_file_path, "r") as f:
|
with open(profiling_file_path, "r") as f:
|
||||||
profiling_stats = f.read()
|
profiling_stats = f.read()
|
||||||
return reporter_pb2.GetProfilingStatsReply(
|
return reporter_pb2.GetProfilingStatsReply(
|
||||||
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr)
|
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr
|
||||||
|
)
|
||||||
|
|
||||||
async def ReportOCMetrics(self, request, context):
|
async def ReportOCMetrics(self, request, context):
|
||||||
# This function receives a GRPC containing OpenCensus (OC) metrics
|
# This function receives a GRPC containing OpenCensus (OC) metrics
|
||||||
# from a Ray process, then exposes those metrics to Prometheus.
|
# from a Ray process, then exposes those metrics to Prometheus.
|
||||||
try:
|
try:
|
||||||
self._metrics_agent.record_metric_points_from_protobuf(
|
self._metrics_agent.record_metric_points_from_protobuf(request.metrics)
|
||||||
request.metrics)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return reporter_pb2.ReportOCMetricsReply()
|
return reporter_pb2.ReportOCMetricsReply()
|
||||||
|
@ -227,10 +244,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
for gpu in gpus:
|
for gpu in gpus:
|
||||||
# Note the keys in this dict have periods which throws
|
# Note the keys in this dict have periods which throws
|
||||||
# off javascript so we change .s to _s
|
# off javascript so we change .s to _s
|
||||||
gpu_data = {
|
gpu_data = {"_".join(key.split(".")): val for key, val in gpu.entry.items()}
|
||||||
"_".join(key.split(".")): val
|
|
||||||
for key, val in gpu.entry.items()
|
|
||||||
}
|
|
||||||
gpu_utilizations.append(gpu_data)
|
gpu_utilizations.append(gpu_data)
|
||||||
return gpu_utilizations
|
return gpu_utilizations
|
||||||
|
|
||||||
|
@ -245,8 +259,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_network_stats():
|
def _get_network_stats():
|
||||||
ifaces = [
|
ifaces = [
|
||||||
v for k, v in psutil.net_io_counters(pernic=True).items()
|
v for k, v in psutil.net_io_counters(pernic=True).items() if k[0] == "e"
|
||||||
if k[0] == "e"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
sent = sum((iface.bytes_sent for iface in ifaces))
|
sent = sum((iface.bytes_sent for iface in ifaces))
|
||||||
|
@ -266,8 +279,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
if IN_KUBERNETES_POD:
|
if IN_KUBERNETES_POD:
|
||||||
# If in a K8s pod, disable disk display by passing in dummy values.
|
# If in a K8s pod, disable disk display by passing in dummy values.
|
||||||
return {
|
return {
|
||||||
"/": psutil._common.sdiskusage(
|
"/": psutil._common.sdiskusage(total=1, used=0, free=1, percent=0.0)
|
||||||
total=1, used=0, free=1, percent=0.0)
|
|
||||||
}
|
}
|
||||||
root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep
|
root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep
|
||||||
tmp = ray._private.utils.get_user_temp_dir()
|
tmp = ray._private.utils.get_user_temp_dir()
|
||||||
|
@ -286,14 +298,18 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
self._workers.update(workers)
|
self._workers.update(workers)
|
||||||
self._workers.discard(psutil.Process())
|
self._workers.discard(psutil.Process())
|
||||||
return [
|
return [
|
||||||
w.as_dict(attrs=[
|
w.as_dict(
|
||||||
"pid",
|
attrs=[
|
||||||
"create_time",
|
"pid",
|
||||||
"cpu_percent",
|
"create_time",
|
||||||
"cpu_times",
|
"cpu_percent",
|
||||||
"cmdline",
|
"cpu_times",
|
||||||
"memory_info",
|
"cmdline",
|
||||||
]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE
|
"memory_info",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for w in self._workers
|
||||||
|
if w.status() != psutil.STATUS_ZOMBIE
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -318,14 +334,16 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
if raylet_proc is None:
|
if raylet_proc is None:
|
||||||
return {}
|
return {}
|
||||||
else:
|
else:
|
||||||
return raylet_proc.as_dict(attrs=[
|
return raylet_proc.as_dict(
|
||||||
"pid",
|
attrs=[
|
||||||
"create_time",
|
"pid",
|
||||||
"cpu_percent",
|
"create_time",
|
||||||
"cpu_times",
|
"cpu_percent",
|
||||||
"cmdline",
|
"cpu_times",
|
||||||
"memory_info",
|
"cmdline",
|
||||||
])
|
"memory_info",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_load_avg(self):
|
def _get_load_avg(self):
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
|
@ -345,8 +363,10 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
then, prev_network_stats = self._network_stats_hist[0]
|
then, prev_network_stats = self._network_stats_hist[0]
|
||||||
prev_send, prev_recv = prev_network_stats
|
prev_send, prev_recv = prev_network_stats
|
||||||
now_send, now_recv = network_stats
|
now_send, now_recv = network_stats
|
||||||
network_speed_stats = ((now_send - prev_send) / (now - then),
|
network_speed_stats = (
|
||||||
(now_recv - prev_recv) / (now - then))
|
(now_send - prev_send) / (now - then),
|
||||||
|
(now_recv - prev_recv) / (now - then),
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"now": now,
|
"now": now,
|
||||||
"hostname": self._hostname,
|
"hostname": self._hostname,
|
||||||
|
@ -379,7 +399,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
Record(
|
Record(
|
||||||
gauge=METRICS_GAUGES["cluster_active_nodes"],
|
gauge=METRICS_GAUGES["cluster_active_nodes"],
|
||||||
value=active_node_count,
|
value=active_node_count,
|
||||||
tags={"node_type": node_type}))
|
tags={"node_type": node_type},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"]
|
failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"]
|
||||||
failed_nodes_dict = {}
|
failed_nodes_dict = {}
|
||||||
|
@ -394,7 +416,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
Record(
|
Record(
|
||||||
gauge=METRICS_GAUGES["cluster_failed_nodes"],
|
gauge=METRICS_GAUGES["cluster_failed_nodes"],
|
||||||
value=failed_node_count,
|
value=failed_node_count,
|
||||||
tags={"node_type": node_type}))
|
tags={"node_type": node_type},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"]
|
pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"]
|
||||||
pending_nodes_dict = {}
|
pending_nodes_dict = {}
|
||||||
|
@ -409,35 +433,36 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
Record(
|
Record(
|
||||||
gauge=METRICS_GAUGES["cluster_pending_nodes"],
|
gauge=METRICS_GAUGES["cluster_pending_nodes"],
|
||||||
value=pending_node_count,
|
value=pending_node_count,
|
||||||
tags={"node_type": node_type}))
|
tags={"node_type": node_type},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# -- CPU per node --
|
# -- CPU per node --
|
||||||
cpu_usage = float(stats["cpu"])
|
cpu_usage = float(stats["cpu"])
|
||||||
cpu_record = Record(
|
cpu_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_cpu_utilization"],
|
gauge=METRICS_GAUGES["node_cpu_utilization"],
|
||||||
value=cpu_usage,
|
value=cpu_usage,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
|
|
||||||
cpu_count, _ = stats["cpus"]
|
cpu_count, _ = stats["cpus"]
|
||||||
cpu_count_record = Record(
|
cpu_count_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_cpu_count"],
|
gauge=METRICS_GAUGES["node_cpu_count"], value=cpu_count, tags={"ip": ip}
|
||||||
value=cpu_count,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
|
|
||||||
# -- Mem per node --
|
# -- Mem per node --
|
||||||
mem_total, mem_available, _, mem_used = stats["mem"]
|
mem_total, mem_available, _, mem_used = stats["mem"]
|
||||||
mem_used_record = Record(
|
mem_used_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_mem_used"],
|
gauge=METRICS_GAUGES["node_mem_used"], value=mem_used, tags={"ip": ip}
|
||||||
value=mem_used,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
mem_available_record = Record(
|
mem_available_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_mem_available"],
|
gauge=METRICS_GAUGES["node_mem_available"],
|
||||||
value=mem_available,
|
value=mem_available,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
mem_total_record = Record(
|
mem_total_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_mem_total"],
|
gauge=METRICS_GAUGES["node_mem_total"], value=mem_total, tags={"ip": ip}
|
||||||
value=mem_total,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
|
|
||||||
# -- GPU per node --
|
# -- GPU per node --
|
||||||
gpus = stats["gpus"]
|
gpus = stats["gpus"]
|
||||||
|
@ -455,23 +480,29 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
gpus_available_record = Record(
|
gpus_available_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_gpus_available"],
|
gauge=METRICS_GAUGES["node_gpus_available"],
|
||||||
value=gpus_available,
|
value=gpus_available,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
gpus_utilization_record = Record(
|
gpus_utilization_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_gpus_utilization"],
|
gauge=METRICS_GAUGES["node_gpus_utilization"],
|
||||||
value=gpus_utilization,
|
value=gpus_utilization,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
gram_used_record = Record(
|
gram_used_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_gram_used"],
|
gauge=METRICS_GAUGES["node_gram_used"], value=gram_used, tags={"ip": ip}
|
||||||
value=gram_used,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
gram_available_record = Record(
|
gram_available_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_gram_available"],
|
gauge=METRICS_GAUGES["node_gram_available"],
|
||||||
value=gram_available,
|
value=gram_available,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
records_reported.extend([
|
)
|
||||||
gpus_available_record, gpus_utilization_record,
|
records_reported.extend(
|
||||||
gram_used_record, gram_available_record
|
[
|
||||||
])
|
gpus_available_record,
|
||||||
|
gpus_utilization_record,
|
||||||
|
gram_used_record,
|
||||||
|
gram_available_record,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# -- Disk per node --
|
# -- Disk per node --
|
||||||
used, free = 0, 0
|
used, free = 0, 0
|
||||||
|
@ -480,39 +511,42 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
free += entry.free
|
free += entry.free
|
||||||
disk_utilization = float(used / (used + free)) * 100
|
disk_utilization = float(used / (used + free)) * 100
|
||||||
disk_usage_record = Record(
|
disk_usage_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_disk_usage"],
|
gauge=METRICS_GAUGES["node_disk_usage"], value=used, tags={"ip": ip}
|
||||||
value=used,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
disk_free_record = Record(
|
disk_free_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_disk_free"],
|
gauge=METRICS_GAUGES["node_disk_free"], value=free, tags={"ip": ip}
|
||||||
value=free,
|
)
|
||||||
tags={"ip": ip})
|
|
||||||
disk_utilization_percentage_record = Record(
|
disk_utilization_percentage_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_disk_utilization_percentage"],
|
gauge=METRICS_GAUGES["node_disk_utilization_percentage"],
|
||||||
value=disk_utilization,
|
value=disk_utilization,
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
|
|
||||||
# -- Network speed (send/receive) stats per node --
|
# -- Network speed (send/receive) stats per node --
|
||||||
network_stats = stats["network"]
|
network_stats = stats["network"]
|
||||||
network_sent_record = Record(
|
network_sent_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_network_sent"],
|
gauge=METRICS_GAUGES["node_network_sent"],
|
||||||
value=network_stats[0],
|
value=network_stats[0],
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
network_received_record = Record(
|
network_received_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_network_received"],
|
gauge=METRICS_GAUGES["node_network_received"],
|
||||||
value=network_stats[1],
|
value=network_stats[1],
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
|
|
||||||
# -- Network speed (send/receive) per node --
|
# -- Network speed (send/receive) per node --
|
||||||
network_speed_stats = stats["network_speed"]
|
network_speed_stats = stats["network_speed"]
|
||||||
network_send_speed_record = Record(
|
network_send_speed_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_network_send_speed"],
|
gauge=METRICS_GAUGES["node_network_send_speed"],
|
||||||
value=network_speed_stats[0],
|
value=network_speed_stats[0],
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
network_receive_speed_record = Record(
|
network_receive_speed_record = Record(
|
||||||
gauge=METRICS_GAUGES["node_network_receive_speed"],
|
gauge=METRICS_GAUGES["node_network_receive_speed"],
|
||||||
value=network_speed_stats[1],
|
value=network_speed_stats[1],
|
||||||
tags={"ip": ip})
|
tags={"ip": ip},
|
||||||
|
)
|
||||||
|
|
||||||
raylet_stats = stats["raylet"]
|
raylet_stats = stats["raylet"]
|
||||||
if raylet_stats:
|
if raylet_stats:
|
||||||
|
@ -522,29 +556,34 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
raylet_cpu_record = Record(
|
raylet_cpu_record = Record(
|
||||||
gauge=METRICS_GAUGES["raylet_cpu"],
|
gauge=METRICS_GAUGES["raylet_cpu"],
|
||||||
value=raylet_cpu_usage,
|
value=raylet_cpu_usage,
|
||||||
tags={
|
tags={"ip": ip, "pid": raylet_pid},
|
||||||
"ip": ip,
|
)
|
||||||
"pid": raylet_pid
|
|
||||||
})
|
|
||||||
|
|
||||||
# -- raylet mem --
|
# -- raylet mem --
|
||||||
raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6
|
raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6
|
||||||
raylet_mem_record = Record(
|
raylet_mem_record = Record(
|
||||||
gauge=METRICS_GAUGES["raylet_mem"],
|
gauge=METRICS_GAUGES["raylet_mem"],
|
||||||
value=raylet_mem_usage,
|
value=raylet_mem_usage,
|
||||||
tags={
|
tags={"ip": ip, "pid": raylet_pid},
|
||||||
"ip": ip,
|
)
|
||||||
"pid": raylet_pid
|
|
||||||
})
|
|
||||||
records_reported.extend([raylet_cpu_record, raylet_mem_record])
|
records_reported.extend([raylet_cpu_record, raylet_mem_record])
|
||||||
|
|
||||||
records_reported.extend([
|
records_reported.extend(
|
||||||
cpu_record, cpu_count_record, mem_used_record,
|
[
|
||||||
mem_available_record, mem_total_record, disk_usage_record,
|
cpu_record,
|
||||||
disk_free_record, disk_utilization_percentage_record,
|
cpu_count_record,
|
||||||
network_sent_record, network_received_record,
|
mem_used_record,
|
||||||
network_send_speed_record, network_receive_speed_record
|
mem_available_record,
|
||||||
])
|
mem_total_record,
|
||||||
|
disk_usage_record,
|
||||||
|
disk_free_record,
|
||||||
|
disk_utilization_percentage_record,
|
||||||
|
network_sent_record,
|
||||||
|
network_received_record,
|
||||||
|
network_send_speed_record,
|
||||||
|
network_receive_speed_record,
|
||||||
|
]
|
||||||
|
)
|
||||||
return records_reported
|
return records_reported
|
||||||
|
|
||||||
async def _perform_iteration(self, publish):
|
async def _perform_iteration(self, publish):
|
||||||
|
@ -552,9 +591,13 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
formatted_status_string = internal_kv._internal_kv_get(
|
formatted_status_string = internal_kv._internal_kv_get(
|
||||||
DEBUG_AUTOSCALING_STATUS)
|
DEBUG_AUTOSCALING_STATUS
|
||||||
cluster_stats = json.loads(formatted_status_string.decode(
|
)
|
||||||
)) if formatted_status_string else {}
|
cluster_stats = (
|
||||||
|
json.loads(formatted_status_string.decode())
|
||||||
|
if formatted_status_string
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
stats = self._get_all_stats()
|
stats = self._get_all_stats()
|
||||||
records_reported = self._record_stats(stats, cluster_stats)
|
records_reported = self._record_stats(stats, cluster_stats)
|
||||||
|
@ -563,8 +606,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error publishing node physical stats.")
|
logger.exception("Error publishing node physical stats.")
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
|
||||||
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
|
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
|
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
|
||||||
|
@ -573,17 +615,20 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
|
||||||
if gcs_addr is None:
|
if gcs_addr is None:
|
||||||
aioredis_client = await aioredis.create_redis_pool(
|
aioredis_client = await aioredis.create_redis_pool(
|
||||||
address=self._dashboard_agent.redis_address,
|
address=self._dashboard_agent.redis_address,
|
||||||
password=self._dashboard_agent.redis_password)
|
password=self._dashboard_agent.redis_password,
|
||||||
|
)
|
||||||
gcs_addr = await aioredis_client.get("GcsServerAddress")
|
gcs_addr = await aioredis_client.get("GcsServerAddress")
|
||||||
gcs_addr = gcs_addr.decode()
|
gcs_addr = gcs_addr.decode()
|
||||||
publisher = GcsAioPublisher(address=gcs_addr)
|
publisher = GcsAioPublisher(address=gcs_addr)
|
||||||
|
|
||||||
async def publish(key: str, data: str):
|
async def publish(key: str, data: str):
|
||||||
await publisher.publish_resource_usage(key, data)
|
await publisher.publish_resource_usage(key, data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
aioredis_client = await aioredis.create_redis_pool(
|
aioredis_client = await aioredis.create_redis_pool(
|
||||||
address=self._dashboard_agent.redis_address,
|
address=self._dashboard_agent.redis_address,
|
||||||
password=self._dashboard_agent.redis_password)
|
password=self._dashboard_agent.redis_password,
|
||||||
|
)
|
||||||
|
|
||||||
async def publish(key: str, data: str):
|
async def publish(key: str, data: str):
|
||||||
await aioredis_client.publish(key, data)
|
await aioredis_client.publish(key, data)
|
||||||
|
|
|
@ -3,4 +3,5 @@ import ray.ray_constants as ray_constants
|
||||||
REPORTER_PREFIX = "RAY_REPORTER:"
|
REPORTER_PREFIX = "RAY_REPORTER:"
|
||||||
# The reporter will report its statistics this often (milliseconds).
|
# The reporter will report its statistics this often (milliseconds).
|
||||||
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
|
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
|
||||||
"REPORTER_UPDATE_INTERVAL_MS", 2500)
|
"REPORTER_UPDATE_INTERVAL_MS", 2500
|
||||||
|
)
|
||||||
|
|
|
@ -9,13 +9,14 @@ import ray
|
||||||
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
|
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
|
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
|
||||||
GcsAioResourceUsageSubscriber
|
|
||||||
import ray._private.services
|
import ray._private.services
|
||||||
import ray._private.utils
|
import ray._private.utils
|
||||||
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS,
|
from ray.ray_constants import (
|
||||||
DEBUG_AUTOSCALING_STATUS_LEGACY,
|
DEBUG_AUTOSCALING_STATUS,
|
||||||
DEBUG_AUTOSCALING_ERROR)
|
DEBUG_AUTOSCALING_STATUS_LEGACY,
|
||||||
|
DEBUG_AUTOSCALING_ERROR,
|
||||||
|
)
|
||||||
from ray.core.generated import reporter_pb2
|
from ray.core.generated import reporter_pb2
|
||||||
from ray.core.generated import reporter_pb2_grpc
|
from ray.core.generated import reporter_pb2_grpc
|
||||||
import ray.experimental.internal_kv as internal_kv
|
import ray.experimental.internal_kv as internal_kv
|
||||||
|
@ -40,9 +41,10 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
if change.new:
|
if change.new:
|
||||||
node_id, ports = change.new
|
node_id, ports = change.new
|
||||||
ip = DataSource.node_id_to_ip[node_id]
|
ip = DataSource.node_id_to_ip[node_id]
|
||||||
options = (("grpc.enable_http_proxy", 0), )
|
options = (("grpc.enable_http_proxy", 0),)
|
||||||
channel = ray._private.utils.init_grpc_channel(
|
channel = ray._private.utils.init_grpc_channel(
|
||||||
f"{ip}:{ports[1]}", options=options, asynchronous=True)
|
f"{ip}:{ports[1]}", options=options, asynchronous=True
|
||||||
|
)
|
||||||
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
|
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
|
||||||
self._stubs[ip] = stub
|
self._stubs[ip] = stub
|
||||||
|
|
||||||
|
@ -53,13 +55,16 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
duration = int(req.query["duration"])
|
duration = int(req.query["duration"])
|
||||||
reporter_stub = self._stubs[ip]
|
reporter_stub = self._stubs[ip]
|
||||||
reply = await reporter_stub.GetProfilingStats(
|
reply = await reporter_stub.GetProfilingStats(
|
||||||
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
|
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)
|
||||||
profiling_info = (json.loads(reply.profiling_stats)
|
)
|
||||||
if reply.profiling_stats else reply.std_out)
|
profiling_info = (
|
||||||
|
json.loads(reply.profiling_stats)
|
||||||
|
if reply.profiling_stats
|
||||||
|
else reply.std_out
|
||||||
|
)
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True, message="Profiling success.", profiling_info=profiling_info
|
||||||
message="Profiling success.",
|
)
|
||||||
profiling_info=profiling_info)
|
|
||||||
|
|
||||||
@routes.get("/api/ray_config")
|
@routes.get("/api/ray_config")
|
||||||
async def get_ray_config(self, req) -> aiohttp.web.Response:
|
async def get_ray_config(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -75,12 +80,12 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
)
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=False,
|
success=False, message="Invalid config, could not load YAML."
|
||||||
message="Invalid config, could not load YAML.")
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"min_workers": cfg.get("min_workers", "unspecified"),
|
"min_workers": cfg.get("min_workers", "unspecified"),
|
||||||
"max_workers": cfg.get("max_workers", "unspecified")
|
"max_workers": cfg.get("max_workers", "unspecified"),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -115,18 +120,18 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert ray.experimental.internal_kv._internal_kv_initialized()
|
assert ray.experimental.internal_kv._internal_kv_initialized()
|
||||||
legacy_status = internal_kv._internal_kv_get(
|
legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
|
||||||
DEBUG_AUTOSCALING_STATUS_LEGACY)
|
formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
|
||||||
formatted_status_string = internal_kv._internal_kv_get(
|
formatted_status = (
|
||||||
DEBUG_AUTOSCALING_STATUS)
|
json.loads(formatted_status_string.decode())
|
||||||
formatted_status = json.loads(formatted_status_string.decode()
|
if formatted_status_string
|
||||||
) if formatted_status_string else {}
|
else {}
|
||||||
|
)
|
||||||
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
|
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Got cluster status.",
|
message="Got cluster status.",
|
||||||
autoscaling_status=legacy_status.decode()
|
autoscaling_status=legacy_status.decode() if legacy_status else None,
|
||||||
if legacy_status else None,
|
|
||||||
autoscaling_error=error.decode() if error else None,
|
autoscaling_error=error.decode() if error else None,
|
||||||
cluster_status=formatted_status if formatted_status else None,
|
cluster_status=formatted_status if formatted_status else None,
|
||||||
)
|
)
|
||||||
|
@ -148,8 +153,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
node_id = key.split(":")[-1]
|
node_id = key.split(":")[-1]
|
||||||
DataSource.node_physical_stats[node_id] = data
|
DataSource.node_physical_stats[node_id] = data
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error receiving node physical stats "
|
logger.exception(
|
||||||
"from reporter agent.")
|
"Error receiving node physical stats " "from reporter agent."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
receiver = Receiver()
|
receiver = Receiver()
|
||||||
aioredis_client = self._dashboard_head.aioredis_client
|
aioredis_client = self._dashboard_head.aioredis_client
|
||||||
|
@ -165,8 +171,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||||
node_id = key.split(":")[-1]
|
node_id = key.split(":")[-1]
|
||||||
DataSource.node_physical_stats[node_id] = data
|
DataSource.node_physical_stats[node_id] = data
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error receiving node physical stats "
|
logger.exception(
|
||||||
"from reporter agent.")
|
"Error receiving node physical stats " "from reporter agent."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_minimal_module():
|
def is_minimal_module():
|
||||||
|
|
|
@ -10,9 +10,13 @@ from ray import ray_constants
|
||||||
from ray.dashboard.tests.conftest import * # noqa
|
from ray.dashboard.tests.conftest import * # noqa
|
||||||
from ray.dashboard.utils import Bunch
|
from ray.dashboard.utils import Bunch
|
||||||
from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent
|
from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent
|
||||||
from ray._private.test_utils import (format_web_url, RayTestTimeoutException,
|
from ray._private.test_utils import (
|
||||||
wait_until_server_available,
|
format_web_url,
|
||||||
wait_for_condition, fetch_prometheus)
|
RayTestTimeoutException,
|
||||||
|
wait_until_server_available,
|
||||||
|
wait_for_condition,
|
||||||
|
fetch_prometheus,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import prometheus_client
|
import prometheus_client
|
||||||
|
@ -34,7 +38,7 @@ def test_profiling(shutdown_only):
|
||||||
actor_pid = ray.get(c.getpid.remote())
|
actor_pid = ray.get(c.getpid.remote())
|
||||||
|
|
||||||
webui_url = addresses["webui_url"]
|
webui_url = addresses["webui_url"]
|
||||||
assert (wait_until_server_available(webui_url) is True)
|
assert wait_until_server_available(webui_url) is True
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -44,14 +48,16 @@ def test_profiling(shutdown_only):
|
||||||
if time.time() - start_time > 15:
|
if time.time() - start_time > 15:
|
||||||
raise RayTestTimeoutException(
|
raise RayTestTimeoutException(
|
||||||
"Timed out while collecting profiling stats, "
|
"Timed out while collecting profiling stats, "
|
||||||
f"launch_profiling: {launch_profiling}")
|
f"launch_profiling: {launch_profiling}"
|
||||||
|
)
|
||||||
launch_profiling = requests.get(
|
launch_profiling = requests.get(
|
||||||
webui_url + "/api/launch_profiling",
|
webui_url + "/api/launch_profiling",
|
||||||
params={
|
params={
|
||||||
"ip": ray.nodes()[0]["NodeManagerAddress"],
|
"ip": ray.nodes()[0]["NodeManagerAddress"],
|
||||||
"pid": actor_pid,
|
"pid": actor_pid,
|
||||||
"duration": 5
|
"duration": 5,
|
||||||
}).json()
|
},
|
||||||
|
).json()
|
||||||
if launch_profiling["result"]:
|
if launch_profiling["result"]:
|
||||||
profiling_info = launch_profiling["data"]["profilingInfo"]
|
profiling_info = launch_profiling["data"]["profilingInfo"]
|
||||||
break
|
break
|
||||||
|
@ -72,13 +78,12 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
|
||||||
actor_pids = set(actor_pids)
|
actor_pids = set(actor_pids)
|
||||||
|
|
||||||
webui_url = addresses["webui_url"]
|
webui_url = addresses["webui_url"]
|
||||||
assert (wait_until_server_available(webui_url) is True)
|
assert wait_until_server_available(webui_url) is True
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
def _check_workers():
|
def _check_workers():
|
||||||
try:
|
try:
|
||||||
resp = requests.get(webui_url +
|
resp = requests.get(webui_url + "/test/dump?key=node_physical_stats")
|
||||||
"/test/dump?key=node_physical_stats")
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
assert result["result"] is True
|
assert result["result"] is True
|
||||||
|
@ -101,8 +106,7 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
|
||||||
wait_for_condition(_check_workers, timeout=10)
|
wait_for_condition(_check_workers, timeout=10)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(prometheus_client is None, reason="prometheus_client not installed")
|
||||||
prometheus_client is None, reason="prometheus_client not installed")
|
|
||||||
def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
|
def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
|
||||||
addresses = ray.init(include_dashboard=True, num_cpus=1)
|
addresses = ray.init(include_dashboard=True, num_cpus=1)
|
||||||
metrics_export_port = addresses["metrics_export_port"]
|
metrics_export_port = addresses["metrics_export_port"]
|
||||||
|
@ -110,29 +114,31 @@ def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
|
||||||
prom_addresses = [f"{addr}:{metrics_export_port}"]
|
prom_addresses = [f"{addr}:{metrics_export_port}"]
|
||||||
|
|
||||||
def test_case_stats_exist():
|
def test_case_stats_exist():
|
||||||
components_dict, metric_names, metric_samples = fetch_prometheus(
|
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
|
||||||
prom_addresses)
|
return all(
|
||||||
return all([
|
[
|
||||||
"ray_node_cpu_utilization" in metric_names,
|
"ray_node_cpu_utilization" in metric_names,
|
||||||
"ray_node_cpu_count" in metric_names,
|
"ray_node_cpu_count" in metric_names,
|
||||||
"ray_node_mem_used" in metric_names,
|
"ray_node_mem_used" in metric_names,
|
||||||
"ray_node_mem_available" in metric_names,
|
"ray_node_mem_available" in metric_names,
|
||||||
"ray_node_mem_total" in metric_names,
|
"ray_node_mem_total" in metric_names,
|
||||||
"ray_raylet_cpu" in metric_names, "ray_raylet_mem" in metric_names,
|
"ray_raylet_cpu" in metric_names,
|
||||||
"ray_node_disk_usage" in metric_names,
|
"ray_raylet_mem" in metric_names,
|
||||||
"ray_node_disk_free" in metric_names,
|
"ray_node_disk_usage" in metric_names,
|
||||||
"ray_node_disk_utilization_percentage" in metric_names,
|
"ray_node_disk_free" in metric_names,
|
||||||
"ray_node_network_sent" in metric_names,
|
"ray_node_disk_utilization_percentage" in metric_names,
|
||||||
"ray_node_network_received" in metric_names,
|
"ray_node_network_sent" in metric_names,
|
||||||
"ray_node_network_send_speed" in metric_names,
|
"ray_node_network_received" in metric_names,
|
||||||
"ray_node_network_receive_speed" in metric_names
|
"ray_node_network_send_speed" in metric_names,
|
||||||
])
|
"ray_node_network_receive_speed" in metric_names,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def test_case_ip_correct():
|
def test_case_ip_correct():
|
||||||
components_dict, metric_names, metric_samples = fetch_prometheus(
|
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
|
||||||
prom_addresses)
|
|
||||||
raylet_proc = ray.worker._global_node.all_processes[
|
raylet_proc = ray.worker._global_node.all_processes[
|
||||||
ray_constants.PROCESS_TYPE_RAYLET][0]
|
ray_constants.PROCESS_TYPE_RAYLET
|
||||||
|
][0]
|
||||||
raylet_pid = None
|
raylet_pid = None
|
||||||
# Find the raylet pid recorded in the tag.
|
# Find the raylet pid recorded in the tag.
|
||||||
for sample in metric_samples:
|
for sample in metric_samples:
|
||||||
|
@ -159,24 +165,25 @@ def test_report_stats():
|
||||||
"cpu": 57.4,
|
"cpu": 57.4,
|
||||||
"cpus": (8, 4),
|
"cpus": (8, 4),
|
||||||
"mem": (17179869184, 5723353088, 66.7, 9234341888),
|
"mem": (17179869184, 5723353088, 66.7, 9234341888),
|
||||||
"workers": [{
|
"workers": [
|
||||||
"memory_info": Bunch(
|
{
|
||||||
rss=55934976, vms=7026937856, pfaults=15354, pageins=0),
|
"memory_info": Bunch(
|
||||||
"cpu_percent": 0.0,
|
rss=55934976, vms=7026937856, pfaults=15354, pageins=0
|
||||||
"cmdline": [
|
),
|
||||||
"ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""
|
"cpu_percent": 0.0,
|
||||||
],
|
"cmdline": ["ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""],
|
||||||
"create_time": 1614826391.338613,
|
"create_time": 1614826391.338613,
|
||||||
"pid": 7174,
|
"pid": 7174,
|
||||||
"cpu_times": Bunch(
|
"cpu_times": Bunch(
|
||||||
user=0.607899328,
|
user=0.607899328,
|
||||||
system=0.274044032,
|
system=0.274044032,
|
||||||
children_user=0.0,
|
children_user=0.0,
|
||||||
children_system=0.0)
|
children_system=0.0,
|
||||||
}],
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
"raylet": {
|
"raylet": {
|
||||||
"memory_info": Bunch(
|
"memory_info": Bunch(rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
|
||||||
rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
|
|
||||||
"cpu_percent": 0.0,
|
"cpu_percent": 0.0,
|
||||||
"cmdline": ["fake raylet cmdline"],
|
"cmdline": ["fake raylet cmdline"],
|
||||||
"create_time": 1614826390.274854,
|
"create_time": 1614826390.274854,
|
||||||
|
@ -185,22 +192,18 @@ def test_report_stats():
|
||||||
user=0.03683138,
|
user=0.03683138,
|
||||||
system=0.035913716,
|
system=0.035913716,
|
||||||
children_user=0.0,
|
children_user=0.0,
|
||||||
children_system=0.0)
|
children_system=0.0,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"bootTime": 1612934656.0,
|
"bootTime": 1612934656.0,
|
||||||
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45,
|
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45, 0.44)),
|
||||||
0.44)),
|
|
||||||
"disk": {
|
"disk": {
|
||||||
"/": Bunch(
|
"/": Bunch(
|
||||||
total=250790436864,
|
total=250790436864, used=11316781056, free=22748921856, percent=33.2
|
||||||
used=11316781056,
|
),
|
||||||
free=22748921856,
|
|
||||||
percent=33.2),
|
|
||||||
"/tmp": Bunch(
|
"/tmp": Bunch(
|
||||||
total=250790436864,
|
total=250790436864, used=209532035072, free=22748921856, percent=90.2
|
||||||
used=209532035072,
|
),
|
||||||
free=22748921856,
|
|
||||||
percent=90.2)
|
|
||||||
},
|
},
|
||||||
"gpus": [],
|
"gpus": [],
|
||||||
"network": (13621160960, 11914936320),
|
"network": (13621160960, 11914936320),
|
||||||
|
@ -209,13 +212,10 @@ def test_report_stats():
|
||||||
|
|
||||||
cluster_stats = {
|
cluster_stats = {
|
||||||
"autoscaler_report": {
|
"autoscaler_report": {
|
||||||
"active_nodes": {
|
"active_nodes": {"head_node": 1, "worker-node-0": 2},
|
||||||
"head_node": 1,
|
|
||||||
"worker-node-0": 2
|
|
||||||
},
|
|
||||||
"failed_nodes": [],
|
"failed_nodes": [],
|
||||||
"pending_launches": {},
|
"pending_launches": {},
|
||||||
"pending_nodes": []
|
"pending_nodes": [],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,11 +226,9 @@ def test_report_stats():
|
||||||
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
||||||
assert len(records) == 14
|
assert len(records) == 14
|
||||||
# Test stats with gpus
|
# Test stats with gpus
|
||||||
test_stats["gpus"] = [{
|
test_stats["gpus"] = [
|
||||||
"utilization_gpu": 1,
|
{"utilization_gpu": 1, "memory_used": 100, "memory_total": 1000}
|
||||||
"memory_used": 100,
|
]
|
||||||
"memory_total": 1000
|
|
||||||
}]
|
|
||||||
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
||||||
assert len(records) == 18
|
assert len(records) == 18
|
||||||
# Test stats without autoscaler report
|
# Test stats without autoscaler report
|
||||||
|
|
|
@ -12,10 +12,11 @@ from ray.core.generated import runtime_env_agent_pb2
|
||||||
from ray.core.generated import runtime_env_agent_pb2_grpc
|
from ray.core.generated import runtime_env_agent_pb2_grpc
|
||||||
from ray.core.generated import agent_manager_pb2
|
from ray.core.generated import agent_manager_pb2
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.modules.runtime_env.runtime_env_consts \
|
import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
|
||||||
as runtime_env_consts
|
from ray.experimental.internal_kv import (
|
||||||
from ray.experimental.internal_kv import _internal_kv_initialized, \
|
_internal_kv_initialized,
|
||||||
_initialize_internal_kv
|
_initialize_internal_kv,
|
||||||
|
)
|
||||||
from ray._private.ray_logging import setup_component_logger
|
from ray._private.ray_logging import setup_component_logger
|
||||||
from ray._private.runtime_env.pip import PipManager
|
from ray._private.runtime_env.pip import PipManager
|
||||||
from ray._private.runtime_env.conda import CondaManager
|
from ray._private.runtime_env.conda import CondaManager
|
||||||
|
@ -42,8 +43,10 @@ class CreatedEnvResult:
|
||||||
result: str
|
result: str
|
||||||
|
|
||||||
|
|
||||||
class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
class RuntimeEnvAgent(
|
||||||
runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer):
|
dashboard_utils.DashboardAgentModule,
|
||||||
|
runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
|
||||||
|
):
|
||||||
"""An RPC server to create and delete runtime envs.
|
"""An RPC server to create and delete runtime envs.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -86,32 +89,33 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
return self._per_job_logger_cache[job_id]
|
return self._per_job_logger_cache[job_id]
|
||||||
|
|
||||||
async def CreateRuntimeEnv(self, request, context):
|
async def CreateRuntimeEnv(self, request, context):
|
||||||
async def _setup_runtime_env(serialized_runtime_env,
|
async def _setup_runtime_env(
|
||||||
serialized_allocated_resource_instances):
|
serialized_runtime_env, serialized_allocated_resource_instances
|
||||||
|
):
|
||||||
# This function will be ran inside a thread
|
# This function will be ran inside a thread
|
||||||
def run_setup_with_logger():
|
def run_setup_with_logger():
|
||||||
runtime_env = RuntimeEnv(
|
runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
|
||||||
serialized_runtime_env=serialized_runtime_env)
|
|
||||||
allocated_resource: dict = json.loads(
|
allocated_resource: dict = json.loads(
|
||||||
serialized_allocated_resource_instances or "{}")
|
serialized_allocated_resource_instances or "{}"
|
||||||
|
)
|
||||||
|
|
||||||
# Use a separate logger for each job.
|
# Use a separate logger for each job.
|
||||||
per_job_logger = self.get_or_create_logger(request.job_id)
|
per_job_logger = self.get_or_create_logger(request.job_id)
|
||||||
# TODO(chenk008): Add log about allocated_resource to
|
# TODO(chenk008): Add log about allocated_resource to
|
||||||
# avoid lint error. That will be moved to cgroup plugin.
|
# avoid lint error. That will be moved to cgroup plugin.
|
||||||
per_job_logger.debug(f"Worker has resource :"
|
per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
|
||||||
f"{allocated_resource}")
|
|
||||||
context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
|
context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
|
||||||
self._pip_manager.setup(
|
self._pip_manager.setup(runtime_env, context, logger=per_job_logger)
|
||||||
runtime_env, context, logger=per_job_logger)
|
self._conda_manager.setup(runtime_env, context, logger=per_job_logger)
|
||||||
self._conda_manager.setup(
|
|
||||||
runtime_env, context, logger=per_job_logger)
|
|
||||||
self._py_modules_manager.setup(
|
self._py_modules_manager.setup(
|
||||||
runtime_env, context, logger=per_job_logger)
|
runtime_env, context, logger=per_job_logger
|
||||||
|
)
|
||||||
self._working_dir_manager.setup(
|
self._working_dir_manager.setup(
|
||||||
runtime_env, context, logger=per_job_logger)
|
runtime_env, context, logger=per_job_logger
|
||||||
|
)
|
||||||
self._container_manager.setup(
|
self._container_manager.setup(
|
||||||
runtime_env, context, logger=per_job_logger)
|
runtime_env, context, logger=per_job_logger
|
||||||
|
)
|
||||||
|
|
||||||
# Add the mapping of URIs -> the serialized environment to be
|
# Add the mapping of URIs -> the serialized environment to be
|
||||||
# used for cache invalidation.
|
# used for cache invalidation.
|
||||||
|
@ -133,14 +137,15 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
|
|
||||||
# Run setup function from all the plugins
|
# Run setup function from all the plugins
|
||||||
for plugin_class_path, config in runtime_env.plugins():
|
for plugin_class_path, config in runtime_env.plugins():
|
||||||
logger.debug(
|
logger.debug(f"Setting up runtime env plugin {plugin_class_path}")
|
||||||
f"Setting up runtime env plugin {plugin_class_path}")
|
|
||||||
plugin_class = import_attr(plugin_class_path)
|
plugin_class = import_attr(plugin_class_path)
|
||||||
# TODO(simon): implement uri support
|
# TODO(simon): implement uri support
|
||||||
plugin_class.create("uri not implemented",
|
plugin_class.create(
|
||||||
json.loads(config), context)
|
"uri not implemented", json.loads(config), context
|
||||||
plugin_class.modify_context("uri not implemented",
|
)
|
||||||
json.loads(config), context)
|
plugin_class.modify_context(
|
||||||
|
"uri not implemented", json.loads(config), context
|
||||||
|
)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
@ -159,18 +164,24 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
result = self._env_cache[serialized_env]
|
result = self._env_cache[serialized_env]
|
||||||
if result.success:
|
if result.success:
|
||||||
context = result.result
|
context = result.result
|
||||||
logger.info("Runtime env already created successfully. "
|
logger.info(
|
||||||
f"Env: {serialized_env}, context: {context}")
|
"Runtime env already created successfully. "
|
||||||
|
f"Env: {serialized_env}, context: {context}"
|
||||||
|
)
|
||||||
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
|
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
|
||||||
serialized_runtime_env_context=context)
|
serialized_runtime_env_context=context,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
error_message = result.result
|
error_message = result.result
|
||||||
logger.info("Runtime env already failed. "
|
logger.info(
|
||||||
f"Env: {serialized_env}, err: {error_message}")
|
"Runtime env already failed. "
|
||||||
|
f"Env: {serialized_env}, err: {error_message}"
|
||||||
|
)
|
||||||
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
||||||
error_message=error_message)
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
|
||||||
if SLEEP_FOR_TESTING_S:
|
if SLEEP_FOR_TESTING_S:
|
||||||
logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
|
logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
|
||||||
|
@ -182,8 +193,8 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
|
for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
|
||||||
try:
|
try:
|
||||||
runtime_env_context = await _setup_runtime_env(
|
runtime_env_context = await _setup_runtime_env(
|
||||||
serialized_env,
|
serialized_env, request.serialized_allocated_resource_instances
|
||||||
request.serialized_allocated_resource_instances)
|
)
|
||||||
break
|
break
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception("Runtime env creation failed.")
|
logger.exception("Runtime env creation failed.")
|
||||||
|
@ -195,22 +206,25 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
logger.error(
|
logger.error(
|
||||||
"Runtime env creation failed for %d times, "
|
"Runtime env creation failed for %d times, "
|
||||||
"don't retry any more.",
|
"don't retry any more.",
|
||||||
runtime_env_consts.RUNTIME_ENV_RETRY_TIMES)
|
runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
|
||||||
self._env_cache[serialized_env] = CreatedEnvResult(
|
)
|
||||||
False, error_message)
|
self._env_cache[serialized_env] = CreatedEnvResult(False, error_message)
|
||||||
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
||||||
error_message=error_message)
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
|
||||||
serialized_context = runtime_env_context.serialize()
|
serialized_context = runtime_env_context.serialize()
|
||||||
self._env_cache[serialized_env] = CreatedEnvResult(
|
self._env_cache[serialized_env] = CreatedEnvResult(True, serialized_context)
|
||||||
True, serialized_context)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Successfully created runtime env: %s, the context: %s",
|
"Successfully created runtime env: %s, the context: %s",
|
||||||
serialized_env, serialized_context)
|
serialized_env,
|
||||||
|
serialized_context,
|
||||||
|
)
|
||||||
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
|
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
|
||||||
serialized_runtime_env_context=serialized_context)
|
serialized_runtime_env_context=serialized_context,
|
||||||
|
)
|
||||||
|
|
||||||
async def DeleteURIs(self, request, context):
|
async def DeleteURIs(self, request, context):
|
||||||
logger.info(f"Got request to delete URIs: {request.uris}.")
|
logger.info(f"Got request to delete URIs: {request.uris}.")
|
||||||
|
@ -239,20 +253,21 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"RuntimeEnvAgent received DeleteURI request "
|
"RuntimeEnvAgent received DeleteURI request "
|
||||||
f"for unsupported plugin {plugin}. URI: {uri}")
|
f"for unsupported plugin {plugin}. URI: {uri}"
|
||||||
|
)
|
||||||
|
|
||||||
if failed_uris:
|
if failed_uris:
|
||||||
return runtime_env_agent_pb2.DeleteURIsReply(
|
return runtime_env_agent_pb2.DeleteURIsReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
|
||||||
error_message="Local files for URI(s) "
|
error_message="Local files for URI(s) " f"{failed_uris} not found.",
|
||||||
f"{failed_uris} not found.")
|
)
|
||||||
else:
|
else:
|
||||||
return runtime_env_agent_pb2.DeleteURIsReply(
|
return runtime_env_agent_pb2.DeleteURIsReply(
|
||||||
status=agent_manager_pb2.AGENT_RPC_STATUS_OK)
|
status=agent_manager_pb2.AGENT_RPC_STATUS_OK
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(
|
runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(self, server)
|
||||||
self, server)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_minimal_module():
|
def is_minimal_module():
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import ray.ray_constants as ray_constants
|
import ray.ray_constants as ray_constants
|
||||||
|
|
||||||
RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES",
|
RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES", 3)
|
||||||
3)
|
|
||||||
|
|
||||||
RUNTIME_ENV_RETRY_INTERVAL_MS = ray_constants.env_integer(
|
RUNTIME_ENV_RETRY_INTERVAL_MS = ray_constants.env_integer(
|
||||||
"RUNTIME_ENV_RETRY_INTERVAL_MS", 1000)
|
"RUNTIME_ENV_RETRY_INTERVAL_MS", 1000
|
||||||
|
)
|
||||||
|
|
|
@ -8,13 +8,19 @@ from ray import ray_constants
|
||||||
from ray.core.generated import gcs_service_pb2
|
from ray.core.generated import gcs_service_pb2
|
||||||
from ray.core.generated import gcs_pb2
|
from ray.core.generated import gcs_pb2
|
||||||
from ray.core.generated import gcs_service_pb2_grpc
|
from ray.core.generated import gcs_service_pb2_grpc
|
||||||
from ray.experimental.internal_kv import (_internal_kv_initialized,
|
from ray.experimental.internal_kv import (
|
||||||
_internal_kv_get, _internal_kv_list)
|
_internal_kv_initialized,
|
||||||
|
_internal_kv_get,
|
||||||
|
_internal_kv_list,
|
||||||
|
)
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray._private.runtime_env.validation import ParsedRuntimeEnv
|
from ray._private.runtime_env.validation import ParsedRuntimeEnv
|
||||||
from ray.dashboard.modules.job.common import (
|
from ray.dashboard.modules.job.common import (
|
||||||
JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY)
|
JobStatusInfo,
|
||||||
|
JobStatusStorageClient,
|
||||||
|
JOB_ID_METADATA_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import aiohttp.web
|
import aiohttp.web
|
||||||
|
@ -32,7 +38,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
self._job_status_client = JobStatusStorageClient()
|
self._job_status_client = JobStatusStorageClient()
|
||||||
# For offloading CPU intensive work.
|
# For offloading CPU intensive work.
|
||||||
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||||
max_workers=2, thread_name_prefix="api_head")
|
max_workers=2, thread_name_prefix="api_head"
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/api/actors/kill")
|
@routes.get("/api/actors/kill")
|
||||||
async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
|
async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -41,7 +48,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
no_restart = req.query.get("no_restart", False) in ("true", "True")
|
no_restart = req.query.get("no_restart", False) in ("true", "True")
|
||||||
if not actor_id:
|
if not actor_id:
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=False, message="actor_id is required.")
|
success=False, message="actor_id is required."
|
||||||
|
)
|
||||||
|
|
||||||
request = gcs_service_pb2.KillActorViaGcsRequest()
|
request = gcs_service_pb2.KillActorViaGcsRequest()
|
||||||
request.actor_id = bytes.fromhex(actor_id)
|
request.actor_id = bytes.fromhex(actor_id)
|
||||||
|
@ -49,31 +57,36 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
request.no_restart = no_restart
|
request.no_restart = no_restart
|
||||||
await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)
|
await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)
|
||||||
|
|
||||||
message = (f"Force killed actor with id {actor_id}" if force_kill else
|
message = (
|
||||||
f"Requested actor with id {actor_id} to terminate. " +
|
f"Force killed actor with id {actor_id}"
|
||||||
"It will exit once running tasks complete")
|
if force_kill
|
||||||
|
else f"Requested actor with id {actor_id} to terminate. "
|
||||||
|
+ "It will exit once running tasks complete"
|
||||||
|
)
|
||||||
|
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(success=True, message=message)
|
||||||
success=True, message=message)
|
|
||||||
|
|
||||||
@routes.get("/api/snapshot")
|
@routes.get("/api/snapshot")
|
||||||
async def snapshot(self, req):
|
async def snapshot(self, req):
|
||||||
job_data, actor_data, serve_data, session_name = await asyncio.gather(
|
job_data, actor_data, serve_data, session_name = await asyncio.gather(
|
||||||
self.get_job_info(), self.get_actor_info(), self.get_serve_info(),
|
self.get_job_info(),
|
||||||
self.get_session_name())
|
self.get_actor_info(),
|
||||||
|
self.get_serve_info(),
|
||||||
|
self.get_session_name(),
|
||||||
|
)
|
||||||
snapshot = {
|
snapshot = {
|
||||||
"jobs": job_data,
|
"jobs": job_data,
|
||||||
"actors": actor_data,
|
"actors": actor_data,
|
||||||
"deployments": serve_data,
|
"deployments": serve_data,
|
||||||
"session_name": session_name,
|
"session_name": session_name,
|
||||||
"ray_version": ray.__version__,
|
"ray_version": ray.__version__,
|
||||||
"ray_commit": ray.__commit__
|
"ray_commit": ray.__commit__,
|
||||||
}
|
}
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="hello", snapshot=snapshot)
|
success=True, message="hello", snapshot=snapshot
|
||||||
|
)
|
||||||
|
|
||||||
def _get_job_status(self,
|
def _get_job_status(self, metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
|
||||||
metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
|
|
||||||
# If a job submission ID has been added to a job, the status is
|
# If a job submission ID has been added to a job, the status is
|
||||||
# guaranteed to be returned.
|
# guaranteed to be returned.
|
||||||
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
|
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
|
||||||
|
@ -91,8 +104,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
"namespace": job_table_entry.config.ray_namespace,
|
"namespace": job_table_entry.config.ray_namespace,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"runtime_env": ParsedRuntimeEnv.deserialize(
|
"runtime_env": ParsedRuntimeEnv.deserialize(
|
||||||
job_table_entry.config.runtime_env_info.
|
job_table_entry.config.runtime_env_info.serialized_runtime_env
|
||||||
serialized_runtime_env),
|
),
|
||||||
}
|
}
|
||||||
status = self._get_job_status(metadata)
|
status = self._get_job_status(metadata)
|
||||||
entry = {
|
entry = {
|
||||||
|
@ -111,8 +124,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
# TODO (Alex): GCS still needs to return actors from dead jobs.
|
# TODO (Alex): GCS still needs to return actors from dead jobs.
|
||||||
request = gcs_service_pb2.GetAllActorInfoRequest()
|
request = gcs_service_pb2.GetAllActorInfoRequest()
|
||||||
request.show_dead_jobs = True
|
request.show_dead_jobs = True
|
||||||
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
|
reply = await self._gcs_actor_info_stub.GetAllActorInfo(request, timeout=5)
|
||||||
request, timeout=5)
|
|
||||||
actors = {}
|
actors = {}
|
||||||
for actor_table_entry in reply.actor_table_data:
|
for actor_table_entry in reply.actor_table_data:
|
||||||
actor_id = actor_table_entry.actor_id.hex()
|
actor_id = actor_table_entry.actor_id.hex()
|
||||||
|
@ -120,37 +132,33 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
entry = {
|
entry = {
|
||||||
"job_id": actor_table_entry.job_id.hex(),
|
"job_id": actor_table_entry.job_id.hex(),
|
||||||
"state": gcs_pb2.ActorTableData.ActorState.Name(
|
"state": gcs_pb2.ActorTableData.ActorState.Name(
|
||||||
actor_table_entry.state),
|
actor_table_entry.state
|
||||||
|
),
|
||||||
"name": actor_table_entry.name,
|
"name": actor_table_entry.name,
|
||||||
"namespace": actor_table_entry.ray_namespace,
|
"namespace": actor_table_entry.ray_namespace,
|
||||||
"runtime_env": runtime_env,
|
"runtime_env": runtime_env,
|
||||||
"start_time": actor_table_entry.start_time,
|
"start_time": actor_table_entry.start_time,
|
||||||
"end_time": actor_table_entry.end_time,
|
"end_time": actor_table_entry.end_time,
|
||||||
"is_detached": actor_table_entry.is_detached,
|
"is_detached": actor_table_entry.is_detached,
|
||||||
"resources": dict(
|
"resources": dict(actor_table_entry.task_spec.required_resources),
|
||||||
actor_table_entry.task_spec.required_resources),
|
|
||||||
"actor_class": actor_table_entry.class_name,
|
"actor_class": actor_table_entry.class_name,
|
||||||
"current_worker_id": actor_table_entry.address.worker_id.hex(),
|
"current_worker_id": actor_table_entry.address.worker_id.hex(),
|
||||||
"current_raylet_id": actor_table_entry.address.raylet_id.hex(),
|
"current_raylet_id": actor_table_entry.address.raylet_id.hex(),
|
||||||
"ip_address": actor_table_entry.address.ip_address,
|
"ip_address": actor_table_entry.address.ip_address,
|
||||||
"port": actor_table_entry.address.port,
|
"port": actor_table_entry.address.port,
|
||||||
"metadata": dict()
|
"metadata": dict(),
|
||||||
}
|
}
|
||||||
actors[actor_id] = entry
|
actors[actor_id] = entry
|
||||||
|
|
||||||
deployments = await self.get_serve_info()
|
deployments = await self.get_serve_info()
|
||||||
for _, deployment_info in deployments.items():
|
for _, deployment_info in deployments.items():
|
||||||
for replica_actor_id, actor_info in deployment_info[
|
for replica_actor_id, actor_info in deployment_info["actors"].items():
|
||||||
"actors"].items():
|
|
||||||
if replica_actor_id in actors:
|
if replica_actor_id in actors:
|
||||||
serve_metadata = dict()
|
serve_metadata = dict()
|
||||||
serve_metadata["replica_tag"] = actor_info[
|
serve_metadata["replica_tag"] = actor_info["replica_tag"]
|
||||||
"replica_tag"]
|
serve_metadata["deployment_name"] = deployment_info["name"]
|
||||||
serve_metadata["deployment_name"] = deployment_info[
|
|
||||||
"name"]
|
|
||||||
serve_metadata["version"] = actor_info["version"]
|
serve_metadata["version"] = actor_info["version"]
|
||||||
actors[replica_actor_id]["metadata"][
|
actors[replica_actor_id]["metadata"]["serve"] = serve_metadata
|
||||||
"serve"] = serve_metadata
|
|
||||||
return actors
|
return actors
|
||||||
|
|
||||||
async def get_serve_info(self) -> Dict[str, Any]:
|
async def get_serve_info(self) -> Dict[str, Any]:
|
||||||
|
@ -168,22 +176,21 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
# TODO: Convert to async GRPC, if CPU usage is not a concern.
|
# TODO: Convert to async GRPC, if CPU usage is not a concern.
|
||||||
def get_deployments():
|
def get_deployments():
|
||||||
serve_keys = _internal_kv_list(
|
serve_keys = _internal_kv_list(
|
||||||
SERVE_CONTROLLER_NAME,
|
SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
|
||||||
namespace=ray_constants.KV_NAMESPACE_SERVE)
|
)
|
||||||
serve_snapshot_keys = filter(
|
serve_snapshot_keys = filter(
|
||||||
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys)
|
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
|
||||||
|
)
|
||||||
|
|
||||||
deployments_per_controller: List[Dict[str, Any]] = []
|
deployments_per_controller: List[Dict[str, Any]] = []
|
||||||
for key in serve_snapshot_keys:
|
for key in serve_snapshot_keys:
|
||||||
val_bytes = _internal_kv_get(
|
val_bytes = _internal_kv_get(
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_SERVE
|
key, namespace=ray_constants.KV_NAMESPACE_SERVE
|
||||||
) or "{}".encode("utf-8")
|
) or "{}".encode("utf-8")
|
||||||
deployments_per_controller.append(
|
deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
|
||||||
json.loads(val_bytes.decode("utf-8")))
|
|
||||||
# Merge the deployments dicts of all controllers.
|
# Merge the deployments dicts of all controllers.
|
||||||
deployments: Dict[str, Any] = {
|
deployments: Dict[str, Any] = {
|
||||||
k: v
|
k: v for d in deployments_per_controller for k, v in d.items()
|
||||||
for d in deployments_per_controller for k, v in d.items()
|
|
||||||
}
|
}
|
||||||
# Replace the keys (deployment names) with their hashes to prevent
|
# Replace the keys (deployment names) with their hashes to prevent
|
||||||
# collisions caused by the automatic conversion to camelcase by the
|
# collisions caused by the automatic conversion to camelcase by the
|
||||||
|
@ -194,24 +201,27 @@ class APIHead(dashboard_utils.DashboardHeadModule):
|
||||||
}
|
}
|
||||||
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
executor=self._thread_pool, func=get_deployments)
|
executor=self._thread_pool, func=get_deployments
|
||||||
|
)
|
||||||
|
|
||||||
async def get_session_name(self):
|
async def get_session_name(self):
|
||||||
# TODO(yic): Convert to async GRPC.
|
# TODO(yic): Convert to async GRPC.
|
||||||
def get_session():
|
def get_session():
|
||||||
return ray.experimental.internal_kv._internal_kv_get(
|
return ray.experimental.internal_kv._internal_kv_get(
|
||||||
"session_name",
|
"session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
|
||||||
namespace=ray_constants.KV_NAMESPACE_SESSION).decode()
|
).decode()
|
||||||
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
executor=self._thread_pool, func=get_session)
|
executor=self._thread_pool, func=get_session
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self, server):
|
async def run(self, server):
|
||||||
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
|
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
|
||||||
self._dashboard_head.aiogrpc_gcs_channel)
|
self._dashboard_head.aiogrpc_gcs_channel
|
||||||
self._gcs_actor_info_stub = \
|
)
|
||||||
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
|
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
|
||||||
self._dashboard_head.aiogrpc_gcs_channel)
|
self._dashboard_head.aiogrpc_gcs_channel
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_minimal_module():
|
def is_minimal_module():
|
||||||
|
|
|
@ -36,15 +36,14 @@ def _actor_killed_loop(worker_pid: str, timeout_secs=3) -> bool:
|
||||||
return dead
|
return dead
|
||||||
|
|
||||||
|
|
||||||
def _kill_actor_using_dashboard_gcs(webui_url: str,
|
def _kill_actor_using_dashboard_gcs(webui_url: str, actor_id: str, force_kill=False):
|
||||||
actor_id: str,
|
|
||||||
force_kill=False):
|
|
||||||
resp = requests.get(
|
resp = requests.get(
|
||||||
webui_url + KILL_ACTOR_ENDPOINT,
|
webui_url + KILL_ACTOR_ENDPOINT,
|
||||||
params={
|
params={
|
||||||
"actor_id": actor_id,
|
"actor_id": actor_id,
|
||||||
"force_kill": force_kill,
|
"force_kill": force_kill,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
resp_json = resp.json()
|
resp_json = resp.json()
|
||||||
assert resp_json["result"] is True, "msg" in resp_json
|
assert resp_json["result"] is True, "msg" in resp_json
|
||||||
|
|
|
@ -8,8 +8,11 @@ import pprint
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ray._private.test_utils import (format_web_url, wait_for_condition,
|
from ray._private.test_utils import (
|
||||||
wait_until_server_available)
|
format_web_url,
|
||||||
|
wait_for_condition,
|
||||||
|
wait_until_server_available,
|
||||||
|
)
|
||||||
from ray.dashboard import dashboard
|
from ray.dashboard import dashboard
|
||||||
from ray.dashboard.tests.conftest import * # noqa
|
from ray.dashboard.tests.conftest import * # noqa
|
||||||
from ray.dashboard.modules.job.sdk import JobSubmissionClient
|
from ray.dashboard.modules.job.sdk import JobSubmissionClient
|
||||||
|
@ -22,25 +25,23 @@ def _get_snapshot(address: str):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
schema_path = os.path.join(
|
schema_path = os.path.join(
|
||||||
os.path.dirname(dashboard.__file__),
|
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
|
||||||
"modules/snapshot/snapshot_schema.json")
|
)
|
||||||
pprint.pprint(data)
|
pprint.pprint(data)
|
||||||
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
|
def test_successful_job_status(
|
||||||
enable_test_module):
|
ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
|
||||||
|
):
|
||||||
address = ray_start_with_dashboard["webui_url"]
|
address = ray_start_with_dashboard["webui_url"]
|
||||||
assert wait_until_server_available(address)
|
assert wait_until_server_available(address)
|
||||||
address = format_web_url(address)
|
address = format_web_url(address)
|
||||||
|
|
||||||
entrypoint_cmd = ("python -c\""
|
entrypoint_cmd = (
|
||||||
"import ray;"
|
'python -c"' "import ray;" "ray.init();" "import time;" "time.sleep(5);" '"'
|
||||||
"ray.init();"
|
)
|
||||||
"import time;"
|
|
||||||
"time.sleep(5);"
|
|
||||||
"\"")
|
|
||||||
|
|
||||||
client = JobSubmissionClient(address)
|
client = JobSubmissionClient(address)
|
||||||
job_id = client.submit_job(entrypoint=entrypoint_cmd)
|
job_id = client.submit_job(entrypoint=entrypoint_cmd)
|
||||||
|
@ -49,11 +50,8 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
|
||||||
data = _get_snapshot(address)
|
data = _get_snapshot(address)
|
||||||
for job_entry in data["data"]["snapshot"]["jobs"].values():
|
for job_entry in data["data"]["snapshot"]["jobs"].values():
|
||||||
if job_entry["status"] is not None:
|
if job_entry["status"] is not None:
|
||||||
assert job_entry["config"]["metadata"][
|
assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
|
||||||
"jobSubmissionId"] == job_id
|
assert job_entry["status"] in {"PENDING", "RUNNING", "SUCCEEDED"}
|
||||||
assert job_entry["status"] in {
|
|
||||||
"PENDING", "RUNNING", "SUCCEEDED"
|
|
||||||
}
|
|
||||||
assert job_entry["statusMessage"] is not None
|
assert job_entry["statusMessage"] is not None
|
||||||
return job_entry["status"] == "SUCCEEDED"
|
return job_entry["status"] == "SUCCEEDED"
|
||||||
|
|
||||||
|
@ -62,20 +60,23 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
|
||||||
wait_for_condition(wait_for_job_to_succeed, timeout=30)
|
wait_for_condition(wait_for_job_to_succeed, timeout=30)
|
||||||
|
|
||||||
|
|
||||||
def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
|
def test_failed_job_status(
|
||||||
enable_test_module):
|
ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
|
||||||
|
):
|
||||||
address = ray_start_with_dashboard["webui_url"]
|
address = ray_start_with_dashboard["webui_url"]
|
||||||
assert wait_until_server_available(address)
|
assert wait_until_server_available(address)
|
||||||
address = format_web_url(address)
|
address = format_web_url(address)
|
||||||
|
|
||||||
entrypoint_cmd = ("python -c\""
|
entrypoint_cmd = (
|
||||||
"import ray;"
|
'python -c"'
|
||||||
"ray.init();"
|
"import ray;"
|
||||||
"import time;"
|
"ray.init();"
|
||||||
"time.sleep(5);"
|
"import time;"
|
||||||
"import sys;"
|
"time.sleep(5);"
|
||||||
"sys.exit(1);"
|
"import sys;"
|
||||||
"\"")
|
"sys.exit(1);"
|
||||||
|
'"'
|
||||||
|
)
|
||||||
client = JobSubmissionClient(address)
|
client = JobSubmissionClient(address)
|
||||||
job_id = client.submit_job(entrypoint=entrypoint_cmd)
|
job_id = client.submit_job(entrypoint=entrypoint_cmd)
|
||||||
|
|
||||||
|
@ -83,8 +84,7 @@ def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
|
||||||
data = _get_snapshot(address)
|
data = _get_snapshot(address)
|
||||||
for job_entry in data["data"]["snapshot"]["jobs"].values():
|
for job_entry in data["data"]["snapshot"]["jobs"].values():
|
||||||
if job_entry["status"] is not None:
|
if job_entry["status"] is not None:
|
||||||
assert job_entry["config"]["metadata"][
|
assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
|
||||||
"jobSubmissionId"] == job_id
|
|
||||||
assert job_entry["status"] in {"PENDING", "RUNNING", "FAILED"}
|
assert job_entry["status"] in {"PENDING", "RUNNING", "FAILED"}
|
||||||
assert job_entry["statusMessage"] is not None
|
assert job_entry["statusMessage"] is not None
|
||||||
return job_entry["status"] == "FAILED"
|
return job_entry["status"] == "FAILED"
|
||||||
|
|
|
@ -34,11 +34,14 @@ ray.get(a.ping.remote())
|
||||||
"""
|
"""
|
||||||
address = ray_start_with_dashboard["address"]
|
address = ray_start_with_dashboard["address"]
|
||||||
detached_driver = driver_template.format(
|
detached_driver = driver_template.format(
|
||||||
address=address, lifetime="'detached'", name="'abc'")
|
address=address, lifetime="'detached'", name="'abc'"
|
||||||
|
)
|
||||||
named_driver = driver_template.format(
|
named_driver = driver_template.format(
|
||||||
address=address, lifetime="None", name="'xyz'")
|
address=address, lifetime="None", name="'xyz'"
|
||||||
|
)
|
||||||
unnamed_driver = driver_template.format(
|
unnamed_driver = driver_template.format(
|
||||||
address=address, lifetime="None", name="None")
|
address=address, lifetime="None", name="None"
|
||||||
|
)
|
||||||
|
|
||||||
run_string_as_driver(detached_driver)
|
run_string_as_driver(detached_driver)
|
||||||
run_string_as_driver(named_driver)
|
run_string_as_driver(named_driver)
|
||||||
|
@ -50,8 +53,8 @@ ray.get(a.ping.remote())
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
schema_path = os.path.join(
|
schema_path = os.path.join(
|
||||||
os.path.dirname(dashboard.__file__),
|
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
|
||||||
"modules/snapshot/snapshot_schema.json")
|
)
|
||||||
pprint.pprint(data)
|
pprint.pprint(data)
|
||||||
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
||||||
|
|
||||||
|
@ -72,10 +75,7 @@ ray.get(a.ping.remote())
|
||||||
assert data["data"]["snapshot"]["rayVersion"] == ray.__version__
|
assert data["data"]["snapshot"]["rayVersion"] == ray.__version__
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("ray_start_with_dashboard", [{"num_cpus": 4}], indirect=True)
|
||||||
"ray_start_with_dashboard", [{
|
|
||||||
"num_cpus": 4
|
|
||||||
}], indirect=True)
|
|
||||||
def test_serve_snapshot(ray_start_with_dashboard):
|
def test_serve_snapshot(ray_start_with_dashboard):
|
||||||
"""Test detached and nondetached Serve instances running concurrently."""
|
"""Test detached and nondetached Serve instances running concurrently."""
|
||||||
|
|
||||||
|
@ -115,8 +115,7 @@ my_func_deleted.delete()
|
||||||
|
|
||||||
my_func_nondetached.deploy()
|
my_func_nondetached.deploy()
|
||||||
|
|
||||||
assert requests.get(
|
assert requests.get("http://127.0.0.1:8123/my_func_nondetached").text == "hello"
|
||||||
"http://127.0.0.1:8123/my_func_nondetached").text == "hello"
|
|
||||||
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
@ -124,15 +123,16 @@ my_func_deleted.delete()
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
schema_path = os.path.join(
|
schema_path = os.path.join(
|
||||||
os.path.dirname(dashboard.__file__),
|
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
|
||||||
"modules/snapshot/snapshot_schema.json")
|
)
|
||||||
pprint.pprint(data)
|
pprint.pprint(data)
|
||||||
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
|
||||||
|
|
||||||
assert len(data["data"]["snapshot"]["deployments"]) == 3
|
assert len(data["data"]["snapshot"]["deployments"]) == 3
|
||||||
|
|
||||||
entry = data["data"]["snapshot"]["deployments"][hashlib.sha1(
|
entry = data["data"]["snapshot"]["deployments"][
|
||||||
"my_func".encode()).hexdigest()]
|
hashlib.sha1("my_func".encode()).hexdigest()
|
||||||
|
]
|
||||||
assert entry["name"] == "my_func"
|
assert entry["name"] == "my_func"
|
||||||
assert entry["version"] is None
|
assert entry["version"] is None
|
||||||
assert entry["namespace"] == "serve"
|
assert entry["namespace"] == "serve"
|
||||||
|
@ -145,14 +145,14 @@ my_func_deleted.delete()
|
||||||
|
|
||||||
assert len(entry["actors"]) == 1
|
assert len(entry["actors"]) == 1
|
||||||
actor_id = next(iter(entry["actors"]))
|
actor_id = next(iter(entry["actors"]))
|
||||||
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
|
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
|
||||||
"serve"]
|
|
||||||
assert metadata["deploymentName"] == "my_func"
|
assert metadata["deploymentName"] == "my_func"
|
||||||
assert metadata["version"] is None
|
assert metadata["version"] is None
|
||||||
assert len(metadata["replicaTag"]) > 0
|
assert len(metadata["replicaTag"]) > 0
|
||||||
|
|
||||||
entry_deleted = data["data"]["snapshot"]["deployments"][hashlib.sha1(
|
entry_deleted = data["data"]["snapshot"]["deployments"][
|
||||||
"my_func_deleted".encode()).hexdigest()]
|
hashlib.sha1("my_func_deleted".encode()).hexdigest()
|
||||||
|
]
|
||||||
assert entry_deleted["name"] == "my_func_deleted"
|
assert entry_deleted["name"] == "my_func_deleted"
|
||||||
assert entry_deleted["version"] == "v1"
|
assert entry_deleted["version"] == "v1"
|
||||||
assert entry_deleted["namespace"] == "serve"
|
assert entry_deleted["namespace"] == "serve"
|
||||||
|
@ -163,8 +163,9 @@ my_func_deleted.delete()
|
||||||
assert entry_deleted["startTime"] > 0
|
assert entry_deleted["startTime"] > 0
|
||||||
assert entry_deleted["endTime"] > entry_deleted["startTime"]
|
assert entry_deleted["endTime"] > entry_deleted["startTime"]
|
||||||
|
|
||||||
entry_nondetached = data["data"]["snapshot"]["deployments"][hashlib.sha1(
|
entry_nondetached = data["data"]["snapshot"]["deployments"][
|
||||||
"my_func_nondetached".encode()).hexdigest()]
|
hashlib.sha1("my_func_nondetached".encode()).hexdigest()
|
||||||
|
]
|
||||||
assert entry_nondetached["name"] == "my_func_nondetached"
|
assert entry_nondetached["name"] == "my_func_nondetached"
|
||||||
assert entry_nondetached["version"] == "v1"
|
assert entry_nondetached["version"] == "v1"
|
||||||
assert entry_nondetached["namespace"] == "default_test_namespace"
|
assert entry_nondetached["namespace"] == "default_test_namespace"
|
||||||
|
@ -177,8 +178,7 @@ my_func_deleted.delete()
|
||||||
|
|
||||||
assert len(entry_nondetached["actors"]) == 1
|
assert len(entry_nondetached["actors"]) == 1
|
||||||
actor_id = next(iter(entry_nondetached["actors"]))
|
actor_id = next(iter(entry_nondetached["actors"]))
|
||||||
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
|
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
|
||||||
"serve"]
|
|
||||||
assert metadata["deploymentName"] == "my_func_nondetached"
|
assert metadata["deploymentName"] == "my_func_nondetached"
|
||||||
assert metadata["version"] == "v1"
|
assert metadata["version"] == "v1"
|
||||||
assert len(metadata["replicaTag"]) > 0
|
assert len(metadata["replicaTag"]) > 0
|
||||||
|
|
|
@ -13,7 +13,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
|
||||||
|
|
||||||
|
|
||||||
@dashboard_utils.dashboard_module(
|
@dashboard_utils.dashboard_module(
|
||||||
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
|
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
|
||||||
|
)
|
||||||
class TestAgent(dashboard_utils.DashboardAgentModule):
|
class TestAgent(dashboard_utils.DashboardAgentModule):
|
||||||
def __init__(self, dashboard_agent):
|
def __init__(self, dashboard_agent):
|
||||||
super().__init__(dashboard_agent)
|
super().__init__(dashboard_agent)
|
||||||
|
@ -25,8 +26,7 @@ class TestAgent(dashboard_utils.DashboardAgentModule):
|
||||||
@routes.get("/test/http_get_from_agent")
|
@routes.get("/test/http_get_from_agent")
|
||||||
async def get_url(self, req) -> aiohttp.web.Response:
|
async def get_url(self, req) -> aiohttp.web.Response:
|
||||||
url = req.query.get("url")
|
url = req.query.get("url")
|
||||||
result = await test_utils.http_get(self._dashboard_agent.http_session,
|
result = await test_utils.http_get(self._dashboard_agent.http_session, url)
|
||||||
url)
|
|
||||||
return aiohttp.web.json_response(result)
|
return aiohttp.web.json_response(result)
|
||||||
|
|
||||||
@routes.head("/test/route_head")
|
@routes.head("/test/route_head")
|
||||||
|
|
|
@ -15,7 +15,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
|
||||||
|
|
||||||
|
|
||||||
@dashboard_utils.dashboard_module(
|
@dashboard_utils.dashboard_module(
|
||||||
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
|
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
|
||||||
|
)
|
||||||
class TestHead(dashboard_utils.DashboardHeadModule):
|
class TestHead(dashboard_utils.DashboardHeadModule):
|
||||||
def __init__(self, dashboard_head):
|
def __init__(self, dashboard_head):
|
||||||
super().__init__(dashboard_head)
|
super().__init__(dashboard_head)
|
||||||
|
@ -62,26 +63,28 @@ class TestHead(dashboard_utils.DashboardHeadModule):
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Fetch all data from datacenter success.",
|
message="Fetch all data from datacenter success.",
|
||||||
**all_data)
|
**all_data,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
data = dict(DataSource.__dict__.get(key))
|
data = dict(DataSource.__dict__.get(key))
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message=f"Fetch {key} from datacenter success.",
|
message=f"Fetch {key} from datacenter success.",
|
||||||
**{key: data})
|
**{key: data},
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/test/notified_agents")
|
@routes.get("/test/notified_agents")
|
||||||
async def get_notified_agents(self, req) -> aiohttp.web.Response:
|
async def get_notified_agents(self, req) -> aiohttp.web.Response:
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True,
|
success=True,
|
||||||
message="Fetch notified agents success.",
|
message="Fetch notified agents success.",
|
||||||
**self._notified_agents)
|
**self._notified_agents,
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/test/http_get")
|
@routes.get("/test/http_get")
|
||||||
async def get_url(self, req) -> aiohttp.web.Response:
|
async def get_url(self, req) -> aiohttp.web.Response:
|
||||||
url = req.query.get("url")
|
url = req.query.get("url")
|
||||||
result = await test_utils.http_get(self._dashboard_head.http_session,
|
result = await test_utils.http_get(self._dashboard_head.http_session, url)
|
||||||
url)
|
|
||||||
return aiohttp.web.json_response(result)
|
return aiohttp.web.json_response(result)
|
||||||
|
|
||||||
@routes.get("/test/aiohttp_cache/{sub_path}")
|
@routes.get("/test/aiohttp_cache/{sub_path}")
|
||||||
|
@ -89,14 +92,16 @@ class TestHead(dashboard_utils.DashboardHeadModule):
|
||||||
async def test_aiohttp_cache(self, req) -> aiohttp.web.Response:
|
async def test_aiohttp_cache(self, req) -> aiohttp.web.Response:
|
||||||
value = req.query["value"]
|
value = req.query["value"]
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="OK", value=value, timestamp=time.time())
|
success=True, message="OK", value=value, timestamp=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/test/aiohttp_cache_lru/{sub_path}")
|
@routes.get("/test/aiohttp_cache_lru/{sub_path}")
|
||||||
@dashboard_optional_utils.aiohttp_cache(ttl_seconds=60, maxsize=5)
|
@dashboard_optional_utils.aiohttp_cache(ttl_seconds=60, maxsize=5)
|
||||||
async def test_aiohttp_cache_lru(self, req) -> aiohttp.web.Response:
|
async def test_aiohttp_cache_lru(self, req) -> aiohttp.web.Response:
|
||||||
value = req.query.get("value")
|
value = req.query.get("value")
|
||||||
return dashboard_optional_utils.rest_response(
|
return dashboard_optional_utils.rest_response(
|
||||||
success=True, message="OK", value=value, timestamp=time.time())
|
success=True, message="OK", value=value, timestamp=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/test/file")
|
@routes.get("/test/file")
|
||||||
async def test_file(self, req) -> aiohttp.web.FileResponse:
|
async def test_file(self, req) -> aiohttp.web.FileResponse:
|
||||||
|
|
|
@ -4,8 +4,7 @@ import copy
|
||||||
import os
|
import os
|
||||||
import aiohttp.web
|
import aiohttp.web
|
||||||
|
|
||||||
import ray.dashboard.modules.tune.tune_consts \
|
import ray.dashboard.modules.tune.tune_consts as tune_consts
|
||||||
as tune_consts
|
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||||
from ray.dashboard.utils import async_loop_forever
|
from ray.dashboard.utils import async_loop_forever
|
||||||
|
@ -45,19 +44,17 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
@routes.get("/tune/info")
|
@routes.get("/tune/info")
|
||||||
async def tune_info(self, req) -> aiohttp.web.Response:
|
async def tune_info(self, req) -> aiohttp.web.Response:
|
||||||
stats = self.get_stats()
|
stats = self.get_stats()
|
||||||
return rest_response(
|
return rest_response(success=True, message="Fetched tune info", result=stats)
|
||||||
success=True, message="Fetched tune info", result=stats)
|
|
||||||
|
|
||||||
@routes.get("/tune/availability")
|
@routes.get("/tune/availability")
|
||||||
async def get_availability(self, req) -> aiohttp.web.Response:
|
async def get_availability(self, req) -> aiohttp.web.Response:
|
||||||
availability = {
|
availability = {
|
||||||
"available": ExperimentAnalysis is not None,
|
"available": ExperimentAnalysis is not None,
|
||||||
"trials_available": self._trials_available
|
"trials_available": self._trials_available,
|
||||||
}
|
}
|
||||||
return rest_response(
|
return rest_response(
|
||||||
success=True,
|
success=True, message="Fetched tune availability", result=availability
|
||||||
message="Fetched tune availability",
|
)
|
||||||
result=availability)
|
|
||||||
|
|
||||||
@routes.get("/tune/set_experiment")
|
@routes.get("/tune/set_experiment")
|
||||||
async def set_tune_experiment(self, req) -> aiohttp.web.Response:
|
async def set_tune_experiment(self, req) -> aiohttp.web.Response:
|
||||||
|
@ -66,25 +63,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
if err:
|
if err:
|
||||||
return rest_response(success=False, error=err)
|
return rest_response(success=False, error=err)
|
||||||
return rest_response(
|
return rest_response(
|
||||||
success=True, message="Successfully set experiment", **experiment)
|
success=True, message="Successfully set experiment", **experiment
|
||||||
|
)
|
||||||
|
|
||||||
@routes.get("/tune/enable_tensorboard")
|
@routes.get("/tune/enable_tensorboard")
|
||||||
async def enable_tensorboard(self, req) -> aiohttp.web.Response:
|
async def enable_tensorboard(self, req) -> aiohttp.web.Response:
|
||||||
self._enable_tensorboard()
|
self._enable_tensorboard()
|
||||||
if not self._tensor_board_dir:
|
if not self._tensor_board_dir:
|
||||||
return rest_response(
|
return rest_response(success=False, message="Error enabling tensorboard")
|
||||||
success=False, message="Error enabling tensorboard")
|
|
||||||
return rest_response(success=True, message="Enabled tensorboard")
|
return rest_response(success=True, message="Enabled tensorboard")
|
||||||
|
|
||||||
def get_stats(self):
|
def get_stats(self):
|
||||||
tensor_board_info = {
|
tensor_board_info = {
|
||||||
"tensorboard_current": self._logdir == self._tensor_board_dir,
|
"tensorboard_current": self._logdir == self._tensor_board_dir,
|
||||||
"tensorboard_enabled": self._tensor_board_dir != ""
|
"tensorboard_enabled": self._tensor_board_dir != "",
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
"trial_records": copy.deepcopy(self._trial_records),
|
"trial_records": copy.deepcopy(self._trial_records),
|
||||||
"errors": copy.deepcopy(self._errors),
|
"errors": copy.deepcopy(self._errors),
|
||||||
"tensorboard": tensor_board_info
|
"tensorboard": tensor_board_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_experiment(self, experiment):
|
def set_experiment(self, experiment):
|
||||||
|
@ -104,7 +101,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
def collect_errors(self, df):
|
def collect_errors(self, df):
|
||||||
sub_dirs = os.listdir(self._logdir)
|
sub_dirs = os.listdir(self._logdir)
|
||||||
trial_names = filter(
|
trial_names = filter(
|
||||||
lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs)
|
lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs
|
||||||
|
)
|
||||||
for trial in trial_names:
|
for trial in trial_names:
|
||||||
error_path = os.path.join(self._logdir, trial, "error.txt")
|
error_path = os.path.join(self._logdir, trial, "error.txt")
|
||||||
if os.path.isfile(error_path):
|
if os.path.isfile(error_path):
|
||||||
|
@ -114,7 +112,7 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
self._errors[str(trial)] = {
|
self._errors[str(trial)] = {
|
||||||
"text": text,
|
"text": text,
|
||||||
"job_id": os.path.basename(self._logdir),
|
"job_id": os.path.basename(self._logdir),
|
||||||
"trial_id": "No Trial ID"
|
"trial_id": "No Trial ID",
|
||||||
}
|
}
|
||||||
other_data = df[df["logdir"].str.contains(trial)]
|
other_data = df[df["logdir"].str.contains(trial)]
|
||||||
if len(other_data) > 0:
|
if len(other_data) > 0:
|
||||||
|
@ -175,12 +173,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
|
|
||||||
# list of static attributes for trial
|
# list of static attributes for trial
|
||||||
default_names = {
|
default_names = {
|
||||||
"logdir", "time_this_iter_s", "done", "episodes_total",
|
"logdir",
|
||||||
"training_iteration", "timestamp", "timesteps_total",
|
"time_this_iter_s",
|
||||||
"experiment_id", "date", "timestamp", "time_total_s", "pid",
|
"done",
|
||||||
"hostname", "node_ip", "time_since_restore",
|
"episodes_total",
|
||||||
"timesteps_since_restore", "iterations_since_restore",
|
"training_iteration",
|
||||||
"experiment_tag", "trial_id"
|
"timestamp",
|
||||||
|
"timesteps_total",
|
||||||
|
"experiment_id",
|
||||||
|
"date",
|
||||||
|
"timestamp",
|
||||||
|
"time_total_s",
|
||||||
|
"pid",
|
||||||
|
"hostname",
|
||||||
|
"node_ip",
|
||||||
|
"time_since_restore",
|
||||||
|
"timesteps_since_restore",
|
||||||
|
"iterations_since_restore",
|
||||||
|
"experiment_tag",
|
||||||
|
"trial_id",
|
||||||
}
|
}
|
||||||
|
|
||||||
# filter attributes into floats, metrics, and config variables
|
# filter attributes into floats, metrics, and config variables
|
||||||
|
@ -196,7 +207,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
|
||||||
for trial, details in trial_details.items():
|
for trial, details in trial_details.items():
|
||||||
ts = os.path.getctime(details["logdir"])
|
ts = os.path.getctime(details["logdir"])
|
||||||
formatted_time = datetime.datetime.fromtimestamp(ts).strftime(
|
formatted_time = datetime.datetime.fromtimestamp(ts).strftime(
|
||||||
"%Y-%m-%d %H:%M:%S")
|
"%Y-%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
details["start_time"] = formatted_time
|
details["start_time"] = formatted_time
|
||||||
details["params"] = {}
|
details["params"] = {}
|
||||||
details["metrics"] = {}
|
details["metrics"] = {}
|
||||||
|
|
|
@ -25,7 +25,7 @@ except AttributeError:
|
||||||
# All third-party dependencies that are not included in the minimal Ray
|
# All third-party dependencies that are not included in the minimal Ray
|
||||||
# installation must be included in this file. This allows us to determine if
|
# installation must be included in this file. This allows us to determine if
|
||||||
# the agent has the necessary dependencies to be started.
|
# the agent has the necessary dependencies to be started.
|
||||||
from ray.dashboard.optional_deps import (aiohttp, hdrs, PathLike, RouteDef)
|
from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
|
||||||
from ray.dashboard.utils import to_google_style, CustomEncoder
|
from ray.dashboard.utils import to_google_style, CustomEncoder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -68,12 +68,15 @@ class ClassMethodRouteTable:
|
||||||
def _wrapper(handler):
|
def _wrapper(handler):
|
||||||
if path in cls._bind_map[method]:
|
if path in cls._bind_map[method]:
|
||||||
bind_info = cls._bind_map[method][path]
|
bind_info = cls._bind_map[method][path]
|
||||||
raise Exception(f"Duplicated route path: {path}, "
|
raise Exception(
|
||||||
f"previous one registered at "
|
f"Duplicated route path: {path}, "
|
||||||
f"{bind_info.filename}:{bind_info.lineno}")
|
f"previous one registered at "
|
||||||
|
f"{bind_info.filename}:{bind_info.lineno}"
|
||||||
|
)
|
||||||
|
|
||||||
bind_info = cls._BindInfo(handler.__code__.co_filename,
|
bind_info = cls._BindInfo(
|
||||||
handler.__code__.co_firstlineno, None)
|
handler.__code__.co_filename, handler.__code__.co_firstlineno, None
|
||||||
|
)
|
||||||
|
|
||||||
@functools.wraps(handler)
|
@functools.wraps(handler)
|
||||||
async def _handler_route(*args) -> aiohttp.web.Response:
|
async def _handler_route(*args) -> aiohttp.web.Response:
|
||||||
|
@ -86,8 +89,7 @@ class ClassMethodRouteTable:
|
||||||
return await handler(bind_info.instance, req)
|
return await handler(bind_info.instance, req)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Handle %s %s failed.", method, path)
|
logger.exception("Handle %s %s failed.", method, path)
|
||||||
return rest_response(
|
return rest_response(success=False, message=traceback.format_exc())
|
||||||
success=False, message=traceback.format_exc())
|
|
||||||
|
|
||||||
cls._bind_map[method][path] = bind_info
|
cls._bind_map[method][path] = bind_info
|
||||||
_handler_route.__route_method__ = method
|
_handler_route.__route_method__ = method
|
||||||
|
@ -132,18 +134,19 @@ class ClassMethodRouteTable:
|
||||||
def bind(cls, instance):
|
def bind(cls, instance):
|
||||||
def predicate(o):
|
def predicate(o):
|
||||||
if inspect.ismethod(o):
|
if inspect.ismethod(o):
|
||||||
return hasattr(o, "__route_method__") and hasattr(
|
return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
|
||||||
o, "__route_path__")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
handler_routes = inspect.getmembers(instance, predicate)
|
handler_routes = inspect.getmembers(instance, predicate)
|
||||||
for _, h in handler_routes:
|
for _, h in handler_routes:
|
||||||
cls._bind_map[h.__func__.__route_method__][
|
cls._bind_map[h.__func__.__route_method__][
|
||||||
h.__func__.__route_path__].instance = instance
|
h.__func__.__route_path__
|
||||||
|
].instance = instance
|
||||||
|
|
||||||
|
|
||||||
def rest_response(success, message, convert_google_style=True,
|
def rest_response(
|
||||||
**kwargs) -> aiohttp.web.Response:
|
success, message, convert_google_style=True, **kwargs
|
||||||
|
) -> aiohttp.web.Response:
|
||||||
# In the dev context we allow a dev server running on a
|
# In the dev context we allow a dev server running on a
|
||||||
# different port to consume the API, meaning we need to allow
|
# different port to consume the API, meaning we need to allow
|
||||||
# cross-origin access
|
# cross-origin access
|
||||||
|
@ -155,24 +158,24 @@ def rest_response(success, message, convert_google_style=True,
|
||||||
{
|
{
|
||||||
"result": success,
|
"result": success,
|
||||||
"msg": message,
|
"msg": message,
|
||||||
"data": to_google_style(kwargs) if convert_google_style else kwargs
|
"data": to_google_style(kwargs) if convert_google_style else kwargs,
|
||||||
},
|
},
|
||||||
dumps=functools.partial(json.dumps, cls=CustomEncoder),
|
dumps=functools.partial(json.dumps, cls=CustomEncoder),
|
||||||
headers=headers)
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# The cache value type used by aiohttp_cache.
|
# The cache value type used by aiohttp_cache.
|
||||||
_AiohttpCacheValue = namedtuple("AiohttpCacheValue",
|
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
|
||||||
["data", "expiration", "task"])
|
|
||||||
# The methods with no request body used by aiohttp_cache.
|
# The methods with no request body used by aiohttp_cache.
|
||||||
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}
|
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}
|
||||||
|
|
||||||
|
|
||||||
def aiohttp_cache(
|
def aiohttp_cache(
|
||||||
ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
|
ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
|
||||||
maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
|
maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
|
||||||
enable=not env_bool(
|
enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
|
||||||
dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False)):
|
):
|
||||||
assert maxsize > 0
|
assert maxsize > 0
|
||||||
cache = collections.OrderedDict()
|
cache = collections.OrderedDict()
|
||||||
|
|
||||||
|
@ -195,8 +198,7 @@ def aiohttp_cache(
|
||||||
value = cache.get(key)
|
value = cache.get(key)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
cache.move_to_end(key)
|
cache.move_to_end(key)
|
||||||
if (not value.task.done()
|
if not value.task.done() or value.expiration >= time.time():
|
||||||
or value.expiration >= time.time()):
|
|
||||||
# Update task not done or the data is not expired.
|
# Update task not done or the data is not expired.
|
||||||
return aiohttp.web.Response(**value.data)
|
return aiohttp.web.Response(**value.data)
|
||||||
|
|
||||||
|
@ -205,15 +207,16 @@ def aiohttp_cache(
|
||||||
response = task.result()
|
response = task.result()
|
||||||
except Exception:
|
except Exception:
|
||||||
response = rest_response(
|
response = rest_response(
|
||||||
success=False, message=traceback.format_exc())
|
success=False, message=traceback.format_exc()
|
||||||
|
)
|
||||||
data = {
|
data = {
|
||||||
"status": response.status,
|
"status": response.status,
|
||||||
"headers": dict(response.headers),
|
"headers": dict(response.headers),
|
||||||
"body": response.body,
|
"body": response.body,
|
||||||
}
|
}
|
||||||
cache[key] = _AiohttpCacheValue(data,
|
cache[key] = _AiohttpCacheValue(
|
||||||
time.time() + ttl_seconds,
|
data, time.time() + ttl_seconds, task
|
||||||
task)
|
)
|
||||||
cache.move_to_end(key)
|
cache.move_to_end(key)
|
||||||
if len(cache) > maxsize:
|
if len(cache) > maxsize:
|
||||||
cache.popitem(last=False)
|
cache.popitem(last=False)
|
||||||
|
|
|
@ -19,11 +19,14 @@ import requests
|
||||||
|
|
||||||
from ray import ray_constants
|
from ray import ray_constants
|
||||||
from ray._private.test_utils import (
|
from ray._private.test_utils import (
|
||||||
format_web_url, wait_for_condition, wait_until_server_available,
|
format_web_url,
|
||||||
run_string_as_driver, wait_until_succeeded_without_exception)
|
wait_for_condition,
|
||||||
|
wait_until_server_available,
|
||||||
|
run_string_as_driver,
|
||||||
|
wait_until_succeeded_without_exception,
|
||||||
|
)
|
||||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled
|
from ray._private.gcs_pubsub import gcs_pubsub_enabled
|
||||||
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS_LEGACY,
|
from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
|
||||||
DEBUG_AUTOSCALING_ERROR)
|
|
||||||
from ray.dashboard import dashboard
|
from ray.dashboard import dashboard
|
||||||
import ray.dashboard.consts as dashboard_consts
|
import ray.dashboard.consts as dashboard_consts
|
||||||
import ray.dashboard.utils as dashboard_utils
|
import ray.dashboard.utils as dashboard_utils
|
||||||
|
@ -43,7 +46,8 @@ def make_gcs_client(address_info):
|
||||||
client = redis.StrictRedis(
|
client = redis.StrictRedis(
|
||||||
host=address[0],
|
host=address[0],
|
||||||
port=int(address[1]),
|
port=int(address[1]),
|
||||||
password=ray_constants.REDIS_DEFAULT_PASSWORD)
|
password=ray_constants.REDIS_DEFAULT_PASSWORD,
|
||||||
|
)
|
||||||
gcs_client = ray._private.gcs_utils.GcsClient.create_from_redis(client)
|
gcs_client = ray._private.gcs_utils.GcsClient.create_from_redis(client)
|
||||||
else:
|
else:
|
||||||
address = address_info["gcs_address"]
|
address = address_info["gcs_address"]
|
||||||
|
@ -73,17 +77,14 @@ cleanup_test_files()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_with_dashboard", [{
|
"ray_start_with_dashboard",
|
||||||
"_system_config": {
|
[{"_system_config": {"agent_register_timeout_ms": 5000}}],
|
||||||
"agent_register_timeout_ms": 5000
|
indirect=True,
|
||||||
}
|
)
|
||||||
}],
|
|
||||||
indirect=True)
|
|
||||||
def test_basic(ray_start_with_dashboard):
|
def test_basic(ray_start_with_dashboard):
|
||||||
"""Dashboard test that starts a Ray cluster with a dashboard server running,
|
"""Dashboard test that starts a Ray cluster with a dashboard server running,
|
||||||
then hits the dashboard API and asserts that it receives sensible data."""
|
then hits the dashboard API and asserts that it receives sensible data."""
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
address_info = ray_start_with_dashboard
|
address_info = ray_start_with_dashboard
|
||||||
node_id = address_info["node_id"]
|
node_id = address_info["node_id"]
|
||||||
gcs_client = make_gcs_client(address_info)
|
gcs_client = make_gcs_client(address_info)
|
||||||
|
@ -92,11 +93,12 @@ def test_basic(ray_start_with_dashboard):
|
||||||
all_processes = ray.worker._global_node.all_processes
|
all_processes = ray.worker._global_node.all_processes
|
||||||
assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
|
assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
|
||||||
assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
|
assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
|
||||||
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][
|
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
||||||
0]
|
|
||||||
dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
|
dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
|
||||||
assert dashboard_proc.status() in [
|
assert dashboard_proc.status() in [
|
||||||
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
|
psutil.STATUS_RUNNING,
|
||||||
|
psutil.STATUS_SLEEPING,
|
||||||
|
psutil.STATUS_DISK_SLEEP,
|
||||||
]
|
]
|
||||||
raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
|
raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
|
||||||
raylet_proc = psutil.Process(raylet_proc_info.process.pid)
|
raylet_proc = psutil.Process(raylet_proc_info.process.pid)
|
||||||
|
@ -140,9 +142,7 @@ def test_basic(ray_start_with_dashboard):
|
||||||
|
|
||||||
logger.info("Test agent register is OK.")
|
logger.info("Test agent register is OK.")
|
||||||
wait_for_condition(lambda: _search_agent(raylet_proc.children()))
|
wait_for_condition(lambda: _search_agent(raylet_proc.children()))
|
||||||
assert dashboard_proc.status() in [
|
assert dashboard_proc.status() in [psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING]
|
||||||
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING
|
|
||||||
]
|
|
||||||
agent_proc = _search_agent(raylet_proc.children())
|
agent_proc = _search_agent(raylet_proc.children())
|
||||||
agent_pid = agent_proc.pid
|
agent_pid = agent_proc.pid
|
||||||
|
|
||||||
|
@ -161,40 +161,39 @@ def test_basic(ray_start_with_dashboard):
|
||||||
# Check kv keys are set.
|
# Check kv keys are set.
|
||||||
logger.info("Check kv keys are set.")
|
logger.info("Check kv keys are set.")
|
||||||
dashboard_address = ray.experimental.internal_kv._internal_kv_get(
|
dashboard_address = ray.experimental.internal_kv._internal_kv_get(
|
||||||
ray_constants.DASHBOARD_ADDRESS,
|
ray_constants.DASHBOARD_ADDRESS, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
)
|
||||||
assert dashboard_address is not None
|
assert dashboard_address is not None
|
||||||
dashboard_rpc_address = ray.experimental.internal_kv._internal_kv_get(
|
dashboard_rpc_address = ray.experimental.internal_kv._internal_kv_get(
|
||||||
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
assert dashboard_rpc_address is not None
|
assert dashboard_rpc_address is not None
|
||||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
||||||
agent_ports = ray.experimental.internal_kv._internal_kv_get(
|
agent_ports = ray.experimental.internal_kv._internal_kv_get(
|
||||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
||||||
|
)
|
||||||
assert agent_ports is not None
|
assert agent_ports is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ray_start_with_dashboard", [{
|
"ray_start_with_dashboard",
|
||||||
"dashboard_host": "127.0.0.1"
|
[
|
||||||
}, {
|
{"dashboard_host": "127.0.0.1"},
|
||||||
"dashboard_host": "0.0.0.0"
|
{"dashboard_host": "0.0.0.0"},
|
||||||
}, {
|
{"dashboard_host": "::"},
|
||||||
"dashboard_host": "::"
|
],
|
||||||
}],
|
indirect=True,
|
||||||
indirect=True)
|
)
|
||||||
def test_dashboard_address(ray_start_with_dashboard):
|
def test_dashboard_address(ray_start_with_dashboard):
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_ip = webui_url.split(":")[0]
|
webui_ip = webui_url.split(":")[0]
|
||||||
assert not ipaddress.ip_address(webui_ip).is_unspecified
|
assert not ipaddress.ip_address(webui_ip).is_unspecified
|
||||||
assert webui_ip in [
|
assert webui_ip in ["127.0.0.1", ray_start_with_dashboard["node_ip_address"]]
|
||||||
"127.0.0.1", ray_start_with_dashboard["node_ip_address"]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_http_get(enable_test_module, ray_start_with_dashboard):
|
def test_http_get(enable_test_module, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
|
@ -205,8 +204,7 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
||||||
while True:
|
while True:
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
try:
|
try:
|
||||||
response = requests.get(webui_url + "/test/http_get?url=" +
|
response = requests.get(webui_url + "/test/http_get?url=" + target_url)
|
||||||
target_url)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
try:
|
try:
|
||||||
dump_info = response.json()
|
dump_info = response.json()
|
||||||
|
@ -221,8 +219,8 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
||||||
http_port, grpc_port = ports
|
http_port, grpc_port = ports
|
||||||
|
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"http://{ip}:{http_port}"
|
f"http://{ip}:{http_port}" f"/test/http_get_from_agent?url={target_url}"
|
||||||
f"/test/http_get_from_agent?url={target_url}")
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
try:
|
try:
|
||||||
dump_info = response.json()
|
dump_info = response.json()
|
||||||
|
@ -239,10 +237,10 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
||||||
|
|
||||||
|
|
||||||
def test_class_method_route_table(enable_test_module):
|
def test_class_method_route_table(enable_test_module):
|
||||||
head_cls_list = dashboard_utils.get_all_modules(
|
head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
|
||||||
dashboard_utils.DashboardHeadModule)
|
|
||||||
agent_cls_list = dashboard_utils.get_all_modules(
|
agent_cls_list = dashboard_utils.get_all_modules(
|
||||||
dashboard_utils.DashboardAgentModule)
|
dashboard_utils.DashboardAgentModule
|
||||||
|
)
|
||||||
test_head_cls = None
|
test_head_cls = None
|
||||||
for cls in head_cls_list:
|
for cls in head_cls_list:
|
||||||
if cls.__name__ == "TestHead":
|
if cls.__name__ == "TestHead":
|
||||||
|
@ -274,28 +272,23 @@ def test_class_method_route_table(enable_test_module):
|
||||||
assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
|
assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
|
||||||
assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
|
assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
|
||||||
assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
|
assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
|
||||||
assert any(
|
assert any(_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
|
||||||
_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
|
|
||||||
assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)
|
assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)
|
||||||
|
|
||||||
# Test bind()
|
# Test bind()
|
||||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||||
)
|
|
||||||
assert len(bound_routes) == 0
|
assert len(bound_routes) == 0
|
||||||
dashboard_optional_utils.ClassMethodRouteTable.bind(
|
dashboard_optional_utils.ClassMethodRouteTable.bind(
|
||||||
test_agent_cls.__new__(test_agent_cls))
|
test_agent_cls.__new__(test_agent_cls)
|
||||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
|
||||||
)
|
)
|
||||||
|
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||||
assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
|
assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
|
||||||
assert all(
|
assert all(not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
|
||||||
not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
|
|
||||||
|
|
||||||
# Static def should be in bound routes.
|
# Static def should be in bound routes.
|
||||||
routes.static("/test/route_static", "/path")
|
routes.static("/test/route_static", "/path")
|
||||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||||
)
|
assert any(_has_static(r, "/path", "/test/route_static") for r in bound_routes)
|
||||||
assert any(
|
|
||||||
_has_static(r, "/path", "/test/route_static") for r in bound_routes)
|
|
||||||
|
|
||||||
# Test duplicated routes should raise exception.
|
# Test duplicated routes should raise exception.
|
||||||
try:
|
try:
|
||||||
|
@ -358,10 +351,10 @@ def test_async_loop_forever():
|
||||||
|
|
||||||
|
|
||||||
def test_dashboard_module_decorator(enable_test_module):
|
def test_dashboard_module_decorator(enable_test_module):
|
||||||
head_cls_list = dashboard_utils.get_all_modules(
|
head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
|
||||||
dashboard_utils.DashboardHeadModule)
|
|
||||||
agent_cls_list = dashboard_utils.get_all_modules(
|
agent_cls_list = dashboard_utils.get_all_modules(
|
||||||
dashboard_utils.DashboardAgentModule)
|
dashboard_utils.DashboardAgentModule
|
||||||
|
)
|
||||||
|
|
||||||
assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
|
assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
|
||||||
assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)
|
assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)
|
||||||
|
@ -385,8 +378,7 @@ print("success")
|
||||||
|
|
||||||
|
|
||||||
def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
webui_url = ray_start_with_dashboard["webui_url"]
|
webui_url = ray_start_with_dashboard["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
|
||||||
|
@ -397,8 +389,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
try:
|
try:
|
||||||
for x in range(10):
|
for x in range(10):
|
||||||
response = requests.get(webui_url +
|
response = requests.get(webui_url + "/test/aiohttp_cache/t1?value=1")
|
||||||
"/test/aiohttp_cache/t1?value=1")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
timestamp = response.json()["data"]["timestamp"]
|
timestamp = response.json()["data"]["timestamp"]
|
||||||
value1_timestamps.append(timestamp)
|
value1_timestamps.append(timestamp)
|
||||||
|
@ -412,8 +403,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
|
|
||||||
sub_path_timestamps = []
|
sub_path_timestamps = []
|
||||||
for x in range(10):
|
for x in range(10):
|
||||||
response = requests.get(webui_url +
|
response = requests.get(webui_url + f"/test/aiohttp_cache/tt{x}?value=1")
|
||||||
f"/test/aiohttp_cache/tt{x}?value=1")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
timestamp = response.json()["data"]["timestamp"]
|
timestamp = response.json()["data"]["timestamp"]
|
||||||
sub_path_timestamps.append(timestamp)
|
sub_path_timestamps.append(timestamp)
|
||||||
|
@ -421,8 +411,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
|
|
||||||
volatile_value_timestamps = []
|
volatile_value_timestamps = []
|
||||||
for x in range(10):
|
for x in range(10):
|
||||||
response = requests.get(webui_url +
|
response = requests.get(webui_url + f"/test/aiohttp_cache/tt?value={x}")
|
||||||
f"/test/aiohttp_cache/tt?value={x}")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
timestamp = response.json()["data"]["timestamp"]
|
timestamp = response.json()["data"]["timestamp"]
|
||||||
volatile_value_timestamps.append(timestamp)
|
volatile_value_timestamps.append(timestamp)
|
||||||
|
@ -436,8 +425,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
|
|
||||||
volatile_value_timestamps = []
|
volatile_value_timestamps = []
|
||||||
for x in range(10):
|
for x in range(10):
|
||||||
response = requests.get(webui_url +
|
response = requests.get(webui_url + f"/test/aiohttp_cache_lru/tt{x % 4}")
|
||||||
f"/test/aiohttp_cache_lru/tt{x % 4}")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
timestamp = response.json()["data"]["timestamp"]
|
timestamp = response.json()["data"]["timestamp"]
|
||||||
volatile_value_timestamps.append(timestamp)
|
volatile_value_timestamps.append(timestamp)
|
||||||
|
@ -446,8 +434,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
volatile_value_timestamps = []
|
volatile_value_timestamps = []
|
||||||
data = collections.defaultdict(set)
|
data = collections.defaultdict(set)
|
||||||
for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]:
|
for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]:
|
||||||
response = requests.get(webui_url +
|
response = requests.get(webui_url + f"/test/aiohttp_cache_lru/t1?value={x}")
|
||||||
f"/test/aiohttp_cache_lru/t1?value={x}")
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
timestamp = response.json()["data"]["timestamp"]
|
timestamp = response.json()["data"]["timestamp"]
|
||||||
data[x].add(timestamp)
|
data[x].add(timestamp)
|
||||||
|
@ -458,8 +445,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
|
||||||
|
|
||||||
|
|
||||||
def test_get_cluster_status(ray_start_with_dashboard):
|
def test_get_cluster_status(ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
address_info = ray_start_with_dashboard
|
address_info = ray_start_with_dashboard
|
||||||
webui_url = address_info["webui_url"]
|
webui_url = address_info["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
@ -478,14 +464,15 @@ def test_get_cluster_status(ray_start_with_dashboard):
|
||||||
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
|
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
|
||||||
|
|
||||||
assert wait_until_succeeded_without_exception(
|
assert wait_until_succeeded_without_exception(
|
||||||
get_cluster_status, (requests.RequestException, ))
|
get_cluster_status, (requests.RequestException,)
|
||||||
|
)
|
||||||
|
|
||||||
gcs_client = make_gcs_client(address_info)
|
gcs_client = make_gcs_client(address_info)
|
||||||
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
||||||
ray.experimental.internal_kv._internal_kv_put(
|
ray.experimental.internal_kv._internal_kv_put(
|
||||||
DEBUG_AUTOSCALING_STATUS_LEGACY, "hello")
|
DEBUG_AUTOSCALING_STATUS_LEGACY, "hello"
|
||||||
ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR,
|
)
|
||||||
"world")
|
ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR, "world")
|
||||||
|
|
||||||
response = requests.get(f"{webui_url}/api/cluster_status")
|
response = requests.get(f"{webui_url}/api/cluster_status")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -508,20 +495,19 @@ def test_immutable_types():
|
||||||
assert immutable_dict == dashboard_utils.ImmutableDict(d)
|
assert immutable_dict == dashboard_utils.ImmutableDict(d)
|
||||||
assert immutable_dict == d
|
assert immutable_dict == d
|
||||||
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
|
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
|
||||||
assert dashboard_utils.ImmutableList(
|
assert (
|
||||||
immutable_dict["list"]) == immutable_dict["list"]
|
dashboard_utils.ImmutableList(immutable_dict["list"]) == immutable_dict["list"]
|
||||||
|
)
|
||||||
assert "512" in d
|
assert "512" in d
|
||||||
assert "512" in d["list"][0]
|
assert "512" in d["list"][0]
|
||||||
assert "512" in d["dict"]
|
assert "512" in d["dict"]
|
||||||
|
|
||||||
# Test type conversion
|
# Test type conversion
|
||||||
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
|
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
|
||||||
assert type(list(
|
assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
|
||||||
immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
|
|
||||||
|
|
||||||
# Test json dumps / loads
|
# Test json dumps / loads
|
||||||
json_str = json.dumps(
|
json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
|
||||||
immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
|
|
||||||
deserialized_immutable_dict = json.loads(json_str)
|
deserialized_immutable_dict = json.loads(json_str)
|
||||||
assert type(deserialized_immutable_dict) == dict
|
assert type(deserialized_immutable_dict) == dict
|
||||||
assert type(deserialized_immutable_dict["list"]) == list
|
assert type(deserialized_immutable_dict["list"]) == list
|
||||||
|
@ -577,7 +563,7 @@ def test_immutable_types():
|
||||||
|
|
||||||
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
|
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
|
||||||
address_info = ray.init(num_cpus=1, include_dashboard=True)
|
address_info = ray.init(num_cpus=1, include_dashboard=True)
|
||||||
assert (wait_until_server_available(address_info["webui_url"]) is True)
|
assert wait_until_server_available(address_info["webui_url"]) is True
|
||||||
|
|
||||||
webui_url = address_info["webui_url"]
|
webui_url = address_info["webui_url"]
|
||||||
webui_url = format_web_url(webui_url)
|
webui_url = format_web_url(webui_url)
|
||||||
|
@ -588,11 +574,8 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
try:
|
try:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
webui_url + "/test/dump",
|
webui_url + "/test/dump", proxies={"http": None, "https": None}
|
||||||
proxies={
|
)
|
||||||
"http": None,
|
|
||||||
"https": None
|
|
||||||
})
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
try:
|
try:
|
||||||
response.json()
|
response.json()
|
||||||
|
@ -609,8 +592,7 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
|
||||||
|
|
||||||
|
|
||||||
def test_dashboard_port_conflict(ray_start_with_dashboard):
|
def test_dashboard_port_conflict(ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
address_info = ray_start_with_dashboard
|
address_info = ray_start_with_dashboard
|
||||||
gcs_client = make_gcs_client(address_info)
|
gcs_client = make_gcs_client(address_info)
|
||||||
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
|
||||||
|
@ -618,11 +600,15 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
|
||||||
temp_dir = "/tmp/ray"
|
temp_dir = "/tmp/ray"
|
||||||
log_dir = "/tmp/ray/session_latest/logs"
|
log_dir = "/tmp/ray/session_latest/logs"
|
||||||
dashboard_cmd = [
|
dashboard_cmd = [
|
||||||
sys.executable, dashboard.__file__, f"--host={host}", f"--port={port}",
|
sys.executable,
|
||||||
f"--temp-dir={temp_dir}", f"--log-dir={log_dir}",
|
dashboard.__file__,
|
||||||
|
f"--host={host}",
|
||||||
|
f"--port={port}",
|
||||||
|
f"--temp-dir={temp_dir}",
|
||||||
|
f"--log-dir={log_dir}",
|
||||||
f"--redis-address={address_info['redis_address']}",
|
f"--redis-address={address_info['redis_address']}",
|
||||||
f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}",
|
f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}",
|
||||||
f"--gcs-address={address_info['gcs_address']}"
|
f"--gcs-address={address_info['gcs_address']}",
|
||||||
]
|
]
|
||||||
logger.info("The dashboard should be exit: %s", dashboard_cmd)
|
logger.info("The dashboard should be exit: %s", dashboard_cmd)
|
||||||
p = subprocess.Popen(dashboard_cmd)
|
p = subprocess.Popen(dashboard_cmd)
|
||||||
|
@ -638,7 +624,8 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
|
||||||
try:
|
try:
|
||||||
dashboard_url = ray.experimental.internal_kv._internal_kv_get(
|
dashboard_url = ray.experimental.internal_kv._internal_kv_get(
|
||||||
ray_constants.DASHBOARD_ADDRESS,
|
ray_constants.DASHBOARD_ADDRESS,
|
||||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||||
|
)
|
||||||
if dashboard_url:
|
if dashboard_url:
|
||||||
new_port = int(dashboard_url.split(b":")[-1])
|
new_port = int(dashboard_url.split(b":")[-1])
|
||||||
assert new_port > int(port)
|
assert new_port > int(port)
|
||||||
|
@ -651,8 +638,7 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
|
||||||
|
|
||||||
|
|
||||||
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
||||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||||
is True)
|
|
||||||
|
|
||||||
all_processes = ray.worker._global_node.all_processes
|
all_processes = ray.worker._global_node.all_processes
|
||||||
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
||||||
|
@ -661,7 +647,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
|
||||||
gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
|
gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
|
||||||
|
|
||||||
assert dashboard_proc.status() in [
|
assert dashboard_proc.status() in [
|
||||||
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
|
psutil.STATUS_RUNNING,
|
||||||
|
psutil.STATUS_SLEEPING,
|
||||||
|
psutil.STATUS_DISK_SLEEP,
|
||||||
]
|
]
|
||||||
|
|
||||||
gcs_server_proc.kill()
|
gcs_server_proc.kill()
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
import ray
|
import ray
|
||||||
from ray.dashboard.memory_utils import (
|
from ray.dashboard.memory_utils import (
|
||||||
ReferenceType, decode_object_ref_if_needed, MemoryTableEntry, MemoryTable,
|
ReferenceType,
|
||||||
SortingType)
|
decode_object_ref_if_needed,
|
||||||
|
MemoryTableEntry,
|
||||||
|
MemoryTable,
|
||||||
|
SortingType,
|
||||||
|
)
|
||||||
|
|
||||||
"""Memory Table Unit Test"""
|
"""Memory Table Unit Test"""
|
||||||
|
|
||||||
NODE_ADDRESS = "127.0.0.1"
|
NODE_ADDRESS = "127.0.0.1"
|
||||||
|
@ -14,15 +19,17 @@ DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
|
||||||
OBJECT_SIZE = 100
|
OBJECT_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
def build_memory_entry(*,
|
def build_memory_entry(
|
||||||
local_ref_count,
|
*,
|
||||||
pinned_in_memory,
|
local_ref_count,
|
||||||
submitted_task_reference_count,
|
pinned_in_memory,
|
||||||
contained_in_owned,
|
submitted_task_reference_count,
|
||||||
object_size,
|
contained_in_owned,
|
||||||
pid,
|
object_size,
|
||||||
object_id=OBJECT_ID,
|
pid,
|
||||||
node_address=NODE_ADDRESS):
|
object_id=OBJECT_ID,
|
||||||
|
node_address=NODE_ADDRESS
|
||||||
|
):
|
||||||
object_ref = {
|
object_ref = {
|
||||||
"objectId": object_id,
|
"objectId": object_id,
|
||||||
"callSite": "(task call) /Users:458",
|
"callSite": "(task call) /Users:458",
|
||||||
|
@ -30,18 +37,16 @@ def build_memory_entry(*,
|
||||||
"localRefCount": local_ref_count,
|
"localRefCount": local_ref_count,
|
||||||
"pinnedInMemory": pinned_in_memory,
|
"pinnedInMemory": pinned_in_memory,
|
||||||
"submittedTaskRefCount": submitted_task_reference_count,
|
"submittedTaskRefCount": submitted_task_reference_count,
|
||||||
"containedInOwned": contained_in_owned
|
"containedInOwned": contained_in_owned,
|
||||||
}
|
}
|
||||||
return MemoryTableEntry(
|
return MemoryTableEntry(
|
||||||
object_ref=object_ref,
|
object_ref=object_ref, node_address=node_address, is_driver=IS_DRIVER, pid=pid
|
||||||
node_address=node_address,
|
)
|
||||||
is_driver=IS_DRIVER,
|
|
||||||
pid=pid)
|
|
||||||
|
|
||||||
|
|
||||||
def build_local_reference_entry(object_size=OBJECT_SIZE,
|
def build_local_reference_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
|
||||||
node_address=NODE_ADDRESS):
|
):
|
||||||
return build_memory_entry(
|
return build_memory_entry(
|
||||||
local_ref_count=1,
|
local_ref_count=1,
|
||||||
pinned_in_memory=False,
|
pinned_in_memory=False,
|
||||||
|
@ -49,12 +54,13 @@ def build_local_reference_entry(object_size=OBJECT_SIZE,
|
||||||
contained_in_owned=[],
|
contained_in_owned=[],
|
||||||
object_size=object_size,
|
object_size=object_size,
|
||||||
pid=pid,
|
pid=pid,
|
||||||
node_address=node_address)
|
node_address=node_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
|
def build_used_by_pending_task_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
|
||||||
node_address=NODE_ADDRESS):
|
):
|
||||||
return build_memory_entry(
|
return build_memory_entry(
|
||||||
local_ref_count=0,
|
local_ref_count=0,
|
||||||
pinned_in_memory=False,
|
pinned_in_memory=False,
|
||||||
|
@ -62,12 +68,13 @@ def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
|
||||||
contained_in_owned=[],
|
contained_in_owned=[],
|
||||||
object_size=object_size,
|
object_size=object_size,
|
||||||
pid=pid,
|
pid=pid,
|
||||||
node_address=node_address)
|
node_address=node_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_captured_in_object_entry(object_size=OBJECT_SIZE,
|
def build_captured_in_object_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
|
||||||
node_address=NODE_ADDRESS):
|
):
|
||||||
return build_memory_entry(
|
return build_memory_entry(
|
||||||
local_ref_count=0,
|
local_ref_count=0,
|
||||||
pinned_in_memory=False,
|
pinned_in_memory=False,
|
||||||
|
@ -75,12 +82,13 @@ def build_captured_in_object_entry(object_size=OBJECT_SIZE,
|
||||||
contained_in_owned=[OBJECT_ID],
|
contained_in_owned=[OBJECT_ID],
|
||||||
object_size=object_size,
|
object_size=object_size,
|
||||||
pid=pid,
|
pid=pid,
|
||||||
node_address=node_address)
|
node_address=node_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_actor_handle_entry(object_size=OBJECT_SIZE,
|
def build_actor_handle_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
|
||||||
node_address=NODE_ADDRESS):
|
):
|
||||||
return build_memory_entry(
|
return build_memory_entry(
|
||||||
local_ref_count=1,
|
local_ref_count=1,
|
||||||
pinned_in_memory=False,
|
pinned_in_memory=False,
|
||||||
|
@ -89,12 +97,13 @@ def build_actor_handle_entry(object_size=OBJECT_SIZE,
|
||||||
object_size=object_size,
|
object_size=object_size,
|
||||||
pid=pid,
|
pid=pid,
|
||||||
node_address=node_address,
|
node_address=node_address,
|
||||||
object_id=ACTOR_ID)
|
object_id=ACTOR_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
|
def build_pinned_in_memory_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
|
||||||
node_address=NODE_ADDRESS):
|
):
|
||||||
return build_memory_entry(
|
return build_memory_entry(
|
||||||
local_ref_count=0,
|
local_ref_count=0,
|
||||||
pinned_in_memory=True,
|
pinned_in_memory=True,
|
||||||
|
@ -102,28 +111,36 @@ def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
|
||||||
contained_in_owned=[],
|
contained_in_owned=[],
|
||||||
object_size=object_size,
|
object_size=object_size,
|
||||||
pid=pid,
|
pid=pid,
|
||||||
node_address=node_address)
|
node_address=node_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_entry(object_size=OBJECT_SIZE,
|
def build_entry(
|
||||||
pid=PID,
|
object_size=OBJECT_SIZE,
|
||||||
node_address=NODE_ADDRESS,
|
pid=PID,
|
||||||
reference_type=ReferenceType.PINNED_IN_MEMORY):
|
node_address=NODE_ADDRESS,
|
||||||
|
reference_type=ReferenceType.PINNED_IN_MEMORY,
|
||||||
|
):
|
||||||
if reference_type == ReferenceType.USED_BY_PENDING_TASK:
|
if reference_type == ReferenceType.USED_BY_PENDING_TASK:
|
||||||
return build_used_by_pending_task_entry(
|
return build_used_by_pending_task_entry(
|
||||||
pid=pid, object_size=object_size, node_address=node_address)
|
pid=pid, object_size=object_size, node_address=node_address
|
||||||
|
)
|
||||||
elif reference_type == ReferenceType.LOCAL_REFERENCE:
|
elif reference_type == ReferenceType.LOCAL_REFERENCE:
|
||||||
return build_local_reference_entry(
|
return build_local_reference_entry(
|
||||||
pid=pid, object_size=object_size, node_address=node_address)
|
pid=pid, object_size=object_size, node_address=node_address
|
||||||
|
)
|
||||||
elif reference_type == ReferenceType.PINNED_IN_MEMORY:
|
elif reference_type == ReferenceType.PINNED_IN_MEMORY:
|
||||||
return build_pinned_in_memory_entry(
|
return build_pinned_in_memory_entry(
|
||||||
pid=pid, object_size=object_size, node_address=node_address)
|
pid=pid, object_size=object_size, node_address=node_address
|
||||||
|
)
|
||||||
elif reference_type == ReferenceType.ACTOR_HANDLE:
|
elif reference_type == ReferenceType.ACTOR_HANDLE:
|
||||||
return build_actor_handle_entry(
|
return build_actor_handle_entry(
|
||||||
pid=pid, object_size=object_size, node_address=node_address)
|
pid=pid, object_size=object_size, node_address=node_address
|
||||||
|
)
|
||||||
elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
|
elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
|
||||||
return build_captured_in_object_entry(
|
return build_captured_in_object_entry(
|
||||||
pid=pid, object_size=object_size, node_address=node_address)
|
pid=pid, object_size=object_size, node_address=node_address
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_memory_entry():
|
def test_invalid_memory_entry():
|
||||||
|
@ -133,7 +150,8 @@ def test_invalid_memory_entry():
|
||||||
submitted_task_reference_count=0,
|
submitted_task_reference_count=0,
|
||||||
contained_in_owned=[],
|
contained_in_owned=[],
|
||||||
object_size=OBJECT_SIZE,
|
object_size=OBJECT_SIZE,
|
||||||
pid=PID)
|
pid=PID,
|
||||||
|
)
|
||||||
assert memory_entry.is_valid() is False
|
assert memory_entry.is_valid() is False
|
||||||
memory_entry = build_memory_entry(
|
memory_entry = build_memory_entry(
|
||||||
local_ref_count=0,
|
local_ref_count=0,
|
||||||
|
@ -141,7 +159,8 @@ def test_invalid_memory_entry():
|
||||||
submitted_task_reference_count=0,
|
submitted_task_reference_count=0,
|
||||||
contained_in_owned=[],
|
contained_in_owned=[],
|
||||||
object_size=-1,
|
object_size=-1,
|
||||||
pid=PID)
|
pid=PID,
|
||||||
|
)
|
||||||
assert memory_entry.is_valid() is False
|
assert memory_entry.is_valid() is False
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,7 +168,8 @@ def test_valid_reference_memory_entry():
|
||||||
memory_entry = build_local_reference_entry()
|
memory_entry = build_local_reference_entry()
|
||||||
assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
|
assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
|
||||||
assert memory_entry.object_ref == ray.ObjectRef(
|
assert memory_entry.object_ref == ray.ObjectRef(
|
||||||
decode_object_ref_if_needed(OBJECT_ID))
|
decode_object_ref_if_needed(OBJECT_ID)
|
||||||
|
)
|
||||||
assert memory_entry.is_valid() is True
|
assert memory_entry.is_valid() is True
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,15 +198,14 @@ def test_memory_table_summary():
|
||||||
build_captured_in_object_entry(),
|
build_captured_in_object_entry(),
|
||||||
build_actor_handle_entry(),
|
build_actor_handle_entry(),
|
||||||
build_local_reference_entry(),
|
build_local_reference_entry(),
|
||||||
build_local_reference_entry()
|
build_local_reference_entry(),
|
||||||
]
|
]
|
||||||
memory_table = MemoryTable(entries)
|
memory_table = MemoryTable(entries)
|
||||||
assert len(memory_table.group) == 1
|
assert len(memory_table.group) == 1
|
||||||
assert memory_table.summary["total_actor_handles"] == 1
|
assert memory_table.summary["total_actor_handles"] == 1
|
||||||
assert memory_table.summary["total_captured_in_objects"] == 1
|
assert memory_table.summary["total_captured_in_objects"] == 1
|
||||||
assert memory_table.summary["total_local_ref_count"] == 2
|
assert memory_table.summary["total_local_ref_count"] == 2
|
||||||
assert memory_table.summary[
|
assert memory_table.summary["total_object_size"] == len(entries) * OBJECT_SIZE
|
||||||
"total_object_size"] == len(entries) * OBJECT_SIZE
|
|
||||||
assert memory_table.summary["total_pinned_in_memory"] == 1
|
assert memory_table.summary["total_pinned_in_memory"] == 1
|
||||||
assert memory_table.summary["total_used_by_pending_task"] == 1
|
assert memory_table.summary["total_used_by_pending_task"] == 1
|
||||||
|
|
||||||
|
@ -202,14 +221,13 @@ def test_memory_table_sort_by_pid():
|
||||||
|
|
||||||
def test_memory_table_sort_by_reference_type():
|
def test_memory_table_sort_by_reference_type():
|
||||||
unsort = [
|
unsort = [
|
||||||
ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
|
ReferenceType.USED_BY_PENDING_TASK,
|
||||||
ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
|
ReferenceType.LOCAL_REFERENCE,
|
||||||
|
ReferenceType.LOCAL_REFERENCE,
|
||||||
|
ReferenceType.PINNED_IN_MEMORY,
|
||||||
]
|
]
|
||||||
entries = [
|
entries = [build_entry(reference_type=reference_type) for reference_type in unsort]
|
||||||
build_entry(reference_type=reference_type) for reference_type in unsort
|
memory_table = MemoryTable(entries, sort_by_type=SortingType.REFERENCE_TYPE)
|
||||||
]
|
|
||||||
memory_table = MemoryTable(
|
|
||||||
entries, sort_by_type=SortingType.REFERENCE_TYPE)
|
|
||||||
sort = sorted(unsort)
|
sort = sorted(unsort)
|
||||||
for reference_type, entry in zip(sort, memory_table.table):
|
for reference_type, entry in zip(sort, memory_table.table):
|
||||||
assert reference_type == entry.reference_type
|
assert reference_type == entry.reference_type
|
||||||
|
@ -231,7 +249,7 @@ def test_group_by():
|
||||||
build_entry(node_address=node_second, pid=2),
|
build_entry(node_address=node_second, pid=2),
|
||||||
build_entry(node_address=node_second, pid=1),
|
build_entry(node_address=node_second, pid=1),
|
||||||
build_entry(node_address=node_first, pid=2),
|
build_entry(node_address=node_first, pid=2),
|
||||||
build_entry(node_address=node_first, pid=1)
|
build_entry(node_address=node_first, pid=1),
|
||||||
]
|
]
|
||||||
memory_table = MemoryTable(entries)
|
memory_table = MemoryTable(entries)
|
||||||
|
|
||||||
|
@ -250,4 +268,5 @@ def test_group_by():
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -18,8 +18,7 @@ import aiosignal # noqa: F401
|
||||||
from google.protobuf.json_format import MessageToDict
|
from google.protobuf.json_format import MessageToDict
|
||||||
from frozenlist import FrozenList # noqa: F401
|
from frozenlist import FrozenList # noqa: F401
|
||||||
|
|
||||||
from ray._private.utils import (binary_to_hex,
|
from ray._private.utils import binary_to_hex, check_dashboard_dependencies_installed
|
||||||
check_dashboard_dependencies_installed)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_task = asyncio.create_task
|
create_task = asyncio.create_task
|
||||||
|
@ -97,23 +96,26 @@ def get_all_modules(module_type):
|
||||||
"""
|
"""
|
||||||
logger.info(f"Get all modules by type: {module_type.__name__}")
|
logger.info(f"Get all modules by type: {module_type.__name__}")
|
||||||
import ray.dashboard.modules
|
import ray.dashboard.modules
|
||||||
should_only_load_minimal_modules = (
|
|
||||||
not check_dashboard_dependencies_installed())
|
should_only_load_minimal_modules = not check_dashboard_dependencies_installed()
|
||||||
|
|
||||||
for module_loader, name, ispkg in pkgutil.walk_packages(
|
for module_loader, name, ispkg in pkgutil.walk_packages(
|
||||||
ray.dashboard.modules.__path__,
|
ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
|
||||||
ray.dashboard.modules.__name__ + "."):
|
):
|
||||||
try:
|
try:
|
||||||
importlib.import_module(name)
|
importlib.import_module(name)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
logger.info(f"Module {name} cannot be loaded because "
|
logger.info(
|
||||||
"we cannot import all dependencies. Download "
|
f"Module {name} cannot be loaded because "
|
||||||
"`pip install ray[default]` for the full "
|
"we cannot import all dependencies. Download "
|
||||||
f"dashboard functionality. Error: {e}")
|
"`pip install ray[default]` for the full "
|
||||||
|
f"dashboard functionality. Error: {e}"
|
||||||
|
)
|
||||||
if not should_only_load_minimal_modules:
|
if not should_only_load_minimal_modules:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Although `pip install ray[default] is downloaded, "
|
"Although `pip install ray[default] is downloaded, "
|
||||||
"module couldn't be imported`")
|
"module couldn't be imported`"
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
imported_modules = []
|
imported_modules = []
|
||||||
|
@ -202,7 +204,8 @@ def message_to_dict(message, decode_keys=None, **kwargs):
|
||||||
|
|
||||||
if decode_keys:
|
if decode_keys:
|
||||||
return _decode_keys(
|
return _decode_keys(
|
||||||
MessageToDict(message, use_integers_for_enums=False, **kwargs))
|
MessageToDict(message, use_integers_for_enums=False, **kwargs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return MessageToDict(message, use_integers_for_enums=False, **kwargs)
|
return MessageToDict(message, use_integers_for_enums=False, **kwargs)
|
||||||
|
|
||||||
|
@ -251,8 +254,9 @@ class Change:
|
||||||
self.new = new
|
self.new = new
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Change(owner: {type(self.owner)}), " \
|
return (
|
||||||
f"old: {self.old}, new: {self.new}"
|
f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NotifyQueue:
|
class NotifyQueue:
|
||||||
|
@ -289,10 +293,7 @@ https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
|
||||||
| None | null |
|
| None | null |
|
||||||
+-------------------+---------------+
|
+-------------------+---------------+
|
||||||
"""
|
"""
|
||||||
_json_compatible_types = {
|
_json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes}
|
||||||
dict, list, tuple, str, int, float, bool,
|
|
||||||
type(None), bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_immutable(self):
|
def is_immutable(self):
|
||||||
|
@ -318,8 +319,7 @@ class Immutable(metaclass=ABCMeta):
|
||||||
|
|
||||||
|
|
||||||
class ImmutableList(Immutable, Sequence):
|
class ImmutableList(Immutable, Sequence):
|
||||||
"""Makes a :class:`list` immutable.
|
"""Makes a :class:`list` immutable."""
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_list", "_proxy")
|
__slots__ = ("_list", "_proxy")
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ class ImmutableList(Immutable, Sequence):
|
||||||
self._proxy = [None] * len(list_value)
|
self._proxy = [None] * len(list_value)
|
||||||
|
|
||||||
def __reduce_ex__(self, protocol):
|
def __reduce_ex__(self, protocol):
|
||||||
return type(self), (self._list, )
|
return type(self), (self._list,)
|
||||||
|
|
||||||
def mutable(self):
|
def mutable(self):
|
||||||
return self._list
|
return self._list
|
||||||
|
@ -366,8 +366,7 @@ class ImmutableList(Immutable, Sequence):
|
||||||
|
|
||||||
|
|
||||||
class ImmutableDict(Immutable, Mapping):
|
class ImmutableDict(Immutable, Mapping):
|
||||||
"""Makes a :class:`dict` immutable.
|
"""Makes a :class:`dict` immutable."""
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_dict", "_proxy")
|
__slots__ = ("_dict", "_proxy")
|
||||||
|
|
||||||
|
@ -380,7 +379,7 @@ class ImmutableDict(Immutable, Mapping):
|
||||||
self._proxy = {}
|
self._proxy = {}
|
||||||
|
|
||||||
def __reduce_ex__(self, protocol):
|
def __reduce_ex__(self, protocol):
|
||||||
return type(self), (self._dict, )
|
return type(self), (self._dict,)
|
||||||
|
|
||||||
def mutable(self):
|
def mutable(self):
|
||||||
return self._dict
|
return self._dict
|
||||||
|
@ -443,21 +442,23 @@ class Dict(ImmutableDict, MutableMapping):
|
||||||
if len(self.signal) and old != value:
|
if len(self.signal) and old != value:
|
||||||
if old is None:
|
if old is None:
|
||||||
co = self.signal.send(
|
co = self.signal.send(
|
||||||
Change(owner=self, new=Dict.ChangeItem(key, value)))
|
Change(owner=self, new=Dict.ChangeItem(key, value))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
co = self.signal.send(
|
co = self.signal.send(
|
||||||
Change(
|
Change(
|
||||||
owner=self,
|
owner=self,
|
||||||
old=Dict.ChangeItem(key, old),
|
old=Dict.ChangeItem(key, old),
|
||||||
new=Dict.ChangeItem(key, value)))
|
new=Dict.ChangeItem(key, value),
|
||||||
|
)
|
||||||
|
)
|
||||||
NotifyQueue.put(co)
|
NotifyQueue.put(co)
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
old = self._dict.pop(key, None)
|
old = self._dict.pop(key, None)
|
||||||
self._proxy.pop(key, None)
|
self._proxy.pop(key, None)
|
||||||
if len(self.signal) and old is not None:
|
if len(self.signal) and old is not None:
|
||||||
co = self.signal.send(
|
co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old)))
|
||||||
Change(owner=self, old=Dict.ChangeItem(key, old)))
|
|
||||||
NotifyQueue.put(co)
|
NotifyQueue.put(co)
|
||||||
|
|
||||||
def reset(self, d):
|
def reset(self, d):
|
||||||
|
@ -482,12 +483,15 @@ def async_loop_forever(interval_seconds, cancellable=False):
|
||||||
await coro(*args, **kwargs)
|
await coro(*args, **kwargs)
|
||||||
except asyncio.CancelledError as ex:
|
except asyncio.CancelledError as ex:
|
||||||
if cancellable:
|
if cancellable:
|
||||||
logger.info(f"An async loop forever coroutine "
|
logger.info(
|
||||||
f"is cancelled {coro}.")
|
f"An async loop forever coroutine " f"is cancelled {coro}."
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
else:
|
else:
|
||||||
logger.exception(f"Can not cancel the async loop "
|
logger.exception(
|
||||||
f"forever coroutine {coro}.")
|
f"Can not cancel the async loop "
|
||||||
|
f"forever coroutine {coro}."
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error looping coroutine {coro}.")
|
logger.exception(f"Error looping coroutine {coro}.")
|
||||||
await asyncio.sleep(interval_seconds)
|
await asyncio.sleep(interval_seconds)
|
||||||
|
@ -497,15 +501,18 @@ def async_loop_forever(interval_seconds, cancellable=False):
|
||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
async def get_aioredis_client(redis_address, redis_password,
|
async def get_aioredis_client(
|
||||||
retry_interval_seconds, retry_times):
|
redis_address, redis_password, retry_interval_seconds, retry_times
|
||||||
|
):
|
||||||
for x in range(retry_times):
|
for x in range(retry_times):
|
||||||
try:
|
try:
|
||||||
return await aioredis.create_redis_pool(
|
return await aioredis.create_redis_pool(
|
||||||
address=redis_address, password=redis_password)
|
address=redis_address, password=redis_password
|
||||||
|
)
|
||||||
except (socket.gaierror, ConnectionError) as ex:
|
except (socket.gaierror, ConnectionError) as ex:
|
||||||
logger.error("Connect to Redis failed: %s, retry...", ex)
|
logger.error("Connect to Redis failed: %s, retry...", ex)
|
||||||
await asyncio.sleep(retry_interval_seconds)
|
await asyncio.sleep(retry_interval_seconds)
|
||||||
# Raise exception from create_redis_pool
|
# Raise exception from create_redis_pool
|
||||||
return await aioredis.create_redis_pool(
|
return await aioredis.create_redis_pool(
|
||||||
address=redis_address, password=redis_password)
|
address=redis_address, password=redis_password
|
||||||
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@ from collections import Counter
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
""" This script is meant to be run from a pod in the same Kubernetes namespace
|
""" This script is meant to be run from a pod in the same Kubernetes namespace
|
||||||
as your Ray cluster.
|
as your Ray cluster.
|
||||||
"""
|
"""
|
||||||
|
@ -11,8 +12,9 @@ as your Ray cluster.
|
||||||
def gethostname(x):
|
def gethostname(x):
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
return x + (platform.node(), )
|
return x + (platform.node(),)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_nodes(expected):
|
def wait_for_nodes(expected):
|
||||||
|
@ -22,8 +24,11 @@ def wait_for_nodes(expected):
|
||||||
node_keys = [key for key in resources if "node" in key]
|
node_keys = [key for key in resources if "node" in key]
|
||||||
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
||||||
if num_nodes < expected:
|
if num_nodes < expected:
|
||||||
print("{} nodes have joined so far, waiting for {} more.".format(
|
print(
|
||||||
num_nodes, expected - num_nodes))
|
"{} nodes have joined so far, waiting for {} more.".format(
|
||||||
|
num_nodes, expected - num_nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
@ -36,9 +41,7 @@ def main():
|
||||||
# Check that objects can be transferred from each node to each other node.
|
# Check that objects can be transferred from each node to each other node.
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
print("Iteration {}".format(i))
|
print("Iteration {}".format(i))
|
||||||
results = [
|
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
|
||||||
gethostname.remote(gethostname.remote(())) for _ in range(100)
|
|
||||||
]
|
|
||||||
print(Counter(ray.get(results)))
|
print(Counter(ray.get(results)))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ from collections import Counter
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
""" Run this script locally to execute a Ray program on your Ray cluster on
|
""" Run this script locally to execute a Ray program on your Ray cluster on
|
||||||
Kubernetes.
|
Kubernetes.
|
||||||
|
|
||||||
|
@ -18,8 +19,9 @@ LOCAL_PORT = 10001
|
||||||
def gethostname(x):
|
def gethostname(x):
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
return x + (platform.node(), )
|
return x + (platform.node(),)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_nodes(expected):
|
def wait_for_nodes(expected):
|
||||||
|
@ -29,8 +31,11 @@ def wait_for_nodes(expected):
|
||||||
node_keys = [key for key in resources if "node" in key]
|
node_keys = [key for key in resources if "node" in key]
|
||||||
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
||||||
if num_nodes < expected:
|
if num_nodes < expected:
|
||||||
print("{} nodes have joined so far, waiting for {} more.".format(
|
print(
|
||||||
num_nodes, expected - num_nodes))
|
"{} nodes have joined so far, waiting for {} more.".format(
|
||||||
|
num_nodes, expected - num_nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
@ -43,9 +48,7 @@ def main():
|
||||||
# Check that objects can be transferred from each node to each other node.
|
# Check that objects can be transferred from each node to each other node.
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
print("Iteration {}".format(i))
|
print("Iteration {}".format(i))
|
||||||
results = [
|
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
|
||||||
gethostname.remote(gethostname.remote(())) for _ in range(100)
|
|
||||||
]
|
|
||||||
print(Counter(ray.get(results)))
|
print(Counter(ray.get(results)))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
|
@ -10,8 +10,9 @@ import ray
|
||||||
def gethostname(x):
|
def gethostname(x):
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
return x + (platform.node(), )
|
return x + (platform.node(),)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_nodes(expected):
|
def wait_for_nodes(expected):
|
||||||
|
@ -21,8 +22,11 @@ def wait_for_nodes(expected):
|
||||||
node_keys = [key for key in resources if "node" in key]
|
node_keys = [key for key in resources if "node" in key]
|
||||||
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
num_nodes = sum(resources[node_key] for node_key in node_keys)
|
||||||
if num_nodes < expected:
|
if num_nodes < expected:
|
||||||
print("{} nodes have joined so far, waiting for {} more.".format(
|
print(
|
||||||
num_nodes, expected - num_nodes))
|
"{} nodes have joined so far, waiting for {} more.".format(
|
||||||
|
num_nodes, expected - num_nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
@ -35,9 +39,7 @@ def main():
|
||||||
# Check that objects can be transferred from each node to each other node.
|
# Check that objects can be transferred from each node to each other node.
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
print("Iteration {}".format(i))
|
print("Iteration {}".format(i))
|
||||||
results = [
|
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
|
||||||
gethostname.remote(gethostname.remote(())) for _ in range(100)
|
|
||||||
]
|
|
||||||
print(Counter(ray.get(results)))
|
print(Counter(ray.get(results)))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
|
@ -25,24 +25,24 @@ if __name__ == "__main__":
|
||||||
"--exp-name",
|
"--exp-name",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The job name and path to logging file (exp_name.log).")
|
help="The job name and path to logging file (exp_name.log).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-nodes",
|
"--num-nodes", "-n", type=int, default=1, help="Number of nodes to use."
|
||||||
"-n",
|
)
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of nodes to use.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--node",
|
"--node",
|
||||||
"-w",
|
"-w",
|
||||||
type=str,
|
type=str,
|
||||||
help="The specified nodes to use. Same format as the "
|
help="The specified nodes to use. Same format as the "
|
||||||
"return of 'sinfo'. Default: ''.")
|
"return of 'sinfo'. Default: ''.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-gpus",
|
"--num-gpus",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Number of GPUs to use in each node. (Default: 0)")
|
help="Number of GPUs to use in each node. (Default: 0)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--partition",
|
"--partition",
|
||||||
"-p",
|
"-p",
|
||||||
|
@ -51,14 +51,16 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-env",
|
"--load-env",
|
||||||
type=str,
|
type=str,
|
||||||
help="The script to load your environment ('module load cuda/10.1')")
|
help="The script to load your environment ('module load cuda/10.1')",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--command",
|
"--command",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The command you wish to execute. For example: "
|
help="The command you wish to execute. For example: "
|
||||||
" --command 'python test.py'. "
|
" --command 'python test.py'. "
|
||||||
"Note that the command must be a string.")
|
"Note that the command must be a string.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.node:
|
if args.node:
|
||||||
|
@ -67,11 +69,13 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
node_info = ""
|
node_info = ""
|
||||||
|
|
||||||
job_name = "{}_{}".format(args.exp_name,
|
job_name = "{}_{}".format(
|
||||||
time.strftime("%m%d-%H%M", time.localtime()))
|
args.exp_name, time.strftime("%m%d-%H%M", time.localtime())
|
||||||
|
)
|
||||||
|
|
||||||
partition_option = "#SBATCH --partition={}".format(
|
partition_option = (
|
||||||
args.partition) if args.partition else ""
|
"#SBATCH --partition={}".format(args.partition) if args.partition else ""
|
||||||
|
)
|
||||||
|
|
||||||
# ===== Modified the template script =====
|
# ===== Modified the template script =====
|
||||||
with open(template_file, "r") as f:
|
with open(template_file, "r") as f:
|
||||||
|
@ -84,10 +88,10 @@ if __name__ == "__main__":
|
||||||
text = text.replace(LOAD_ENV, str(args.load_env))
|
text = text.replace(LOAD_ENV, str(args.load_env))
|
||||||
text = text.replace(GIVEN_NODE, node_info)
|
text = text.replace(GIVEN_NODE, node_info)
|
||||||
text = text.replace(
|
text = text.replace(
|
||||||
"# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
|
"# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
|
||||||
"PRODUCTION!",
|
|
||||||
"# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
|
"# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
|
||||||
"RUNNABLE!")
|
"RUNNABLE!",
|
||||||
|
)
|
||||||
|
|
||||||
# ===== Save the script =====
|
# ===== Save the script =====
|
||||||
script_file = "{}.sh".format(job_name)
|
script_file = "{}.sh".format(job_name)
|
||||||
|
@ -99,5 +103,7 @@ if __name__ == "__main__":
|
||||||
subprocess.Popen(["sbatch", script_file])
|
subprocess.Popen(["sbatch", script_file])
|
||||||
print(
|
print(
|
||||||
"Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
|
"Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
|
||||||
script_file, "{}.log".format(job_name)))
|
script_file, "{}.log".format(job_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
|
@ -89,7 +89,7 @@ myst_enable_extensions = [
|
||||||
]
|
]
|
||||||
|
|
||||||
external_toc_exclude_missing = False
|
external_toc_exclude_missing = False
|
||||||
external_toc_path = '_toc.yml'
|
external_toc_path = "_toc.yml"
|
||||||
|
|
||||||
# There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
|
# There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
|
||||||
# of imports.
|
# of imports.
|
||||||
|
@ -112,7 +112,8 @@ versionwarning_messages = {
|
||||||
"<b>Got questions?</b> Join "
|
"<b>Got questions?</b> Join "
|
||||||
f'<a href="{FORUM_LINK}">the Ray Community forum</a> '
|
f'<a href="{FORUM_LINK}">the Ray Community forum</a> '
|
||||||
"for Q&A on all things Ray, as well as to share and learn use cases "
|
"for Q&A on all things Ray, as well as to share and learn use cases "
|
||||||
"and best practices with the Ray community."),
|
"and best practices with the Ray community."
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
versionwarning_body_selector = "#main-content"
|
versionwarning_body_selector = "#main-content"
|
||||||
|
@ -189,11 +190,16 @@ exclude_patterns += sphinx_gallery_conf["examples_dirs"]
|
||||||
# If "DOC_LIB" is found, only build that top-level navigation item.
|
# If "DOC_LIB" is found, only build that top-level navigation item.
|
||||||
build_one_lib = os.getenv("DOC_LIB")
|
build_one_lib = os.getenv("DOC_LIB")
|
||||||
|
|
||||||
all_toc_libs = [
|
all_toc_libs = [f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path]
|
||||||
f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path
|
|
||||||
]
|
|
||||||
all_toc_libs += [
|
all_toc_libs += [
|
||||||
"cluster", "tune", "data", "raysgd", "train", "rllib", "serve", "workflows"
|
"cluster",
|
||||||
|
"tune",
|
||||||
|
"data",
|
||||||
|
"raysgd",
|
||||||
|
"train",
|
||||||
|
"rllib",
|
||||||
|
"serve",
|
||||||
|
"workflows",
|
||||||
]
|
]
|
||||||
if build_one_lib and build_one_lib in all_toc_libs:
|
if build_one_lib and build_one_lib in all_toc_libs:
|
||||||
all_toc_libs.remove(build_one_lib)
|
all_toc_libs.remove(build_one_lib)
|
||||||
|
@ -405,7 +411,8 @@ def setup(app):
|
||||||
# Custom JS
|
# Custom JS
|
||||||
app.add_js_file(
|
app.add_js_file(
|
||||||
"https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js",
|
"https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js",
|
||||||
defer="defer")
|
defer="defer",
|
||||||
|
)
|
||||||
app.add_js_file("js/docsearch.js", defer="defer")
|
app.add_js_file("js/docsearch.js", defer="defer")
|
||||||
# Custom Sphinx directives
|
# Custom Sphinx directives
|
||||||
app.add_directive("customgalleryitem", CustomGalleryItemDirective)
|
app.add_directive("customgalleryitem", CustomGalleryItemDirective)
|
||||||
|
|
|
@ -6,13 +6,17 @@ from docutils import nodes
|
||||||
import os
|
import os
|
||||||
import sphinx_gallery
|
import sphinx_gallery
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
# Note: the scipy import has to stay here, it's used implicitly down the line
|
# Note: the scipy import has to stay here, it's used implicitly down the line
|
||||||
import scipy.stats # noqa: F401
|
import scipy.stats # noqa: F401
|
||||||
import scipy.linalg # noqa: F401
|
import scipy.linalg # noqa: F401
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CustomGalleryItemDirective", "fix_xgb_lgbm_docs", "MOCK_MODULES",
|
"CustomGalleryItemDirective",
|
||||||
"CHILD_MOCK_MODULES", "update_context"
|
"fix_xgb_lgbm_docs",
|
||||||
|
"MOCK_MODULES",
|
||||||
|
"CHILD_MOCK_MODULES",
|
||||||
|
"update_context",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -60,7 +64,7 @@ class CustomGalleryItemDirective(Directive):
|
||||||
option_spec = {
|
option_spec = {
|
||||||
"tooltip": directives.unchanged,
|
"tooltip": directives.unchanged,
|
||||||
"figure": directives.unchanged,
|
"figure": directives.unchanged,
|
||||||
"description": directives.unchanged
|
"description": directives.unchanged,
|
||||||
}
|
}
|
||||||
|
|
||||||
has_content = False
|
has_content = False
|
||||||
|
@ -73,8 +77,9 @@ class CustomGalleryItemDirective(Directive):
|
||||||
if len(self.options["tooltip"]) > 195:
|
if len(self.options["tooltip"]) > 195:
|
||||||
tooltip = tooltip[:195] + "..."
|
tooltip = tooltip[:195] + "..."
|
||||||
else:
|
else:
|
||||||
raise ValueError("Need to provide :tooltip: under "
|
raise ValueError(
|
||||||
"`.. customgalleryitem::`.")
|
"Need to provide :tooltip: under " "`.. customgalleryitem::`."
|
||||||
|
)
|
||||||
|
|
||||||
# Generate `thumbnail` used in the gallery.
|
# Generate `thumbnail` used in the gallery.
|
||||||
if "figure" in self.options:
|
if "figure" in self.options:
|
||||||
|
@ -95,11 +100,13 @@ class CustomGalleryItemDirective(Directive):
|
||||||
if "description" in self.options:
|
if "description" in self.options:
|
||||||
description = self.options["description"]
|
description = self.options["description"]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Need to provide :description: under "
|
raise ValueError(
|
||||||
"`customgalleryitem::`.")
|
"Need to provide :description: under " "`customgalleryitem::`."
|
||||||
|
)
|
||||||
|
|
||||||
thumbnail_rst = GALLERY_TEMPLATE.format(
|
thumbnail_rst = GALLERY_TEMPLATE.format(
|
||||||
tooltip=tooltip, thumbnail=thumbnail, description=description)
|
tooltip=tooltip, thumbnail=thumbnail, description=description
|
||||||
|
)
|
||||||
thumbnail = StringList(thumbnail_rst.split("\n"))
|
thumbnail = StringList(thumbnail_rst.split("\n"))
|
||||||
thumb = nodes.paragraph()
|
thumb = nodes.paragraph()
|
||||||
self.state.nested_parse(thumbnail, self.content_offset, thumb)
|
self.state.nested_parse(thumbnail, self.content_offset, thumb)
|
||||||
|
@ -146,29 +153,30 @@ def fix_xgb_lgbm_docs(app, what, name, obj, options, lines):
|
||||||
|
|
||||||
|
|
||||||
# Taken from https://github.com/edx/edx-documentation
|
# Taken from https://github.com/edx/edx-documentation
|
||||||
FEEDBACK_FORM_FMT = "https://github.com/ray-project/ray/issues/new?" \
|
FEEDBACK_FORM_FMT = (
|
||||||
"title={title}&labels=docs&body={body}"
|
"https://github.com/ray-project/ray/issues/new?"
|
||||||
|
"title={title}&labels=docs&body={body}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def feedback_form_url(project, page):
|
def feedback_form_url(project, page):
|
||||||
"""Create a URL for feedback on a particular page in a project."""
|
"""Create a URL for feedback on a particular page in a project."""
|
||||||
return FEEDBACK_FORM_FMT.format(
|
return FEEDBACK_FORM_FMT.format(
|
||||||
title=urllib.parse.quote(
|
title=urllib.parse.quote("[docs] Issue on `{page}.rst`".format(page=page)),
|
||||||
"[docs] Issue on `{page}.rst`".format(page=page)),
|
|
||||||
body=urllib.parse.quote(
|
body=urllib.parse.quote(
|
||||||
"# Documentation Problem/Question/Comment\n"
|
"# Documentation Problem/Question/Comment\n"
|
||||||
"<!-- Describe your issue/question/comment below. -->\n"
|
"<!-- Describe your issue/question/comment below. -->\n"
|
||||||
"<!-- If there are typos or errors in the docs, feel free "
|
"<!-- If there are typos or errors in the docs, feel free "
|
||||||
"to create a pull-request. -->\n"
|
"to create a pull-request. -->\n"
|
||||||
"\n\n\n\n"
|
"\n\n\n\n"
|
||||||
"(Created directly from the docs)\n"),
|
"(Created directly from the docs)\n"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_context(app, pagename, templatename, context, doctree):
|
def update_context(app, pagename, templatename, context, doctree):
|
||||||
"""Update the page rendering context to include ``feedback_form_url``."""
|
"""Update the page rendering context to include ``feedback_form_url``."""
|
||||||
context["feedback_form_url"] = feedback_form_url(app.config.project,
|
context["feedback_form_url"] = feedback_form_url(app.config.project, pagename)
|
||||||
pagename)
|
|
||||||
|
|
||||||
|
|
||||||
MOCK_MODULES = [
|
MOCK_MODULES = [
|
||||||
|
@ -187,8 +195,7 @@ MOCK_MODULES = [
|
||||||
"horovod.ray.runner",
|
"horovod.ray.runner",
|
||||||
"horovod.ray.utils",
|
"horovod.ray.utils",
|
||||||
"hyperopt",
|
"hyperopt",
|
||||||
"hyperopt.hp"
|
"hyperopt.hp" "kubernetes",
|
||||||
"kubernetes",
|
|
||||||
"mlflow",
|
"mlflow",
|
||||||
"modin",
|
"modin",
|
||||||
"mxnet",
|
"mxnet",
|
||||||
|
|
|
@ -59,13 +59,16 @@ from ray.data.datasource.datasource import RandomIntRowDatasource
|
||||||
# Let’s see how we implement such pipeline using Ray Dataset:
|
# Let’s see how we implement such pipeline using Ray Dataset:
|
||||||
|
|
||||||
|
|
||||||
def create_shuffle_pipeline(training_data_dir: str, num_epochs: int,
|
def create_shuffle_pipeline(
|
||||||
num_shards: int) -> List[DatasetPipeline]:
|
training_data_dir: str, num_epochs: int, num_shards: int
|
||||||
|
) -> List[DatasetPipeline]:
|
||||||
|
|
||||||
return ray.data.read_parquet(training_data_dir) \
|
return (
|
||||||
.repeat(num_epochs) \
|
ray.data.read_parquet(training_data_dir)
|
||||||
.random_shuffle_each_window() \
|
.repeat(num_epochs)
|
||||||
|
.random_shuffle_each_window()
|
||||||
.split(num_shards, equal=True)
|
.split(num_shards, equal=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
@ -117,7 +120,8 @@ parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--large-scale-test",
|
"--large-scale-test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Run large scale test (500GiB of data).")
|
help="Run large scale test (500GiB of data).",
|
||||||
|
)
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
@ -142,13 +146,15 @@ if not args.large_scale_test:
|
||||||
ray.data.read_datasource(
|
ray.data.read_datasource(
|
||||||
RandomIntRowDatasource(),
|
RandomIntRowDatasource(),
|
||||||
n=size_bytes // 8 // NUM_COLUMNS,
|
n=size_bytes // 8 // NUM_COLUMNS,
|
||||||
num_columns=NUM_COLUMNS).write_parquet(tmpdir)
|
num_columns=NUM_COLUMNS,
|
||||||
|
).write_parquet(tmpdir)
|
||||||
return tmpdir
|
return tmpdir
|
||||||
|
|
||||||
example_files_dir = generate_example_files(SIZE_100MiB)
|
example_files_dir = generate_example_files(SIZE_100MiB)
|
||||||
|
|
||||||
splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS,
|
splits = create_shuffle_pipeline(
|
||||||
NUM_TRAINING_WORKERS)
|
example_files_dir, NUM_EPOCHS, NUM_TRAINING_WORKERS
|
||||||
|
)
|
||||||
|
|
||||||
training_workers = [
|
training_workers = [
|
||||||
TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)
|
TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)
|
||||||
|
@ -198,18 +204,22 @@ if not args.large_scale_test:
|
||||||
# generated data.
|
# generated data.
|
||||||
|
|
||||||
|
|
||||||
def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int,
|
def create_large_shuffle_pipeline(
|
||||||
num_columns: int,
|
data_size_bytes: int, num_epochs: int, num_columns: int, num_shards: int
|
||||||
num_shards: int) -> List[DatasetPipeline]:
|
) -> List[DatasetPipeline]:
|
||||||
# _spread_resource_prefix is used to ensure tasks are evenly spread to all
|
# _spread_resource_prefix is used to ensure tasks are evenly spread to all
|
||||||
# CPU nodes.
|
# CPU nodes.
|
||||||
return ray.data.read_datasource(
|
return (
|
||||||
RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns,
|
ray.data.read_datasource(
|
||||||
|
RandomIntRowDatasource(),
|
||||||
|
n=data_size_bytes // 8 // num_columns,
|
||||||
num_columns=num_columns,
|
num_columns=num_columns,
|
||||||
_spread_resource_prefix="node:") \
|
_spread_resource_prefix="node:",
|
||||||
.repeat(num_epochs) \
|
)
|
||||||
.random_shuffle_each_window(_spread_resource_prefix="node:") \
|
.repeat(num_epochs)
|
||||||
|
.random_shuffle_each_window(_spread_resource_prefix="node:")
|
||||||
.split(num_shards, equal=True)
|
.split(num_shards, equal=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
#################################################################################
|
#################################################################################
|
||||||
|
@ -229,19 +239,18 @@ if args.large_scale_test:
|
||||||
|
|
||||||
# waiting for cluster nodes to come up.
|
# waiting for cluster nodes to come up.
|
||||||
while len(ray.nodes()) < TOTAL_NUM_NODES:
|
while len(ray.nodes()) < TOTAL_NUM_NODES:
|
||||||
print(
|
print(f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}")
|
||||||
f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}"
|
|
||||||
)
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS,
|
splits = create_large_shuffle_pipeline(
|
||||||
NUM_COLUMNS, NUM_TRAINING_WORKERS)
|
SIZE_500GiB, NUM_EPOCHS, NUM_COLUMNS, NUM_TRAINING_WORKERS
|
||||||
|
)
|
||||||
|
|
||||||
# Note we set num_gpus=1 for workers so that
|
# Note we set num_gpus=1 for workers so that
|
||||||
# the workers will only run on GPU nodes.
|
# the workers will only run on GPU nodes.
|
||||||
training_workers = [
|
training_workers = [
|
||||||
TrainingWorker.options(num_gpus=1) \
|
TrainingWorker.options(num_gpus=1).remote(rank, shard)
|
||||||
.remote(rank, shard) for rank, shard in enumerate(splits)
|
for rank, shard in enumerate(splits)
|
||||||
]
|
]
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
|
@ -4,21 +4,30 @@ import pyximport
|
||||||
|
|
||||||
pyximport.install(setup_args={"include_dirs": numpy.get_include()})
|
pyximport.install(setup_args={"include_dirs": numpy.get_include()})
|
||||||
|
|
||||||
from .cython_simple import simple_func, fib, fib_int, \
|
from .cython_simple import simple_func, fib, fib_int, fib_cpdef, fib_cdef, simple_class
|
||||||
fib_cpdef, fib_cdef, simple_class
|
|
||||||
from .masked_log import masked_log
|
from .masked_log import masked_log
|
||||||
|
|
||||||
from .cython_blas import \
|
from .cython_blas import (
|
||||||
compute_self_corr_for_voxel_sel, \
|
compute_self_corr_for_voxel_sel,
|
||||||
compute_kernel_matrix, \
|
compute_kernel_matrix,
|
||||||
compute_single_self_corr_syrk, \
|
compute_single_self_corr_syrk,
|
||||||
compute_single_self_corr_gemm, \
|
compute_single_self_corr_gemm,
|
||||||
compute_corr_vectors, \
|
compute_corr_vectors,
|
||||||
compute_single_matrix_multiplication
|
compute_single_matrix_multiplication,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"simple_func", "fib", "fib_int", "fib_cpdef", "fib_cdef", "simple_class",
|
"simple_func",
|
||||||
"masked_log", "compute_self_corr_for_voxel_sel", "compute_kernel_matrix",
|
"fib",
|
||||||
"compute_single_self_corr_syrk", "compute_single_self_corr_gemm",
|
"fib_int",
|
||||||
"compute_corr_vectors", "compute_single_matrix_multiplication"
|
"fib_cpdef",
|
||||||
|
"fib_cdef",
|
||||||
|
"simple_class",
|
||||||
|
"masked_log",
|
||||||
|
"compute_self_corr_for_voxel_sel",
|
||||||
|
"compute_kernel_matrix",
|
||||||
|
"compute_single_self_corr_syrk",
|
||||||
|
"compute_single_self_corr_gemm",
|
||||||
|
"compute_corr_vectors",
|
||||||
|
"compute_single_matrix_multiplication",
|
||||||
]
|
]
|
||||||
|
|
|
@ -94,11 +94,11 @@ def example8():
|
||||||
|
|
||||||
# See cython_blas.pyx for argument documentation
|
# See cython_blas.pyx for argument documentation
|
||||||
mat = np.array(
|
mat = np.array(
|
||||||
[[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32)
|
[[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32
|
||||||
|
)
|
||||||
result = np.zeros((2, 2), np.float32, order="C")
|
result = np.zeros((2, 2), np.float32, order="C")
|
||||||
|
|
||||||
run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0,
|
run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0, result, 2)
|
||||||
result, 2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -13,6 +13,7 @@ include_dirs = [numpy.get_include()]
|
||||||
# dependencies
|
# dependencies
|
||||||
try:
|
try:
|
||||||
import scipy # noqa
|
import scipy # noqa
|
||||||
|
|
||||||
modules.append("cython_blas.pyx")
|
modules.append("cython_blas.pyx")
|
||||||
install_requires.append("scipy")
|
install_requires.append("scipy")
|
||||||
except ImportError as e: # noqa
|
except ImportError as e: # noqa
|
||||||
|
@ -27,4 +28,5 @@ setup(
|
||||||
packages=[pkg_dir],
|
packages=[pkg_dir],
|
||||||
ext_modules=cythonize(modules),
|
ext_modules=cythonize(modules),
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
include_dirs=include_dirs)
|
include_dirs=include_dirs,
|
||||||
|
)
|
||||||
|
|
|
@ -56,4 +56,5 @@ class CythonTest(unittest.TestCase):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
sys.exit(pytest.main(["-v", __file__]))
|
||||||
|
|
|
@ -65,31 +65,34 @@ from ray.util.dask import ray_dask_get
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--address", type=str, default="auto", help="The address to use for Ray.")
|
"--address", type=str, default="auto", help="The address to use for Ray."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--smoke-test",
|
"--smoke-test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Read a smaller dataset for quick testing purposes.")
|
help="Read a smaller dataset for quick testing purposes.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-actors",
|
"--num-actors", type=int, default=4, help="Sets number of actors for training."
|
||||||
type=int,
|
)
|
||||||
default=4,
|
|
||||||
help="Sets number of actors for training.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpus-per-actor",
|
"--cpus-per-actor",
|
||||||
type=int,
|
type=int,
|
||||||
default=6,
|
default=6,
|
||||||
help="The number of CPUs per actor for training.")
|
help="The number of CPUs per actor for training.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-actors-inference",
|
"--num-actors-inference",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="Sets number of actors for inference.")
|
help="Sets number of actors for inference.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpus-per-actor-inference",
|
"--cpus-per-actor-inference",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The number of CPUs per actor for inference.")
|
help="The number of CPUs per actor for inference.",
|
||||||
|
)
|
||||||
# Ignore -f from ipykernel_launcher
|
# Ignore -f from ipykernel_launcher
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
@ -125,12 +128,13 @@ if not ray.is_initialized():
|
||||||
LABEL_COLUMN = "label"
|
LABEL_COLUMN = "label"
|
||||||
if smoke_test:
|
if smoke_test:
|
||||||
# Test dataset with only 10,000 records.
|
# Test dataset with only 10,000 records.
|
||||||
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
|
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
|
||||||
".csv"
|
|
||||||
else:
|
else:
|
||||||
# Full dataset. This may take a couple of minutes to load.
|
# Full dataset. This may take a couple of minutes to load.
|
||||||
FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
|
FILE_URL = (
|
||||||
"/00280/HIGGS.csv.gz"
|
"https://archive.ics.uci.edu/ml/machine-learning-databases"
|
||||||
|
"/00280/HIGGS.csv.gz"
|
||||||
|
)
|
||||||
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
|
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
|
||||||
dask.config.set(scheduler=ray_dask_get)
|
dask.config.set(scheduler=ray_dask_get)
|
||||||
|
|
||||||
|
@ -192,7 +196,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
|
||||||
dtrain=train_set,
|
dtrain=train_set,
|
||||||
evals=[(test_set, "eval")],
|
evals=[(test_set, "eval")],
|
||||||
evals_result=evals_result,
|
evals_result=evals_result,
|
||||||
ray_params=ray_params)
|
ray_params=ray_params,
|
||||||
|
)
|
||||||
|
|
||||||
train_end_time = time.time()
|
train_end_time = time.time()
|
||||||
train_duration = train_end_time - train_start_time
|
train_duration = train_end_time - train_start_time
|
||||||
|
@ -200,8 +205,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
|
||||||
|
|
||||||
model_path = "model.xgb"
|
model_path = "model.xgb"
|
||||||
bst.save_model(model_path)
|
bst.save_model(model_path)
|
||||||
print("Final validation error: {:.4f}".format(
|
print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))
|
||||||
evals_result["eval"]["error"][-1]))
|
|
||||||
|
|
||||||
return bst, evals_result
|
return bst, evals_result
|
||||||
|
|
||||||
|
@ -221,8 +225,12 @@ config = {
|
||||||
}
|
}
|
||||||
|
|
||||||
bst, evals_result = train_xgboost(
|
bst, evals_result = train_xgboost(
|
||||||
config, train_df, eval_df, LABEL_COLUMN,
|
config,
|
||||||
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
|
train_df,
|
||||||
|
eval_df,
|
||||||
|
LABEL_COLUMN,
|
||||||
|
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
|
||||||
|
)
|
||||||
print(f"Results: {evals_result}")
|
print(f"Results: {evals_result}")
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -260,13 +268,12 @@ def tune_xgboost(train_df, test_df, target_column):
|
||||||
"eval_metric": ["logloss", "error"],
|
"eval_metric": ["logloss", "error"],
|
||||||
"eta": tune.loguniform(1e-4, 1e-1),
|
"eta": tune.loguniform(1e-4, 1e-1),
|
||||||
"subsample": tune.uniform(0.5, 1.0),
|
"subsample": tune.uniform(0.5, 1.0),
|
||||||
"max_depth": tune.randint(1, 9)
|
"max_depth": tune.randint(1, 9),
|
||||||
}
|
}
|
||||||
|
|
||||||
ray_params = RayParams(
|
ray_params = RayParams(
|
||||||
max_actor_restarts=1,
|
max_actor_restarts=1, cpus_per_actor=cpus_per_actor, num_actors=num_actors
|
||||||
cpus_per_actor=cpus_per_actor,
|
)
|
||||||
num_actors=num_actors)
|
|
||||||
|
|
||||||
tune_start_time = time.time()
|
tune_start_time = time.time()
|
||||||
|
|
||||||
|
@ -276,19 +283,21 @@ def tune_xgboost(train_df, test_df, target_column):
|
||||||
train_df=train_df,
|
train_df=train_df,
|
||||||
test_df=test_df,
|
test_df=test_df,
|
||||||
target_column=target_column,
|
target_column=target_column,
|
||||||
ray_params=ray_params),
|
ray_params=ray_params,
|
||||||
|
),
|
||||||
# Use the `get_tune_resources` helper function to set the resources.
|
# Use the `get_tune_resources` helper function to set the resources.
|
||||||
resources_per_trial=ray_params.get_tune_resources(),
|
resources_per_trial=ray_params.get_tune_resources(),
|
||||||
config=config,
|
config=config,
|
||||||
num_samples=10,
|
num_samples=10,
|
||||||
metric="eval-error",
|
metric="eval-error",
|
||||||
mode="min")
|
mode="min",
|
||||||
|
)
|
||||||
|
|
||||||
tune_end_time = time.time()
|
tune_end_time = time.time()
|
||||||
tune_duration = tune_end_time - tune_start_time
|
tune_duration = tune_end_time - tune_start_time
|
||||||
print(f"Total time taken: {tune_duration} seconds.")
|
print(f"Total time taken: {tune_duration} seconds.")
|
||||||
|
|
||||||
accuracy = 1. - analysis.best_result["eval-error"]
|
accuracy = 1.0 - analysis.best_result["eval-error"]
|
||||||
print(f"Best model parameters: {analysis.best_config}")
|
print(f"Best model parameters: {analysis.best_config}")
|
||||||
print(f"Best model total accuracy: {accuracy:.4f}")
|
print(f"Best model total accuracy: {accuracy:.4f}")
|
||||||
|
|
||||||
|
@ -315,7 +324,8 @@ results = predict(
|
||||||
bst,
|
bst,
|
||||||
inference_df,
|
inference_df,
|
||||||
ray_params=RayParams(
|
ray_params=RayParams(
|
||||||
cpus_per_actor=cpus_per_actor_inference,
|
cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
|
||||||
num_actors=num_actors_inference))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
|
@ -59,7 +59,8 @@ def make_and_upload_dataset(dir_path):
|
||||||
shift=0.0,
|
shift=0.0,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
random_state=seed)
|
random_state=seed,
|
||||||
|
)
|
||||||
|
|
||||||
# turn into dataframe with column names
|
# turn into dataframe with column names
|
||||||
col_names = ["feature_%0d" % i for i in range(1, d + 1, 1)]
|
col_names = ["feature_%0d" % i for i in range(1, d + 1, 1)]
|
||||||
|
@ -91,10 +92,8 @@ def make_and_upload_dataset(dir_path):
|
||||||
path = os.path.join(data_path, f"data_{i:05d}.parquet.snappy")
|
path = os.path.join(data_path, f"data_{i:05d}.parquet.snappy")
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
tmp_df = create_data_chunk(
|
tmp_df = create_data_chunk(
|
||||||
n=PARQUET_FILE_CHUNK_SIZE,
|
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=True
|
||||||
d=NUM_FEATURES,
|
)
|
||||||
seed=i,
|
|
||||||
include_label=True)
|
|
||||||
tmp_df.to_parquet(path, compression="snappy", index=False)
|
tmp_df.to_parquet(path, compression="snappy", index=False)
|
||||||
print(f"Wrote {path} to disk...")
|
print(f"Wrote {path} to disk...")
|
||||||
# todo: at large enough scale we might want to upload the rest after
|
# todo: at large enough scale we might want to upload the rest after
|
||||||
|
@ -108,10 +107,8 @@ def make_and_upload_dataset(dir_path):
|
||||||
path = os.path.join(inference_path, f"data_{i:05d}.parquet.snappy")
|
path = os.path.join(inference_path, f"data_{i:05d}.parquet.snappy")
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
tmp_df = create_data_chunk(
|
tmp_df = create_data_chunk(
|
||||||
n=PARQUET_FILE_CHUNK_SIZE,
|
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=False
|
||||||
d=NUM_FEATURES,
|
)
|
||||||
seed=i,
|
|
||||||
include_label=False)
|
|
||||||
tmp_df.to_parquet(path, compression="snappy", index=False)
|
tmp_df.to_parquet(path, compression="snappy", index=False)
|
||||||
print(f"Wrote {path} to disk...")
|
print(f"Wrote {path} to disk...")
|
||||||
# todo: at large enough scale we might want to upload the rest after
|
# todo: at large enough scale we might want to upload the rest after
|
||||||
|
@ -124,8 +121,9 @@ def make_and_upload_dataset(dir_path):
|
||||||
|
|
||||||
def read_dataset(path: str) -> ray.data.Dataset:
|
def read_dataset(path: str) -> ray.data.Dataset:
|
||||||
print(f"reading data from {path}")
|
print(f"reading data from {path}")
|
||||||
return ray.data.read_parquet(path, _spread_resource_prefix="node:") \
|
return ray.data.read_parquet(path, _spread_resource_prefix="node:").random_shuffle(
|
||||||
.random_shuffle(_spread_resource_prefix="node:")
|
_spread_resource_prefix="node:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DataPreprocessor:
|
class DataPreprocessor:
|
||||||
|
@ -141,20 +139,20 @@ class DataPreprocessor:
|
||||||
# columns.
|
# columns.
|
||||||
self.standard_stats = None
|
self.standard_stats = None
|
||||||
|
|
||||||
def preprocess_train_data(self, ds: ray.data.Dataset
|
def preprocess_train_data(
|
||||||
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
|
self, ds: ray.data.Dataset
|
||||||
|
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
|
||||||
print("\n\nPreprocessing training dataset.\n")
|
print("\n\nPreprocessing training dataset.\n")
|
||||||
return self._preprocess(ds, False)
|
return self._preprocess(ds, False)
|
||||||
|
|
||||||
def preprocess_inference_data(self,
|
def preprocess_inference_data(self, df: ray.data.Dataset) -> ray.data.Dataset:
|
||||||
df: ray.data.Dataset) -> ray.data.Dataset:
|
|
||||||
print("\n\nPreprocessing inference dataset.\n")
|
print("\n\nPreprocessing inference dataset.\n")
|
||||||
return self._preprocess(df, True)[0]
|
return self._preprocess(df, True)[0]
|
||||||
|
|
||||||
def _preprocess(self, ds: ray.data.Dataset, inferencing: bool
|
def _preprocess(
|
||||||
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
|
self, ds: ray.data.Dataset, inferencing: bool
|
||||||
print(
|
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
|
||||||
"\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")
|
print("\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")
|
||||||
|
|
||||||
def batch_transformer(df: pd.DataFrame):
|
def batch_transformer(df: pd.DataFrame):
|
||||||
# Disable chained assignment warning.
|
# Disable chained assignment warning.
|
||||||
|
@ -165,25 +163,27 @@ class DataPreprocessor:
|
||||||
|
|
||||||
# Add new column.
|
# Add new column.
|
||||||
df["new_col"] = (
|
df["new_col"] = (
|
||||||
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]) / 3.
|
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]
|
||||||
|
) / 3.0
|
||||||
|
|
||||||
# Transform column.
|
# Transform column.
|
||||||
df["feature_1"] = 2. * df["feature_1"] + 0.1
|
df["feature_1"] = 2.0 * df["feature_1"] + 0.1
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
ds = ds.map_batches(batch_transformer, batch_format="pandas")
|
ds = ds.map_batches(batch_transformer, batch_format="pandas")
|
||||||
|
|
||||||
print("\nStep 2: Precalculating fruit-grouped mean for new column and "
|
print(
|
||||||
"for one-hot encoding (latter only uses fruit groups)\n")
|
"\nStep 2: Precalculating fruit-grouped mean for new column and "
|
||||||
|
"for one-hot encoding (latter only uses fruit groups)\n"
|
||||||
|
)
|
||||||
agg_ds = ds.groupby("fruit").mean("feature_1")
|
agg_ds = ds.groupby("fruit").mean("feature_1")
|
||||||
fruit_means = {
|
fruit_means = {r["fruit"]: r["mean(feature_1)"] for r in agg_ds.take_all()}
|
||||||
r["fruit"]: r["mean(feature_1)"]
|
|
||||||
for r in agg_ds.take_all()
|
|
||||||
}
|
|
||||||
|
|
||||||
print("\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
|
print(
|
||||||
"fruit; one-hot encode fruit column\n")
|
"\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
|
||||||
|
"fruit; one-hot encode fruit column\n"
|
||||||
|
)
|
||||||
|
|
||||||
if inferencing:
|
if inferencing:
|
||||||
assert self.fruits is not None
|
assert self.fruits is not None
|
||||||
|
@ -192,8 +192,7 @@ class DataPreprocessor:
|
||||||
self.fruits = list(fruit_means.keys())
|
self.fruits = list(fruit_means.keys())
|
||||||
|
|
||||||
fruit_one_hots = {
|
fruit_one_hots = {
|
||||||
fruit: collections.defaultdict(int, fruit=1)
|
fruit: collections.defaultdict(int, fruit=1) for fruit in self.fruits
|
||||||
for fruit in self.fruits
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def batch_transformer(df: pd.DataFrame):
|
def batch_transformer(df: pd.DataFrame):
|
||||||
|
@ -224,12 +223,12 @@ class DataPreprocessor:
|
||||||
# Split into 90% training set, 10% test set.
|
# Split into 90% training set, 10% test set.
|
||||||
train_ds, test_ds = ds.split_at_indices([split_index])
|
train_ds, test_ds = ds.split_at_indices([split_index])
|
||||||
|
|
||||||
print("\nStep 4b: Precalculate training dataset stats for "
|
print(
|
||||||
"standard scaling\n")
|
"\nStep 4b: Precalculate training dataset stats for "
|
||||||
|
"standard scaling\n"
|
||||||
|
)
|
||||||
# Calculate stats needed for standard scaling feature columns.
|
# Calculate stats needed for standard scaling feature columns.
|
||||||
feature_columns = [
|
feature_columns = [col for col in train_ds.schema().names if col != "label"]
|
||||||
col for col in train_ds.schema().names if col != "label"
|
|
||||||
]
|
|
||||||
standard_aggs = [
|
standard_aggs = [
|
||||||
agg(on=col) for col in feature_columns for agg in (Mean, Std)
|
agg(on=col) for col in feature_columns for agg in (Mean, Std)
|
||||||
]
|
]
|
||||||
|
@ -252,30 +251,29 @@ class DataPreprocessor:
|
||||||
|
|
||||||
if inferencing:
|
if inferencing:
|
||||||
# Apply standard scaling to inference dataset.
|
# Apply standard scaling to inference dataset.
|
||||||
inference_ds = ds.map_batches(
|
inference_ds = ds.map_batches(batch_standard_scaler, batch_format="pandas")
|
||||||
batch_standard_scaler, batch_format="pandas")
|
|
||||||
return inference_ds, None
|
return inference_ds, None
|
||||||
else:
|
else:
|
||||||
# Apply standard scaling to both training dataset and test dataset.
|
# Apply standard scaling to both training dataset and test dataset.
|
||||||
train_ds = train_ds.map_batches(
|
train_ds = train_ds.map_batches(
|
||||||
batch_standard_scaler, batch_format="pandas")
|
batch_standard_scaler, batch_format="pandas"
|
||||||
test_ds = test_ds.map_batches(
|
)
|
||||||
batch_standard_scaler, batch_format="pandas")
|
test_ds = test_ds.map_batches(batch_standard_scaler, batch_format="pandas")
|
||||||
return train_ds, test_ds
|
return train_ds, test_ds
|
||||||
|
|
||||||
|
|
||||||
def inference(dataset, model_cls: type, batch_size: int, result_path: str,
|
def inference(
|
||||||
use_gpu: bool):
|
dataset, model_cls: type, batch_size: int, result_path: str, use_gpu: bool
|
||||||
|
):
|
||||||
print("inferencing...")
|
print("inferencing...")
|
||||||
num_gpus = 1 if use_gpu else 0
|
num_gpus = 1 if use_gpu else 0
|
||||||
dataset \
|
dataset.map_batches(
|
||||||
.map_batches(
|
model_cls,
|
||||||
model_cls,
|
compute="actors",
|
||||||
compute="actors",
|
batch_size=batch_size,
|
||||||
batch_size=batch_size,
|
num_gpus=num_gpus,
|
||||||
num_gpus=num_gpus,
|
num_cpus=0,
|
||||||
num_cpus=0) \
|
).write_parquet(result_path)
|
||||||
.write_parquet(result_path)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -295,8 +293,7 @@ P1:
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self, n_layers, n_features, num_hidden, dropout_every,
|
def __init__(self, n_layers, n_features, num_hidden, dropout_every, drop_prob):
|
||||||
drop_prob):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
self.dropout_every = dropout_every
|
self.dropout_every = dropout_every
|
||||||
|
@ -406,8 +403,9 @@ def train_func(config):
|
||||||
print("Defining model, loss, and optimizer...")
|
print("Defining model, loss, and optimizer...")
|
||||||
|
|
||||||
# Setup device.
|
# Setup device.
|
||||||
device = torch.device(f"cuda:{train.local_rank()}"
|
device = torch.device(
|
||||||
if use_gpu and torch.cuda.is_available() else "cpu")
|
f"cuda:{train.local_rank()}" if use_gpu and torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
print(f"Device: {device}")
|
print(f"Device: {device}")
|
||||||
|
|
||||||
# Setup data.
|
# Setup data.
|
||||||
|
@ -415,7 +413,8 @@ def train_func(config):
|
||||||
train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
|
train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
|
||||||
test_dataset = train.get_dataset_shard("test_dataset")
|
test_dataset = train.get_dataset_shard("test_dataset")
|
||||||
test_torch_dataset = test_dataset.to_torch(
|
test_torch_dataset = test_dataset.to_torch(
|
||||||
label_column="label", batch_size=batch_size)
|
label_column="label", batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
net = Net(
|
net = Net(
|
||||||
n_layers=num_layers,
|
n_layers=num_layers,
|
||||||
|
@ -436,30 +435,37 @@ def train_func(config):
|
||||||
train_dataset = next(train_dataset_epoch_iterator)
|
train_dataset = next(train_dataset_epoch_iterator)
|
||||||
|
|
||||||
train_torch_dataset = train_dataset.to_torch(
|
train_torch_dataset = train_dataset.to_torch(
|
||||||
label_column="label", batch_size=batch_size)
|
label_column="label", batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
train_running_loss, train_num_correct, train_num_total = train_epoch(
|
train_running_loss, train_num_correct, train_num_total = train_epoch(
|
||||||
train_torch_dataset, net, device, criterion, optimizer)
|
train_torch_dataset, net, device, criterion, optimizer
|
||||||
|
)
|
||||||
train_acc = train_num_correct / train_num_total
|
train_acc = train_num_correct / train_num_total
|
||||||
print(f"epoch [{epoch + 1}]: training accuracy: "
|
print(
|
||||||
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
|
f"epoch [{epoch + 1}]: training accuracy: "
|
||||||
|
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
test_running_loss, test_num_correct, test_num_total = test_epoch(
|
test_running_loss, test_num_correct, test_num_total = test_epoch(
|
||||||
test_torch_dataset, net, device, criterion)
|
test_torch_dataset, net, device, criterion
|
||||||
|
)
|
||||||
test_acc = test_num_correct / test_num_total
|
test_acc = test_num_correct / test_num_total
|
||||||
print(f"epoch [{epoch + 1}]: testing accuracy: "
|
print(
|
||||||
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}")
|
f"epoch [{epoch + 1}]: testing accuracy: "
|
||||||
|
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
# Record and log stats.
|
# Record and log stats.
|
||||||
train.report(
|
train.report(
|
||||||
train_acc=train_acc,
|
train_acc=train_acc,
|
||||||
train_loss=train_running_loss,
|
train_loss=train_running_loss,
|
||||||
test_acc=test_acc,
|
test_acc=test_acc,
|
||||||
test_loss=test_running_loss)
|
test_loss=test_running_loss,
|
||||||
|
)
|
||||||
|
|
||||||
# Checkpoint model.
|
# Checkpoint model.
|
||||||
module = (net.module
|
module = net.module if isinstance(net, DistributedDataParallel) else net
|
||||||
if isinstance(net, DistributedDataParallel) else net)
|
|
||||||
train.save_checkpoint(model_state_dict=module.state_dict())
|
train.save_checkpoint(model_state_dict=module.state_dict())
|
||||||
|
|
||||||
if train.world_rank() == 0:
|
if train.world_rank() == 0:
|
||||||
|
@ -469,46 +475,44 @@ def train_func(config):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dir-path",
|
"--dir-path", default=".", type=str, help="Path to read and write data from"
|
||||||
default=".",
|
)
|
||||||
type=str,
|
|
||||||
help="Path to read and write data from")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-s3",
|
"--use-s3",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Use data from s3 for testing.")
|
help="Use data from s3 for testing.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--smoke-test",
|
"--smoke-test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Finish quickly for testing.")
|
help="Finish quickly for testing.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--address",
|
"--address",
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
help="The address to use for Ray. "
|
help="The address to use for Ray. " "`auto` if running through `ray submit.",
|
||||||
"`auto` if running through `ray submit.")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
default=1,
|
default=1,
|
||||||
type=int,
|
type=int,
|
||||||
help="The number of Ray workers to use for distributed training")
|
help="The number of Ray workers to use for distributed training",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--large-dataset",
|
"--large-dataset", action="store_true", default=False, help="Use 500GB dataset"
|
||||||
action="store_true",
|
)
|
||||||
default=False,
|
|
||||||
help="Use 500GB dataset")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-gpu",
|
"--use-gpu", action="store_true", default=False, help="Use GPU for training."
|
||||||
action="store_true",
|
)
|
||||||
default=False,
|
|
||||||
help="Use GPU for training.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mlflow-register-model",
|
"--mlflow-register-model",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use mlflow model registry. If set, a local MLflow "
|
help="Whether to use mlflow model registry. If set, a local MLflow "
|
||||||
"tracking server is expected to have already been started.")
|
"tracking server is expected to have already been started.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
smoke_test = args.smoke_test
|
smoke_test = args.smoke_test
|
||||||
|
@ -553,8 +557,11 @@ if __name__ == "__main__":
|
||||||
if len(list(count)) == 0:
|
if len(list(count)) == 0:
|
||||||
print("please run `python make_and_upload_dataset.py` first")
|
print("please run `python make_and_upload_dataset.py` first")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
data_path = ("s3://cuj-big-data/big-data/"
|
data_path = (
|
||||||
if large_dataset else "s3://cuj-big-data/data/")
|
"s3://cuj-big-data/big-data/"
|
||||||
|
if large_dataset
|
||||||
|
else "s3://cuj-big-data/data/"
|
||||||
|
)
|
||||||
inference_path = "s3://cuj-big-data/inference/"
|
inference_path = "s3://cuj-big-data/inference/"
|
||||||
inference_output_path = "s3://cuj-big-data/output/"
|
inference_output_path = "s3://cuj-big-data/output/"
|
||||||
else:
|
else:
|
||||||
|
@ -562,20 +569,19 @@ if __name__ == "__main__":
|
||||||
inference_path = os.path.join(dir_path, "inference")
|
inference_path = os.path.join(dir_path, "inference")
|
||||||
inference_output_path = "/tmp"
|
inference_output_path = "/tmp"
|
||||||
|
|
||||||
if len(os.listdir(data_path)) <= 1 or len(
|
if len(os.listdir(data_path)) <= 1 or len(os.listdir(inference_path)) <= 1:
|
||||||
os.listdir(inference_path)) <= 1:
|
|
||||||
print("please run `python make_and_upload_dataset.py` first")
|
print("please run `python make_and_upload_dataset.py` first")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if smoke_test:
|
if smoke_test:
|
||||||
# Only read a single file.
|
# Only read a single file.
|
||||||
data_path = os.path.join(data_path, "data_00000.parquet.snappy")
|
data_path = os.path.join(data_path, "data_00000.parquet.snappy")
|
||||||
inference_path = os.path.join(inference_path,
|
inference_path = os.path.join(inference_path, "data_00000.parquet.snappy")
|
||||||
"data_00000.parquet.snappy")
|
|
||||||
|
|
||||||
preprocessor = DataPreprocessor()
|
preprocessor = DataPreprocessor()
|
||||||
train_dataset, test_dataset = preprocessor.preprocess_train_data(
|
train_dataset, test_dataset = preprocessor.preprocess_train_data(
|
||||||
read_dataset(data_path))
|
read_dataset(data_path)
|
||||||
|
)
|
||||||
|
|
||||||
num_columns = len(train_dataset.schema().names)
|
num_columns = len(train_dataset.schema().names)
|
||||||
# remove label column and internal Arrow column.
|
# remove label column and internal Arrow column.
|
||||||
|
@ -589,14 +595,12 @@ if __name__ == "__main__":
|
||||||
DROPOUT_PROB = 0.2
|
DROPOUT_PROB = 0.2
|
||||||
|
|
||||||
# Random global shuffle
|
# Random global shuffle
|
||||||
train_dataset_pipeline = train_dataset.repeat() \
|
train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window(
|
||||||
.random_shuffle_each_window(_spread_resource_prefix="node:")
|
_spread_resource_prefix="node:"
|
||||||
|
)
|
||||||
del train_dataset
|
del train_dataset
|
||||||
|
|
||||||
datasets = {
|
datasets = {"train_dataset": train_dataset_pipeline, "test_dataset": test_dataset}
|
||||||
"train_dataset": train_dataset_pipeline,
|
|
||||||
"test_dataset": test_dataset
|
|
||||||
}
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"use_gpu": use_gpu,
|
"use_gpu": use_gpu,
|
||||||
|
@ -606,7 +610,7 @@ if __name__ == "__main__":
|
||||||
"num_layers": NUM_LAYERS,
|
"num_layers": NUM_LAYERS,
|
||||||
"dropout_every": DROPOUT_EVERY,
|
"dropout_every": DROPOUT_EVERY,
|
||||||
"dropout_prob": DROPOUT_PROB,
|
"dropout_prob": DROPOUT_PROB,
|
||||||
"num_features": num_features
|
"num_features": num_features,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create 2 callbacks: one for Tensorboard Logging and one for MLflow
|
# Create 2 callbacks: one for Tensorboard Logging and one for MLflow
|
||||||
|
@ -619,7 +623,8 @@ if __name__ == "__main__":
|
||||||
callbacks = [
|
callbacks = [
|
||||||
TBXLoggerCallback(logdir=tbx_logdir),
|
TBXLoggerCallback(logdir=tbx_logdir),
|
||||||
MLflowLoggerCallback(
|
MLflowLoggerCallback(
|
||||||
experiment_name="cuj-big-data-training", save_artifact=True)
|
experiment_name="cuj-big-data-training", save_artifact=True
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Remove CPU resource so Datasets can be scheduled.
|
# Remove CPU resource so Datasets can be scheduled.
|
||||||
|
@ -629,19 +634,19 @@ if __name__ == "__main__":
|
||||||
backend="torch",
|
backend="torch",
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
use_gpu=use_gpu,
|
use_gpu=use_gpu,
|
||||||
resources_per_worker=resources_per_worker)
|
resources_per_worker=resources_per_worker,
|
||||||
|
)
|
||||||
trainer.start()
|
trainer.start()
|
||||||
results = trainer.run(
|
results = trainer.run(
|
||||||
train_func=train_func,
|
train_func=train_func, config=config, callbacks=callbacks, dataset=datasets
|
||||||
config=config,
|
)
|
||||||
callbacks=callbacks,
|
|
||||||
dataset=datasets)
|
|
||||||
model = results[0]
|
model = results[0]
|
||||||
trainer.shutdown()
|
trainer.shutdown()
|
||||||
|
|
||||||
if args.mlflow_register_model:
|
if args.mlflow_register_model:
|
||||||
mlflow.pytorch.log_model(
|
mlflow.pytorch.log_model(
|
||||||
model, artifact_path="models", registered_model_name="torch_model")
|
model, artifact_path="models", registered_model_name="torch_model"
|
||||||
|
)
|
||||||
|
|
||||||
# Get the latest model from mlflow model registry.
|
# Get the latest model from mlflow model registry.
|
||||||
client = mlflow.tracking.MlflowClient()
|
client = mlflow.tracking.MlflowClient()
|
||||||
|
@ -649,12 +654,14 @@ if __name__ == "__main__":
|
||||||
# Get the info for the latest model.
|
# Get the info for the latest model.
|
||||||
# By default, registered models are in stage "None".
|
# By default, registered models are in stage "None".
|
||||||
latest_model_info = client.get_latest_versions(
|
latest_model_info = client.get_latest_versions(
|
||||||
registered_model_name, stages=["None"])[0]
|
registered_model_name, stages=["None"]
|
||||||
|
)[0]
|
||||||
latest_version = latest_model_info.version
|
latest_version = latest_model_info.version
|
||||||
|
|
||||||
def load_model_func():
|
def load_model_func():
|
||||||
model_uri = f"models:/torch_model/{latest_version}"
|
model_uri = f"models:/torch_model/{latest_version}"
|
||||||
return mlflow.pytorch.load_model(model_uri)
|
return mlflow.pytorch.load_model(model_uri)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
@ -670,25 +677,30 @@ if __name__ == "__main__":
|
||||||
n_features=num_features,
|
n_features=num_features,
|
||||||
num_hidden=num_hidden,
|
num_hidden=num_hidden,
|
||||||
dropout_every=dropout_every,
|
dropout_every=dropout_every,
|
||||||
drop_prob=dropout_prob)
|
drop_prob=dropout_prob,
|
||||||
|
)
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
class BatchInferModel:
|
class BatchInferModel:
|
||||||
def __init__(self, load_model_func):
|
def __init__(self, load_model_func):
|
||||||
self.device = torch.device("cuda:0"
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
if torch.cuda.is_available() else "cpu")
|
|
||||||
self.model = load_model_func().to(self.device)
|
self.model = load_model_func().to(self.device)
|
||||||
|
|
||||||
def __call__(self, batch) -> "pd.DataFrame":
|
def __call__(self, batch) -> "pd.DataFrame":
|
||||||
tensor = torch.FloatTensor(batch.to_pandas().values).to(
|
tensor = torch.FloatTensor(batch.to_pandas().values).to(self.device)
|
||||||
self.device)
|
|
||||||
return pd.DataFrame(self.model(tensor).cpu().detach().numpy())
|
return pd.DataFrame(self.model(tensor).cpu().detach().numpy())
|
||||||
|
|
||||||
inference_dataset = preprocessor.preprocess_inference_data(
|
inference_dataset = preprocessor.preprocess_inference_data(
|
||||||
read_dataset(inference_path))
|
read_dataset(inference_path)
|
||||||
inference(inference_dataset, BatchInferModel(load_model_func), 100,
|
)
|
||||||
inference_output_path, use_gpu)
|
inference(
|
||||||
|
inference_dataset,
|
||||||
|
BatchInferModel(load_model_func),
|
||||||
|
100,
|
||||||
|
inference_output_path,
|
||||||
|
use_gpu,
|
||||||
|
)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
|
|
@ -14,20 +14,23 @@ class MyActor:
|
||||||
self.counter = Counter(
|
self.counter = Counter(
|
||||||
"num_requests",
|
"num_requests",
|
||||||
description="Number of requests processed by the actor.",
|
description="Number of requests processed by the actor.",
|
||||||
tag_keys=("actor_name", ))
|
tag_keys=("actor_name",),
|
||||||
|
)
|
||||||
self.counter.set_default_tags({"actor_name": name})
|
self.counter.set_default_tags({"actor_name": name})
|
||||||
|
|
||||||
self.gauge = Gauge(
|
self.gauge = Gauge(
|
||||||
"curr_count",
|
"curr_count",
|
||||||
description="Current count held by the actor. Goes up and down.",
|
description="Current count held by the actor. Goes up and down.",
|
||||||
tag_keys=("actor_name", ))
|
tag_keys=("actor_name",),
|
||||||
|
)
|
||||||
self.gauge.set_default_tags({"actor_name": name})
|
self.gauge.set_default_tags({"actor_name": name})
|
||||||
|
|
||||||
self.histogram = Histogram(
|
self.histogram = Histogram(
|
||||||
"request_latency",
|
"request_latency",
|
||||||
description="Latencies of requests in ms.",
|
description="Latencies of requests in ms.",
|
||||||
boundaries=[0.1, 1],
|
boundaries=[0.1, 1],
|
||||||
tag_keys=("actor_name", ))
|
tag_keys=("actor_name",),
|
||||||
|
)
|
||||||
self.histogram.set_default_tags({"actor_name": name})
|
self.histogram.set_default_tags({"actor_name": name})
|
||||||
|
|
||||||
def process_request(self, num):
|
def process_request(self, num):
|
||||||
|
|
|
@ -46,7 +46,8 @@ class LinearModel(object):
|
||||||
y_ = tf.placeholder(tf.float32, [None, shape[1]])
|
y_ = tf.placeholder(tf.float32, [None, shape[1]])
|
||||||
self.y_ = y_
|
self.y_ = y_
|
||||||
cross_entropy = tf.reduce_mean(
|
cross_entropy = tf.reduce_mean(
|
||||||
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
|
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
|
||||||
|
)
|
||||||
self.cross_entropy = cross_entropy
|
self.cross_entropy = cross_entropy
|
||||||
self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
|
self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
|
||||||
self.sess = tf.Session()
|
self.sess = tf.Session()
|
||||||
|
@ -54,24 +55,20 @@ class LinearModel(object):
|
||||||
# Ray's TensorFlowVariables to automatically create methods to modify
|
# Ray's TensorFlowVariables to automatically create methods to modify
|
||||||
# the weights.
|
# the weights.
|
||||||
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||||
cross_entropy, self.sess)
|
cross_entropy, self.sess
|
||||||
|
)
|
||||||
|
|
||||||
def loss(self, xs, ys):
|
def loss(self, xs, ys):
|
||||||
"""Computes the loss of the network."""
|
"""Computes the loss of the network."""
|
||||||
return float(
|
return float(
|
||||||
self.sess.run(
|
self.sess.run(self.cross_entropy, feed_dict={self.x: xs, self.y_: ys})
|
||||||
self.cross_entropy, feed_dict={
|
)
|
||||||
self.x: xs,
|
|
||||||
self.y_: ys
|
|
||||||
}))
|
|
||||||
|
|
||||||
def grad(self, xs, ys):
|
def grad(self, xs, ys):
|
||||||
"""Computes the gradients of the network."""
|
"""Computes the gradients of the network."""
|
||||||
return self.sess.run(
|
return self.sess.run(
|
||||||
self.cross_entropy_grads, feed_dict={
|
self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys}
|
||||||
self.x: xs,
|
)
|
||||||
self.y_: ys
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -143,4 +140,5 @@ if __name__ == "__main__":
|
||||||
# Use L-BFGS to minimize the loss function.
|
# Use L-BFGS to minimize the loss function.
|
||||||
print("Running L-BFGS.")
|
print("Running L-BFGS.")
|
||||||
result = scipy.optimize.fmin_l_bfgs_b(
|
result = scipy.optimize.fmin_l_bfgs_b(
|
||||||
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True)
|
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True
|
||||||
|
)
|
||||||
|
|
|
@ -26,8 +26,7 @@ class RayDistributedActor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set the init_method and rank of the process for distributed training.
|
# Set the init_method and rank of the process for distributed training.
|
||||||
print("Ray worker at {url} rank {rank}".format(
|
print("Ray worker at {url} rank {rank}".format(url=url, rank=world_rank))
|
||||||
url=url, rank=world_rank))
|
|
||||||
self.url = url
|
self.url = url
|
||||||
self.world_rank = world_rank
|
self.world_rank = world_rank
|
||||||
args.distributed_rank = world_rank
|
args.distributed_rank = world_rank
|
||||||
|
@ -55,8 +54,10 @@ class RayDistributedActor:
|
||||||
n_cpus = int(ray.cluster_resources()["CPU"])
|
n_cpus = int(ray.cluster_resources()["CPU"])
|
||||||
if n_cpus > original_n_cpus:
|
if n_cpus > original_n_cpus:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"New CPUs find (original %d CPUs, now %d CPUs)" %
|
"New CPUs find (original %d CPUs, now %d CPUs)"
|
||||||
(original_n_cpus, n_cpus))
|
% (original_n_cpus, n_cpus)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
original_n_gpus = args.distributed_world_size
|
original_n_gpus = args.distributed_world_size
|
||||||
|
|
||||||
|
@ -65,8 +66,9 @@ class RayDistributedActor:
|
||||||
n_gpus = int(ray.cluster_resources().get("GPU", 0))
|
n_gpus = int(ray.cluster_resources().get("GPU", 0))
|
||||||
if n_gpus > original_n_gpus:
|
if n_gpus > original_n_gpus:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"New GPUs find (original %d GPUs, now %d GPUs)" %
|
"New GPUs find (original %d GPUs, now %d GPUs)"
|
||||||
(original_n_gpus, n_gpus))
|
% (original_n_gpus, n_gpus)
|
||||||
|
)
|
||||||
|
|
||||||
fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint
|
fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint
|
||||||
|
|
||||||
|
@ -103,8 +105,7 @@ def run_fault_tolerant_loop():
|
||||||
set_batch_size(args)
|
set_batch_size(args)
|
||||||
|
|
||||||
# Set up Ray distributed actors.
|
# Set up Ray distributed actors.
|
||||||
Actor = ray.remote(
|
Actor = ray.remote(num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
|
||||||
num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
|
|
||||||
workers = [Actor.remote() for i in range(args.distributed_world_size)]
|
workers = [Actor.remote() for i in range(args.distributed_world_size)]
|
||||||
|
|
||||||
# Get the IP address and a free port of actor 0, which is used for
|
# Get the IP address and a free port of actor 0, which is used for
|
||||||
|
@ -116,8 +117,7 @@ def run_fault_tolerant_loop():
|
||||||
# Start the remote processes, and check whether their are any process
|
# Start the remote processes, and check whether their are any process
|
||||||
# fails. If so, restart all the processes.
|
# fails. If so, restart all the processes.
|
||||||
unfinished = [
|
unfinished = [
|
||||||
worker.run.remote(address, i, args)
|
worker.run.remote(address, i, args) for i, worker in enumerate(workers)
|
||||||
for i, worker in enumerate(workers)
|
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
while len(unfinished) > 0:
|
while len(unfinished) > 0:
|
||||||
|
@ -135,10 +135,8 @@ def add_ray_args(parser):
|
||||||
"""Add ray and fault-tolerance related parser arguments to the parser."""
|
"""Add ray and fault-tolerance related parser arguments to the parser."""
|
||||||
group = parser.add_argument_group("Ray related arguments")
|
group = parser.add_argument_group("Ray related arguments")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--ray-address",
|
"--ray-address", default="auto", type=str, help="address for ray initialization"
|
||||||
default="auto",
|
)
|
||||||
type=str,
|
|
||||||
help="address for ray initialization")
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--fix-batch-size",
|
"--fix-batch-size",
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -147,7 +145,8 @@ def add_ray_args(parser):
|
||||||
help="fix the actual batch size (max_sentences * update_freq "
|
help="fix the actual batch size (max_sentences * update_freq "
|
||||||
"* n_GPUs) to be the fixed input values by adjusting update_freq "
|
"* n_GPUs) to be the fixed input values by adjusting update_freq "
|
||||||
"accroding to actual n_GPUs; the batch size is fixed to B_i for "
|
"accroding to actual n_GPUs; the batch size is fixed to B_i for "
|
||||||
"epoch i; all epochs >N are fixed to B_N")
|
"epoch i; all epochs >N are fixed to B_N",
|
||||||
|
)
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,13 +167,13 @@ def set_batch_size(args):
|
||||||
"""Fixes the total batch_size to be agnostic to the GPU count."""
|
"""Fixes the total batch_size to be agnostic to the GPU count."""
|
||||||
if args.fix_batch_size is not None:
|
if args.fix_batch_size is not None:
|
||||||
args.update_freq = [
|
args.update_freq = [
|
||||||
math.ceil(batch_size /
|
math.ceil(batch_size / (args.max_sentences * args.distributed_world_size))
|
||||||
(args.max_sentences * args.distributed_world_size))
|
|
||||||
for batch_size in args.fix_batch_size
|
for batch_size in args.fix_batch_size
|
||||||
]
|
]
|
||||||
print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
|
print(
|
||||||
(args.distributed_world_size, args.max_sentences,
|
"Training on %d GPUs, max_sentences=%d, update_freq=%s"
|
||||||
repr(args.update_freq)))
|
% (args.distributed_world_size, args.max_sentences, repr(args.update_freq))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -62,31 +62,34 @@ import ray
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--address", type=str, default="auto", help="The address to use for Ray.")
|
"--address", type=str, default="auto", help="The address to use for Ray."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--smoke-test",
|
"--smoke-test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Read a smaller dataset for quick testing purposes.")
|
help="Read a smaller dataset for quick testing purposes.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-actors",
|
"--num-actors", type=int, default=4, help="Sets number of actors for training."
|
||||||
type=int,
|
)
|
||||||
default=4,
|
|
||||||
help="Sets number of actors for training.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpus-per-actor",
|
"--cpus-per-actor",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="The number of CPUs per actor for training.")
|
help="The number of CPUs per actor for training.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-actors-inference",
|
"--num-actors-inference",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="Sets number of actors for inference.")
|
help="Sets number of actors for inference.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cpus-per-actor-inference",
|
"--cpus-per-actor-inference",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The number of CPUs per actor for inference.")
|
help="The number of CPUs per actor for inference.",
|
||||||
|
)
|
||||||
# Ignore -f from ipykernel_launcher
|
# Ignore -f from ipykernel_launcher
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
@ -119,12 +122,13 @@ if not ray.is_initialized():
|
||||||
LABEL_COLUMN = "label"
|
LABEL_COLUMN = "label"
|
||||||
if smoke_test:
|
if smoke_test:
|
||||||
# Test dataset with only 10,000 records.
|
# Test dataset with only 10,000 records.
|
||||||
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
|
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
|
||||||
".csv"
|
|
||||||
else:
|
else:
|
||||||
# Full dataset. This may take a couple of minutes to load.
|
# Full dataset. This may take a couple of minutes to load.
|
||||||
FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
|
FILE_URL = (
|
||||||
"/00280/HIGGS.csv.gz"
|
"https://archive.ics.uci.edu/ml/machine-learning-databases"
|
||||||
|
"/00280/HIGGS.csv.gz"
|
||||||
|
)
|
||||||
|
|
||||||
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
|
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
|
||||||
|
|
||||||
|
@ -182,7 +186,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
|
||||||
evals_result=evals_result,
|
evals_result=evals_result,
|
||||||
verbose_eval=False,
|
verbose_eval=False,
|
||||||
num_boost_round=100,
|
num_boost_round=100,
|
||||||
ray_params=ray_params)
|
ray_params=ray_params,
|
||||||
|
)
|
||||||
|
|
||||||
train_end_time = time.time()
|
train_end_time = time.time()
|
||||||
train_duration = train_end_time - train_start_time
|
train_duration = train_end_time - train_start_time
|
||||||
|
@ -190,8 +195,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
|
||||||
|
|
||||||
model_path = "model.xgb"
|
model_path = "model.xgb"
|
||||||
bst.save_model(model_path)
|
bst.save_model(model_path)
|
||||||
print("Final validation error: {:.4f}".format(
|
print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))
|
||||||
evals_result["eval"]["error"][-1]))
|
|
||||||
|
|
||||||
return bst, evals_result
|
return bst, evals_result
|
||||||
|
|
||||||
|
@ -208,8 +212,12 @@ config = {
|
||||||
}
|
}
|
||||||
|
|
||||||
bst, evals_result = train_xgboost(
|
bst, evals_result = train_xgboost(
|
||||||
config, df_train, df_validation, LABEL_COLUMN,
|
config,
|
||||||
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
|
df_train,
|
||||||
|
df_validation,
|
||||||
|
LABEL_COLUMN,
|
||||||
|
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
|
||||||
|
)
|
||||||
print(f"Results: {evals_result}")
|
print(f"Results: {evals_result}")
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -227,7 +235,8 @@ results = predict(
|
||||||
bst,
|
bst,
|
||||||
inference_df,
|
inference_df,
|
||||||
ray_params=RayParams(
|
ray_params=RayParams(
|
||||||
cpus_per_actor=cpus_per_actor_inference,
|
cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
|
||||||
num_actors=num_actors_inference))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
|
@ -12,10 +12,12 @@ class NewsServer(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.conn = sqlite3.connect("newsreader.db")
|
self.conn = sqlite3.connect("newsreader.db")
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
c.execute("""CREATE TABLE IF NOT EXISTS news
|
c.execute(
|
||||||
|
"""CREATE TABLE IF NOT EXISTS news
|
||||||
(title text, link text,
|
(title text, link text,
|
||||||
description text, published timestamp,
|
description text, published timestamp,
|
||||||
feed url, liked bool)""")
|
feed url, liked bool)"""
|
||||||
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def retrieve_feed(self, url):
|
def retrieve_feed(self, url):
|
||||||
|
@ -24,36 +26,41 @@ class NewsServer(object):
|
||||||
items = []
|
items = []
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
for item in feed.items:
|
for item in feed.items:
|
||||||
items.append({
|
items.append(
|
||||||
"title": item.title,
|
{
|
||||||
"link": item.link,
|
"title": item.title,
|
||||||
"description": item.description,
|
"link": item.link,
|
||||||
"description_text": item.description,
|
"description": item.description,
|
||||||
"pubDate": str(item.pub_date)
|
"description_text": item.description,
|
||||||
})
|
"pubDate": str(item.pub_date),
|
||||||
|
}
|
||||||
|
)
|
||||||
c.execute(
|
c.execute(
|
||||||
"""INSERT INTO news (title, link, description,
|
"""INSERT INTO news (title, link, description,
|
||||||
published, feed, liked) values
|
published, feed, liked) values
|
||||||
(?, ?, ?, ?, ?, ?)""",
|
(?, ?, ?, ?, ?, ?)""",
|
||||||
(item.title, item.link, item.description, item.pub_date,
|
(
|
||||||
feed.link, False))
|
item.title,
|
||||||
|
item.link,
|
||||||
|
item.description,
|
||||||
|
item.pub_date,
|
||||||
|
feed.link,
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"channel": {
|
"channel": {"title": feed.title, "link": feed.link, "url": feed.link},
|
||||||
"title": feed.title,
|
"items": items,
|
||||||
"link": feed.link,
|
|
||||||
"url": feed.link
|
|
||||||
},
|
|
||||||
"items": items
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def like_item(self, url, is_faved):
|
def like_item(self, url, is_faved):
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
if is_faved:
|
if is_faved:
|
||||||
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url, ))
|
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url,))
|
||||||
else:
|
else:
|
||||||
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url, ))
|
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url,))
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,9 +84,7 @@ def dispatcher():
|
||||||
result = ray.get(method.remote(*method_args))
|
result = ray.get(method.remote(*method_args))
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
else:
|
else:
|
||||||
return jsonify({
|
return jsonify({"error": "method_name '" + method_name + "' not found"})
|
||||||
"error": "method_name '" + method_name + "' not found"
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -44,16 +44,16 @@ num_evaluations = 10
|
||||||
# A function for generating random hyperparameters.
|
# A function for generating random hyperparameters.
|
||||||
def generate_hyperparameters():
|
def generate_hyperparameters():
|
||||||
return {
|
return {
|
||||||
"learning_rate": 10**np.random.uniform(-5, 1),
|
"learning_rate": 10 ** np.random.uniform(-5, 1),
|
||||||
"batch_size": np.random.randint(1, 100),
|
"batch_size": np.random.randint(1, 100),
|
||||||
"momentum": np.random.uniform(0, 1)
|
"momentum": np.random.uniform(0, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_data_loaders(batch_size):
|
def get_data_loaders(batch_size):
|
||||||
mnist_transforms = transforms.Compose(
|
mnist_transforms = transforms.Compose(
|
||||||
[transforms.ToTensor(),
|
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])
|
)
|
||||||
|
|
||||||
# We add FileLock here because multiple workers will want to
|
# We add FileLock here because multiple workers will want to
|
||||||
# download data, and this may cause overwrites since
|
# download data, and this may cause overwrites since
|
||||||
|
@ -61,16 +61,16 @@ def get_data_loaders(batch_size):
|
||||||
with FileLock(os.path.expanduser("~/data.lock")):
|
with FileLock(os.path.expanduser("~/data.lock")):
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
datasets.MNIST(
|
datasets.MNIST(
|
||||||
"~/data",
|
"~/data", train=True, download=True, transform=mnist_transforms
|
||||||
train=True,
|
),
|
||||||
download=True,
|
|
||||||
transform=mnist_transforms),
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
test_loader = torch.utils.data.DataLoader(
|
test_loader = torch.utils.data.DataLoader(
|
||||||
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
|
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
return train_loader, test_loader
|
return train_loader, test_loader
|
||||||
|
|
||||||
|
|
||||||
|
@ -152,9 +152,8 @@ def evaluate_hyperparameters(config):
|
||||||
model = ConvNet()
|
model = ConvNet()
|
||||||
train_loader, test_loader = get_data_loaders(config["batch_size"])
|
train_loader, test_loader = get_data_loaders(config["batch_size"])
|
||||||
optimizer = optim.SGD(
|
optimizer = optim.SGD(
|
||||||
model.parameters(),
|
model.parameters(), lr=config["learning_rate"], momentum=config["momentum"]
|
||||||
lr=config["learning_rate"],
|
)
|
||||||
momentum=config["momentum"])
|
|
||||||
train(model, optimizer, train_loader)
|
train(model, optimizer, train_loader)
|
||||||
return test(model, test_loader)
|
return test(model, test_loader)
|
||||||
|
|
||||||
|
@ -202,22 +201,33 @@ while remaining_ids:
|
||||||
|
|
||||||
hyperparameters = hyperparameters_mapping[result_id]
|
hyperparameters = hyperparameters_mapping[result_id]
|
||||||
accuracy = ray.get(result_id)
|
accuracy = ray.get(result_id)
|
||||||
print("""We achieve accuracy {:.3}% with
|
print(
|
||||||
|
"""We achieve accuracy {:.3}% with
|
||||||
learning_rate: {:.2}
|
learning_rate: {:.2}
|
||||||
batch_size: {}
|
batch_size: {}
|
||||||
momentum: {:.2}
|
momentum: {:.2}
|
||||||
""".format(100 * accuracy, hyperparameters["learning_rate"],
|
""".format(
|
||||||
hyperparameters["batch_size"], hyperparameters["momentum"]))
|
100 * accuracy,
|
||||||
|
hyperparameters["learning_rate"],
|
||||||
|
hyperparameters["batch_size"],
|
||||||
|
hyperparameters["momentum"],
|
||||||
|
)
|
||||||
|
)
|
||||||
if accuracy > best_accuracy:
|
if accuracy > best_accuracy:
|
||||||
best_hyperparameters = hyperparameters
|
best_hyperparameters = hyperparameters
|
||||||
best_accuracy = accuracy
|
best_accuracy = accuracy
|
||||||
|
|
||||||
# Record the best performing set of hyperparameters.
|
# Record the best performing set of hyperparameters.
|
||||||
print("""Best accuracy over {} trials was {:.3} with
|
print(
|
||||||
|
"""Best accuracy over {} trials was {:.3} with
|
||||||
learning_rate: {:.2}
|
learning_rate: {:.2}
|
||||||
batch_size: {}
|
batch_size: {}
|
||||||
momentum: {:.2}
|
momentum: {:.2}
|
||||||
""".format(num_evaluations, 100 * best_accuracy,
|
""".format(
|
||||||
best_hyperparameters["learning_rate"],
|
num_evaluations,
|
||||||
best_hyperparameters["batch_size"],
|
100 * best_accuracy,
|
||||||
best_hyperparameters["momentum"]))
|
best_hyperparameters["learning_rate"],
|
||||||
|
best_hyperparameters["batch_size"],
|
||||||
|
best_hyperparameters["momentum"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -39,8 +39,8 @@ import ray
|
||||||
def get_data_loader():
|
def get_data_loader():
|
||||||
"""Safely downloads data. Returns training/validation set dataloader."""
|
"""Safely downloads data. Returns training/validation set dataloader."""
|
||||||
mnist_transforms = transforms.Compose(
|
mnist_transforms = transforms.Compose(
|
||||||
[transforms.ToTensor(),
|
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])
|
)
|
||||||
|
|
||||||
# We add FileLock here because multiple workers will want to
|
# We add FileLock here because multiple workers will want to
|
||||||
# download data, and this may cause overwrites since
|
# download data, and this may cause overwrites since
|
||||||
|
@ -48,16 +48,16 @@ def get_data_loader():
|
||||||
with FileLock(os.path.expanduser("~/data.lock")):
|
with FileLock(os.path.expanduser("~/data.lock")):
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
datasets.MNIST(
|
datasets.MNIST(
|
||||||
"~/data",
|
"~/data", train=True, download=True, transform=mnist_transforms
|
||||||
train=True,
|
),
|
||||||
download=True,
|
|
||||||
transform=mnist_transforms),
|
|
||||||
batch_size=128,
|
batch_size=128,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
test_loader = torch.utils.data.DataLoader(
|
test_loader = torch.utils.data.DataLoader(
|
||||||
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
|
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
|
||||||
batch_size=128,
|
batch_size=128,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
return train_loader, test_loader
|
return train_loader, test_loader
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def evaluate(model, test_loader):
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
total += target.size(0)
|
total += target.size(0)
|
||||||
correct += (predicted == target).sum().item()
|
correct += (predicted == target).sum().item()
|
||||||
return 100. * correct / total
|
return 100.0 * correct / total
|
||||||
|
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
|
@ -144,8 +144,7 @@ class ParameterServer(object):
|
||||||
|
|
||||||
def apply_gradients(self, *gradients):
|
def apply_gradients(self, *gradients):
|
||||||
summed_gradients = [
|
summed_gradients = [
|
||||||
np.stack(gradient_zip).sum(axis=0)
|
np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)
|
||||||
for gradient_zip in zip(*gradients)
|
|
||||||
]
|
]
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.model.set_gradients(summed_gradients)
|
self.model.set_gradients(summed_gradients)
|
||||||
|
@ -215,9 +214,7 @@ test_loader = get_data_loader()[1]
|
||||||
print("Running synchronous parameter server training.")
|
print("Running synchronous parameter server training.")
|
||||||
current_weights = ps.get_weights.remote()
|
current_weights = ps.get_weights.remote()
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
gradients = [
|
gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
|
||||||
worker.compute_gradients.remote(current_weights) for worker in workers
|
|
||||||
]
|
|
||||||
# Calculate update after all gradients are available.
|
# Calculate update after all gradients are available.
|
||||||
current_weights = ps.apply_gradients.remote(*gradients)
|
current_weights = ps.apply_gradients.remote(*gradients)
|
||||||
|
|
||||||
|
|
|
@ -197,7 +197,7 @@ class Model(object):
|
||||||
"""Applies the gradients to the model parameters with RMSProp."""
|
"""Applies the gradients to the model parameters with RMSProp."""
|
||||||
for k, v in self.weights.items():
|
for k, v in self.weights.items():
|
||||||
g = grad_buffer[k]
|
g = grad_buffer[k]
|
||||||
rmsprop_cache[k] = (decay * rmsprop_cache[k] + (1 - decay) * g**2)
|
rmsprop_cache[k] = decay * rmsprop_cache[k] + (1 - decay) * g ** 2
|
||||||
self.weights[k] += lr * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)
|
self.weights[k] += lr * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -278,20 +278,24 @@ for i in range(1, 1 + iterations):
|
||||||
gradient_ids = []
|
gradient_ids = []
|
||||||
# Launch tasks to compute gradients from multiple rollouts in parallel.
|
# Launch tasks to compute gradients from multiple rollouts in parallel.
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
gradient_ids = [
|
gradient_ids = [actor.compute_gradient.remote(model_id) for actor in actors]
|
||||||
actor.compute_gradient.remote(model_id) for actor in actors
|
|
||||||
]
|
|
||||||
for batch in range(batch_size):
|
for batch in range(batch_size):
|
||||||
[grad_id], gradient_ids = ray.wait(gradient_ids)
|
[grad_id], gradient_ids = ray.wait(gradient_ids)
|
||||||
grad, reward_sum = ray.get(grad_id)
|
grad, reward_sum = ray.get(grad_id)
|
||||||
# Accumulate the gradient over batch.
|
# Accumulate the gradient over batch.
|
||||||
for k in model.weights:
|
for k in model.weights:
|
||||||
grad_buffer[k] += grad[k]
|
grad_buffer[k] += grad[k]
|
||||||
running_reward = (reward_sum if running_reward is None else
|
running_reward = (
|
||||||
running_reward * 0.99 + reward_sum * 0.01)
|
reward_sum
|
||||||
|
if running_reward is None
|
||||||
|
else running_reward * 0.99 + reward_sum * 0.01
|
||||||
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print("Batch {} computed {} rollouts in {} seconds, "
|
print(
|
||||||
"running mean is {}".format(i, batch_size, end_time - start_time,
|
"Batch {} computed {} rollouts in {} seconds, "
|
||||||
running_reward))
|
"running mean is {}".format(
|
||||||
|
i, batch_size, end_time - start_time, running_reward
|
||||||
|
)
|
||||||
|
)
|
||||||
model.update(grad_buffer, rmsprop_cache, learning_rate, decay_rate)
|
model.update(grad_buffer, rmsprop_cache, learning_rate, decay_rate)
|
||||||
zero_grads(grad_buffer)
|
zero_grads(grad_buffer)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import Tuple
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
# For typing purposes
|
# For typing purposes
|
||||||
from ray.actor import ActorHandle
|
from ray.actor import ActorHandle
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
|
@ -8,12 +8,11 @@ import wikipedia
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-mappers", help="number of mapper actors used", default=3, type=int)
|
"--num-mappers", help="number of mapper actors used", default=3, type=int
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-reducers",
|
"--num-reducers", help="number of reducer actors used", default=4, type=int
|
||||||
help="number of reducer actors used",
|
)
|
||||||
default=4,
|
|
||||||
type=int)
|
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -36,8 +35,11 @@ class Mapper(object):
|
||||||
while self.num_articles_processed < article_index + 1:
|
while self.num_articles_processed < article_index + 1:
|
||||||
self.get_new_article()
|
self.get_new_article()
|
||||||
# Return the word counts from within a given character range.
|
# Return the word counts from within a given character range.
|
||||||
return [(k, v) for k, v in self.word_counts[article_index].items()
|
return [
|
||||||
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]]
|
(k, v)
|
||||||
|
for k, v in self.word_counts[article_index].items()
|
||||||
|
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -51,8 +53,7 @@ class Reducer(object):
|
||||||
# Get the word counts for this Reducer's keys from all of the Mappers
|
# Get the word counts for this Reducer's keys from all of the Mappers
|
||||||
# and aggregate the results.
|
# and aggregate the results.
|
||||||
count_ids = [
|
count_ids = [
|
||||||
mapper.get_range.remote(article_index, self.keys)
|
mapper.get_range.remote(article_index, self.keys) for mapper in self.mappers
|
||||||
for mapper in self.mappers
|
|
||||||
]
|
]
|
||||||
|
|
||||||
while len(count_ids) > 0:
|
while len(count_ids) > 0:
|
||||||
|
@ -87,9 +88,9 @@ if __name__ == "__main__":
|
||||||
streams.append(Stream([line.strip() for line in f.readlines()]))
|
streams.append(Stream([line.strip() for line in f.readlines()]))
|
||||||
|
|
||||||
# Partition the keys among the reducers.
|
# Partition the keys among the reducers.
|
||||||
chunks = np.array_split([chr(i)
|
chunks = np.array_split(
|
||||||
for i in range(ord("a"),
|
[chr(i) for i in range(ord("a"), ord("z") + 1)], args.num_reducers
|
||||||
ord("z") + 1)], args.num_reducers)
|
)
|
||||||
keys = [[chunk[0], chunk[-1]] for chunk in chunks]
|
keys = [[chunk[0], chunk[-1]] for chunk in chunks]
|
||||||
|
|
||||||
# Create a number of mappers.
|
# Create a number of mappers.
|
||||||
|
@ -103,14 +104,12 @@ if __name__ == "__main__":
|
||||||
while True:
|
while True:
|
||||||
print("article index = {}".format(article_index))
|
print("article index = {}".format(article_index))
|
||||||
wordcounts = {}
|
wordcounts = {}
|
||||||
counts = ray.get([
|
counts = ray.get(
|
||||||
reducer.next_reduce_result.remote(article_index)
|
[reducer.next_reduce_result.remote(article_index) for reducer in reducers]
|
||||||
for reducer in reducers
|
)
|
||||||
])
|
|
||||||
for count in counts:
|
for count in counts:
|
||||||
wordcounts.update(count)
|
wordcounts.update(count)
|
||||||
most_frequent_words = heapq.nlargest(
|
most_frequent_words = heapq.nlargest(10, wordcounts, key=wordcounts.get)
|
||||||
10, wordcounts, key=wordcounts.get)
|
|
||||||
for word in most_frequent_words:
|
for word in most_frequent_words:
|
||||||
print(" ", word, wordcounts[word])
|
print(" ", word, wordcounts[word])
|
||||||
article_index += 1
|
article_index += 1
|
||||||
|
|
|
@ -125,12 +125,12 @@ class MNISTDataInterface(object):
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
self.max_days = max_days
|
self.max_days = max_days
|
||||||
|
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose(
|
||||||
transforms.ToTensor(),
|
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
)
|
||||||
])
|
|
||||||
self.dataset = MNIST(
|
self.dataset = MNIST(
|
||||||
self.data_dir, train=True, download=True, transform=transform)
|
self.data_dir, train=True, download=True, transform=transform
|
||||||
|
)
|
||||||
|
|
||||||
def _get_day_slice(self, day=0):
|
def _get_day_slice(self, day=0):
|
||||||
if day < 0:
|
if day < 0:
|
||||||
|
@ -154,8 +154,7 @@ class MNISTDataInterface(object):
|
||||||
end = self._get_day_slice(day)
|
end = self._get_day_slice(day)
|
||||||
|
|
||||||
available_data = Subset(self.dataset, list(range(start, end)))
|
available_data = Subset(self.dataset, list(range(start, end)))
|
||||||
train_n = int(
|
train_n = int(0.8 * (end - start)) # 80% train data, 20% validation data
|
||||||
0.8 * (end - start)) # 80% train data, 20% validation data
|
|
||||||
|
|
||||||
return random_split(available_data, [train_n, end - start - train_n])
|
return random_split(available_data, [train_n, end - start - train_n])
|
||||||
|
|
||||||
|
@ -223,13 +222,15 @@ def test(model, data_loader, device=None):
|
||||||
# will take care of creating the model and optimizer and repeatedly
|
# will take care of creating the model and optimizer and repeatedly
|
||||||
# call the ``train`` function to train the model. Also, this function
|
# call the ``train`` function to train the model. Also, this function
|
||||||
# will report the training progress back to Tune.
|
# will report the training progress back to Tune.
|
||||||
def train_mnist(config,
|
def train_mnist(
|
||||||
start_model=None,
|
config,
|
||||||
checkpoint_dir=None,
|
start_model=None,
|
||||||
num_epochs=10,
|
checkpoint_dir=None,
|
||||||
use_gpus=False,
|
num_epochs=10,
|
||||||
data_fn=None,
|
use_gpus=False,
|
||||||
day=0):
|
data_fn=None,
|
||||||
|
day=0,
|
||||||
|
):
|
||||||
# Create model
|
# Create model
|
||||||
use_cuda = use_gpus and torch.cuda.is_available()
|
use_cuda = use_gpus and torch.cuda.is_available()
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
@ -237,7 +238,8 @@ def train_mnist(config,
|
||||||
|
|
||||||
# Create optimizer
|
# Create optimizer
|
||||||
optimizer = optim.SGD(
|
optimizer = optim.SGD(
|
||||||
model.parameters(), lr=config["lr"], momentum=config["momentum"])
|
model.parameters(), lr=config["lr"], momentum=config["momentum"]
|
||||||
|
)
|
||||||
|
|
||||||
# Load checkpoint, or load start model if no checkpoint has been
|
# Load checkpoint, or load start model if no checkpoint has been
|
||||||
# passed and a start model is specified
|
# passed and a start model is specified
|
||||||
|
@ -248,8 +250,7 @@ def train_mnist(config,
|
||||||
load_dir = start_model
|
load_dir = start_model
|
||||||
|
|
||||||
if load_dir:
|
if load_dir:
|
||||||
model_state, optimizer_state = torch.load(
|
model_state, optimizer_state = torch.load(os.path.join(load_dir, "checkpoint"))
|
||||||
os.path.join(load_dir, "checkpoint"))
|
|
||||||
model.load_state_dict(model_state)
|
model.load_state_dict(model_state)
|
||||||
optimizer.load_state_dict(optimizer_state)
|
optimizer.load_state_dict(optimizer_state)
|
||||||
|
|
||||||
|
@ -257,18 +258,22 @@ def train_mnist(config,
|
||||||
train_dataset, validation_dataset = data_fn(day=day)
|
train_dataset, validation_dataset = data_fn(day=day)
|
||||||
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=config["batch_size"], shuffle=True)
|
train_dataset, batch_size=config["batch_size"], shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
validation_loader = torch.utils.data.DataLoader(
|
validation_loader = torch.utils.data.DataLoader(
|
||||||
validation_dataset, batch_size=config["batch_size"], shuffle=True)
|
validation_dataset, batch_size=config["batch_size"], shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(num_epochs):
|
for i in range(num_epochs):
|
||||||
train(model, optimizer, train_loader, device)
|
train(model, optimizer, train_loader, device)
|
||||||
acc = test(model, validation_loader, device)
|
acc = test(model, validation_loader, device)
|
||||||
if i == num_epochs - 1:
|
if i == num_epochs - 1:
|
||||||
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
||||||
torch.save((model.state_dict(), optimizer.state_dict()),
|
torch.save(
|
||||||
os.path.join(checkpoint_dir, "checkpoint"))
|
(model.state_dict(), optimizer.state_dict()),
|
||||||
|
os.path.join(checkpoint_dir, "checkpoint"),
|
||||||
|
)
|
||||||
tune.report(mean_accuracy=acc, done=True)
|
tune.report(mean_accuracy=acc, done=True)
|
||||||
else:
|
else:
|
||||||
tune.report(mean_accuracy=acc)
|
tune.report(mean_accuracy=acc)
|
||||||
|
@ -286,7 +291,7 @@ def train_mnist(config,
|
||||||
# until the given day. Our search space can thus also contain parameters
|
# until the given day. Our search space can thus also contain parameters
|
||||||
# that affect the model complexity (such as the layer size), since it
|
# that affect the model complexity (such as the layer size), since it
|
||||||
# does not have to be compatible to an existing model.
|
# does not have to be compatible to an existing model.
|
||||||
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
|
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):
|
||||||
data_interface = MNISTDataInterface("~/data", max_days=10)
|
data_interface = MNISTDataInterface("~/data", max_days=10)
|
||||||
num_examples = data_interface._get_day_slice(day)
|
num_examples = data_interface._get_day_slice(day)
|
||||||
|
|
||||||
|
@ -302,11 +307,13 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
|
||||||
mode="max",
|
mode="max",
|
||||||
max_t=num_epochs,
|
max_t=num_epochs,
|
||||||
grace_period=1,
|
grace_period=1,
|
||||||
reduction_factor=2)
|
reduction_factor=2,
|
||||||
|
)
|
||||||
|
|
||||||
reporter = CLIReporter(
|
reporter = CLIReporter(
|
||||||
parameter_columns=["layer_size", "lr", "momentum", "batch_size"],
|
parameter_columns=["layer_size", "lr", "momentum", "batch_size"],
|
||||||
metric_columns=["mean_accuracy", "training_iteration"])
|
metric_columns=["mean_accuracy", "training_iteration"],
|
||||||
|
)
|
||||||
|
|
||||||
analysis = tune.run(
|
analysis = tune.run(
|
||||||
partial(
|
partial(
|
||||||
|
@ -315,17 +322,16 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
|
||||||
data_fn=data_interface.get_data,
|
data_fn=data_interface.get_data,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
use_gpus=True if gpus_per_trial > 0 else False,
|
use_gpus=True if gpus_per_trial > 0 else False,
|
||||||
day=day),
|
day=day,
|
||||||
resources_per_trial={
|
),
|
||||||
"cpu": 1,
|
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||||
"gpu": gpus_per_trial
|
|
||||||
},
|
|
||||||
config=config,
|
config=config,
|
||||||
num_samples=num_samples,
|
num_samples=num_samples,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
progress_reporter=reporter,
|
progress_reporter=reporter,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
name="tune_serve_mnist_fromscratch")
|
name="tune_serve_mnist_fromscratch",
|
||||||
|
)
|
||||||
|
|
||||||
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
|
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
|
||||||
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
|
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
|
||||||
|
@ -344,33 +350,35 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
|
||||||
# layer size parameter. Since we continue to train an existing model,
|
# layer size parameter. Since we continue to train an existing model,
|
||||||
# we cannot change the layer size mid training, so we just continue
|
# we cannot change the layer size mid training, so we just continue
|
||||||
# to use the existing one.
|
# to use the existing one.
|
||||||
def tune_from_existing(start_model,
|
def tune_from_existing(
|
||||||
start_config,
|
start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0
|
||||||
num_samples=10,
|
):
|
||||||
num_epochs=10,
|
|
||||||
gpus_per_trial=0.,
|
|
||||||
day=0):
|
|
||||||
data_interface = MNISTDataInterface("/tmp/mnist_data", max_days=10)
|
data_interface = MNISTDataInterface("/tmp/mnist_data", max_days=10)
|
||||||
num_examples = data_interface._get_day_slice(day) - \
|
num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(
|
||||||
data_interface._get_day_slice(day - 1)
|
day - 1
|
||||||
|
)
|
||||||
|
|
||||||
config = start_config.copy()
|
config = start_config.copy()
|
||||||
config.update({
|
config.update(
|
||||||
"batch_size": tune.choice([16, 32, 64]),
|
{
|
||||||
"lr": tune.loguniform(1e-4, 1e-1),
|
"batch_size": tune.choice([16, 32, 64]),
|
||||||
"momentum": tune.uniform(0.1, 0.9),
|
"lr": tune.loguniform(1e-4, 1e-1),
|
||||||
})
|
"momentum": tune.uniform(0.1, 0.9),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
scheduler = ASHAScheduler(
|
scheduler = ASHAScheduler(
|
||||||
metric="mean_accuracy",
|
metric="mean_accuracy",
|
||||||
mode="max",
|
mode="max",
|
||||||
max_t=num_epochs,
|
max_t=num_epochs,
|
||||||
grace_period=1,
|
grace_period=1,
|
||||||
reduction_factor=2)
|
reduction_factor=2,
|
||||||
|
)
|
||||||
|
|
||||||
reporter = CLIReporter(
|
reporter = CLIReporter(
|
||||||
parameter_columns=["lr", "momentum", "batch_size"],
|
parameter_columns=["lr", "momentum", "batch_size"],
|
||||||
metric_columns=["mean_accuracy", "training_iteration"])
|
metric_columns=["mean_accuracy", "training_iteration"],
|
||||||
|
)
|
||||||
|
|
||||||
analysis = tune.run(
|
analysis = tune.run(
|
||||||
partial(
|
partial(
|
||||||
|
@ -379,17 +387,16 @@ def tune_from_existing(start_model,
|
||||||
data_fn=data_interface.get_incremental_data,
|
data_fn=data_interface.get_incremental_data,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
use_gpus=True if gpus_per_trial > 0 else False,
|
use_gpus=True if gpus_per_trial > 0 else False,
|
||||||
day=day),
|
day=day,
|
||||||
resources_per_trial={
|
),
|
||||||
"cpu": 1,
|
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||||
"gpu": gpus_per_trial
|
|
||||||
},
|
|
||||||
config=config,
|
config=config,
|
||||||
num_samples=num_samples,
|
num_samples=num_samples,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
progress_reporter=reporter,
|
progress_reporter=reporter,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
name="tune_serve_mnist_fromsexisting")
|
name="tune_serve_mnist_fromsexisting",
|
||||||
|
)
|
||||||
|
|
||||||
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
|
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
|
||||||
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
|
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
|
||||||
|
@ -423,8 +430,8 @@ class MNISTDeployment:
|
||||||
model = ConvNet(layer_size=self.config["layer_size"]).to(self.device)
|
model = ConvNet(layer_size=self.config["layer_size"]).to(self.device)
|
||||||
|
|
||||||
model_state, optimizer_state = torch.load(
|
model_state, optimizer_state = torch.load(
|
||||||
os.path.join(self.checkpoint_dir, "checkpoint"),
|
os.path.join(self.checkpoint_dir, "checkpoint"), map_location=self.device
|
||||||
map_location=self.device)
|
)
|
||||||
model.load_state_dict(model_state)
|
model.load_state_dict(model_state)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -442,12 +449,12 @@ class MNISTDeployment:
|
||||||
# active model. We call this directory ``model_dir``. Every time we
|
# active model. We call this directory ``model_dir``. Every time we
|
||||||
# would like to update our model, we copy the checkpoint of the new
|
# would like to update our model, we copy the checkpoint of the new
|
||||||
# model to this directory. We then update the deployment to the new version.
|
# model to this directory. We then update the deployment to the new version.
|
||||||
def serve_new_model(model_dir, checkpoint, config, metrics, day,
|
def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):
|
||||||
use_gpu=False):
|
|
||||||
print("Serving checkpoint: {}".format(checkpoint))
|
print("Serving checkpoint: {}".format(checkpoint))
|
||||||
|
|
||||||
checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint,
|
checkpoint_path = _move_checkpoint_to_model_dir(
|
||||||
config, metrics)
|
model_dir, checkpoint, config, metrics
|
||||||
|
)
|
||||||
|
|
||||||
serve.start(detached=True)
|
serve.start(detached=True)
|
||||||
MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)
|
MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)
|
||||||
|
@ -482,8 +489,7 @@ def get_current_model(model_dir):
|
||||||
checkpoint_path = os.path.join(model_dir, "checkpoint")
|
checkpoint_path = os.path.join(model_dir, "checkpoint")
|
||||||
meta_path = os.path.join(model_dir, "meta.json")
|
meta_path = os.path.join(model_dir, "meta.json")
|
||||||
|
|
||||||
if not os.path.exists(checkpoint_path) or \
|
if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):
|
||||||
not os.path.exists(meta_path):
|
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
with open(meta_path, "rt") as fp:
|
with open(meta_path, "rt") as fp:
|
||||||
|
@ -559,28 +565,33 @@ if __name__ == "__main__":
|
||||||
"--from_scratch",
|
"--from_scratch",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Train and select best model from scratch",
|
help="Train and select best model from scratch",
|
||||||
default=False)
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--from_existing",
|
"--from_existing",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Train and select best model from existing model",
|
help="Train and select best model from existing model",
|
||||||
default=False)
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--day",
|
"--day",
|
||||||
help="Indicate the day to simulate the amount of data available to us",
|
help="Indicate the day to simulate the amount of data available to us",
|
||||||
type=int,
|
type=int,
|
||||||
default=0)
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--query", help="Query endpoint with example", type=int, default=-1)
|
"--query", help="Query endpoint with example", type=int, default=-1
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--smoke-test",
|
"--smoke-test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Finish quickly for testing",
|
help="Finish quickly for testing",
|
||||||
default=False)
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -600,20 +611,23 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Query our model
|
# Query our model
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8000/mnist",
|
"http://localhost:8000/mnist", json={"images": [data[0].numpy().tolist()]}
|
||||||
json={"images": [data[0].numpy().tolist()]})
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pred = response.json()["result"][0]
|
pred = response.json()["result"][0]
|
||||||
except: # noqa: E722
|
except: # noqa: E722
|
||||||
pred = -1
|
pred = -1
|
||||||
|
|
||||||
print("Querying model with example #{}. "
|
print(
|
||||||
"Label = {}, Response = {}, Correct = {}".format(
|
"Querying model with example #{}. "
|
||||||
args.query, label, pred, label == pred))
|
"Label = {}, Response = {}, Correct = {}".format(
|
||||||
|
args.query, label, pred, label == pred
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
gpus_per_trial = 0.5 if not args.smoke_test else 0.
|
gpus_per_trial = 0.5 if not args.smoke_test else 0.0
|
||||||
serve_gpu = True if gpus_per_trial > 0 else False
|
serve_gpu = True if gpus_per_trial > 0 else False
|
||||||
num_samples = 8 if not args.smoke_test else 1
|
num_samples = 8 if not args.smoke_test else 1
|
||||||
num_epochs = 10 if not args.smoke_test else 1
|
num_epochs = 10 if not args.smoke_test else 1
|
||||||
|
@ -621,23 +635,22 @@ if __name__ == "__main__":
|
||||||
if args.from_scratch: # train everyday from scratch
|
if args.from_scratch: # train everyday from scratch
|
||||||
print("Start training job from scratch on day {}.".format(args.day))
|
print("Start training job from scratch on day {}.".format(args.day))
|
||||||
acc, config, best_checkpoint, num_examples = tune_from_scratch(
|
acc, config, best_checkpoint, num_examples = tune_from_scratch(
|
||||||
num_samples, num_epochs, gpus_per_trial, day=args.day)
|
num_samples, num_epochs, gpus_per_trial, day=args.day
|
||||||
print("Trained day {} from scratch on {} samples. "
|
)
|
||||||
"Best accuracy: {:.4f}. Best config: {}".format(
|
print(
|
||||||
args.day, num_examples, acc, config))
|
"Trained day {} from scratch on {} samples. "
|
||||||
|
"Best accuracy: {:.4f}. Best config: {}".format(
|
||||||
|
args.day, num_examples, acc, config
|
||||||
|
)
|
||||||
|
)
|
||||||
serve_new_model(
|
serve_new_model(
|
||||||
model_dir,
|
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
|
||||||
best_checkpoint,
|
)
|
||||||
config,
|
|
||||||
acc,
|
|
||||||
args.day,
|
|
||||||
use_gpu=serve_gpu)
|
|
||||||
|
|
||||||
if args.from_existing:
|
if args.from_existing:
|
||||||
old_checkpoint, old_config, old_acc = get_current_model(model_dir)
|
old_checkpoint, old_config, old_acc = get_current_model(model_dir)
|
||||||
if not old_checkpoint or not old_config or not old_acc:
|
if not old_checkpoint or not old_config or not old_acc:
|
||||||
print("No existing model found. Train one with --from_scratch "
|
print("No existing model found. Train one with --from_scratch " "first.")
|
||||||
"first.")
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
acc, config, best_checkpoint, num_examples = tune_from_existing(
|
acc, config, best_checkpoint, num_examples = tune_from_existing(
|
||||||
old_checkpoint,
|
old_checkpoint,
|
||||||
|
@ -645,17 +658,17 @@ if __name__ == "__main__":
|
||||||
num_samples,
|
num_samples,
|
||||||
num_epochs,
|
num_epochs,
|
||||||
gpus_per_trial,
|
gpus_per_trial,
|
||||||
day=args.day)
|
day=args.day,
|
||||||
print("Trained day {} from existing on {} samples. "
|
)
|
||||||
"Best accuracy: {:.4f}. Best config: {}".format(
|
print(
|
||||||
args.day, num_examples, acc, config))
|
"Trained day {} from existing on {} samples. "
|
||||||
|
"Best accuracy: {:.4f}. Best config: {}".format(
|
||||||
|
args.day, num_examples, acc, config
|
||||||
|
)
|
||||||
|
)
|
||||||
serve_new_model(
|
serve_new_model(
|
||||||
model_dir,
|
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
|
||||||
best_checkpoint,
|
)
|
||||||
config,
|
|
||||||
acc,
|
|
||||||
args.day,
|
|
||||||
use_gpu=serve_gpu)
|
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# That's it! We now have an end-to-end workflow to train and update a
|
# That's it! We now have an end-to-end workflow to train and update a
|
||||||
|
|
|
@ -38,6 +38,7 @@ To start out, change the import statement to get tune-scikit-learn’s grid sear
|
||||||
"""
|
"""
|
||||||
# Keep this here for https://github.com/ray-project/ray/issues/11547
|
# Keep this here for https://github.com/ray-project/ray/issues/11547
|
||||||
from sklearn.model_selection import GridSearchCV
|
from sklearn.model_selection import GridSearchCV
|
||||||
|
|
||||||
# Replace above line with:
|
# Replace above line with:
|
||||||
from ray.tune.sklearn import TuneGridSearchCV
|
from ray.tune.sklearn import TuneGridSearchCV
|
||||||
|
|
||||||
|
@ -60,7 +61,8 @@ X, y = make_classification(
|
||||||
n_informative=50,
|
n_informative=50,
|
||||||
n_redundant=0,
|
n_redundant=0,
|
||||||
n_classes=10,
|
n_classes=10,
|
||||||
class_sep=2.5)
|
class_sep=2.5,
|
||||||
|
)
|
||||||
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=1000)
|
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=1000)
|
||||||
|
|
||||||
# Example parameters to tune from SGDClassifier
|
# Example parameters to tune from SGDClassifier
|
||||||
|
@ -70,9 +72,11 @@ parameter_grid = {"alpha": [1e-4, 1e-1, 1], "epsilon": [0.01, 0.1]}
|
||||||
# As you can see, the setup here is exactly how you would do it for Scikit-Learn. Now, let's try fitting a model.
|
# As you can see, the setup here is exactly how you would do it for Scikit-Learn. Now, let's try fitting a model.
|
||||||
|
|
||||||
tune_search = TuneGridSearchCV(
|
tune_search = TuneGridSearchCV(
|
||||||
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10)
|
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10
|
||||||
|
)
|
||||||
|
|
||||||
import time # Just to compare fit times
|
import time # Just to compare fit times
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
tune_search.fit(x_train, y_train)
|
tune_search.fit(x_train, y_train)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
@ -93,6 +97,7 @@ print("Tune GridSearch Fit Time:", end - start)
|
||||||
# Try running this compared to the GridSearchCV equivalent, and see the speedup for yourself!
|
# Try running this compared to the GridSearchCV equivalent, and see the speedup for yourself!
|
||||||
|
|
||||||
from sklearn.model_selection import GridSearchCV
|
from sklearn.model_selection import GridSearchCV
|
||||||
|
|
||||||
# n_jobs=-1 enables use of all cores like Tune does
|
# n_jobs=-1 enables use of all cores like Tune does
|
||||||
sklearn_search = GridSearchCV(SGDClassifier(), parameter_grid, n_jobs=-1)
|
sklearn_search = GridSearchCV(SGDClassifier(), parameter_grid, n_jobs=-1)
|
||||||
|
|
||||||
|
@ -120,7 +125,7 @@ import numpy as np
|
||||||
digits = datasets.load_digits()
|
digits = datasets.load_digits()
|
||||||
x = digits.data
|
x = digits.data
|
||||||
y = digits.target
|
y = digits.target
|
||||||
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.2)
|
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
|
||||||
|
|
||||||
clf = SGDClassifier()
|
clf = SGDClassifier()
|
||||||
parameter_grid = {"alpha": (1e-4, 1), "epsilon": (0.01, 0.1)}
|
parameter_grid = {"alpha": (1e-4, 1), "epsilon": (0.01, 0.1)}
|
||||||
|
|
|
@ -8,8 +8,9 @@ import ray
|
||||||
def get_host_name(x):
|
def get_host_name(x):
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
return x + (platform.node(), )
|
return x + (platform.node(),)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_nodes(expected):
|
def wait_for_nodes(expected):
|
||||||
|
@ -17,8 +18,11 @@ def wait_for_nodes(expected):
|
||||||
while True:
|
while True:
|
||||||
num_nodes = len(ray.nodes())
|
num_nodes = len(ray.nodes())
|
||||||
if num_nodes < expected:
|
if num_nodes < expected:
|
||||||
print("{} nodes have joined so far, waiting for {} more.".format(
|
print(
|
||||||
num_nodes, expected - num_nodes))
|
"{} nodes have joined so far, waiting for {} more.".format(
|
||||||
|
num_nodes, expected - num_nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
@ -31,9 +35,7 @@ def main():
|
||||||
# Check that objects can be transferred from each node to each other node.
|
# Check that objects can be transferred from each node to each other node.
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
print("Iteration {}".format(i))
|
print("Iteration {}".format(i))
|
||||||
results = [
|
results = [get_host_name.remote(get_host_name.remote(())) for _ in range(100)]
|
||||||
get_host_name.remote(get_host_name.remote(())) for _ in range(100)
|
|
||||||
]
|
|
||||||
print(Counter(ray.get(results)))
|
print(Counter(ray.get(results)))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,9 @@ def setup_logging() -> None:
|
||||||
setup_component_logger(
|
setup_component_logger(
|
||||||
logging_level=ray_constants.LOGGER_LEVEL, # info
|
logging_level=ray_constants.LOGGER_LEVEL, # info
|
||||||
logging_format=ray_constants.LOGGER_FORMAT,
|
logging_format=ray_constants.LOGGER_FORMAT,
|
||||||
log_dir=os.path.join(ray._private.utils.get_ray_temp_dir(),
|
log_dir=os.path.join(
|
||||||
ray.node.SESSION_LATEST, "logs"),
|
ray._private.utils.get_ray_temp_dir(), ray.node.SESSION_LATEST, "logs"
|
||||||
|
),
|
||||||
filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log
|
filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log
|
||||||
max_bytes=ray_constants.LOGGING_ROTATE_BYTES,
|
max_bytes=ray_constants.LOGGING_ROTATE_BYTES,
|
||||||
backup_count=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
backup_count=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
|
||||||
|
@ -47,11 +48,11 @@ if __name__ == "__main__":
|
||||||
required=False,
|
required=False,
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The password to use for Redis")
|
help="The password to use for Redis",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
cluster_name = yaml.safe_load(
|
cluster_name = yaml.safe_load(open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
|
||||||
open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
|
|
||||||
head_ip = get_node_ip_address()
|
head_ip = get_node_ip_address()
|
||||||
Monitor(
|
Monitor(
|
||||||
address=f"{head_ip}:6379",
|
address=f"{head_ip}:6379",
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue