Mirror of https://github.com/vale981/ray, synced 2025-03-04 17:41:43 -05:00

[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
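
The changes themselves are mechanical: Black reflows code to its 88-column limit, breaks call arguments onto one line each when a call no longer fits, adds a trailing comma to every multi-line literal or argument list, and collapses constructs that do fit on one line. As a minimal sketch of the pattern that recurs throughout the diff (adapted from the first hunk below):

    # Before: hand-wrapped style, closing paren on the last argument line.
    resp = requests.get(
        "https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
        auth=auth,
        params={"job_id": os.environ["BUILDKITE_JOB_ID"]})

    # After Black: trailing comma added, closing paren moved to its own line.
    resp = requests.get(
        "https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
        auth=auth,
        params={"job_id": os.environ["BUILDKITE_JOB_ID"]},
    )

Output like this is typically produced by running `black .` over the repository and enforced with `black --check .` in CI; the exact invocation wired into the build is an assumption here, as it is not shown in this excerpt.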
This commit is contained in:
parent 95877be8ee
commit 7f1bacc7dc

1637 changed files with 75167 additions and 59155 deletions
@@ -39,14 +39,16 @@ def perform_auth():
     resp = requests.get(
         "https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
         auth=auth,
-        params={"job_id": os.environ["BUILDKITE_JOB_ID"]})
+        params={"job_id": os.environ["BUILDKITE_JOB_ID"]},
+    )
     return resp


 def handle_docker_login(resp):
     pwd = resp.json()["docker_password"]
     subprocess.call(
-        ["docker", "login", "--username", "raytravisbot", "--password", pwd])
+        ["docker", "login", "--username", "raytravisbot", "--password", pwd]
+    )


 def gather_paths(dir_path) -> List[str]:

@@ -86,7 +88,7 @@ def upload_paths(paths, resp, destination):
         "branch_wheels": f"{branch}/{sha}/{fn}",
         "jars": f"jars/latest/{current_os}/{fn}",
         "branch_jars": f"jars/{branch}/{sha}/{current_os}/{fn}",
-        "logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}"
+        "logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}",
     }[destination]
     of["file"] = open(path, "rb")
     r = requests.post(c["url"], files=of)

@@ -95,14 +97,19 @@ def upload_paths(paths, resp, destination):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Helper script to upload files to S3 bucket")
+        description="Helper script to upload files to S3 bucket"
+    )
     parser.add_argument("--path", type=str, required=False)
     parser.add_argument("--destination", type=str)
     args = parser.parse_args()

     assert args.destination in {
-        "branch_jars", "branch_wheels", "jars", "logs", "wheels",
-        "docker_login"
+        "branch_jars",
+        "branch_wheels",
+        "jars",
+        "logs",
+        "wheels",
+        "docker_login",
     }
     assert "BUILDKITE_JOB_ID" in os.environ
     assert "BUILDKITE_COMMIT" in os.environ

@@ -51,8 +51,10 @@ del monitor_actor
 test_utils.wait_for_condition(no_resource_leaks)

 rate = MAX_ACTORS_IN_CLUSTER / (end_time - start_time)
-print(f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
-      f"{end_time - start_time}s. ({rate} actors/s)")
+print(
+    f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
+    f"{end_time - start_time}s. ({rate} actors/s)"
+)

 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")

@@ -62,6 +64,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "time": end_time - start_time,
         "success": "1",
         "_peak_memory": round(used_gb, 2),
-        "_peak_process_memory": usage
+        "_peak_process_memory": usage,
     }
     json.dump(results, out_file)

@@ -77,8 +77,10 @@ del monitor_actor
 test_utils.wait_for_condition(no_resource_leaks)

 rate = MAX_PLACEMENT_GROUPS / (end_time - start_time)
-print(f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
-      f"{end_time - start_time}s. ({rate} pgs/s)")
+print(
+    f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
+    f"{end_time - start_time}s. ({rate} pgs/s)"
+)

 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")

@@ -88,6 +90,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "time": end_time - start_time,
         "success": "1",
         "_peak_memory": round(used_gb, 2),
-        "_peak_process_memory": usage
+        "_peak_process_memory": usage,
     }
     json.dump(results, out_file)

@@ -16,9 +16,7 @@ def test_max_running_tasks(num_tasks):
     def task():
         time.sleep(sleep_time)

-    refs = [
-        task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")
-    ]
+    refs = [task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")]

     max_cpus = ray.cluster_resources()["CPU"]
     min_cpus_available = max_cpus

@@ -48,8 +46,7 @@ def no_resource_leaks():


 @click.command()
-@click.option(
-    "--num-tasks", required=True, type=int, help="Number of tasks to launch.")
+@click.option("--num-tasks", required=True, type=int, help="Number of tasks to launch.")
 def test(num_tasks):
     ray.init(address="auto")

@@ -66,8 +63,10 @@ def test(num_tasks):
     test_utils.wait_for_condition(no_resource_leaks)

     rate = num_tasks / (end_time - start_time - sleep_time)
-    print(f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
-          f"({rate} tasks/s)")
+    print(
+        f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
+        f"({rate} tasks/s)"
+    )

     if "TEST_OUTPUT_JSON" in os.environ:
         out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")

@@ -77,7 +76,7 @@ def test(num_tasks):
             "time": end_time - start_time,
             "success": "1",
             "_peak_memory": round(used_gb, 2),
-            "_peak_process_memory": usage
+            "_peak_process_memory": usage,
         }
         json.dump(results, out_file)

@@ -25,10 +25,12 @@ class SimpleActor:


 def start_tasks(num_task, num_cpu_per_task, task_duration):
-    ray.get([
-        simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
-        for _ in range(num_task)
-    ])
+    ray.get(
+        [
+            simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
+            for _ in range(num_task)
+        ]
+    )


 def measure(f):

@@ -40,13 +42,16 @@ def measure(f):

 def start_actor(num_actors, num_actors_per_nodes, job):
     resources = {"node": floor(1.0 / num_actors_per_nodes)}
-    submission_cost, actors = measure(lambda: [
-        SimpleActor.options(resources=resources, num_cpus=0).remote(job)
-        for _ in range(num_actors)])
-    ready_cost, _ = measure(
-        lambda: ray.get([actor.ready.remote() for actor in actors]))
+    submission_cost, actors = measure(
+        lambda: [
+            SimpleActor.options(resources=resources, num_cpus=0).remote(job)
+            for _ in range(num_actors)
+        ]
+    )
+    ready_cost, _ = measure(lambda: ray.get([actor.ready.remote() for actor in actors]))
     actor_job_cost, _ = measure(
-        lambda: ray.get([actor.do_job.remote() for actor in actors]))
+        lambda: ray.get([actor.do_job.remote() for actor in actors])
+    )
     return (submission_cost, ready_cost, actor_job_cost)


@@ -54,33 +59,32 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog="Test Scheduling")
     # Task workloads
     parser.add_argument(
-        "--total-num-task",
-        type=int,
-        help="Total number of tasks.",
-        required=False)
+        "--total-num-task", type=int, help="Total number of tasks.", required=False
+    )
     parser.add_argument(
         "--num-cpu-per-task",
         type=int,
         help="Resources needed for tasks.",
-        required=False)
+        required=False,
+    )
     parser.add_argument(
         "--task-duration-s",
         type=int,
         help="How long does each task execute.",
         required=False,
-        default=1)
+        default=1,
+    )

     # Actor workloads
     parser.add_argument(
-        "--total-num-actors",
-        type=int,
-        help="Total number of actors.",
-        required=True)
+        "--total-num-actors", type=int, help="Total number of actors.", required=True
+    )
     parser.add_argument(
         "--num-actors-per-nodes",
         type=int,
         help="How many actors to allocate for each nodes.",
-        required=True)
+        required=True,
+    )

     ray.init(address="auto")

@@ -92,13 +96,14 @@ if __name__ == "__main__":
     job = None
     if args.total_num_task is not None:
         if args.num_cpu_per_task is None:
-            args.num_cpu_per_task = floor(
-                1.0 * total_cpus / args.total_num_task)
+            args.num_cpu_per_task = floor(1.0 * total_cpus / args.total_num_task)
         job = lambda: start_tasks(  # noqa: E731
-            args.total_num_task, args.num_cpu_per_task, args.task_duration_s)
+            args.total_num_task, args.num_cpu_per_task, args.task_duration_s
+        )

     submission_cost, ready_cost, actor_job_cost = start_actor(
-        args.total_num_actors, args.num_actors_per_nodes, job)
+        args.total_num_actors, args.num_actors_per_nodes, job
+    )

     output = os.environ.get("TEST_OUTPUT_JSON")

@@ -118,6 +123,7 @@ if __name__ == "__main__":

     if output is not None:
         from pathlib import Path
+
         p = Path(output)
         p.write_text(json.dumps(result))

@@ -12,8 +12,7 @@ def num_alive_nodes():


 @click.command()
-@click.option(
-    "--num-nodes", required=True, type=int, help="The target number of nodes")
+@click.option("--num-nodes", required=True, type=int, help="The target number of nodes")
 def wait_cluster(num_nodes: int):
     ray.init(address="auto")
     while num_alive_nodes() != num_nodes:

@@ -9,7 +9,7 @@ from time import perf_counter
 from tqdm import tqdm

 NUM_NODES = 50
-OBJECT_SIZE = 2**30
+OBJECT_SIZE = 2 ** 30


 def num_alive_nodes():

@@ -60,6 +60,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "broadcast_time": end - start,
         "object_size": OBJECT_SIZE,
         "num_nodes": NUM_NODES,
-        "success": "1"
+        "success": "1",
     }
     json.dump(results, out_file)

@@ -13,7 +13,7 @@ MAX_ARGS = 10000
 MAX_RETURNS = 3000
 MAX_RAY_GET_ARGS = 10000
 MAX_QUEUED_TASKS = 1_000_000
-MAX_RAY_GET_SIZE = 100 * 2**30
+MAX_RAY_GET_SIZE = 100 * 2 ** 30


 def assert_no_leaks():

@@ -189,8 +189,7 @@ print(f"Many args time: {args_time} ({MAX_ARGS} args)")
 print(f"Many returns time: {returns_time} ({MAX_RETURNS} returns)")
 print(f"Ray.get time: {get_time} ({MAX_RAY_GET_ARGS} args)")
 print(f"Queued task time: {queued_time} ({MAX_QUEUED_TASKS} tasks)")
-print(f"Ray.get large object time: {large_object_time} "
-      f"({MAX_RAY_GET_SIZE} bytes)")
+print(f"Ray.get large object time: {large_object_time} " f"({MAX_RAY_GET_SIZE} bytes)")

 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")

@@ -205,6 +204,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "num_queued": MAX_QUEUED_TASKS,
         "large_object_time": large_object_time,
         "large_object_size": MAX_RAY_GET_SIZE,
-        "success": "1"
+        "success": "1",
     }
     json.dump(results, out_file)

@@ -66,7 +66,7 @@ def get_remote_url(remote):

 def replace_suffix(base, old_suffix, new_suffix=""):
     if base.endswith(old_suffix):
-        base = base[:len(base) - len(old_suffix)] + new_suffix
+        base = base[: len(base) - len(old_suffix)] + new_suffix
     return base


@@ -199,12 +199,21 @@ def monitor():
     expected_line = "{}\t{}".format(expected_sha, ref)

     if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
-        logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
-                    remote, expected_line)
+        logger.info(
+            "Not monitoring %s on %s due to keep-alive on: %s",
+            ref,
+            remote,
+            expected_line,
+        )
         return

-    logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
-                get_remote_url(remote), ref, expected_line)
+    logger.info(
+        "Monitoring %s (%s) for changes in %s: %s",
+        remote,
+        get_remote_url(remote),
+        ref,
+        expected_line,
+    )

     for to_wait in yield_poll_schedule():
         time.sleep(to_wait)

@@ -217,12 +226,21 @@ def monitor():
             status = ex.returncode

         if status == 2:
-            logger.info("Terminating job as %s has been deleted on %s: %s",
-                        ref, remote, expected_line)
+            logger.info(
+                "Terminating job as %s has been deleted on %s: %s",
+                ref,
+                remote,
+                expected_line,
+            )
             break
         elif status != 0:
-            logger.error("Error %d: unable to check %s on %s: %s", status, ref,
-                         remote, expected_line)
+            logger.error(
+                "Error %d: unable to check %s on %s: %s",
+                status,
+                ref,
+                remote,
+                expected_line,
+            )
         else:
             prev = expected_line
             expected_line = detect_spurious_commit(line, expected_line, remote)

@@ -230,14 +248,24 @@ def monitor():
                 logger.info(
                     "Terminating job as %s has been updated on %s\n"
                     " from:\t%s\n"
-                    " to: \t%s", ref, remote, expected_line, line)
+                    " to: \t%s",
+                    ref,
+                    remote,
+                    expected_line,
+                    line,
+                )
                 time.sleep(1)  # wait for CI to flush output
                 break
             if expected_line != prev:
                 logger.info(
                     "%s appeared to spuriously change on %s\n"
                     " from:\t%s\n"
-                    " to: \t%s", ref, remote, prev, expected_line)
+                    " to: \t%s",
+                    ref,
+                    remote,
+                    prev,
+                    expected_line,
+                )

     return terminate_my_process_group()

@@ -259,9 +287,8 @@ def main(program, *args):

 if __name__ == "__main__":
     logging.basicConfig(
-        format="%(levelname)s: %(message)s",
-        stream=sys.stderr,
-        level=logging.DEBUG)
+        format="%(levelname)s: %(message)s", stream=sys.stderr, level=logging.DEBUG
+    )
     try:
         raise SystemExit(main(*sys.argv) or 0)
     except KeyboardInterrupt:

ci/repro-ci.py (271 lines changed)
@@ -53,7 +53,10 @@ def maybe_fetch_buildkite_token():
         "secretsmanager", region_name="us-west-2"
     ).get_secret_value(
         SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
-        "buildkite/ro-token")["SecretString"]
+        "buildkite/ro-token"
+    )[
+        "SecretString"
+    ]


 def escape(v: Any):

@@ -85,26 +88,26 @@ def env_str(env: Dict[str, Any]):

 def script_str(v: Any):
     if isinstance(v, bool):
-        return f"\"{int(v)}\""
+        return f'"{int(v)}"'
     elif isinstance(v, Number):
-        return f"\"{v}\""
+        return f'"{v}"'
     elif isinstance(v, list):
-        return "(" + " ".join(f"\"{shlex.quote(w)}\"" for w in v) + ")"
+        return "(" + " ".join(f'"{shlex.quote(w)}"' for w in v) + ")"
     else:
-        return f"\"{shlex.quote(v)}\""
+        return f'"{shlex.quote(v)}"'


 class ReproSession:
     plugin_default_env = {
-        "docker": {
-            "BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False
-        }
+        "docker": {"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False}
     }

-    def __init__(self,
-                 buildkite_token: str,
-                 instance_name: Optional[str] = None,
-                 logger: Optional[logging.Logger] = None):
+    def __init__(
+        self,
+        buildkite_token: str,
+        instance_name: Optional[str] = None,
+        logger: Optional[logging.Logger] = None,
+    ):
         self.logger = logger or logging.getLogger(self.__class__.__name__)

         self.bk = Buildkite()

@@ -139,12 +142,15 @@ class ReproSession:
         # https://buildkite.com/ray-project/ray-builders-pr/
         # builds/19635#55a0d71a-831e-4f68-b668-2b10c6f65ee6
         pattern = re.compile(
-            "https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)")
+            "https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)"
+        )
         org, pipeline, build_id, job_id = pattern.match(session_url).groups()

-        self.logger.debug(f"Parsed session URL: {session_url}. "
-                          f"Got org='{org}', pipeline='{pipeline}', "
-                          f"build_id='{build_id}', job_id='{job_id}'.")
+        self.logger.debug(
+            f"Parsed session URL: {session_url}. "
+            f"Got org='{org}', pipeline='{pipeline}', "
+            f"build_id='{build_id}', job_id='{job_id}'."
+        )

         self.org = org
         self.pipeline = pipeline

@@ -155,7 +161,8 @@ class ReproSession:
         assert self.bk

         self.env = self.bk.jobs().get_job_environment_variables(
-            self.org, self.pipeline, self.build_id, self.job_id)["env"]
+            self.org, self.pipeline, self.build_id, self.job_id
+        )["env"]

         if overwrite:
             self.env.update(overwrite)

@@ -166,33 +173,30 @@ class ReproSession:
         assert self.env

         if not self.aws_instance_name:
-            self.aws_instance_name = (
-                f"repro_ci_{self.build_id}_{self.job_id[:8]}")
+            self.aws_instance_name = f"repro_ci_{self.build_id}_{self.job_id[:8]}"
             self.logger.info(
-                f"No instance name provided, using {self.aws_instance_name}")
+                f"No instance name provided, using {self.aws_instance_name}"
+            )

         instance_type = self.env["BUILDKITE_AGENT_META_DATA_AWS_INSTANCE_TYPE"]
         instance_ami = self.env["BUILDKITE_AGENT_META_DATA_AWS_AMI_ID"]
         instance_sg = "sg-0ccfca2ef191c04ae"
-        instance_block_device_mappings = [{
-            "DeviceName": "/dev/xvda",
-            "Ebs": {
-                "VolumeSize": 500
-            }
-        }]
+        instance_block_device_mappings = [
+            {"DeviceName": "/dev/xvda", "Ebs": {"VolumeSize": 500}}
+        ]

         # Check if instance exists:
-        running_instances = self.ec2_resource.instances.filter(Filters=[{
-            "Name": "tag:repro_name",
-            "Values": [self.aws_instance_name]
-        }, {
-            "Name": "instance-state-name",
-            "Values": ["running"]
-        }])
+        running_instances = self.ec2_resource.instances.filter(
+            Filters=[
+                {"Name": "tag:repro_name", "Values": [self.aws_instance_name]},
+                {"Name": "instance-state-name", "Values": ["running"]},
+            ]
+        )

         self.logger.info(
             f"Check if instance with name {self.aws_instance_name} "
-            f"already exists...")
+            f"already exists..."
+        )

         for instance in running_instances:
             self.aws_instance_id = instance.id

@@ -201,8 +205,8 @@ class ReproSession:
             return

         self.logger.info(
-            f"Instance with name {self.aws_instance_name} not found, "
-            f"creating...")
+            f"Instance with name {self.aws_instance_name} not found, " f"creating..."
+        )

         # Else, not running, yet, start.
         instance = self.ec2_resource.create_instances(

@@ -211,20 +215,18 @@ class ReproSession:
             InstanceType=instance_type,
             KeyName=self.ssh_key_name,
             SecurityGroupIds=[instance_sg],
-            TagSpecifications=[{
-                "ResourceType": "instance",
-                "Tags": [{
-                    "Key": "repro_name",
-                    "Value": self.aws_instance_name
-                }]
-            }],
+            TagSpecifications=[
+                {
+                    "ResourceType": "instance",
+                    "Tags": [{"Key": "repro_name", "Value": self.aws_instance_name}],
+                }
+            ],
             MinCount=1,
             MaxCount=1,
         )[0]

         self.aws_instance_id = instance.id
-        self.logger.info(
-            f"Created new instance with ID {self.aws_instance_id}")
+        self.logger.info(f"Created new instance with ID {self.aws_instance_id}")

     def aws_wait_for_instance(self):
         assert self.aws_instance_id

@@ -234,28 +236,32 @@ class ReproSession:
         repro_instance_state = None
         while repro_instance_state != "running":
             detail = self.ec2_client.describe_instances(
-                InstanceIds=[self.aws_instance_id], )
-            repro_instance_state = \
-                detail["Reservations"][0]["Instances"][0]["State"]["Name"]
+                InstanceIds=[self.aws_instance_id],
+            )
+            repro_instance_state = detail["Reservations"][0]["Instances"][0]["State"][
+                "Name"
+            ]

             if repro_instance_state != "running":
                 time.sleep(2)

         self.aws_instance_ip = detail["Reservations"][0]["Instances"][0][
-            "PublicIpAddress"]
+            "PublicIpAddress"
+        ]

     def aws_stop_instance(self):
         assert self.aws_instance_id

         self.ec2_client.terminate_instances(
-            InstanceIds=[self.aws_instance_id], )
+            InstanceIds=[self.aws_instance_id],
+        )

     def print_stop_command(self):
         click.secho("To stop this instance in the future, run this: ")
         click.secho(
-            f"aws ec2 terminate-instances "
-            f"--instance-ids={self.aws_instance_id}",
-            bold=True)
+            f"aws ec2 terminate-instances " f"--instance-ids={self.aws_instance_id}",
+            bold=True,
+        )

     def create_new_ssh_client(self):
         assert self.aws_instance_ip

@@ -264,7 +270,8 @@ class ReproSession:
             self.ssh.close()

         self.logger.info(
-            "Creating SSH client and waiting for SSH to become available...")
+            "Creating SSH client and waiting for SSH to become available..."
+        )

         ssh = paramiko.client.SSHClient()
         ssh.load_system_host_keys()

@@ -275,7 +282,8 @@ class ReproSession:
                 ssh.connect(
                     self.aws_instance_ip,
                     username=self.ssh_user,
-                    key_filename=os.path.expanduser(self.ssh_key_file))
+                    key_filename=os.path.expanduser(self.ssh_key_file),
+                )
                 break
             except paramiko.ssh_exception.NoValidConnectionsError:
                 self.logger.info("SSH not ready, yet, sleeping 5 seconds")

@@ -291,8 +299,7 @@ class ReproSession:
         result = {}

         def exec():
-            stdin, stdout, stderr = self.ssh.exec_command(
-                command, get_pty=True)
+            stdin, stdout, stderr = self.ssh.exec_command(command, get_pty=True)

             output = ""
             for line in stdout.readlines():

@@ -321,12 +328,13 @@ class ReproSession:
         return result.get("output", "")

     def execute_ssh_command(
-            self,
-            command: str,
-            env: Optional[Dict[str, str]] = None,
-            as_script: bool = False,
-            quiet: bool = False,
-            command_wrapper: Optional[Callable[[str], str]] = None) -> str:
+        self,
+        command: str,
+        env: Optional[Dict[str, str]] = None,
+        as_script: bool = False,
+        quiet: bool = False,
+        command_wrapper: Optional[Callable[[str], str]] = None,
+    ) -> str:
         assert self.ssh

         if not command_wrapper:

@@ -360,23 +368,25 @@ class ReproSession:

         return output

-    def execute_ssh_commands(self,
-                             commands: List[str],
-                             env: Optional[Dict[str, str]] = None,
-                             quiet: bool = False):
+    def execute_ssh_commands(
+        self,
+        commands: List[str],
+        env: Optional[Dict[str, str]] = None,
+        quiet: bool = False,
+    ):
         for command in commands:
             self.execute_ssh_command(command, env=env, quiet=quiet)

-    def execute_docker_command(self,
-                               command: str,
-                               env: Optional[Dict[str, str]] = None,
-                               quiet: bool = False):
+    def execute_docker_command(
+        self, command: str, env: Optional[Dict[str, str]] = None, quiet: bool = False
+    ):
         def command_wrapper(s):
             escaped = s.replace("'", "'\"'\"'")
             return f"docker exec -it ray_container /bin/bash -ci '{escaped}'"

         self.execute_ssh_command(
-            command, env=env, quiet=quiet, command_wrapper=command_wrapper)
+            command, env=env, quiet=quiet, command_wrapper=command_wrapper
+        )

     def prepare_instance(self):
         self.create_new_ssh_client()

@@ -387,8 +397,9 @@ class ReproSession:

         self.logger.info("Preparing instance (installing docker etc.)")
         commands = [
-            "sudo yum install -y docker", "sudo service docker start",
-            f"sudo usermod -aG docker {self.ssh_user}"
+            "sudo yum install -y docker",
+            "sudo service docker start",
+            f"sudo usermod -aG docker {self.ssh_user}",
         ]
         self.execute_ssh_commands(commands, quiet=True)
         self.create_new_ssh_client()

@@ -398,13 +409,18 @@ class ReproSession:
     def docker_login(self):
         self.logger.info("Logging into docker...")
         credentials = boto3.client(
-            "ecr", region_name="us-west-2").get_authorization_token()
-        token = base64.b64decode(credentials["authorizationData"][0][
-            "authorizationToken"]).decode("utf-8").replace("AWS:", "")
+            "ecr", region_name="us-west-2"
+        ).get_authorization_token()
+        token = (
+            base64.b64decode(credentials["authorizationData"][0]["authorizationToken"])
+            .decode("utf-8")
+            .replace("AWS:", "")
+        )
         endpoint = credentials["authorizationData"][0]["proxyEndpoint"]

         self.execute_ssh_command(
-            f"docker login -u AWS -p {token} {endpoint}", quiet=True)
+            f"docker login -u AWS -p {token} {endpoint}", quiet=True
+        )

     def fetch_buildkite_plugins(self):
         assert self.env

@@ -415,8 +431,9 @@ class ReproSession:
         for collection in plugins:
             for plugin, options in collection.items():
                 plugin_url, plugin_version = plugin.split("#")
-                if not plugin_url.startswith(
-                        "http://") or not plugin_url.startswith("https://"):
+                if not plugin_url.startswith("http://") or not plugin_url.startswith(
+                    "https://"
+                ):
                     plugin_url = f"https://{plugin_url}"

                 plugin_name = plugin_url.split("/")[-1].rstrip(".git")

@@ -432,7 +449,7 @@ class ReproSession:
                     "version": plugin_version,
                     "dir": plugin_dir,
                     "env": plugin_env,
-                    "details": {}
+                    "details": {},
                 }

     def get_plugin_env(self, plugin_short: str, options: Dict[str, Any]):

@@ -457,30 +474,33 @@ class ReproSession:
         self.execute_ssh_command(
             f"[ ! -e {plugin_dir} ] && git clone --depth 1 "
             f"--branch {plugin_version} {plugin_url} {plugin_dir}",
-            quiet=True)
+            quiet=True,
+        )

     def load_plugin_details(self, plugin: str):
         assert plugin in self.plugins

         plugin_dir = self.plugins[plugin]["dir"]

-        yaml_str = self.execute_ssh_command(
-            f"cat {plugin_dir}/plugin.yml", quiet=True)
+        yaml_str = self.execute_ssh_command(f"cat {plugin_dir}/plugin.yml", quiet=True)

         details = yaml.safe_load(yaml_str)
         self.plugins[plugin]["details"] = details
         return details

-    def execute_plugin_hook(self,
-                            plugin: str,
-                            hook: str,
-                            env: Optional[Dict[str, Any]] = None,
-                            script_command: Optional[str] = None):
+    def execute_plugin_hook(
+        self,
+        plugin: str,
+        hook: str,
+        env: Optional[Dict[str, Any]] = None,
+        script_command: Optional[str] = None,
+    ):
         assert plugin in self.plugins

         self.logger.info(
             f"Executing Buildkite hook for plugin {plugin}: {hook}. "
-            f"This pulls a Docker image and could take a while.")
+            f"This pulls a Docker image and could take a while."
+        )

         plugin_dir = self.plugins[plugin]["dir"]
         plugin_env = self.plugins[plugin]["env"].copy()

@@ -500,21 +520,23 @@ class ReproSession:

     def print_buildkite_command(self, skipped: bool = False):
         print("-" * 80)
-        print("These are the commands you need to execute to fully reproduce "
-              "the run")
+        print(
+            "These are the commands you need to execute to fully reproduce " "the run"
+        )
         print("-" * 80)
         print(self.env["BUILDKITE_COMMAND"])
         print("-" * 80)

         if skipped and self.skipped_commands:
-            print("Some of the commands above have already been run. "
-                  "Remaining commands:")
+            print(
+                "Some of the commands above have already been run. "
+                "Remaining commands:"
+            )
             print("-" * 80)
             print("\n".join(self.skipped_commands))
             print("-" * 80)

-    def run_buildkite_command(self,
-                              command_filter: Optional[List[str]] = None):
+    def run_buildkite_command(self, command_filter: Optional[List[str]] = None):
         commands = self.env["BUILDKITE_COMMAND"].split("\n")
         regexes = [re.compile(cf) for cf in command_filter or []]

@@ -537,15 +559,18 @@ class ReproSession:
             f"grep -q 'source ~/.env' $HOME/.bashrc "
             f"|| echo 'source ~/.env' >> $HOME/.bashrc; "
             f"echo 'export {escaped}' > $HOME/.env",
-            quiet=True)
+            quiet=True,
+        )

     def attach_to_container(self):
         self.logger.info("Attaching to AWS instance...")
-        ssh_command = (f"ssh -ti {self.ssh_key_file} "
-                       f"-o StrictHostKeyChecking=no "
-                       f"-o ServerAliveInterval=30 "
-                       f"{self.ssh_user}@{self.aws_instance_ip} "
-                       f"'docker exec -it ray_container bash -l'")
+        ssh_command = (
+            f"ssh -ti {self.ssh_key_file} "
+            f"-o StrictHostKeyChecking=no "
+            f"-o ServerAliveInterval=30 "
+            f"{self.ssh_user}@{self.aws_instance_ip} "
+            f"'docker exec -it ray_container bash -l'"
+        )

         subprocess.run(ssh_command, shell=True)

@@ -555,29 +580,32 @@ class ReproSession:
 @click.option("-n", "--instance-name", default=None)
 @click.option("-c", "--commands", is_flag=True, default=False)
 @click.option("-f", "--filters", multiple=True, default=[])
-def main(session_url: Optional[str],
-         instance_name: Optional[str] = None,
-         commands: bool = False,
-         filters: Optional[List[str]] = None):
+def main(
+    session_url: Optional[str],
+    instance_name: Optional[str] = None,
+    commands: bool = False,
+    filters: Optional[List[str]] = None,
+):
     random.seed(1235)

     logger = logging.getLogger("main")
     logger.setLevel(logging.INFO)
     handler = logging.StreamHandler()
     handler.setFormatter(
-        logging.Formatter("[%(levelname)s %(asctime)s] "
-                          "%(filename)s: %(lineno)d "
-                          "%(message)s"))
+        logging.Formatter(
+            "[%(levelname)s %(asctime)s] " "%(filename)s: %(lineno)d " "%(message)s"
+        )
+    )
     logger.addHandler(handler)

     maybe_fetch_buildkite_token()
     repro = ReproSession(
-        os.environ["BUILDKITE_TOKEN"],
-        instance_name=instance_name,
-        logger=logger)
+        os.environ["BUILDKITE_TOKEN"], instance_name=instance_name, logger=logger
+    )

     session_url = session_url or click.prompt(
-        "Please copy and paste the Buildkite job build URI here")
+        "Please copy and paste the Buildkite job build URI here"
+    )

     repro.set_session(session_url)

@@ -610,13 +638,16 @@ def main(session_url: Optional[str],
             "BUILDKITE_PLUGIN_DOCKER_TTY": "0",
             "BUILDKITE_PLUGIN_DOCKER_MOUNT_CHECKOUT": "0",
         },
-        script_command=("sed -E 's/"
-                        "docker run/"
-                        "docker run "
-                        "--cap-add=SYS_PTRACE "
-                        "--name ray_container "
-                        "-d/g' | "
-                        "bash -l"))
+        script_command=(
+            "sed -E 's/"
+            "docker run/"
+            "docker run "
+            "--cap-add=SYS_PTRACE "
+            "--name ray_container "
+            "-d/g' | "
+            "bash -l"
+        ),
+    )

     repro.create_new_ssh_client()

@@ -54,7 +54,8 @@ def get_target_expansion_query(targets, tests_only, exclude_manual):

     if exclude_manual:
         query = '{} except tests(attr("tags", "manual", set({})))'.format(
-            query, included_targets)
+            query, included_targets
+        )

     return query


@@ -82,17 +83,16 @@ def get_targets_for_shard(targets, index, count):


 def main():
-    parser = argparse.ArgumentParser(
-        description="Expand and shard Bazel targets.")
+    parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--tests_only", action="store_true")
     parser.add_argument("--exclude_manual", action="store_true")
     parser.add_argument(
-        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1))
+        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1)
+    )
     parser.add_argument(
-        "--count",
-        type=int,
-        default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1))
+        "--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
+    )
     parser.add_argument("targets", nargs="+")
     args, extra_args = parser.parse_known_args()
     args.targets = list(args.targets) + list(extra_args)

@@ -100,11 +100,11 @@ def main():
     if args.index >= args.count:
         parser.error("--index must be between 0 and {}".format(args.count - 1))

-    query = get_target_expansion_query(args.targets, args.tests_only,
-                                       args.exclude_manual)
+    query = get_target_expansion_query(
+        args.targets, args.tests_only, args.exclude_manual
+    )
     expanded_targets = run_bazel_query(query, args.debug)
-    my_targets = get_targets_for_shard(expanded_targets, args.index,
-                                       args.count)
+    my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
     print(" ".join(my_targets))

     return 0

@@ -14,10 +14,10 @@ from collections import defaultdict, OrderedDict

 def textproto_format(space, key, value, json_encoder):
     """Rewrites a key-value pair from textproto as JSON."""
-    if value.startswith(b"\""):
+    if value.startswith(b'"'):
         evaluated = ast.literal_eval(value.decode("utf-8"))
         value = json_encoder.encode(evaluated).encode("utf-8")
-    return b"%s[\"%s\", %s]" % (space, key, value)
+    return b'%s["%s", %s]' % (space, key, value)


 def textproto_split(input_lines, json_encoder):

@@ -50,19 +50,21 @@ def textproto_split(input_lines, json_encoder):
         pieces = re.split(b"(\\r|\\n)", full_line, 1)
         pieces[1:] = [b"".join(pieces[1:])]
         [line, tail] = pieces
-        next_line = pat_open.sub(b"\\1[\"\\2\",\\3[", line)
-        outputs.append(b"" if not prev_comma else b"]"
-                       if next_line.endswith(b"}") else b",")
+        next_line = pat_open.sub(b'\\1["\\2",\\3[', line)
+        outputs.append(
+            b"" if not prev_comma else b"]" if next_line.endswith(b"}") else b","
+        )
         next_line = pat_close.sub(b"]", next_line)
         next_line = pat_line.sub(
-            lambda m: textproto_format(*(m.groups() + (json_encoder, ))),
-            next_line)
+            lambda m: textproto_format(*(m.groups() + (json_encoder,))), next_line
+        )
         outputs.append(prev_tail + next_line)
         if line == b"}":
             yield b"".join(outputs)
             del outputs[:]
-        prev_comma = line != b"}" and (next_line.endswith(b"]")
-                                       or next_line.endswith(b"\""))
+        prev_comma = line != b"}" and (
+            next_line.endswith(b"]") or next_line.endswith(b'"')
+        )
         prev_tail = tail
     if len(outputs) > 0:
         yield b"".join(outputs)

@@ -80,13 +82,14 @@ class Bazel(object):
     def __init__(self, program=None):
         if program is None:
             program = os.getenv("BAZEL_EXECUTABLE", "bazel")
-        self.argv = (program, )
-        self.extra_args = ("--show_progress=no", )
+        self.argv = (program,)
+        self.extra_args = ("--show_progress=no",)

     def _call(self, command, *args):
         return subprocess.check_output(
-            self.argv + (command, ) + args[:1] + self.extra_args + args[1:],
-            stdin=subprocess.PIPE)
+            self.argv + (command,) + args[:1] + self.extra_args + args[1:],
+            stdin=subprocess.PIPE,
+        )

     def info(self, *args):
         result = OrderedDict()

@@ -248,8 +251,7 @@ def shellcheck(bazel_aquery, *shellcheck_argv):
 def main(program, command, *command_args):
     result = 0
     if command == textproto2json.__name__:
-        result = textproto2json(sys.stdin.buffer, sys.stdout.buffer,
-                                *command_args)
+        result = textproto2json(sys.stdin.buffer, sys.stdout.buffer, *command_args)
     elif command == shellcheck.__name__:
         result = shellcheck(*command_args)
     elif command == preclean.__name__:

@ -20,21 +20,16 @@ DOCKER_CLIENT = None
|
|||
PYTHON_WHL_VERSION = "cp3"
|
||||
|
||||
DOCKER_HUB_DESCRIPTION = {
|
||||
"base-deps": ("Internal Image, refer to "
|
||||
"https://hub.docker.com/r/rayproject/ray"),
|
||||
"ray-deps": ("Internal Image, refer to "
|
||||
"https://hub.docker.com/r/rayproject/ray"),
|
||||
"base-deps": (
|
||||
"Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"
|
||||
),
|
||||
"ray-deps": ("Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"),
|
||||
"ray": "Official Docker Images for Ray, the distributed computing API.",
|
||||
"ray-ml": "Developer ready Docker Image for Ray.",
|
||||
"ray-worker-container": "Internal Image for CI test",
|
||||
}
|
||||
|
||||
PY_MATRIX = {
|
||||
"py36": "3.6.12",
|
||||
"py37": "3.7.7",
|
||||
"py38": "3.8.5",
|
||||
"py39": "3.9.5"
|
||||
}
|
||||
PY_MATRIX = {"py36": "3.6.12", "py37": "3.7.7", "py38": "3.8.5", "py39": "3.9.5"}
|
||||
|
||||
BASE_IMAGES = {
|
||||
"cu112": "nvidia/cuda:11.2.0-cudnn8-devel-ubuntu18.04",
|
||||
|
@ -50,7 +45,7 @@ CUDA_FULL = {
|
|||
"cu111": "CUDA 11.1",
|
||||
"cu110": "CUDA 11.0",
|
||||
"cu102": "CUDA 10.2",
|
||||
"cu101": "CUDA 10.1"
|
||||
"cu101": "CUDA 10.1",
|
||||
}
|
||||
|
||||
# The CUDA version to use for the ML Docker image.
|
||||
|
@ -62,8 +57,7 @@ IMAGE_NAMES = list(DOCKER_HUB_DESCRIPTION.keys())
|
|||
|
||||
|
||||
def _get_branch():
|
||||
branch = (os.environ.get("TRAVIS_BRANCH")
|
||||
or os.environ.get("BUILDKITE_BRANCH"))
|
||||
branch = os.environ.get("TRAVIS_BRANCH") or os.environ.get("BUILDKITE_BRANCH")
|
||||
if not branch:
|
||||
print("Branch not found!")
|
||||
print(os.environ)
|
||||
|
@ -94,8 +88,7 @@ def _get_root_dir():
|
|||
|
||||
|
||||
def _get_commit_sha():
|
||||
sha = (os.environ.get("TRAVIS_COMMIT")
|
||||
or os.environ.get("BUILDKITE_COMMIT") or "")
|
||||
sha = os.environ.get("TRAVIS_COMMIT") or os.environ.get("BUILDKITE_COMMIT") or ""
|
||||
if len(sha) < 6:
|
||||
print("INVALID SHA FOUND")
|
||||
return "ERROR"
|
||||
|
@ -105,8 +98,9 @@ def _get_commit_sha():
|
|||
def _configure_human_version():
|
||||
global _get_branch
|
||||
global _get_commit_sha
|
||||
fake_branch_name = input("Provide a 'branch name'. For releases, it "
|
||||
"should be `releases/x.x.x`")
|
||||
fake_branch_name = input(
|
||||
"Provide a 'branch name'. For releases, it " "should be `releases/x.x.x`"
|
||||
)
|
||||
_get_branch = lambda: fake_branch_name # noqa: E731
|
||||
fake_sha = input("Provide a SHA (used for tag value)")
|
||||
_get_commit_sha = lambda: fake_sha # noqa: E731
|
||||
|
@ -115,38 +109,44 @@ def _configure_human_version():
|
|||
def _get_wheel_name(minor_version_number):
|
||||
if minor_version_number:
|
||||
matches = [
|
||||
file for file in glob.glob(
|
||||
file
|
||||
for file in glob.glob(
|
||||
f"{_get_root_dir()}/.whl/ray-*{PYTHON_WHL_VERSION}"
|
||||
f"{minor_version_number}*-manylinux*")
|
||||
f"{minor_version_number}*-manylinux*"
|
||||
)
|
||||
if "+" not in file # Exclude dbg, asan builds
|
||||
]
|
||||
assert len(matches) == 1, (
|
||||
f"Found ({len(matches)}) matches for 'ray-*{PYTHON_WHL_VERSION}"
|
||||
f"{minor_version_number}*-manylinux*' instead of 1.\n"
|
||||
f"wheel matches: {matches}")
|
||||
f"wheel matches: {matches}"
|
||||
)
|
||||
return os.path.basename(matches[0])
|
||||
else:
|
||||
matches = glob.glob(
|
||||
f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
|
||||
matches = glob.glob(f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
|
||||
return [os.path.basename(i) for i in matches]
|
||||
|
||||
|
||||
def _check_if_docker_files_modified():
|
||||
stdout = subprocess.check_output([
|
||||
sys.executable, f"{_get_curr_dir()}/determine_tests_to_run.py",
|
||||
"--output=json"
|
||||
])
|
||||
stdout = subprocess.check_output(
|
||||
[
|
||||
sys.executable,
|
||||
f"{_get_curr_dir()}/determine_tests_to_run.py",
|
||||
"--output=json",
|
||||
]
|
||||
)
|
||||
affected_env_var_list = json.loads(stdout)
|
||||
affected = ("RAY_CI_DOCKER_AFFECTED" in affected_env_var_list or
|
||||
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list)
|
||||
affected = (
|
||||
"RAY_CI_DOCKER_AFFECTED" in affected_env_var_list
|
||||
or "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list
|
||||
)
|
||||
print(f"Docker affected: {affected}")
|
||||
return affected
|
||||
|
||||
|
||||
def _build_docker_image(image_name: str,
|
||||
py_version: str,
|
||||
image_type: str,
|
||||
no_cache=True):
|
||||
def _build_docker_image(
|
||||
image_name: str, py_version: str, image_type: str, no_cache=True
|
||||
):
|
||||
"""Builds Docker image with the provided info.
|
||||
|
||||
image_name (str): The name of the image to build. Must be one of
|
||||
|
@ -161,23 +161,27 @@ def _build_docker_image(image_name: str,
|
|||
if image_name not in IMAGE_NAMES:
|
||||
raise ValueError(
|
||||
f"The provided image name {image_name} is not "
|
||||
f"recognized. Image names must be one of {IMAGE_NAMES}")
|
||||
f"recognized. Image names must be one of {IMAGE_NAMES}"
|
||||
)
|
||||
|
||||
if py_version not in PY_MATRIX.keys():
|
||||
raise ValueError(f"The provided python version {py_version} is not "
|
||||
f"recognized. Python version must be one of"
|
||||
f" {PY_MATRIX.keys()}")
|
||||
raise ValueError(
|
||||
f"The provided python version {py_version} is not "
|
||||
f"recognized. Python version must be one of"
|
||||
f" {PY_MATRIX.keys()}"
|
||||
)
|
||||
|
||||
if image_type not in BASE_IMAGES.keys():
|
||||
raise ValueError(f"The provided CUDA version {image_type} is not "
|
||||
f"recognized. CUDA version must be one of"
|
||||
f" {image_type.keys()}")
|
||||
raise ValueError(
|
||||
f"The provided CUDA version {image_type} is not "
|
||||
f"recognized. CUDA version must be one of"
|
||||
f" {image_type.keys()}"
|
||||
)
|
||||
|
||||
# TODO(https://github.com/ray-project/ray/issues/16599):
|
||||
# remove below after supporting ray-ml images with Python 3.9
|
||||
if image_name == "ray-ml" and py_version == "py39":
|
||||
print(f"{image_name} image is currently unsupported with "
|
||||
"Python 3.9")
|
||||
print(f"{image_name} image is currently unsupported with " "Python 3.9")
|
||||
return
|
||||
|
||||
build_args = {}
|
||||
|
@ -212,7 +216,7 @@ def _build_docker_image(image_name: str,
|
|||
labels = {
|
||||
"image-name": image_name,
|
||||
"python-version": PY_MATRIX[py_version],
|
||||
"ray-commit": _get_commit_sha()
|
||||
"ray-commit": _get_commit_sha(),
|
||||
}
|
||||
if image_type in CUDA_FULL:
|
||||
labels["cuda-version"] = CUDA_FULL[image_type]
|
||||
|
@ -222,7 +226,8 @@ def _build_docker_image(image_name: str,
|
|||
tag=tagged_name,
|
||||
nocache=no_cache,
|
||||
labels=labels,
|
||||
buildargs=build_args)
|
||||
buildargs=build_args,
|
||||
)
|
||||
|
||||
cmd_output = []
|
||||
try:
|
||||
|
@ -230,12 +235,15 @@ def _build_docker_image(image_name: str,
|
|||
current_iter = start
|
||||
for line in output:
|
||||
cmd_output.append(line.decode("utf-8"))
|
||||
if datetime.datetime.now(
|
||||
) - current_iter >= datetime.timedelta(minutes=5):
|
||||
if datetime.datetime.now() - current_iter >= datetime.timedelta(
|
||||
minutes=5
|
||||
):
|
||||
current_iter = datetime.datetime.now()
|
||||
elapsed = datetime.datetime.now() - start
|
||||
print(f"Still building {tagged_name} after "
|
||||
f"{elapsed.seconds} seconds")
|
||||
print(
|
||||
f"Still building {tagged_name} after "
|
||||
f"{elapsed.seconds} seconds"
|
||||
)
|
||||
if elapsed >= datetime.timedelta(minutes=15):
|
||||
print("Additional build output:")
|
||||
print(*cmd_output, sep="\n")
|
||||
|
@ -259,8 +267,10 @@ def _build_docker_image(image_name: str,
|
|||
|
||||
def copy_wheels(human_build):
|
||||
if human_build:
|
||||
print("Please download images using:\n"
|
||||
"`pip download --python-version <py_version> ray==<ray_version>")
|
||||
print(
|
||||
"Please download images using:\n"
|
||||
"`pip download --python-version <py_version> ray==<ray_version>"
|
||||
)
|
||||
root_dir = _get_root_dir()
|
||||
wheels = _get_wheel_name(None)
|
||||
for wheel in wheels:
|
||||
|
@ -268,7 +278,8 @@ def copy_wheels(human_build):
|
|||
ray_dst = os.path.join(root_dir, "docker/ray/.whl/")
|
||||
ray_dep_dst = os.path.join(root_dir, "docker/ray-deps/.whl/")
|
||||
ray_worker_container_dst = os.path.join(
|
||||
root_dir, "docker/ray-worker-container/.whl/")
|
||||
root_dir, "docker/ray-worker-container/.whl/"
|
||||
)
|
||||
os.makedirs(ray_dst, exist_ok=True)
|
||||
shutil.copy(source, ray_dst)
|
||||
os.makedirs(ray_dep_dst, exist_ok=True)
|
||||
|
@ -282,8 +293,7 @@ def check_staleness(repository, tag):
|
|||
|
||||
age = DOCKER_CLIENT.api.inspect_image(f"{repository}:{tag}")["Created"]
|
||||
short_date = datetime.datetime.strptime(age.split("T")[0], "%Y-%m-%d")
|
||||
is_stale = (
|
||||
datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
|
||||
is_stale = (datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
|
||||
return is_stale
|
||||
|
||||
|
||||
|
@ -292,28 +302,23 @@ def build_for_all_versions(image_name, py_versions, image_types, **kwargs):
|
|||
for py_version in py_versions:
|
||||
for image_type in image_types:
|
||||
_build_docker_image(
|
||||
image_name,
|
||||
py_version=py_version,
|
||||
image_type=image_type,
|
||||
**kwargs)
|
||||
image_name, py_version=py_version, image_type=image_type, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def build_base_images(py_versions, image_types):
|
||||
build_for_all_versions(
|
||||
"base-deps", py_versions, image_types, no_cache=False)
|
||||
build_for_all_versions(
|
||||
"ray-deps", py_versions, image_types, no_cache=False)
|
||||
build_for_all_versions("base-deps", py_versions, image_types, no_cache=False)
|
||||
build_for_all_versions("ray-deps", py_versions, image_types, no_cache=False)
|
||||
|
||||
|
||||
def build_or_pull_base_images(py_versions: List[str],
|
||||
image_types: List[str],
|
||||
rebuild_base_images: bool = True) -> bool:
|
||||
def build_or_pull_base_images(
|
||||
py_versions: List[str], image_types: List[str], rebuild_base_images: bool = True
|
||||
) -> bool:
|
||||
"""Returns images to tag and build."""
|
||||
repositories = ["rayproject/base-deps", "rayproject/ray-deps"]
|
||||
tags = [
|
||||
f"nightly-{py_version}-{image_type}"
|
||||
for py_version, image_type in itertools.product(
|
||||
py_versions, image_types)
|
||||
for py_version, image_type in itertools.product(py_versions, image_types)
|
||||
]
|
||||
|
||||
try:
|
||||
|
@ -339,12 +344,15 @@ def build_or_pull_base_images(py_versions: List[str],
|
|||
def prep_ray_ml():
|
||||
root_dir = _get_root_dir()
|
||||
requirement_files = glob.glob(
|
||||
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True)
|
||||
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True
|
||||
)
|
||||
for fl in requirement_files:
|
||||
shutil.copy(fl, os.path.join(root_dir, "docker/ray-ml/"))
|
||||
# Install atari roms script
|
||||
shutil.copy(f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
|
||||
os.path.join(root_dir, "docker/ray-ml/"))
|
||||
shutil.copy(
|
||||
f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
|
||||
os.path.join(root_dir, "docker/ray-ml/"),
|
||||
)
|
||||
|
||||
|
||||
def _get_docker_creds() -> Tuple[str, str]:
|
||||
|
@ -377,10 +385,13 @@ def _tag_and_push(full_image_name, old_tag, new_tag, merge_build=False):
|
|||
DOCKER_CLIENT.api.tag(
|
||||
image=f"{full_image_name}:{old_tag}",
|
||||
repository=full_image_name,
|
||||
tag=new_tag)
|
||||
tag=new_tag,
|
||||
)
|
||||
if not merge_build:
|
||||
print("This is a PR Build! On a merge build, we would normally push"
|
||||
f"to: {full_image_name}:{new_tag}")
|
||||
print(
|
||||
"This is a PR Build! On a merge build, we would normally push"
|
||||
f"to: {full_image_name}:{new_tag}"
|
||||
)
|
||||
else:
|
||||
_docker_push(full_image_name, new_tag)
|
||||
|
||||
|
@ -395,16 +406,17 @@ def _create_new_tags(all_tags, old_str, new_str):
|
|||
|
||||
# For non-release builds, push "nightly" & "sha"
|
||||
# For release builds, push "nightly" & "latest" & "x.x.x"
|
||||
def push_and_tag_images(py_versions: List[str],
|
||||
image_types: List[str],
|
||||
push_base_images: bool,
|
||||
merge_build: bool = False):
|
||||
def push_and_tag_images(
|
||||
py_versions: List[str],
|
||||
image_types: List[str],
|
||||
push_base_images: bool,
|
||||
merge_build: bool = False,
|
||||
):
|
||||
|
||||
date_tag = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
sha_tag = _get_commit_sha()
|
||||
if _release_build():
|
||||
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*",
|
||||
_get_branch()).group(0)
|
||||
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*", _get_branch()).group(0)
|
||||
date_tag = release_name
|
||||
sha_tag = release_name
|
||||
|
||||
|
@ -423,16 +435,19 @@ def push_and_tag_images(py_versions: List[str],
|
|||
for py_name in py_versions:
|
||||
for image_type in image_types:
|
||||
if image_name == "ray-ml" and image_type != ML_CUDA_VERSION:
|
||||
print("ML Docker image is not built for the following "
|
||||
f"device type: {image_type}")
|
||||
print(
|
||||
"ML Docker image is not built for the following "
|
||||
f"device type: {image_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO(https://github.com/ray-project/ray/issues/16599):
|
||||
# remove below after supporting ray-ml images with Python 3.9
|
||||
if image_name in ["ray-ml"
|
||||
] and PY_MATRIX[py_name].startswith("3.9"):
|
||||
print(f"{image_name} image is currently "
|
||||
f"unsupported with Python 3.9")
|
||||
if image_name in ["ray-ml"] and PY_MATRIX[py_name].startswith("3.9"):
|
||||
print(
|
||||
f"{image_name} image is currently "
|
||||
f"unsupported with Python 3.9"
|
||||
)
|
||||
continue
|
||||
|
||||
tag = f"nightly-{py_name}-{image_type}"
|
||||
|
@ -445,20 +460,19 @@ def push_and_tag_images(py_versions: List[str],
|
|||
for old_tag in tag_mapping.keys():
|
||||
if "cpu" in old_tag:
|
||||
new_tags = _create_new_tags(
|
||||
tag_mapping[old_tag], old_str="-cpu", new_str="")
|
||||
tag_mapping[old_tag], old_str="-cpu", new_str=""
|
||||
)
|
||||
tag_mapping[old_tag].extend(new_tags)
|
||||
elif ML_CUDA_VERSION in old_tag:
|
||||
new_tags = _create_new_tags(
|
||||
tag_mapping[old_tag],
|
||||
old_str=f"-{ML_CUDA_VERSION}",
|
||||
new_str="-gpu")
|
||||
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str="-gpu"
|
||||
)
|
||||
tag_mapping[old_tag].extend(new_tags)
|
||||
|
||||
if image_name == "ray-ml":
|
||||
new_tags = _create_new_tags(
|
||||
tag_mapping[old_tag],
|
||||
old_str=f"-{ML_CUDA_VERSION}",
|
||||
new_str="")
|
||||
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str=""
|
||||
)
|
||||
tag_mapping[old_tag].extend(new_tags)
|
||||
|
||||
# No Python version specified should refer to DEFAULT_PYTHON_VERSION
|
||||
|
@ -467,7 +481,8 @@ def push_and_tag_images(py_versions: List[str],
|
|||
new_tags = _create_new_tags(
|
||||
tag_mapping[old_tag],
|
||||
old_str=f"-{DEFAULT_PYTHON_VERSION}",
|
||||
new_str="")
|
||||
new_str="",
|
||||
)
|
||||
tag_mapping[old_tag].extend(new_tags)
|
||||
|
||||
# For all tags, create Date/Sha tags
|
||||
|
@ -475,7 +490,8 @@ def push_and_tag_images(py_versions: List[str],
|
|||
new_tags = _create_new_tags(
|
||||
tag_mapping[old_tag],
|
||||
old_str="nightly",
|
||||
new_str=date_tag if "-deps" in image_name else sha_tag)
|
||||
new_str=date_tag if "-deps" in image_name else sha_tag,
|
||||
)
|
||||
tag_mapping[old_tag].extend(new_tags)
|
||||
|
||||
# Sanity checking.
|
||||
|
@ -511,7 +527,8 @@ def push_and_tag_images(py_versions: List[str],
|
|||
full_image_name,
|
||||
old_tag=old_tag,
|
||||
new_tag=new_tag,
|
||||
merge_build=merge_build)
|
||||
merge_build=merge_build,
|
||||
)
|
||||
|
||||
|
||||
# Push infra here:
|
||||
|
@ -527,9 +544,9 @@ def push_readmes(merge_build: bool):
|
|||
"DOCKER_PASS": password,
|
||||
"PUSHRM_FILE": f"/myvol/docker/{image}/README.md",
|
||||
"PUSHRM_DEBUG": 1,
|
||||
"PUSHRM_SHORT": tag_line
|
||||
"PUSHRM_SHORT": tag_line,
|
||||
}
|
||||
cmd_string = (f"rayproject/{image}")
|
||||
cmd_string = f"rayproject/{image}"
|
||||
|
||||
print(
|
||||
DOCKER_CLIENT.containers.run(
|
||||
|
@ -546,7 +563,9 @@ def push_readmes(merge_build: bool):
|
|||
detach=False,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
tty=False))
|
||||
tty=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Build base-deps/ray-deps only on file change, 2 weeks, per release
|
||||
|
@ -566,63 +585,73 @@ if __name__ == "__main__":
|
|||
choices=list(PY_MATRIX.keys()),
|
||||
default="py37",
|
||||
nargs="*",
|
||||
help="Which python versions to build. "
|
||||
"Must be in (py36, py37, py38, py39)")
|
||||
help="Which python versions to build. " "Must be in (py36, py37, py38, py39)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device-types",
|
||||
choices=list(BASE_IMAGES.keys()),
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="Which device types (CPU/CUDA versions) to build images for. "
|
||||
"If not specified, images will be built for all device types.")
|
||||
"If not specified, images will be built for all device types.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-type",
|
||||
choices=BUILD_TYPES,
|
||||
required=True,
|
||||
help="Whether to bypass checking if docker is affected")
|
||||
help="Whether to bypass checking if docker is affected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-base",
|
||||
dest="base",
|
||||
action="store_true",
|
||||
help="Whether to build base-deps & ray-deps")
|
||||
help="Whether to build base-deps & ray-deps",
|
||||
)
|
||||
parser.add_argument("--no-build-base", dest="base", action="store_false")
|
||||
parser.set_defaults(base=True)
|
||||
parser.add_argument(
|
||||
"--only-build-worker-container",
|
||||
dest="only_build_worker_container",
|
||||
action="store_true",
|
||||
help="Whether only to build ray-worker-container")
|
||||
help="Whether only to build ray-worker-container",
|
||||
)
|
||||
parser.set_defaults(only_build_worker_container=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
py_versions = args.py_versions
|
||||
py_versions = py_versions if isinstance(py_versions,
|
||||
list) else [py_versions]
|
||||
py_versions = py_versions if isinstance(py_versions, list) else [py_versions]
|
||||
|
||||
image_types = args.device_types if args.device_types else list(
|
||||
BASE_IMAGES.keys())
|
||||
image_types = args.device_types if args.device_types else list(BASE_IMAGES.keys())
|
||||
|
||||
assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys())
|
||||
|
||||
# Make sure the python images and cuda versions we build here are
|
||||
# consistent with the ones used with fix-latest-docker.sh script.
|
||||
py_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
|
||||
"python_versions.txt")
|
||||
py_version_file = os.path.join(
|
||||
_get_root_dir(), "docker/retag-lambda", "python_versions.txt"
|
||||
)
|
||||
with open(py_version_file) as f:
|
||||
py_file_versions = f.read().splitlines()
|
||||
assert set(PY_MATRIX.keys()) == set(py_file_versions), \
|
||||
(PY_MATRIX.keys(), py_file_versions)
|
||||
assert set(PY_MATRIX.keys()) == set(py_file_versions), (
|
||||
PY_MATRIX.keys(),
|
||||
py_file_versions,
|
||||
)
|
||||
|
||||
cuda_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
|
||||
"cuda_versions.txt")
|
||||
cuda_version_file = os.path.join(
|
||||
_get_root_dir(), "docker/retag-lambda", "cuda_versions.txt"
|
||||
)
|
||||
|
||||
with open(cuda_version_file) as f:
|
||||
cuda_file_versions = f.read().splitlines()
|
||||
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]),\
|
||||
(BASE_IMAGES.keys(), cuda_file_versions + ["cpu"])
|
||||
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]), (
|
||||
BASE_IMAGES.keys(),
|
||||
cuda_file_versions + ["cpu"],
|
||||
)
|
||||
|
||||
print("Building the following python versions: ",
|
||||
[PY_MATRIX[py_version] for py_version in py_versions])
|
||||
print(
|
||||
"Building the following python versions: ",
|
||||
[PY_MATRIX[py_version] for py_version in py_versions],
|
||||
)
|
||||
print("Building images for the following devices: ", image_types)
|
||||
print("Building base images: ", args.base)
|
||||
|
||||
|
@ -639,9 +668,11 @@ if __name__ == "__main__":
|
|||
if build_type == HUMAN:
|
||||
# If manually triggered, request user for branch and SHA value to use.
|
||||
_configure_human_version()
|
||||
if (build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
|
||||
or _check_if_docker_files_modified()
|
||||
or args.only_build_worker_container):
|
||||
if (
|
||||
build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
|
||||
or _check_if_docker_files_modified()
|
||||
or args.only_build_worker_container
|
||||
):
|
||||
DOCKER_CLIENT = docker.from_env()
|
||||
is_merge = build_type == MERGE
|
||||
# Buildkite is authenticated in the background.
|
||||
|
@ -652,11 +683,11 @@ if __name__ == "__main__":
|
|||
DOCKER_CLIENT.api.login(username=username, password=password)
|
||||
copy_wheels(build_type == HUMAN)
|
||||
is_base_images_built = build_or_pull_base_images(
|
||||
py_versions, image_types, args.base)
|
||||
py_versions, image_types, args.base
|
||||
)
|
||||
|
||||
if args.only_build_worker_container:
|
||||
build_for_all_versions("ray-worker-container", py_versions,
|
||||
image_types)
|
||||
build_for_all_versions("ray-worker-container", py_versions, image_types)
|
||||
# TODO Currently don't push ray_worker_container
|
||||
else:
|
||||
# Build Ray Docker images.
|
||||
|
@ -668,15 +699,19 @@ if __name__ == "__main__":
|
|||
prep_ray_ml()
|
||||
# Only build ML Docker for the ML_CUDA_VERSION
|
||||
build_for_all_versions(
|
||||
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION])
|
||||
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION]
|
||||
)
|
||||
|
||||
if build_type in {MERGE, PR}:
|
||||
valid_branch = _valid_branch()
|
||||
if (not valid_branch) and is_merge:
|
||||
print(f"Invalid Branch found: {_get_branch()}")
|
||||
push_and_tag_images(py_versions, image_types,
|
||||
is_base_images_built, valid_branch
|
||||
and is_merge)
|
||||
push_and_tag_images(
|
||||
py_versions,
|
||||
image_types,
|
||||
is_base_images_built,
|
||||
valid_branch and is_merge,
|
||||
)
|
||||
|
||||
# TODO(ilr) Re-Enable Push READMEs by using a normal password
|
||||
# (not auth token :/)
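
Note: the churn above is mechanical output from Black (pinned later in this diff as BLACK_VERSION_REQUIRED="21.12b0", default line length 88). Two behaviors explain most of it: a call that overflows 88 columns is exploded to one argument per line, and Black adds a trailing comma (the "magic trailing comma") that keeps the call exploded on every later run. A minimal runnable sketch with a hypothetical helper (the function name and values are illustrative, not from the commit):

# Hypothetical stand-in for any long call site in this diff.
def configure_build(versions, images, base_built, should_push):
    return versions, images, base_built, should_push

# Hand-wrapped input, the pre-Black style seen above:
result = configure_build(["py37", "py38"], ["cpu", "cu112"],
                         True, False)

# Black's stable form: one argument per line plus a trailing comma.
# (This short call only stays exploded because the comma is present.)
result = configure_build(
    ["py37", "py38"],
    ["cpu", "cu112"],
    True,
    False,
)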

@ -20,7 +20,8 @@ def build_multinode_image(source_image: str, target_image: str):
f.write("RUN sudo apt install -y openssh-server\n")

subprocess.check_output(
f"docker build -t {target_image} .", shell=True, cwd=tempdir)
f"docker build -t {target_image} .", shell=True, cwd=tempdir
)

shutil.rmtree(tempdir)

@ -25,9 +25,7 @@ def perform_check(raw_xml_string: str):
missing_owners = []
for rule in tree.findall("rule"):
test_name = rule.attrib["name"]
tags = [
child.attrib["value"] for child in rule.find("list").getchildren()
]
tags = [child.attrib["value"] for child in rule.find("list").getchildren()]
team_owner = [t for t in tags if t.startswith("team")]
if len(team_owner) == 0:
missing_owners.append(test_name)

@ -36,7 +34,8 @@ def perform_check(raw_xml_string: str):
if len(missing_owners):
raise Exception(
f"Cannot find owner for tests {missing_owners}, please add "
"`team:*` to the tags.")
"`team:*` to the tags."
)

print(owners)

@ -19,11 +19,7 @@ exit_with_error = False


def check_import(file):
check_to_lines = {
"import ray": -1,
"import psutil": -1,
"import setproctitle": -1
}
check_to_lines = {"import ray": -1, "import psutil": -1, "import setproctitle": -1}

with io.open(file, "r", encoding="utf-8") as f:
for i, line in enumerate(f):

@ -37,8 +33,10 @@ def check_import(file):
# It will not match the following
# - submodule import: `import ray.constants as ray_constants`
# - submodule import: `from ray import xyz`
if re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$",
line) and check_to_lines[check] == -1:
if (
re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$", line)
and check_to_lines[check] == -1
):
check_to_lines[check] = i

for import_lib in ["import psutil", "import setproctitle"]:

@ -48,8 +46,8 @@ def check_import(file):
if import_ray_line == -1 or import_ray_line > import_psutil_line:
print(
"{}:{}".format(str(file), import_psutil_line + 1),
"{} without explicitly import ray before it.".format(
import_lib))
"{} without explicitly import ray before it.".format(import_lib),
)
global exit_with_error
exit_with_error = True

@ -59,8 +57,7 @@ if __name__ == "__main__":
parser.add_argument("path", help="File path to check. e.g. '.' or './src'")
# TODO(simon): For the future, consider adding a feature to explicitly
# white-list the path instead of skipping them.
parser.add_argument(
"-s", "--skip", action="append", help="Skip certian directory")
parser.add_argument("-s", "--skip", action="append", help="Skip certian directory")
args = parser.parse_args()

file_path = Path(args.path)

@ -18,7 +18,7 @@ DEFAULT_BLACKLIST = [
"gpustat",
"opencensus",
"prometheus_client",
"smart_open"
"smart_open",
]


@ -28,19 +28,20 @@ def assert_packages_not_installed(blacklist: List[str]):
except ImportError: # pip < 10.0
from pip.operations import freeze

installed_packages = [
p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()
]
installed_packages = [p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()]

assert not any(p in installed_packages for p in blacklist), \
f"Found blacklisted packages in installed python packages: " \
f"{[p for p in blacklist if p in installed_packages]}. " \
f"Minimal dependency tests could be tainted by this. " \
f"Check the install logs and primary dependencies if any of these " \
assert not any(p in installed_packages for p in blacklist), (
f"Found blacklisted packages in installed python packages: "
f"{[p for p in blacklist if p in installed_packages]}. "
f"Minimal dependency tests could be tainted by this. "
f"Check the install logs and primary dependencies if any of these "
f"packages were installed as part of another install step."
)

print(f"Confirmed that blacklisted packages are not installed in "
f"current Python environment: {blacklist}")
print(
f"Confirmed that blacklisted packages are not installed in "
f"current Python environment: {blacklist}"
)


if __name__ == "__main__":
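
Note: the assert rewrite above is behavior-preserving. The backslash-continued message becomes a parenthesized group, and the adjacent f-strings inside the parentheses are implicitly concatenated into a single message string; only the line-continuation mechanism changes. A small runnable sketch with made-up names:

is_clean, num_issues = True, 0

# Before: backslash line continuations.
assert is_clean, \
    f"Found {num_issues} issues. " \
    f"Check the logs for details."

# After: parentheses carry the continuation; same single message string.
assert is_clean, (
    f"Found {num_issues} issues. "
    f"Check the logs for details."
)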

@ -54,7 +54,8 @@ def run_tidy(task_queue, lock, timeout):
command = task_queue.get()
try:
proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)

if timeout is not None:
watchdog = threading.Timer(timeout, proc.kill)

@ -70,22 +71,21 @@ def run_tidy(task_queue, lock, timeout):
sys.stderr.flush()
except Exception as e:
with lock:
sys.stderr.write("Failed: " + str(e) + ": ".join(command) +
"\n")
sys.stderr.write("Failed: " + str(e) + ": ".join(command) + "\n")
finally:
with lock:
if timeout is not None and watchdog is not None:
if not watchdog.is_alive():
sys.stderr.write("Terminated by timeout: " +
" ".join(command) + "\n")
sys.stderr.write(
"Terminated by timeout: " + " ".join(command) + "\n"
)
watchdog.cancel()
task_queue.task_done()


def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout):
for _ in range(max_tasks):
t = threading.Thread(
target=tidy_caller, args=(task_queue, lock, timeout))
t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout))
t.daemon = True
t.start()

@ -119,84 +119,87 @@ def main():
parser = argparse.ArgumentParser(
description="Run clang-tidy against changed files, and "
"output diagnostics only for modified "
"lines.")
"lines."
)
parser.add_argument(
"-clang-tidy-binary",
metavar="PATH",
default="clang-tidy",
help="path to clang-tidy binary")
help="path to clang-tidy binary",
)
parser.add_argument(
"-p",
metavar="NUM",
default=0,
help="strip the smallest prefix containing P slashes")
help="strip the smallest prefix containing P slashes",
)
parser.add_argument(
"-regex",
metavar="PATTERN",
default=None,
help="custom pattern selecting file paths to check "
"(case sensitive, overrides -iregex)")
"(case sensitive, overrides -iregex)",
)
parser.add_argument(
"-iregex",
metavar="PATTERN",
default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)",
help="custom pattern selecting file paths to check "
"(case insensitive, overridden by -regex)")
"(case insensitive, overridden by -regex)",
)
parser.add_argument(
"-j",
type=int,
default=1,
help="number of tidy instances to be run in parallel.")
help="number of tidy instances to be run in parallel.",
)
parser.add_argument(
"-timeout",
type=int,
default=None,
help="timeout per each file in seconds.")
"-timeout", type=int, default=None, help="timeout per each file in seconds."
)
parser.add_argument(
"-fix",
action="store_true",
default=False,
help="apply suggested fixes")
"-fix", action="store_true", default=False, help="apply suggested fixes"
)
parser.add_argument(
"-checks",
help="checks filter, when not specified, use clang-tidy "
"default",
default="")
help="checks filter, when not specified, use clang-tidy " "default",
default="",
)
parser.add_argument(
"-path",
dest="build_path",
help="Path used to read a compile command database.")
"-path", dest="build_path", help="Path used to read a compile command database."
)
if yaml:
parser.add_argument(
"-export-fixes",
metavar="FILE",
dest="export_fixes",
help="Create a yaml file to store suggested fixes in, "
"which can be applied with clang-apply-replacements.")
"which can be applied with clang-apply-replacements.",
)
parser.add_argument(
"-extra-arg",
dest="extra_arg",
action="append",
default=[],
help="Additional argument to append to the compiler "
"command line.")
help="Additional argument to append to the compiler " "command line.",
)
parser.add_argument(
"-extra-arg-before",
dest="extra_arg_before",
action="append",
default=[],
help="Additional argument to prepend to the compiler "
"command line.")
help="Additional argument to prepend to the compiler " "command line.",
)
parser.add_argument(
"-quiet",
action="store_true",
default=False,
help="Run clang-tidy in quiet mode")
help="Run clang-tidy in quiet mode",
)
clang_tidy_args = []
argv = sys.argv[1:]
if "--" in argv:
clang_tidy_args.extend(argv[argv.index("--"):])
argv = argv[:argv.index("--")]
clang_tidy_args.extend(argv[argv.index("--") :])
argv = argv[: argv.index("--")]

args = parser.parse_args(argv)

@ -204,7 +207,7 @@ def main():
filename = None
lines_by_file = {}
for line in sys.stdin:
match = re.search('^\+\+\+\ \"?(.*?/){%s}([^ \t\n\"]*)' % args.p, line)
match = re.search('^\+\+\+\ "?(.*?/){%s}([^ \t\n"]*)' % args.p, line)
if match:
filename = match.group(2)
if filename is None:

@ -226,8 +229,7 @@ def main():
if line_count == 0:
continue
end_line = start_line + line_count - 1
lines_by_file.setdefault(filename,
[]).append([start_line, end_line])
lines_by_file.setdefault(filename, []).append([start_line, end_line])

if not any(lines_by_file):
print("No relevant changes found.")

@ -267,11 +269,8 @@ def main():

for name in lines_by_file:
line_filter_json = json.dumps(
[{
"name": name,
"lines": lines_by_file[name]
}],
separators=(",", ":"))
[{"name": name, "lines": lines_by_file[name]}], separators=(",", ":")
)

# Run clang-tidy on files containing changes.
command = [args.clang_tidy_binary]
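
Note: one change above that often looks like a typo is `argv[argv.index("--") :]`. It is deliberate: PEP 8 (and therefore Black) treats a slice bound that is more than a plain name or literal as a complex expression, putting a space on each side of the colon and omitting the space where the bound itself is omitted. A quick runnable sketch:

argv = ["-p", "1", "--", "-checks=*"]
idx = argv.index("--")
tail = argv[idx:]                 # simple bound: no spaces around the colon
tail = argv[argv.index("--") :]   # complex bound: spaced colon
head = argv[: argv.index("--")]   # same rule applied to the other side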

@ -38,8 +38,10 @@ def is_pull_request():
for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]:
event_type = os.getenv(key, event_type)

if (os.environ.get("BUILDKITE")
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"):
if (
os.environ.get("BUILDKITE")
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"
):
event_type = "pull_request"

return event_type == "pull_request"

@ -67,8 +69,7 @@ def get_commit_range():

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output", type=str, help="json or envvars", default="envvars")
parser.add_argument("--output", type=str, help="json or envvars", default="envvars")
args = parser.parse_args()

RAY_CI_TUNE_AFFECTED = 0

@ -103,8 +104,7 @@ if __name__ == "__main__":
try:
graph = pda.build_dep_graph()
rllib_tests = pda.list_rllib_tests()
print(
"Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
print("Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)

impacted = {}
for test in rllib_tests:

@ -120,9 +120,7 @@ if __name__ == "__main__":
print(e, file=sys.stderr)
# End of dry run.

skip_prefix_list = [
"doc/", "examples/", "dev/", "kubernetes/", "site/"
]
skip_prefix_list = ["doc/", "examples/", "dev/", "kubernetes/", "site/"]

for changed_file in files:
if changed_file.startswith("python/ray/tune"):

@ -181,7 +179,8 @@ if __name__ == "__main__":
# Java also depends on Python CLI to manage processes.
RAY_CI_JAVA_AFFECTED = 1
if changed_file.startswith("python/setup.py") or re.match(
".*requirements.*\.txt", changed_file):
".*requirements.*\.txt", changed_file
):
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1
elif changed_file.startswith("java/"):
RAY_CI_JAVA_AFFECTED = 1

@ -190,12 +189,9 @@ if __name__ == "__main__":
elif changed_file.startswith("docker/"):
RAY_CI_DOCKER_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
elif changed_file.startswith("doc/") and changed_file.endswith(
".py"):
elif changed_file.startswith("doc/") and changed_file.endswith(".py"):
RAY_CI_DOC_AFFECTED = 1
elif any(
changed_file.startswith(prefix)
for prefix in skip_prefix_list):
elif any(changed_file.startswith(prefix) for prefix in skip_prefix_list):
# nothing is run but linting in these cases
pass
elif changed_file.endswith("build-docker-images.py"):

@ -246,26 +242,28 @@ if __name__ == "__main__":
RAY_CI_DASHBOARD_AFFECTED = 1

# Log the modified environment variables visible in console.
output_string = " ".join([
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(
RAY_CI_RLLIB_DIRECTLY_AFFECTED),
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED),
])
output_string = " ".join(
[
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(RAY_CI_RLLIB_DIRECTLY_AFFECTED),
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED
),
]
)

# Debug purpose
print(output_string, file=sys.stderr)

@ -5,7 +5,7 @@
# Cause the script to exit if a single command fails
set -euo pipefail

BLACK_IS_ENABLED=false
BLACK_IS_ENABLED=true

FLAKE8_VERSION_REQUIRED="3.9.1"
BLACK_VERSION_REQUIRED="21.12b0"

@ -17,19 +17,19 @@ import json

def gha_get_self_url():
import requests

# stringed together api call to get the current check's html url.
sha = os.environ["GITHUB_SHA"]
repo = os.environ["GITHUB_REPOSITORY"]
resp = requests.get(
"https://api.github.com/repos/{}/commits/{}/check-suites".format(
repo, sha))
"https://api.github.com/repos/{}/commits/{}/check-suites".format(repo, sha)
)
data = resp.json()
for check in data["check_suites"]:
slug = check["app"]["slug"]
if slug == "github-actions":
run_url = check["check_runs_url"]
html_url = (
requests.get(run_url).json()["check_runs"][0]["html_url"])
html_url = requests.get(run_url).json()["check_runs"][0]["html_url"]
return html_url

# Return a fallback url

@ -47,10 +47,12 @@ def get_build_env():
if os.environ.get("BUILDKITE"):
return {
"TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"],
"TRAVIS_JOB_WEB_URL": (os.environ["BUILDKITE_BUILD_URL"] + "#" +
os.environ["BUILDKITE_BUILD_ID"]),
"TRAVIS_OS_NAME": # The map is used to stay consistent with Travis
{
"TRAVIS_JOB_WEB_URL": (
os.environ["BUILDKITE_BUILD_URL"]
+ "#"
+ os.environ["BUILDKITE_BUILD_ID"]
),
"TRAVIS_OS_NAME": { # The map is used to stay consistent with Travis
"linux": "linux",
"darwin": "osx",
"win32": "windows",

@ -70,13 +72,10 @@ def get_build_config():
return {"config": {"env": "Windows CI"}}

if os.environ.get("BUILDKITE"):
return {
"config": {
"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]
}
}
return {"config": {"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]}}

import requests

url = "https://api.travis-ci.com/job/{job_id}?include=job.config"
url = url.format(job_id=os.environ["TRAVIS_JOB_ID"])
resp = requests.get(url, headers={"Travis-API-Version": "3"})

@ -87,9 +86,4 @@ if __name__ == "__main__":
build_env = get_build_env()
build_config = get_build_config()

print(
json.dumps(
{
"build_env": build_env,
"build_config": build_config
}, indent=2))
print(json.dumps({"build_env": build_env, "build_config": build_config}, indent=2))

@ -43,7 +43,8 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
test: only return information about a specific test.
"""
tests_res = _run_shell(
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"])
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"]
)

all_tests = []

@ -53,15 +54,18 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
if test and t != test:
continue

src_out = _run_shell([
"bazel", "query", "kind(\"source file\", deps({}))".format(t),
"--output", "label"
])
src_out = _run_shell(
[
"bazel",
"query",
'kind("source file", deps({}))'.format(t),
"--output",
"label",
]
)

srcs = [f.strip() for f in src_out.splitlines()]
srcs = [
f for f in srcs if f.startswith("//python") and f.endswith(".py")
]
srcs = [f for f in srcs if f.startswith("//python") and f.endswith(".py")]
if srcs:
all_tests.append((t, srcs))

@ -73,8 +77,7 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:


def _new_dep(graph: DepGraph, src_module: str, dep: str):
"""Create a new dependency between src_module and dep.
"""
"""Create a new dependency between src_module and dep."""
if dep not in graph.ids:
graph.ids[dep] = len(graph.ids)

@ -87,8 +90,7 @@ def _new_dep(graph: DepGraph, src_module: str, dep: str):


def _new_import(graph: DepGraph, src_module: str, dep_module: str):
"""Process a new import statement in src_module.
"""
"""Process a new import statement in src_module."""
# We don't care about system imports.
if not dep_module.startswith("ray"):
return

@ -97,8 +99,7 @@ def _new_import(graph: DepGraph, src_module: str, dep_module: str):


def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
"""Figure out if base.sub is a python module or not.
"""
"""Figure out if base.sub is a python module or not."""
# Special handling for _raylet, which is a C++ lib.
if module == "ray._raylet":
return False

@ -110,10 +111,10 @@ def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
return False


def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
dep_name: str, _base_dir: str):
"""Process a new "from ... import ..." statement in src_module.
"""
def _new_from_import(
graph: DepGraph, src_module: str, dep_module: str, dep_name: str, _base_dir: str
):
"""Process a new "from ... import ..." statement in src_module."""
# We don't care about imports outside of ray package.
if not dep_module or not dep_module.startswith("ray"):
return

@ -126,10 +127,7 @@ def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
_new_dep(graph, src_module, dep_module)


def _process_file(graph: DepGraph,
src_path: str,
src_module: str,
_base_dir=""):
def _process_file(graph: DepGraph, src_path: str, src_module: str, _base_dir=""):
"""Create dependencies from src_module to all the valid imports in src_path.

Args:

@ -147,13 +145,13 @@ def _process_file(graph: DepGraph,
_new_import(graph, src_module, alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
_new_from_import(graph, src_module, node.module,
alias.name, _base_dir)
_new_from_import(
graph, src_module, node.module, alias.name, _base_dir
)


def build_dep_graph() -> DepGraph:
"""Build index from py files to their immediate dependees.
"""
"""Build index from py files to their immediate dependees."""
graph = DepGraph()

# Assuming we run from root /ray directory.

@ -197,8 +195,7 @@ def _full_module_path(module, f) -> str:


def _should_skip(d: str) -> bool:
"""Skip directories that should not contain py sources.
"""
"""Skip directories that should not contain py sources."""
if d.startswith("python/.eggs/"):
return True
if d.startswith("python/."):

@ -224,14 +221,14 @@ def _bazel_path_to_module_path(d: str) -> str:


def _file_path_to_module_path(f: str) -> str:
"""Return the corresponding module path for a .py file.
"""
"""Return the corresponding module path for a .py file."""
dir, fn = os.path.split(f)
return _full_module_path(_bazel_path_to_module_path(dir), fn)


def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
qid: int) -> List[int]:
def _depends(
graph: DepGraph, visited: Dict[int, bool], tid: int, qid: int
) -> List[int]:
"""Whether there is a dependency path from module tid to module qid.

Given graph, and without going through visited.

@ -253,8 +250,9 @@ def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
return []


def test_depends_on_file(graph: DepGraph, test: Tuple[str, Tuple[str]],
path: str) -> List[int]:
def test_depends_on_file(
graph: DepGraph, test: Tuple[str, Tuple[str]], path: str
) -> List[int]:
"""Give dependency graph, check if a test depends on a specific .py file.

Args:

@ -307,8 +305,7 @@ def _find_circular_dep_impl(graph: DepGraph, id: str, branch: str) -> bool:


def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]:
"""Find circular dependencies among a dependency graph.
"""
"""Find circular dependencies among a dependency graph."""
known = {}
circles = {}
for m, id in graph.ids.items():

@ -334,25 +331,29 @@ if __name__ == "__main__":
"--mode",
type=str,
default="test-dep",
help=("test-dep: find dependencies for a specified test. "
"circular-dep: find circular dependencies in "
"the specific codebase."))
help=(
"test-dep: find dependencies for a specified test. "
"circular-dep: find circular dependencies in "
"the specific codebase."
),
)
parser.add_argument(
"--file",
type=str,
help="Path of a .py source file relative to --base_dir.")
"--file", type=str, help="Path of a .py source file relative to --base_dir."
)
parser.add_argument("--test", type=str, help="Specific test to check.")
parser.add_argument(
"--smoke-test",
action="store_true",
help="Load only a few tests for testing.")
"--smoke-test", action="store_true", help="Load only a few tests for testing."
)

args = parser.parse_args()

print("building dep graph ...")
graph = build_dep_graph()
print("done. total {} files, {} of which have dependencies.".format(
len(graph.ids), len(graph.edges)))
print(
"done. total {} files, {} of which have dependencies.".format(
len(graph.ids), len(graph.edges)
)
)

if args.mode == "circular-dep":
circles = find_circular_dep(graph)
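
Note: the signature changes above show Black's two layouts for an overlong def line. If the grouped parameters fit within 88 columns on a single indented line, Black uses that form (as in `_new_from_import` and `test_depends_on_file`); otherwise each parameter goes on its own line with a trailing comma. A sketch with hypothetical names (neither function exists in the commit):

from typing import Dict, List

# Grouped parameters still fit on one indented line:
def resolve_transitive_dependencies(
    graph: Dict[str, List[str]], root_module: str, target_module: str
) -> List[str]:
    ...

# Too wide even when grouped: one parameter per line, trailing comma added.
def resolve_transitive_dependencies_with_options(
    graph: Dict[str, List[str]],
    root_module: str,
    target_module: str,
    max_depth: int = 10,
    ignore_cycles: bool = False,
) -> List[str]:
    ...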

@ -12,30 +12,33 @@ class TestPyDepAnalysis(unittest.TestCase):
f.close()

def test_full_module_path(self):
self.assertEqual(
pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
self.assertEqual(
pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
self.assertEqual(pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(pda._full_module_path("", "dd.py"), "dd")

def test_bazel_path_to_module_path(self):
self.assertEqual(
pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)
self.assertEqual(
pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)
self.assertEqual(
pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)

def test_file_path_to_module_path(self):
self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/env.py"),
"ray.rllib.env.env")
"ray.rllib.env.env",
)
self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"),
"ray.rllib.env")
"ray.rllib.env",
)

def test_import_line_continuation(self):
graph = pda.DepGraph()

@ -44,11 +47,13 @@ class TestPyDepAnalysis(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file(
src_path, """
src_path,
"""
import ray.rllib.env.\\
mock_env
b = 2
""")
""",
)
pda._process_file(graph, src_path, "ray")

self.assertEqual(len(graph.ids), 2)

@ -64,11 +69,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file(
src_path, """
src_path,
"""
from ray.rllib.env import (ClassName,
module1, module2)
b = 2
""")
""",
)
pda._process_file(graph, src_path, "ray")

self.assertEqual(len(graph.ids), 2)

@ -84,11 +91,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py"
self.create_tmp_file(
os.path.join(tmpdir, src_path), """
os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import mock_env
a = 1
b = 2
""")
""",
)
# Touch ray/rllib/env/mock_env.py in tmpdir,
# so that it looks like a module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env")

@ -112,11 +121,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py"
self.create_tmp_file(
os.path.join(tmpdir, src_path), """
os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import MockEnv
a = 1
b = 2
""")
""",
)
# Touch ray/rllib/env.py in tmpdir,
# MockEnv is a class on env module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib")

@ -138,4 +149,5 @@ b = 2
if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))

@ -22,8 +22,11 @@ import ray.ray_constants as ray_constants
import ray._private.services
import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, \
get_gcs_address_from_redis, use_gcs_for_bootstrap
from ray._private.gcs_utils import (
GcsClient,
get_gcs_address_from_redis,
use_gcs_for_bootstrap,
)
from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc
from ray._private.ray_logging import setup_component_logger

@ -42,23 +45,25 @@ aiogrpc.init_grpc_aio()


class DashboardAgent(object):
def __init__(self,
node_ip_address,
redis_address,
dashboard_agent_port,
gcs_address,
minimal,
redis_password=None,
temp_dir=None,
session_dir=None,
runtime_env_dir=None,
log_dir=None,
metrics_export_port=None,
node_manager_port=None,
listen_port=0,
object_store_name=None,
raylet_name=None,
logging_params=None):
def __init__(
self,
node_ip_address,
redis_address,
dashboard_agent_port,
gcs_address,
minimal,
redis_password=None,
temp_dir=None,
session_dir=None,
runtime_env_dir=None,
log_dir=None,
metrics_export_port=None,
node_manager_port=None,
listen_port=0,
object_store_name=None,
raylet_name=None,
logging_params=None,
):
"""Initialize the DashboardAgent object."""
# Public attributes are accessible for all agent modules.
self.ip = node_ip_address

@ -92,15 +97,16 @@ class DashboardAgent(object):
self.ppid = int(os.environ["RAY_RAYLET_PID"])
assert self.ppid > 0
logger.info("Parent pid is %s", self.ppid)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:{self.dashboard_agent_port}")
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip,
self.grpc_port)
options = (("grpc.enable_http_proxy", 0), )
self.server, f"{grpc_ip}:{self.dashboard_agent_port}"
)
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port)
options = (("grpc.enable_http_proxy", 0),)
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True)
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
)

# If the agent is started as non-minimal version, http server should
# be configured to communicate with the dashboard in a head node.

@ -108,6 +114,7 @@ class DashboardAgent(object):

async def _configure_http_server(self, modules):
from ray.dashboard.http_server_agent import HttpServerAgent

http_server = HttpServerAgent(self.ip, self.listen_port)
await http_server.start(modules)
return http_server

@ -116,10 +123,12 @@ class DashboardAgent(object):
"""Load dashboard agent modules."""
modules = []
agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
dashboard_utils.DashboardAgentModule
)
for cls in agent_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardAgentModule.__name__, cls)
logger.info(
"Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
)
c = cls(self)
modules.append(c)
logger.info("Loaded %d modules.", len(modules))

@ -137,13 +146,12 @@ class DashboardAgent(object):
curr_proc = psutil.Process()
while True:
parent = curr_proc.parent()
if (parent is None or parent.pid == 1
or self.ppid != parent.pid):
if parent is None or parent.pid == 1 or self.ppid != parent.pid:
logger.error("Raylet is dead, exiting.")
sys.exit(0)
await asyncio.sleep(
dashboard_consts.
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS
)
except Exception:
logger.error("Failed to check parent PID, exiting.")
sys.exit(1)

@ -154,15 +162,17 @@ class DashboardAgent(object):
if not use_gcs_for_bootstrap():
# Create an aioredis client for all modules.
try:
self.aioredis_client = \
await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address,
self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
)
except (socket.gaierror, ConnectionRefusedError):
logger.error(
"Dashboard agent exiting: "
"Failed to connect to redis at %s", self.redis_address)
"Dashboard agent exiting: " "Failed to connect to redis at %s",
self.redis_address,
)
sys.exit(-1)

# Start a grpc asyncio server.

@ -170,7 +180,8 @@ class DashboardAgent(object):

if not use_gcs_for_bootstrap():
gcs_address = await self.aioredis_client.get(
dashboard_consts.GCS_SERVER_ADDRESS)
dashboard_consts.GCS_SERVER_ADDRESS
)
self.gcs_client = GcsClient(address=gcs_address.decode())
else:
self.gcs_client = GcsClient(address=self.gcs_address)

@ -192,17 +203,21 @@ class DashboardAgent(object):
internal_kv._internal_kv_put(
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
json.dumps([http_port, self.grpc_port]),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)

# Register agent to agent manager.
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
self.aiogrpc_raylet_channel)
self.aiogrpc_raylet_channel
)

await raylet_stub.RegisterAgent(
agent_manager_pb2.RegisterAgentRequest(
agent_pid=os.getpid(),
agent_port=self.grpc_port,
agent_ip_address=self.ip))
agent_ip_address=self.ip,
)
)

tasks = [m.run(self.server) for m in modules]
if sys.platform not in ["win32", "cygwin"]:

@ -221,123 +236,139 @@ if __name__ == "__main__":
"--node-ip-address",
required=True,
type=str,
help="the IP address of this node.")
help="the IP address of this node.",
)
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
)
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
"--redis-address", required=True, type=str, help="The address to use for Redis."
)
parser.add_argument(
"--metrics-export-port",
required=True,
type=int,
help="The port to expose metrics through Prometheus.")
help="The port to expose metrics through Prometheus.",
)
parser.add_argument(
"--dashboard-agent-port",
required=True,
type=int,
help="The port on which the dashboard agent will receive GRPCs.")
help="The port on which the dashboard agent will receive GRPCs.",
)
parser.add_argument(
"--node-manager-port",
required=True,
type=int,
help="The port to use for starting the node manager")
help="The port to use for starting the node manager",
)
parser.add_argument(
"--object-store-name",
required=True,
type=str,
default=None,
help="The socket name of the plasma store")
help="The socket name of the plasma store",
)
parser.add_argument(
"--listen-port",
required=False,
type=int,
default=0,
help="Port for HTTP server to listen on")
help="Port for HTTP server to listen on",
)
parser.add_argument(
"--raylet-name",
required=True,
type=str,
default=None,
help="The socket path of the raylet process")
help="The socket path of the raylet process",
)
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
help=ray_constants.LOGGER_FORMAT_HELP,
)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\".".format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME))
'log to stdout if set empty, default is "{}".'.format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
),
)
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
ray_constants.LOGGING_ROTATE_BYTES))
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
)
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
help="Specify the backup count of rotated log file, default is {}.".format(
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
),
)
parser.add_argument(
"--log-dir",
required=True,
type=str,
default=None,
help="Specify the path of log directory.")
help="Specify the path of log directory.",
)
parser.add_argument(
"--temp-dir",
required=True,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
help="Specify the path of the temporary directory use by Ray process.",
)
parser.add_argument(
"--session-dir",
required=True,
type=str,
default=None,
help="Specify the path of this session.")
help="Specify the path of this session.",
)
parser.add_argument(
"--runtime-env-dir",
required=True,
type=str,
default=None,
help="Specify the path of the resource directory used by runtime_env.")
help="Specify the path of the resource directory used by runtime_env.",
)
parser.add_argument(
"--minimal",
action="store_true",
help=(
"Minimal agent only contains a subset of features that don't "
"require additional dependencies installed when ray is installed "
"by `pip install ray[default]`."))
"by `pip install ray[default]`."
),
)

args = parser.parse_args()
try:

@ -347,7 +378,8 @@ if __name__ == "__main__":
log_dir=args.log_dir,
filename=args.logging_filename,
max_bytes=args.logging_rotate_bytes,
backup_count=args.logging_rotate_backup_count)
backup_count=args.logging_rotate_backup_count,
)
setup_component_logger(**logging_params)

agent = DashboardAgent(

@ -366,7 +398,8 @@ if __name__ == "__main__":
listen_port=args.listen_port,
object_store_name=args.object_store_name,
raylet_name=args.raylet_name,
logging_params=logging_params)
logging_params=logging_params,
)
if os.environ.get("_RAY_AGENT_FAILING"):
raise Exception("Failure injection failure.")

@ -390,15 +423,19 @@ if __name__ == "__main__":
gcs_publisher = GcsPublisher(args.gcs_address)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
address=get_gcs_address_from_redis(redis_client)
)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)

traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
traceback.format_exc()
)
message = (
f"(ip={node_ip}) "
f"The agent on node {platform.uname()[1]} failed to "

@ -409,12 +446,14 @@ if __name__ == "__main__":
"\n 2. Metrics on this node won't be reported."
"\n 3. runtime_env APIs won't work."
"\nCheck out the `dashboard_agent.log` to see the "
"detailed failure messages.")
"detailed failure messages."
)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
gcs_publisher=gcs_publisher,
)
logger.error(message)
logger.exception(e)
exit(1)

@ -12,12 +12,13 @@ DASHBOARD_RPC_ADDRESS = "dashboard_rpc"
GCS_SERVER_ADDRESS = "GcsServerAddress"
# GCS check alive
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR = env_integer(
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10)
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer(
"GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10
)
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2)
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
)
# aiohttp_cache
AIOHTTP_CACHE_TTL_SECONDS = 2
AIOHTTP_CACHE_MAX_SIZE = 128
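
Note: in the constants above, whether a call wraps is purely a function of the 88-column limit: `env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)` now fits and is joined onto one line, while the longer names keep their arguments on an indented line. A sketch with hypothetical variable names (the `env_integer` body is a stand-in for ray's helper of the same name):

import os

def env_integer(key, default):  # stand-in for ray's env_integer helper
    return int(os.environ.get(key, default))

# Fits within 88 columns: Black keeps the assignment on one line.
CHECK_INTERVAL_S = env_integer("HYPOTHETICAL_CHECK_INTERVAL_S", 5)

# Would overflow 88 columns: the arguments drop to a single indented line.
RETRY_INTERVAL_S = env_integer(
    "HYPOTHETICAL_EXTREMELY_LONG_CONFIGURATION_ENVIRONMENT_VARIABLE_NAME_HERE", 2
)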

@ -38,17 +38,21 @@ class FrontendNotFoundError(OSError):

def setup_static_dir():
build_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "client", "build")
os.path.dirname(os.path.abspath(__file__)), "client", "build"
)
module_name = os.path.basename(os.path.dirname(__file__))
if not os.path.isdir(build_dir):
raise FrontendNotFoundError(
errno.ENOENT, "Dashboard build directory not found. If installing "
errno.ENOENT,
"Dashboard build directory not found. If installing "
"from source, please follow the additional steps "
"required to build the dashboard"
f"(cd python/ray/{module_name}/client "
"&& npm install "
"&& npm ci "
"&& npm run build)", build_dir)
"&& npm run build)",
build_dir,
)

static_dir = os.path.join(build_dir, "static")
routes.static("/static", static_dir, follow_symlinks=True)

@ -72,14 +76,16 @@ class Dashboard:
log_dir(str): Log directory of dashboard.
"""

def __init__(self,
host,
port,
port_retries,
gcs_address,
redis_address,
redis_password=None,
log_dir=None):
def __init__(
self,
host,
port,
port_retries,
gcs_address,
redis_address,
redis_password=None,
log_dir=None,
):
self.dashboard_head = dashboard_head.DashboardHead(
http_host=host,
http_port=port,

@ -87,7 +93,8 @@ class Dashboard:
gcs_address=gcs_address,
redis_address=redis_address,
redis_password=redis_password,
log_dir=log_dir)
log_dir=log_dir,
)

# Setup Dashboard Routes
try:

@ -107,15 +114,17 @@ class Dashboard:
async def get_index(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/index.html"))
os.path.dirname(os.path.abspath(__file__)), "client/build/index.html"
)
)

@routes.get("/favicon.ico")
async def get_favicon(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/favicon.ico"))
os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico"
)
)

async def run(self):
await self.dashboard_head.run()

@ -124,92 +133,96 @@ class Dashboard:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Ray dashboard.")
parser.add_argument(
"--host",
required=True,
type=str,
help="The host to use for the HTTP server.")
"--host", required=True, type=str, help="The host to use for the HTTP server."
)
parser.add_argument(
"--port",
required=True,
type=int,
help="The port to use for the HTTP server.")
"--port", required=True, type=int, help="The port to use for the HTTP server."
)
parser.add_argument(
"--port-retries",
required=False,
type=int,
default=0,
help="The retry times to select a valid port.")
help="The retry times to select a valid port.",
)
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
)
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
"--redis-address", required=True, type=str, help="The address to use for Redis."
)
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
help=ray_constants.LOGGER_FORMAT_HELP,
)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default=dashboard_consts.DASHBOARD_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\"".format(
dashboard_consts.DASHBOARD_LOG_FILENAME))
'log to stdout if set empty, default is "{}"'.format(
dashboard_consts.DASHBOARD_LOG_FILENAME
),
)
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
ray_constants.LOGGING_ROTATE_BYTES))
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
)
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
help="Specify the backup count of rotated log file, default is {}.".format(
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
),
)
parser.add_argument(
"--log-dir",
required=True,
type=str,
default=None,
help="Specify the path of log directory.")
help="Specify the path of log directory.",
)
parser.add_argument(
"--temp-dir",
required=True,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
help="Specify the path of the temporary directory use by Ray process.",
)
parser.add_argument(
"--minimal",
action="store_true",
help=(
"Minimal dashboard only contains a subset of features that don't "
"require additional dependencies installed when ray is installed "
"by `pip install ray[default]`."))
"by `pip install ray[default]`."
),
)

args = parser.parse_args()

@ -226,7 +239,8 @@ if __name__ == "__main__":
log_dir=args.log_dir,
filename=args.logging_filename,
max_bytes=args.logging_rotate_bytes,
backup_count=args.logging_rotate_backup_count)
backup_count=args.logging_rotate_backup_count,
)

dashboard = Dashboard(
args.host,

@ -235,25 +249,27 @@ if __name__ == "__main__":
args.gcs_address,
args.redis_address,
redis_password=args.redis_password,
log_dir=args.log_dir)
log_dir=args.log_dir,
)
# TODO(fyrestone): Avoid using ray.state in dashboard, it's not
# asynchronous and will lead to low performance. ray disconnect()
# will be hang when the ray.state is connected and the GCS is exit.
# Please refer to: https://github.com/ray-project/ray/issues/16328
service_discovery = PrometheusServiceDiscoveryWriter(
args.redis_address, args.redis_password, args.gcs_address,
args.temp_dir)
args.redis_address, args.redis_password, args.gcs_address, args.temp_dir
)
# Need daemon True to avoid dashboard hangs at exit.
service_discovery.daemon = True
service_discovery.start()
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
except Exception as e:
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
message = f"The dashboard on node {platform.uname()[1]} " \
f"failed with the following " \
f"error:\n{traceback_str}"
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
message = (
f"The dashboard on node {platform.uname()[1]} "
f"failed with the following "
f"error:\n{traceback_str}"
)
if isinstance(e, FrontendNotFoundError):
logger.warning(message)
else:

@ -268,17 +284,21 @@ if __name__ == "__main__":
gcs_publisher = GcsPublisher(args.gcs_address)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
address=gcs_utils.get_gcs_address_from_redis(redis_client)
)
redis_client = None
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)

ray._private.utils.publish_error_to_driver(
redis_client,
ray_constants.DASHBOARD_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
gcs_publisher=gcs_publisher,
)
|
||||
|
|
|
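Every hunk in this file shows the same two Black rules at work: a call that no longer fits in the 88-column limit is exploded to one argument per line, and a magic trailing comma is added so the closing parenthesis lands on its own line. A minimal sketch of the transformation, using a placeholder function rather than anything from this diff:

def report(client, error_type, message, publisher=None):
    # Hypothetical stand-in so the sketch runs on its own.
    return (client, error_type, message, publisher)

# Before: hanging indent, closing paren glued to the last argument.
report("redis", "DASHBOARD_DIED", "boom",
       publisher="gcs")

# After Black: one argument per line plus the magic trailing comma,
# so adding an argument later changes exactly one diff line.
report(
    "redis",
    "DASHBOARD_DIED",
    "boom",
    publisher="gcs",
)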
@ -2,9 +2,9 @@ import asyncio
import logging
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.memory_utils as memory_utils

# TODO(fyrestone): Not import from dashboard module.
from ray.dashboard.modules.actor.actor_utils import \
actor_classname_from_task_spec
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.dashboard.utils import Dict, Signal, async_loop_forever

logger = logging.getLogger(__name__)

@ -132,9 +132,9 @@ class DataOrganizer:
worker["errorCount"] = len(node_errs.get(str(pid), []))
worker["coreWorkerStats"] = pid_to_worker_stats.get(pid, [])
worker["language"] = pid_to_language.get(
pid, dashboard_consts.DEFAULT_LANGUAGE)
worker["jobId"] = pid_to_job_id.get(
pid, dashboard_consts.DEFAULT_JOB_ID)
pid, dashboard_consts.DEFAULT_LANGUAGE
)
worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID)

await GlobalSignals.worker_info_fetched.send(node_id, worker)

@ -143,8 +143,7 @@ class DataOrganizer:

@classmethod
async def get_node_info(cls, node_id):
node_physical_stats = dict(
DataSource.node_physical_stats.get(node_id, {}))
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
node_stats = dict(DataSource.node_stats.get(node_id, {}))
node = DataSource.nodes.get(node_id, {})
node_ip = DataSource.node_id_to_ip.get(node_id)

@ -162,8 +161,8 @@ class DataOrganizer:

view_data = node_stats.get("viewData", [])
ray_stats = cls._extract_view_data(
view_data,
{"object_store_used_memory", "object_store_available_memory"})
view_data, {"object_store_used_memory", "object_store_available_memory"}
)

node_info = node_physical_stats
# Merge node stats to node physical stats under raylet

@ -184,8 +183,7 @@ class DataOrganizer:

@classmethod
async def get_node_summary(cls, node_id):
node_physical_stats = dict(
DataSource.node_physical_stats.get(node_id, {}))
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
node_stats = dict(DataSource.node_stats.get(node_id, {}))
node = DataSource.nodes.get(node_id, {})

@ -193,8 +191,8 @@ class DataOrganizer:
node_stats.pop("workersStats", None)
view_data = node_stats.get("viewData", [])
ray_stats = cls._extract_view_data(
view_data,
{"object_store_used_memory", "object_store_available_memory"})
view_data, {"object_store_used_memory", "object_store_available_memory"}
)
node_stats.pop("viewData", None)

node_summary = node_physical_stats

@ -244,8 +242,9 @@ class DataOrganizer:
actor = dict(actor)
worker_id = actor["address"]["workerId"]
core_worker_stats = DataSource.core_worker_stats.get(worker_id, {})
actor_constructor = core_worker_stats.get("actorTitle",
"Unknown actor constructor")
actor_constructor = core_worker_stats.get(
"actorTitle", "Unknown actor constructor"
)
actor["actorConstructor"] = actor_constructor
actor.update(core_worker_stats)

@ -275,8 +274,12 @@ class DataOrganizer:
@classmethod
async def get_actor_creation_tasks(cls):
infeasible_tasks = sum(
(list(node_stats.get("infeasibleTasks", []))
for node_stats in DataSource.node_stats.values()), [])
(
list(node_stats.get("infeasibleTasks", []))
for node_stats in DataSource.node_stats.values()
),
[],
)
new_infeasible_tasks = []
for task in infeasible_tasks:
task = dict(task)

@ -285,8 +288,12 @@ class DataOrganizer:
new_infeasible_tasks.append(task)

resource_pending_tasks = sum(
(list(data.get("readyTasks", []))
for data in DataSource.node_stats.values()), [])
(
list(data.get("readyTasks", []))
for data in DataSource.node_stats.values()
),
[],
)
new_resource_pending_tasks = []
for task in resource_pending_tasks:
task = dict(task)

@ -301,14 +308,17 @@ class DataOrganizer:
return results

@classmethod
async def get_memory_table(cls,
sort_by=memory_utils.SortingType.OBJECT_SIZE,
group_by=memory_utils.GroupByType.STACK_TRACE):
async def get_memory_table(
cls,
sort_by=memory_utils.SortingType.OBJECT_SIZE,
group_by=memory_utils.GroupByType.STACK_TRACE,
):
all_worker_stats = []
for node_stats in DataSource.node_stats.values():
all_worker_stats.extend(node_stats.get("coreWorkersStats", []))
memory_information = memory_utils.construct_memory_table(
all_worker_stats, group_by=group_by, sort_by=sort_by)
all_worker_stats, group_by=group_by, sort_by=sort_by
)
return memory_information

@staticmethod
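The `get_actor_creation_tasks` hunks above are the `sum(generator, [])` list-flattening idiom; Black gives the generator its own parenthesized block so each clause sits on one line. A standalone sketch with made-up sample data (not Ray's real schema):

node_stats = [{"infeasibleTasks": [1, 2]}, {"infeasibleTasks": [3]}]

# sum() with a [] start value concatenates the per-node lists.
infeasible_tasks = sum(
    (
        list(stats.get("infeasibleTasks", []))
        for stats in node_stats
    ),
    [],
)
assert infeasible_tasks == [1, 2, 3]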
@ -10,6 +10,7 @@ from queue import Queue

from distutils.version import LooseVersion
import grpc

try:
from grpc import aio as aiogrpc
except ImportError:

@ -23,8 +24,11 @@ import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray import ray_constants
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
GcsAioErrorSubscriber, GcsAioLogSubscriber
from ray._private.gcs_pubsub import (
gcs_pubsub_enabled,
GcsAioErrorSubscriber,
GcsAioLogSubscriber,
)
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer

@ -42,33 +46,33 @@ aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length",
ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)


async def get_gcs_address_with_retry(redis_client) -> str:
while True:
try:
gcs_address = (await redis_client.get(
dashboard_consts.GCS_SERVER_ADDRESS)).decode()
gcs_address = (
await redis_client.get(dashboard_consts.GCS_SERVER_ADDRESS)
).decode()
if not gcs_address:
raise Exception("GCS address not found.")
logger.info("Connect to GCS at %s", gcs_address)
return gcs_address
except Exception as ex:
logger.error("Connect to GCS failed: %s, retry...", ex)
await asyncio.sleep(
dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
await asyncio.sleep(dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)


class GCSHealthCheckThread(threading.Thread):
def __init__(self, gcs_address: str):
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, options=GRPC_CHANNEL_OPTIONS)
self.gcs_heartbeat_info_stub = (
gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self.grpc_gcs_channel))
gcs_address, options=GRPC_CHANNEL_OPTIONS
)
self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self.grpc_gcs_channel
)
self.work_queue = Queue()

super().__init__(daemon=True)

@ -83,10 +87,10 @@ class GCSHealthCheckThread(threading.Thread):
request = gcs_service_pb2.CheckAliveRequest()
try:
reply = self.gcs_heartbeat_info_stub.CheckAlive(
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT)
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
)
if reply.status.code != 0:
logger.exception(
f"Failed to CheckAlive: {reply.status.message}")
logger.exception(f"Failed to CheckAlive: {reply.status.message}")
return False
except grpc.RpcError: # Deadline Exceeded
logger.exception("Got RpcError when checking GCS is alive")

@ -95,9 +99,9 @@ class GCSHealthCheckThread(threading.Thread):

async def check_once(self) -> bool:
"""Ask the thread to perform a healthcheck."""
assert threading.current_thread != self, (
"caller shouldn't be from the same thread as GCSHealthCheckThread."
)
assert (
threading.current_thread != self
), "caller shouldn't be from the same thread as GCSHealthCheckThread."

future = Future()
self.work_queue.put(future)

@ -105,8 +109,16 @@ class GCSHealthCheckThread(threading.Thread):


class DashboardHead:
def __init__(self, http_host, http_port, http_port_retries, gcs_address,
redis_address, redis_password, log_dir):
def __init__(
self,
http_host,
http_port,
http_port_retries,
gcs_address,
redis_address,
redis_password,
log_dir,
):
self.health_check_thread: GCSHealthCheckThread = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.

@ -134,12 +146,12 @@ class DashboardHead:
else:
ip, port = gcs_address.split(":")

self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:0")
logger.info("Dashboard head grpc address: %s:%s", grpc_ip,
self.grpc_port)
self.server, f"{grpc_ip}:0"
)
logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)

@async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
async def _gcs_check_alive(self):

@ -149,7 +161,8 @@ class DashboardHead:
# Otherwise, the dashboard will always think that gcs is alive.
try:
is_alive = await asyncio.wait_for(
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1)
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
)
except asyncio.TimeoutError:
logger.error("Failed to check gcs health, client timed out.")
is_alive = False

@ -158,13 +171,16 @@ class DashboardHead:
self._gcs_rpc_error_counter = 0
else:
self._gcs_rpc_error_counter += 1
if self._gcs_rpc_error_counter > \
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR:
if (
self._gcs_rpc_error_counter
> dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
):
logger.error(
"Dashboard exiting because it received too many GCS RPC "
"errors count: %s, threshold is %s.",
self._gcs_rpc_error_counter,
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR)
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
)
# TODO(fyrestone): Do not use ray.state in
# PrometheusServiceDiscoveryWriter.
# Currently, we use os._exit() here to avoid hanging at the ray

@ -176,10 +192,12 @@ class DashboardHead:
"""Load dashboard head modules."""
modules = []
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
dashboard_utils.DashboardHeadModule
)
for cls in head_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardHeadModule.__name__, cls)
logger.info(
"Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
)
c = cls(self)
dashboard_optional_utils.ClassMethodRouteTable.bind(c)
modules.append(c)

@ -192,15 +210,17 @@ class DashboardHead:
return self.gcs_address
else:
try:
self.aioredis_client = \
await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address,
self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
)
except (socket.gaierror, ConnectionError):
logger.error(
"Dashboard head exiting: "
"Failed to connect to redis at %s", self.redis_address)
"Dashboard head exiting: " "Failed to connect to redis at %s",
self.redis_address,
)
sys.exit(-1)
return await get_gcs_address_with_retry(self.aioredis_client)

@ -209,22 +229,20 @@ class DashboardHead:
# Create a http session for all modules.
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
else:
self.http_session = aiohttp.ClientSession()

gcs_address = await self.get_gcs_address()

# Dashboard will handle connection failure automatically
self.gcs_client = GcsClient(
address=gcs_address, nums_reconnect_retry=0)
self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
internal_kv._initialize_internal_kv(self.gcs_client)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
if gcs_pubsub_enabled():
self.gcs_error_subscriber = GcsAioErrorSubscriber(
address=gcs_address)
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
await self.gcs_error_subscriber.subscribe()
await self.gcs_log_subscriber.subscribe()

@ -248,7 +266,7 @@ class DashboardHead:

# Http server should be initialized after all modules loaded.
# working_dir uploads for job submission can be up to 100MiB.
app = aiohttp.web.Application(client_max_size=100 * 1024**2)
app = aiohttp.web.Application(client_max_size=100 * 1024 ** 2)
app.add_routes(routes=routes.bound_routes())

runner = aiohttp.web.AppRunner(app)

@ -256,8 +274,7 @@ class DashboardHead:
last_ex = None
for i in range(1 + self.http_port_retries):
try:
site = aiohttp.web.TCPSite(runner, self.http_host,
self.http_port)
site = aiohttp.web.TCPSite(runner, self.http_host, self.http_port)
await site.start()
break
except OSError as e:

@ -265,11 +282,14 @@ class DashboardHead:
self.http_port += 1
logger.warning("Try to use port %s: %s", self.http_port, e)
else:
raise Exception(f"Failed to find a valid port for dashboard after "
f"{self.http_port_retries} retries: {last_ex}")
raise Exception(
f"Failed to find a valid port for dashboard after "
f"{self.http_port_retries} retries: {last_ex}"
)
http_host, http_port, *_ = site._server.sockets[0].getsockname()
http_host = self.ip if ipaddress.ip_address(
http_host).is_unspecified else http_host
http_host = (
self.ip if ipaddress.ip_address(http_host).is_unspecified else http_host
)
logger.info("Dashboard head http address: %s:%s", http_host, http_port)

# TODO: Use async version if performance is an issue

@ -277,16 +297,16 @@ class DashboardHead:
internal_kv._internal_kv_put(
ray_constants.DASHBOARD_ADDRESS,
f"{http_host}:{http_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
internal_kv._internal_kv_put(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
f"{self.ip}:{self.grpc_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)

# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))

@ -299,6 +319,5 @@ class DashboardHead:
DataOrganizer.purge(),
DataOrganizer.organize(),
]
await asyncio.gather(*concurrent_tasks,
*(m.run(self.server) for m in modules))
await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
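Two patterns above are worth calling out. Black never emits backslash continuations, so `if self._gcs_rpc_error_counter > \ ...` becomes a parenthesized condition with the comparison operator leading the continuation line, and the `self.aioredis_client = \ await ...` assignment is re-joined. A condition that fits in 88 columns stays on one line; the wrapped form only appears when it does not, as in this sketch with hypothetical names chosen long enough to force the wrap:

gcs_rpc_error_counter_for_this_dashboard_process = 11
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR_THRESHOLD = 10

if (
    gcs_rpc_error_counter_for_this_dashboard_process
    > GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR_THRESHOLD
):
    print("too many GCS RPC errors, exiting")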
@ -23,8 +23,7 @@ class HttpServerAgent:
# Create a http session for all modules.
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
else:
self.http_session = aiohttp.ClientSession()

@ -47,25 +46,26 @@ class HttpServerAgent:
allow_methods="*",
allow_headers=("Content-Type", "X-Header"),
)
})
},
)
for route in list(app.router.routes()):
cors.add(route)

self.runner = aiohttp.web.AppRunner(app)
await self.runner.setup()
site = aiohttp.web.TCPSite(
self.runner, "127.0.0.1"
if self.ip == "127.0.0.1" else "0.0.0.0", self.listen_port)
self.runner,
"127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0",
self.listen_port,
)
await site.start()
self.http_host, self.http_port, *_ = (
site._server.sockets[0].getsockname())
logger.info("Dashboard agent http address: %s:%s", self.http_host,
self.http_port)
self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
logger.info(
"Dashboard agent http address: %s:%s", self.http_host, self.http_port
)

# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))
@ -31,7 +31,7 @@ def cpu_percent():
delta in total host cpu usage, averaged over host's cpus.

Since deltas are not initially available, return 0.0 on first call.
""" # noqa
"""  # noqa
global last_system_usage
global last_cpu_usage
try:

@ -43,12 +43,10 @@ def cpu_percent():
else:
cpu_delta = cpu_usage - last_cpu_usage
# "System time passed." (Typically close to clock time.)
system_delta = (
(system_usage - last_system_usage) / _host_num_cpus())
system_delta = (system_usage - last_system_usage) / _host_num_cpus()

quotient = cpu_delta / system_delta
cpu_percent = round(
quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
cpu_percent = round(quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
last_system_usage = system_usage
last_cpu_usage = cpu_usage
# Computed percentage might be slightly above 100%.

@ -73,14 +71,14 @@ def _system_usage():

See also the /proc/stat entry here:
https://man7.org/linux/man-pages/man5/proc.5.html
""" # noqa
"""  # noqa
cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0]
parts = cpu_summary_str.split()
assert parts[0] == "cpu"
usage_data = parts[1:8]
total_clock_ticks = sum(int(entry) for entry in usage_data)
# 100 clock ticks per second, 10^9 ns per second
usage_ns = total_clock_ticks * 10**7
usage_ns = total_clock_ticks * 10 ** 7
return usage_ns


@ -91,7 +89,8 @@ def _host_num_cpus():
proc_stat_lines = open(PROC_STAT_PATH).read().split("\n")
split_proc_stat_lines = [line.split() for line in proc_stat_lines]
cpu_lines = [
split_line for split_line in split_proc_stat_lines
split_line
for split_line in split_proc_stat_lines
if len(split_line) > 0 and "cpu" in split_line[0]
]
# Number of lines starting with a word including 'cpu', subtracting
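Note the `10**7` to `10 ** 7` change here and `1024**2` to `1024 ** 2` in the head.py hunk above: the Black release this commit was formatted with put spaces around every `**`. If I remember the changelog correctly, Black 22.1.0 and later hug simple powers again (`10**7`), so re-running a newer Black would flip these lines back. Either spelling is identical at runtime:

NS_PER_CLOCK_TICK = 10 ** 7  # what this commit's Black emits
NS_PER_CLOCK_TICK_HUGGED = 10**7  # what newer Black emits for simple operands
assert NS_PER_CLOCK_TICK == NS_PER_CLOCK_TICK_HUGGED == 10_000_000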
@ -6,7 +6,7 @@ from typing import List

import ray

from ray._raylet import (TaskID, ActorID, JobID)
from ray._raylet import TaskID, ActorID, JobID
from ray.internal.internal_api import node_stats
import logging


@ -69,8 +69,10 @@ def get_sorting_type(sort_by: str):
elif sort_by == "REFERENCE_TYPE":
return SortingType.REFERENCE_TYPE
else:
raise Exception("The sort-by input provided is not one of\
PID, OBJECT_SIZE, or REFERENCE_TYPE.")
raise Exception(
"The sort-by input provided is not one of\
PID, OBJECT_SIZE, or REFERENCE_TYPE."
)


def get_group_by_type(group_by: str):

@ -81,13 +83,16 @@ def get_group_by_type(group_by: str):
elif group_by == "STACK_TRACE":
return GroupByType.STACK_TRACE
else:
raise Exception("The group-by input provided is not one of\
NODE_ADDRESS or STACK_TRACE.")
raise Exception(
"The group-by input provided is not one of\
NODE_ADDRESS or STACK_TRACE."
)


class MemoryTableEntry:
def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
pid: int):
def __init__(
self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int
):
# worker info
self.is_driver = is_driver
self.pid = pid

@ -97,13 +102,13 @@ class MemoryTableEntry:
self.object_size = int(object_ref.get("objectSize", -1))
self.call_site = object_ref.get("callSite", "<Unknown>")
self.object_ref = ray.ObjectRef(
decode_object_ref_if_needed(object_ref["objectId"]))
decode_object_ref_if_needed(object_ref["objectId"])
)

# reference info
self.local_ref_count = int(object_ref.get("localRefCount", 0))
self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
self.submitted_task_ref_count = int(
object_ref.get("submittedTaskRefCount", 0))
self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0))
self.contained_in_owned = [
ray.ObjectRef(decode_object_ref_if_needed(object_ref))
for object_ref in object_ref.get("containedInOwned", [])

@ -113,9 +118,12 @@ class MemoryTableEntry:
def is_valid(self) -> bool:
# If the entry doesn't have a reference type or some invalid state,
# (e.g., no object ref presented), it is considered invalid.
if (not self.pinned_in_memory and self.local_ref_count == 0
and self.submitted_task_ref_count == 0
and len(self.contained_in_owned) == 0):
if (
not self.pinned_in_memory
and self.local_ref_count == 0
and self.submitted_task_ref_count == 0
and len(self.contained_in_owned) == 0
):
return False
elif self.object_ref.is_nil():
return False

@ -153,10 +161,10 @@ class MemoryTableEntry:
# are not all 'f', that means it is an actor creation
# task, which is an actor handle.
random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
actor_random_bits = object_ref_hex[TASKID_RANDOM_BITS_SIZE:
TASKID_RANDOM_BITS_SIZE +
ACTORID_RANDOM_BITS_SIZE]
if (random_bits == "f" * 16 and not actor_random_bits == "f" * 24):
actor_random_bits = object_ref_hex[
TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
]
if random_bits == "f" * 16 and not actor_random_bits == "f" * 24:
return True
else:
return False

@ -175,7 +183,7 @@ class MemoryTableEntry:
"contained_in_owned": [
object_ref.hex() for object_ref in self.contained_in_owned
],
"type": "Driver" if self.is_driver else "Worker"
"type": "Driver" if self.is_driver else "Worker",
}

def __str__(self):

@ -186,10 +194,12 @@ class MemoryTableEntry:


class MemoryTable:
def __init__(self,
entries: List[MemoryTableEntry],
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
sort_by_type: SortingType = SortingType.PID):
def __init__(
self,
entries: List[MemoryTableEntry],
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
sort_by_type: SortingType = SortingType.PID,
):
self.table = entries
# Group is a list of memory tables grouped by a group key.
self.group = {}

@ -247,7 +257,7 @@ class MemoryTable:
"total_pinned_in_memory": total_pinned_in_memory,
"total_used_by_pending_task": total_used_by_pending_task,
"total_captured_in_objects": total_captured_in_objects,
"total_actor_handles": total_actor_handles
"total_actor_handles": total_actor_handles,
}
return self

@ -278,7 +288,8 @@ class MemoryTable:
# Build a group table.
for group_key, entries in group.items():
self.group[group_key] = MemoryTable(
entries, group_by_type=None, sort_by_type=None)
entries, group_by_type=None, sort_by_type=None
)
for group_key, group_memory_table in self.group.items():
group_memory_table.summarize()
return self

@ -289,10 +300,10 @@ class MemoryTable:
"group": {
group_key: {
"entries": group_memory_table.get_entries(),
"summary": group_memory_table.summary
"summary": group_memory_table.summary,
}
for group_key, group_memory_table in self.group.items()
}
},
}

def get_entries(self) -> List[dict]:

@ -305,9 +316,11 @@ class MemoryTable:
return self.__repr__()


def construct_memory_table(workers_stats: List,
group_by: GroupByType = GroupByType.NODE_ADDRESS,
sort_by=SortingType.OBJECT_SIZE) -> MemoryTable:
def construct_memory_table(
workers_stats: List,
group_by: GroupByType = GroupByType.NODE_ADDRESS,
sort_by=SortingType.OBJECT_SIZE,
) -> MemoryTable:
memory_table_entries = []
for core_worker_stats in workers_stats:
pid = core_worker_stats["pid"]

@ -320,11 +333,13 @@ def construct_memory_table(workers_stats: List,
object_ref=object_ref,
node_address=node_address,
is_driver=is_driver,
pid=pid)
pid=pid,
)
if memory_table_entry.is_valid():
memory_table_entries.append(memory_table_entry)
memory_table = MemoryTable(
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by)
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by
)
return memory_table


@ -337,7 +352,7 @@ def track_reference_size(group):
"PINNED_IN_MEMORY": "total_pinned_in_memory",
"USED_BY_PENDING_TASK": "total_used_by_pending_task",
"CAPTURED_IN_OBJECT": "total_captured_in_objects",
"ACTOR_HANDLE": "total_actor_handles"
"ACTOR_HANDLE": "total_actor_handles",
}
for entry in group["entries"]:
size = entry["object_size"]

@ -348,51 +363,64 @@ def track_reference_size(group):
return d


def memory_summary(state,
group_by="NODE_ADDRESS",
sort_by="OBJECT_SIZE",
line_wrap=True,
unit="B",
num_entries=None) -> str:
from ray.dashboard.modules.node.node_head\
import node_stats_to_dict
def memory_summary(
state,
group_by="NODE_ADDRESS",
sort_by="OBJECT_SIZE",
line_wrap=True,
unit="B",
num_entries=None,
) -> str:
from ray.dashboard.modules.node.node_head import node_stats_to_dict

# Get terminal size
import shutil

size = shutil.get_terminal_size((80, 20)).columns
line_wrap_threshold = 137

# Unit conversions
units = {"B": 10**0, "KB": 10**3, "MB": 10**6, "GB": 10**9}
units = {"B": 10 ** 0, "KB": 10 ** 3, "MB": 10 ** 6, "GB": 10 ** 9}

# Fetch core memory worker stats, store as a dictionary
core_worker_stats = []
for raylet in state.node_table():
stats = node_stats_to_dict(
node_stats(raylet["NodeManagerAddress"],
raylet["NodeManagerPort"]))
node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
)
core_worker_stats.extend(stats["coreWorkersStats"])
assert type(stats) is dict and "coreWorkersStats" in stats

# Build memory table with "group_by" and "sort_by" parameters
group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
memory_table = construct_memory_table(core_worker_stats, group_by,
sort_by).as_dict()
memory_table = construct_memory_table(
core_worker_stats, group_by, sort_by
).as_dict()
assert "summary" in memory_table and "group" in memory_table

# Build memory summary
mem = ""
group_by, sort_by = group_by.name.lower().replace(
"_", " "), sort_by.name.lower().replace("_", " ")
"_", " "
), sort_by.name.lower().replace("_", " ")
summary_labels = [
"Mem Used by Objects", "Local References", "Pinned", "Pending Tasks",
"Captured in Objects", "Actor Handles"
"Mem Used by Objects",
"Local References",
"Pinned",
"Pending Tasks",
"Captured in Objects",
"Actor Handles",
]
summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"

object_ref_labels = [
"IP Address", "PID", "Type", "Call Site", "Size", "Reference Type",
"Object Ref"
"IP Address",
"PID",
"Type",
"Call Site",
"Size",
"Reference Type",
"Object Ref",
]
object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
| {:<8} | {:<14} | {:<10}\n"

@ -416,22 +444,21 @@ entries per group...\n\n\n"
else:
summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"
mem += f"--- Summary for {group_by}: {key} ---\n"
mem += summary_string\
.format(*summary_labels)
mem += summary_string\
.format(*summary.values()) + "\n"
mem += summary_string.format(*summary_labels)
mem += summary_string.format(*summary.values()) + "\n"

# Memory table per group
mem += f"--- Object references for {group_by}: {key} ---\n"
mem += object_ref_string\
.format(*object_ref_labels)
mem += object_ref_string.format(*object_ref_labels)
n = 1 # Counter for num entries per group
for entry in group["entries"]:
if num_entries is not None and n > num_entries:
break
entry["object_size"] = str(
entry["object_size"] /
units[unit]) + f" {unit}" if entry["object_size"] > -1 else "?"
entry["object_size"] = (
str(entry["object_size"] / units[unit]) + f" {unit}"
if entry["object_size"] > -1
else "?"
)
num_lines = 1
if size > line_wrap_threshold and line_wrap:
call_site_length = 22

@ -439,30 +466,36 @@ entries per group...\n\n\n"
entry["call_site"] = ["disabled"]
else:
entry["call_site"] = [
entry["call_site"][i:i + call_site_length] for i in
range(0, len(entry["call_site"]), call_site_length)
entry["call_site"][i : i + call_site_length]
for i in range(0, len(entry["call_site"]), call_site_length)
]
num_lines = len(entry["call_site"])
else:
mem += "\n"
object_ref_values = [
entry["node_ip_address"], entry["pid"], entry["type"],
entry["call_site"], entry["object_size"],
entry["reference_type"], entry["object_ref"]
entry["node_ip_address"],
entry["pid"],
entry["type"],
entry["call_site"],
entry["object_size"],
entry["reference_type"],
entry["object_ref"],
]
for i in range(len(object_ref_values)):
if not isinstance(object_ref_values[i], list):
object_ref_values[i] = [object_ref_values[i]]
object_ref_values[i].extend(
["" for x in range(num_lines - len(object_ref_values[i]))])
["" for x in range(num_lines - len(object_ref_values[i]))]
)
for i in range(num_lines):
row = [elem[i] for elem in object_ref_values]
mem += object_ref_string\
.format(*row)
mem += object_ref_string.format(*row)
mem += "\n"
n += 1

mem += "To record callsite information for each ObjectRef created, set " \
"env variable RAY_record_ref_creation_sites=1\n\n"
mem += (
"To record callsite information for each ObjectRef created, set "
"env variable RAY_record_ref_creation_sites=1\n\n"
)

return mem
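The slice rewrites above (`[i:i + call_site_length]` to `[i : i + call_site_length]`) follow PEP 8's rule, which Black enforces: the slice colon acts like a binary operator, so when a bound is a compound expression it gets a space on both sides, while simple bounds stay tight. A sketch:

s = "abcdefgh"
i, width = 2, 3

simple = s[2:5]  # simple bounds: no spaces around the colon
compound = s[i : i + width]  # compound bound: spaces on both sides
assert simple == compound == "cde"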
@ -4,6 +4,7 @@ import aiohttp.web
import ray._private.utils
from ray.dashboard.modules.actor import actor_utils
from aioredis.pubsub import Receiver

try:
from grpc import aio as aiogrpc
except ImportError:

@ -15,8 +16,7 @@ import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
from ray.dashboard.modules.actor import actor_consts
from ray.dashboard.modules.actor.actor_utils import \
actor_classname_from_task_spec
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc

@ -30,12 +30,22 @@ routes = dashboard_optional_utils.ClassMethodRouteTable

def actor_table_data_to_dict(message):
orig_message = dashboard_utils.message_to_dict(
message, {
"actorId", "parentId", "jobId", "workerId", "rayletId",
"actorCreationDummyObjectId", "callerId", "taskId", "parentTaskId",
"sourceActorId", "placementGroupId"
message,
{
"actorId",
"parentId",
"jobId",
"workerId",
"rayletId",
"actorCreationDummyObjectId",
"callerId",
"taskId",
"parentTaskId",
"sourceActorId",
"placementGroupId",
},
including_default_value_fields=True)
including_default_value_fields=True,
)
# The complete schema for actor table is here:
# src/ray/protobuf/gcs.proto
# It is super big and for dashboard, we don't need that much information.

@ -58,8 +68,7 @@ def actor_table_data_to_dict(message):
light_message["actorClass"] = actor_class
if "functionDescriptor" in light_message["taskSpec"]:
light_message["taskSpec"] = {
"functionDescriptor": light_message["taskSpec"][
"functionDescriptor"]
"functionDescriptor": light_message["taskSpec"]["functionDescriptor"]
}
else:
light_message.pop("taskSpec")

@ -81,11 +90,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if change.new:
# TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub


@ -96,7 +107,8 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
logger.info("Getting all actor info from GCS.")
request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=5)
request, timeout=5
)
if reply.status.code == 0:
actors = {}
for message in reply.actor_table_data:

@ -110,24 +122,25 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
for actor_id, actor_table_data in actors.items():
job_id = actor_table_data["jobId"]
node_id = actor_table_data["address"]["rayletId"]
job_actors.setdefault(job_id,
{})[actor_id] = actor_table_data
job_actors.setdefault(job_id, {})[actor_id] = actor_table_data
# Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault(
node_id, {})[actor_id] = actor_table_data
node_actors.setdefault(node_id, {})[
actor_id
] = actor_table_data
DataSource.job_actors.reset(job_actors)
DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.",
len(actors))
logger.info("Received %d actor info from GCS.", len(actors))
break
else:
raise Exception(
f"Failed to GetAllActorInfo: {reply.status.message}")
f"Failed to GetAllActorInfo: {reply.status.message}"
)
except Exception:
logger.exception("Error Getting all actor info from GCS.")
await asyncio.sleep(
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS
)

state_keys = ("state", "address", "numRestarts", "timestamp", "pid")


@ -167,8 +180,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if actor_id is not None:
# Convert to lower case hex ID.
actor_id = actor_id.hex()
process_actor_data_from_pubsub(actor_id,
actor_table_data)
process_actor_data_from_pubsub(actor_id, actor_table_data)
except Exception:
logger.exception("Error processing actor info from GCS.")

@ -183,12 +195,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
async for sender, msg in receiver.iter():
try:
actor_id, actor_table_data = msg
actor_id = actor_id.decode("UTF-8")[len(
gcs_utils.TablePrefix_ACTOR_string + ":"):]
actor_id = actor_id.decode("UTF-8")[
len(gcs_utils.TablePrefix_ACTOR_string + ":") :
]
pubsub_message = gcs_utils.PubSubMessage.FromString(
actor_table_data)
actor_table_data
)
actor_table_data = gcs_utils.ActorTableData.FromString(
pubsub_message.data)
pubsub_message.data
)
process_actor_data_from_pubsub(actor_id, actor_table_data)
except Exception:
logger.exception("Error processing actor info from Redis.")

@ -203,17 +218,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
actors.update(actor_creation_tasks)
actor_groups = actor_utils.construct_actor_groups(actors)
return rest_response(
success=True,
message="Fetched actor groups.",
actor_groups=actor_groups)
success=True, message="Fetched actor groups.", actor_groups=actor_groups
)

@routes.get("/logical/actors")
@dashboard_optional_utils.aiohttp_cache
async def get_all_actors(self, req) -> aiohttp.web.Response:
return rest_response(
success=True,
message="All actors fetched.",
actors=DataSource.actors)
success=True, message="All actors fetched.", actors=DataSource.actors
)

@routes.get("/logical/kill_actor")
async def kill_actor(self, req) -> aiohttp.web.Response:

@ -224,15 +237,17 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
except KeyError:
return rest_response(success=False, message="Bad Request")
try:
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True)
f"{ip_address}:{port}", options=options, asynchronous=True
)
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)

await stub.KillActor(
core_worker_pb2.KillActorRequest(
intended_actor_id=ray._private.utils.hex_to_binary(
actor_id)))
intended_actor_id=ray._private.utils.hex_to_binary(actor_id)
)
)

except aiogrpc.AioRpcError:
# This always throws an exception because the worker

@ -240,13 +255,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
# before this handler, however it deletes the actor correctly.
pass

return rest_response(
success=True, message=f"Killed actor with id {actor_id}")
return rest_response(success=True, message=f"Killed actor with id {actor_id}")

async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = \
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(gcs_channel)
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)

await asyncio.gather(self._update_actors())
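One of Black's more surprising outputs appears in the `@ -110,24 +122,25` hunk above: when an assignment whose target is a subscript exceeds the line limit, Black splits inside the brackets, leaving `[` at the end of one line and `] =` opening a later one. A standalone reproduction, with hypothetical names long enough to force the split:

node_actors_keyed_by_node_then_actor_identifier = {}
some_node_identifier, actor_id = "node-1", "abc123"
actor_table_data = {"state": "ALIVE"}

node_actors_keyed_by_node_then_actor_identifier.setdefault(some_node_identifier, {})[
    actor_id
] = actor_table_data
assert node_actors_keyed_by_node_then_actor_identifier["node-1"]["abc123"] == {
    "state": "ALIVE"
}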
@ -7,27 +7,29 @@ PYCLASSNAME_RE = re.compile(r"(.+?)\(")
|
|||
|
||||
def construct_actor_groups(actors):
|
||||
"""actors is a dict from actor id to an actor or an
|
||||
actor creation task The shared fields currently are
|
||||
"actorClass", "actorId", and "state" """
|
||||
actor creation task The shared fields currently are
|
||||
"actorClass", "actorId", and "state" """
|
||||
actor_groups = _group_actors_by_python_class(actors)
|
||||
stats_by_group = {
|
||||
name: _get_actor_group_stats(group)
|
||||
for name, group in actor_groups.items()
|
||||
name: _get_actor_group_stats(group) for name, group in actor_groups.items()
|
||||
}
|
||||
|
||||
summarized_actor_groups = {}
|
||||
for name, group in actor_groups.items():
|
||||
summarized_actor_groups[name] = {
|
||||
"entries": group,
|
||||
"summary": stats_by_group[name]
|
||||
"summary": stats_by_group[name],
|
||||
}
|
||||
return summarized_actor_groups
|
||||
|
||||
|
||||
def actor_classname_from_task_spec(task_spec):
|
||||
return task_spec.get("functionDescriptor", {})\
|
||||
.get("pythonFunctionDescriptor", {})\
|
||||
.get("className", "Unknown actor class").split(".")[-1]
|
||||
return (
|
||||
task_spec.get("functionDescriptor", {})
|
||||
.get("pythonFunctionDescriptor", {})
|
||||
.get("className", "Unknown actor class")
|
||||
.split(".")[-1]
|
||||
)
|
||||
|
||||
|
||||
def _group_actors_by_python_class(actors):
|
||||
|
|
|
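`actor_classname_from_task_spec` above shows how Black handles a fluent chain: the old backslash-continued `return` becomes `return (...)` with each `.get(...)` on its own line. The function is small enough to test directly; the sample `task_spec` below is made up for illustration:

def actor_classname_from_task_spec(task_spec):
    # Parentheses let the chain break cleanly, one call per line,
    # with no backslash continuations.
    return (
        task_spec.get("functionDescriptor", {})
        .get("pythonFunctionDescriptor", {})
        .get("className", "Unknown actor class")
        .split(".")[-1]
    )

task_spec = {
    "functionDescriptor": {
        "pythonFunctionDescriptor": {"className": "my_module.Foo"}
    }
}
assert actor_classname_from_task_spec(task_spec) == "Foo"
assert actor_classname_from_task_spec({}) == "Unknown actor class"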
@ -50,8 +50,7 @@ def test_actor_groups(ray_start_with_dashboard):
response = requests.get(webui_url + "/logical/actor_groups")
response.raise_for_status()
actor_groups_resp = response.json()
assert actor_groups_resp["result"] is True, actor_groups_resp[
"msg"]
assert actor_groups_resp["result"] is True, actor_groups_resp["msg"]
actor_groups = actor_groups_resp["data"]["actorGroups"]
assert "Foo" in actor_groups
summary = actor_groups["Foo"]["summary"]

@ -78,9 +77,13 @@ def test_actor_groups(ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")

@ -135,9 +138,13 @@ def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")

@ -183,8 +190,9 @@ def test_kill_actor(ray_start_with_dashboard):
params={
"actorId": actor["actorId"],
"ipAddress": actor["ipAddress"],
"port": actor["port"]
})
"port": actor["port"],
},
)
resp.raise_for_status()
resp_json = resp.json()
assert resp_json["result"] is True, "msg" in resp_json

@ -199,19 +207,17 @@ def test_kill_actor(ray_start_with_dashboard):
break
except (KeyError, AssertionError) as e:
last_exc = e
time.sleep(.1)
time.sleep(0.1)
assert last_exc is None


def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
timeout = 5
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
address_info = ray_start_with_dashboard

if gcs_pubsub.gcs_pubsub_enabled():
sub = gcs_pubsub.GcsActorSubscriber(
address=address_info["gcs_address"])
sub = gcs_pubsub.GcsActorSubscriber(address=address_info["gcs_address"])
sub.subscribe()
else:
address = address_info["redis_address"]

@ -221,7 +227,8 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
client = redis.StrictRedis(
host=address[0],
port=int(address[1]),
password=ray_constants.REDIS_DEFAULT_PASSWORD)
password=ray_constants.REDIS_DEFAULT_PASSWORD,
)

sub = client.pubsub(ignore_subscribe_messages=True)
sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)

@ -245,8 +252,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
time.sleep(0.01)
continue
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
actor_data = gcs_utils.ActorTableData.FromString(
pubsub_msg.data)
actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
if actor_data is None:
continue
msgs.append(actor_data)

@ -266,12 +272,22 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):

def actor_table_data_to_dict(message):
return dashboard_utils.message_to_dict(
message, {
"actorId", "parentId", "jobId", "workerId", "rayletId",
"actorCreationDummyObjectId", "callerId", "taskId",
"parentTaskId", "sourceActorId", "placementGroupId"
message,
{
"actorId",
"parentId",
"jobId",
"workerId",
"rayletId",
"actorCreationDummyObjectId",
"callerId",
"taskId",
"parentTaskId",
"sourceActorId",
"placementGroupId",
},
including_default_value_fields=False)
including_default_value_fields=False,
)

non_state_keys = ("actorId", "jobId", "taskSpec")


@ -287,23 +303,31 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
# be published.
elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
assert actor_data_dict.keys() >= {
"state", "address", "timestamp", "pid", "rayNamespace"
"state",
"address",
"timestamp",
"pid",
"rayNamespace",
}
elif actor_data_dict["state"] == "PENDING_CREATION":
assert actor_data_dict.keys() == {
"state", "address", "actorId", "actorCreationDummyObjectId",
"jobId", "ownerAddress", "taskSpec", "className",
"serializedRuntimeEnv", "rayNamespace"
"state",
"address",
"actorId",
"actorCreationDummyObjectId",
"jobId",
"ownerAddress",
"taskSpec",
"className",
"serializedRuntimeEnv",
"rayNamespace",
}
else:
raise Exception("Unknown state: {}".format(
actor_data_dict["state"]))
raise Exception("Unknown state: {}".format(actor_data_dict["state"]))


def test_nil_node(enable_test_module, disable_aiohttp_cache,
ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
def test_nil_node(enable_test_module, disable_aiohttp_cache, ray_start_with_dashboard):
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
assert wait_until_server_available(webui_url)
webui_url = format_web_url(webui_url)

@ -334,9 +358,13 @@ def test_nil_node(enable_test_module, disable_aiohttp_cache,
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
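The three identical `ex_stack` hunks in these tests show Black's treatment of a long conditional expression: the whole right-hand side is parenthesized and the `if`/`else` arms drop onto their own lines. A reduced, runnable version:

import traceback

last_ex = ValueError("boom")  # never raised, so __traceback__ is None

ex_stack = (
    traceback.format_exception(type(last_ex), last_ex, last_ex.__traceback__)
    if last_ex
    else []
)
assert "ValueError: boom" in "".join(ex_stack)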
@ -24,13 +24,11 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
os.makedirs(self._event_dir, exist_ok=True)
self._monitor: Union[asyncio.Task, None] = None
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
self._cached_events = asyncio.Queue(
event_consts.EVENT_AGENT_CACHE_SIZE)
logger.info("Event agent cache buffer size: %s",
self._cached_events.maxsize)
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)

async def _connect_to_dashboard(self):
""" Connect to the dashboard. If the dashboard is not started, then
"""Connect to the dashboard. If the dashboard is not started, then
this method will never returns.

Returns:

@ -41,23 +39,24 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
# TODO: Use async version if performance is an issue
dashboard_rpc_address = internal_kv._internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = utils.init_grpc_channel(
dashboard_rpc_address,
options=options,
asynchronous=True)
dashboard_rpc_address, options=options, asynchronous=True
)
return event_pb2_grpc.ReportEventServiceStub(channel)
except Exception:
logger.exception("Connect to dashboard failed.")
await asyncio.sleep(
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS)
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS
)

@async_loop_forever(event_consts.EVENT_AGENT_REPORT_INTERVAL_SECONDS)
async def report_events(self):
""" Report events from cached events queue. Reconnect to dashboard if
"""Report events from cached events queue. Reconnect to dashboard if
report failed. Log error after retry EVENT_AGENT_RETRY_TIMES.

This method will never returns.

@ -70,14 +69,15 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
await self._stub.ReportEvents(request)
break
except Exception:
logger.exception("Report event failed, reconnect to the "
"dashboard.")
logger.exception("Report event failed, reconnect to the " "dashboard.")
self._stub = await self._connect_to_dashboard()
else:
data_str = str(data)
limit = event_consts.LOG_ERROR_EVENT_STRING_LENGTH_LIMIT
logger.error("Report event failed: %s",
data_str[:limit] + (data_str[limit:] and "..."))
logger.error(
"Report event failed: %s",
data_str[:limit] + (data_str[limit:] and "..."),
)

async def run(self, server):
# Connect to dashboard.

@ -86,7 +86,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
self._monitor = monitor_events(
self._event_dir,
lambda data: create_task(self._cached_events.put(data)),
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES)
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES,
)
# Start reporting events.
await self.report_events()
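One artifact to watch for in this kind of migration is visible above: Black joined the two fragments of the log message onto one line but, by design, it never merges adjacent string literals, so the implicit concatenation survives as two tokens that form one string at runtime:

message = "Report event failed, reconnect to the " "dashboard."
assert message == "Report event failed, reconnect to the dashboard."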
@ -4,22 +4,20 @@ from ray.core.generated import event_pb2
LOG_ERROR_EVENT_STRING_LENGTH_LIMIT = 1000
RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS = 2
# Monitor events
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer(
"SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer("SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
SCAN_EVENT_START_OFFSET_SECONDS = -30 * 60
CONCURRENT_READ_LIMIT = 50
EVENT_READ_LINE_COUNT_LIMIT = 200
EVENT_READ_LINE_LENGTH_LIMIT = env_integer("EVENT_READ_LINE_LENGTH_LIMIT",
2 * 1024 * 1024) # 2MB
EVENT_READ_LINE_LENGTH_LIMIT = env_integer(
"EVENT_READ_LINE_LENGTH_LIMIT", 2 * 1024 * 1024
) # 2MB
# Report events
EVENT_AGENT_REPORT_INTERVAL_SECONDS = 0.1
EVENT_AGENT_RETRY_TIMES = 10
EVENT_AGENT_CACHE_SIZE = 10240
# Event sources
EVENT_HEAD_MONITOR_SOURCE_TYPES = [
event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
]
EVENT_HEAD_MONITOR_SOURCE_TYPES = [event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)]
EVENT_AGENT_MONITOR_SOURCE_TYPES = list(
set(event_pb2.Event.SourceType.keys()) -
set(EVENT_HEAD_MONITOR_SOURCE_TYPES))
set(event_pb2.Event.SourceType.keys()) - set(EVENT_HEAD_MONITOR_SOURCE_TYPES)
)
EVENT_SOURCE_ALL = event_pb2.Event.SourceType.keys()
@ -24,8 +24,9 @@ JobEvents = OrderedDict
dashboard_utils._json_compatible_types.add(JobEvents)


class EventHead(dashboard_utils.DashboardHeadModule,
                event_pb2_grpc.ReportEventServiceServicer):
class EventHead(
    dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer
):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")

@ -70,21 +71,24 @@ class EventHead(dashboard_utils.DashboardHeadModule,
                for job_id, job_events in DataSource.events.items()
            }
            return dashboard_optional_utils.rest_response(
                success=True, message="All events fetched.", events=all_events)
                success=True, message="All events fetched.", events=all_events
            )

        job_events = DataSource.events.get(job_id, {})
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Job events fetched.",
            job_id=job_id,
            events=list(job_events.values()))
            events=list(job_events.values()),
        )

    async def run(self, server):
        event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
        self._monitor = monitor_events(
            self._event_dir,
            lambda data: self._update_events(parse_event_strings(data)),
            source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES)
            source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES,
        )

    @staticmethod
    def is_minimal_module():
@ -19,8 +19,7 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
    source_files = {}
    all_source_types = set(event_consts.EVENT_SOURCE_ALL)
    for source_type in source_types or event_consts.EVENT_SOURCE_ALL:
        assert source_type in all_source_types, \
            f"Invalid source type: {source_type}"
        assert source_type in all_source_types, f"Invalid source type: {source_type}"
        files = []
        for n in event_log_names:
            if fnmatch.fnmatch(n, f"*{source_type}*"):

@ -35,9 +34,9 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):

def _restore_newline(event_dict):
    try:
        event_dict["message"] = event_dict["message"]\
            .replace("\\n", "\n")\
            .replace("\\r", "\n")
        event_dict["message"] = (
            event_dict["message"].replace("\\n", "\n").replace("\\r", "\n")
        )
    except Exception:
        logger.exception("Restore newline for event failed: %s", event_dict)
    return event_dict

@ -61,13 +60,13 @@ def parse_event_strings(event_string_list):


ReadFileResult = collections.namedtuple(
    "ReadFileResult", ["fid", "size", "mtime", "position", "lines"])
    "ReadFileResult", ["fid", "size", "mtime", "position", "lines"]
)


def _read_file(file,
               pos,
               n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT,
               closefd=True):
def _read_file(
    file, pos, n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT, closefd=True
):
    with open(file, "rb", closefd=closefd) as f:
        # The ino may be 0 on Windows.
        stat = os.stat(f.fileno())

@ -82,24 +81,25 @@ def _read_file(file,
                if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT:
                    lines.append(mm[start:sep].decode("utf-8"))
                else:
                    truncated_size = min(
                        100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
                    truncated_size = min(100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
                    logger.warning(
                        "Ignored long string: %s...(%s chars)",
                        mm[start:start + truncated_size].decode("utf-8"),
                        sep - start)
                        mm[start : start + truncated_size].decode("utf-8"),
                        sep - start,
                    )
                start = sep + 1
    return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines)


def monitor_events(
        event_dir,
        callback,
        scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
        start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
        monitor_files=None,
        source_types=None):
    """ Monitor events in directory. New events will be read and passed to the
    event_dir,
    callback,
    scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
    start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
    monitor_files=None,
    source_types=None,
):
    """Monitor events in directory. New events will be read and passed to the
    callback.

    Args:

@ -121,20 +121,22 @@ def monitor_events(
        monitor_files = {}

    logger.info(
        "Monitor events logs modified after %s on %s, "
        "the source types are %s.", start_mtime, event_dir, "all"
        if source_types is None else source_types)
        "Monitor events logs modified after %s on %s, " "the source types are %s.",
        start_mtime,
        event_dir,
        "all" if source_types is None else source_types,
    )

    MonitorFile = collections.namedtuple("MonitorFile",
                                         ["size", "mtime", "position"])
    MonitorFile = collections.namedtuple("MonitorFile", ["size", "mtime", "position"])

    def _source_file_filter(source_file):
        stat = os.stat(source_file)
        return stat.st_mtime > start_mtime

    def _read_monitor_file(file, pos):
        assert isinstance(file, str), \
            f"File should be a str, but a {type(file)}({file}) found"
        assert isinstance(
            file, str
        ), f"File should be a str, but a {type(file)}({file}) found"
        fd = os.open(file, os.O_RDONLY)
        try:
            stat = os.stat(fd)

@ -145,12 +147,14 @@ def monitor_events(
            fid = stat.st_ino or file
            monitor_file = monitor_files.get(fid)
            if monitor_file:
                if (monitor_file.position == monitor_file.size
                        and monitor_file.size == stat.st_size
                        and monitor_file.mtime == stat.st_mtime):
                if (
                    monitor_file.position == monitor_file.size
                    and monitor_file.size == stat.st_size
                    and monitor_file.mtime == stat.st_mtime
                ):
                    logger.debug(
                        "Skip reading the file because "
                        "there is no change: %s", file)
                        "Skip reading the file because " "there is no change: %s", file
                    )
                    return []
                position = monitor_file.position
            else:

@ -169,22 +173,23 @@ def monitor_events(
    @async_loop_forever(scan_interval_seconds, cancellable=True)
    async def _scan_event_log_files():
        # Scan event files.
        source_files = await loop.run_in_executor(None, _get_source_files,
                                                  event_dir, source_types,
                                                  _source_file_filter)
        source_files = await loop.run_in_executor(
            None, _get_source_files, event_dir, source_types, _source_file_filter
        )

        # Limit concurrent read to avoid fd exhaustion.
        semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT)

        async def _concurrent_coro(filename):
            async with semaphore:
                return await loop.run_in_executor(None, _read_monitor_file,
                                                  filename, 0)
                return await loop.run_in_executor(None, _read_monitor_file, filename, 0)

        # Read files.
        await asyncio.gather(*[
            _concurrent_coro(filename)
            for filename in list(itertools.chain(*source_files.values()))
        ])
        await asyncio.gather(
            *[
                _concurrent_coro(filename)
                for filename in list(itertools.chain(*source_files.values()))
            ]
        )

    return create_task(_scan_event_log_files())
@ -23,7 +23,8 @@ from ray._private.test_utils import (
    wait_for_condition,
)
from ray.dashboard.modules.event.event_utils import (
    monitor_events, )
    monitor_events,
)

logger = logging.getLogger(__name__)


@ -32,7 +33,8 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
    return {
        "event_id": binary_to_hex(np.random.bytes(18)),
        "source_type": random.choice(event_pb2.Event.SourceType.keys())
        if source_type is None else source_type,
        if source_type is None
        else source_type,
        "host_name": "po-dev.inc.alipay.net",
        "pid": random.randint(1, 65536),
        "label": "",

@ -41,16 +43,18 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
        "severity": "INFO",
        "custom_fields": {
            "job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
            if job_id is None else job_id,
            if job_id is None
            else job_id,
            "node_id": "",
            "task_id": "",
        }
        },
    }


def _test_logger(name, log_file, max_bytes, backup_count):
    handler = logging.handlers.RotatingFileHandler(
        log_file, maxBytes=max_bytes, backupCount=backup_count)
        log_file, maxBytes=max_bytes, backupCount=backup_count
    )
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)


@ -63,15 +67,14 @@ def _test_logger(name, log_file, max_bytes, backup_count):


def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")
    job_id = ray.JobID.from_int(100).hex()

    source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
    source_type_raylet = event_pb2.Event.SourceType.Name(
        event_pb2.Event.RAYLET)
    source_type_raylet = event_pb2.Event.SourceType.Name(event_pb2.Event.RAYLET)
    test_count = 20

    for source_type in [source_type_gcs, source_type_raylet]:

@ -80,10 +83,10 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
            __name__ + str(random.random()),
            test_log_file,
            max_bytes=2000,
            backup_count=1000)
            backup_count=1000,
        )
        for i in range(test_count):
            sample_event = _get_event(
                str(i), job_id=job_id, source_type=source_type)
            sample_event = _get_event(str(i), job_id=job_id, source_type=source_type)
            test_logger.info("%s", json.dumps(sample_event))

    def _check_events():

@ -112,10 +115,11 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
    wait_for_condition(_check_events, timeout=15)


def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
                             ray_start_with_dashboard):
def test_event_message_limit(
    small_event_line_limit, disable_aiohttp_cache, ray_start_with_dashboard
):
    event_read_line_length_limit = small_event_line_limit
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")

@ -148,8 +152,8 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
    except Exception:
        pass
    os.rename(
        os.path.join(event_dir, "tmp.log"),
        os.path.join(event_dir, "event_GCS.log"))
        os.path.join(event_dir, "tmp.log"), os.path.join(event_dir, "event_GCS.log")
    )

    def _check_events():
        try:

@ -157,14 +161,14 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
            resp.raise_for_status()
            result = resp.json()
            all_events = result["data"]["events"]
            assert len(all_events[job_id]
                       ) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
            assert (
                len(all_events[job_id]) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
            )
            messages = [e["message"] for e in all_events[job_id]]
            for i in range(10):
                assert str(i) * message_len in messages
            assert "2" * (message_len + 1) not in messages
            assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT -
                       1) in messages
            assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT - 1) in messages
            return True
        except Exception as ex:
            logger.exception(ex)

@ -179,15 +183,12 @@ async def test_monitor_events():
        common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
        common_log = os.path.join(temp_dir, f"event_{common}.log")
        test_logger = _test_logger(
            __name__ + str(random.random()),
            common_log,
            max_bytes=10,
            backup_count=10)
            __name__ + str(random.random()), common_log, max_bytes=10, backup_count=10
        )
        test_events1 = []
        monitor_task = monitor_events(
            temp_dir,
            lambda x: test_events1.extend(x),
            scan_interval_seconds=0.01)
            temp_dir, lambda x: test_events1.extend(x), scan_interval_seconds=0.01
        )
        assert not monitor_task.done()
        count = 10


@ -206,7 +207,8 @@ async def test_monitor_events():
                if time.time() - start_time > timeout:
                    raise TimeoutError(
                        f"Timeout, read events: {sorted_events}, "
                        f"expect events: {expect_events}")
                        f"expect events: {expect_events}"
                    )
                if len(sorted_events) == len(expect_events):
                    if sorted_events == expect_events:
                        break

@ -214,40 +216,37 @@ async def test_monitor_events():

        await asyncio.gather(
            _writer(count, read_events=test_events1),
            _check_events(
                [str(i) for i in range(count)], read_events=test_events1))
            _check_events([str(i) for i in range(count)], read_events=test_events1),
        )

        monitor_task.cancel()
        test_events2 = []
        monitor_task = monitor_events(
            temp_dir,
            lambda x: test_events2.extend(x),
            scan_interval_seconds=0.1)
            temp_dir, lambda x: test_events2.extend(x), scan_interval_seconds=0.1
        )

        await _check_events(
            [str(i) for i in range(count)], read_events=test_events2)
        await _check_events([str(i) for i in range(count)], read_events=test_events2)

        await _writer(count, count * 2, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 2)], read_events=test_events2)
            [str(i) for i in range(count * 2)], read_events=test_events2
        )

        log_file_count = len(os.listdir(temp_dir))

        test_logger = _test_logger(
            __name__ + str(random.random()),
            common_log,
            max_bytes=1000,
            backup_count=10)
            __name__ + str(random.random()), common_log, max_bytes=1000, backup_count=10
        )
        assert len(os.listdir(temp_dir)) == log_file_count

        await _writer(
            count * 2, count * 3, spin=False, read_events=test_events2)
        await _writer(count * 2, count * 3, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 3)], read_events=test_events2)
        await _writer(
            count * 3, count * 4, spin=False, read_events=test_events2)
            [str(i) for i in range(count * 3)], read_events=test_events2
        )
        await _writer(count * 3, count * 4, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 4)], read_events=test_events2)
            [str(i) for i in range(count * 4)], read_events=test_events2
        )

        # Test cancel monitor task.
        monitor_task.cancel()

@ -255,8 +254,7 @@ async def test_monitor_events():
            await monitor_task
        assert monitor_task.done()

        assert len(
            os.listdir(temp_dir)) > 1, "Event log should have rollovers."
        assert len(os.listdir(temp_dir)) > 1, "Event log should have rollovers."


if __name__ == "__main__":
@ -7,21 +7,21 @@ import yaml

import click

from ray.autoscaler._private.cli_logger import (add_click_logging_options,
                                                cli_logger, cf)
from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf
from ray.dashboard.modules.job.common import JobStatus
from ray.dashboard.modules.job.sdk import JobSubmissionClient


def _get_sdk_client(address: Optional[str],
                    create_cluster_if_needed: bool = False
                    ) -> JobSubmissionClient:
def _get_sdk_client(
    address: Optional[str], create_cluster_if_needed: bool = False
) -> JobSubmissionClient:

    if address is None:
        if "RAY_ADDRESS" not in os.environ:
            raise ValueError(
                "Address must be specified using either the --address flag "
                "or RAY_ADDRESS environment variable.")
                "or RAY_ADDRESS environment variable."
            )
        address = os.environ["RAY_ADDRESS"]

    cli_logger.labeled_value("Job submission server address", address)

@ -73,55 +73,67 @@ def job_cli_group():
    pass


@job_cli_group.command(
    "submit", help="Submit a job to be executed on the cluster.")
@job_cli_group.command("submit", help="Submit a job to be executed on the cluster.")
@click.option(
    "--address",
    type=str,
    default=None,
    required=False,
    help=("Address of the Ray cluster to connect to. Can also be specified "
          "using the RAY_ADDRESS environment variable."))
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.option(
    "--job-id",
    type=str,
    default=None,
    required=False,
    help=("Job ID to specify for the job. "
          "If not provided, one will be generated."))
    help=("Job ID to specify for the job. " "If not provided, one will be generated."),
)
@click.option(
    "--runtime-env",
    type=str,
    default=None,
    required=False,
    help="Path to a local YAML file containing a runtime_env definition.")
    help="Path to a local YAML file containing a runtime_env definition.",
)
@click.option(
    "--runtime-env-json",
    type=str,
    default=None,
    required=False,
    help="JSON-serialized runtime_env dictionary.")
    help="JSON-serialized runtime_env dictionary.",
)
@click.option(
    "--working-dir",
    type=str,
    default=None,
    required=False,
    help=("Directory containing files that your job will run in. Can be a "
          "local directory or a remote URI to a .zip file (S3, GS, HTTP). "
          "If specified, this overrides the option in --runtime-env."),
    help=(
        "Directory containing files that your job will run in. Can be a "
        "local directory or a remote URI to a .zip file (S3, GS, HTTP). "
        "If specified, this overrides the option in --runtime-env."
    ),
)
@click.option(
    "--no-wait",
    is_flag=True,
    type=bool,
    default=False,
    help="If set, will not stream logs and wait for the job to exit.")
    help="If set, will not stream logs and wait for the job to exit.",
)
@add_click_logging_options
@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
def job_submit(address: Optional[str], job_id: Optional[str],
               runtime_env: Optional[str], runtime_env_json: Optional[str],
               working_dir: Optional[str], entrypoint: Tuple[str],
               no_wait: bool):
def job_submit(
    address: Optional[str],
    job_id: Optional[str],
    runtime_env: Optional[str],
    runtime_env_json: Optional[str],
    working_dir: Optional[str],
    entrypoint: Tuple[str],
    no_wait: bool,
):
    """Submits a job to be run on the cluster.

    Example:

@ -132,8 +144,9 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    final_runtime_env = {}
    if runtime_env is not None:
        if runtime_env_json is not None:
            raise ValueError("Only one of --runtime_env and "
                             "--runtime-env-json can be provided.")
            raise ValueError(
                "Only one of --runtime_env and " "--runtime-env-json can be provided."
            )
        with open(runtime_env, "r") as f:
            final_runtime_env = yaml.safe_load(f)


@ -143,14 +156,14 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    if working_dir is not None:
        if "working_dir" in final_runtime_env:
            cli_logger.warning(
                "Overriding runtime_env working_dir with --working-dir option")
                "Overriding runtime_env working_dir with --working-dir option"
            )

        final_runtime_env["working_dir"] = working_dir

    job_id = client.submit_job(
        entrypoint=" ".join(entrypoint),
        job_id=job_id,
        runtime_env=final_runtime_env)
        entrypoint=" ".join(entrypoint), job_id=job_id, runtime_env=final_runtime_env
    )

    _log_big_success_msg(f"Job '{job_id}' submitted successfully")


@ -172,15 +185,16 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    # sdk version 0 does not have log streaming
    if not no_wait:
        if int(sdk_version) > 0:
            cli_logger.print("Tailing logs until the job exits "
                             "(disable with --no-wait):")
            asyncio.get_event_loop().run_until_complete(
                _tail_logs(client, job_id))
            cli_logger.print(
                "Tailing logs until the job exits " "(disable with --no-wait):"
            )
            asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your ray to latest version "
                "for this feature.")
                "for this feature."
            )


@job_cli_group.command("status", help="Get the status of a running job.")

@ -189,8 +203,11 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    type=str,
    default=None,
    required=False,
    help=("Address of the Ray cluster to connect to. Can also be specified "
          "using the RAY_ADDRESS environment variable."))
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_status(address: Optional[str], job_id: str):

@ -209,14 +226,18 @@ def job_status(address: Optional[str], job_id: str):
    type=str,
    default=None,
    required=False,
    help=("Address of the Ray cluster to connect to. Can also be specified "
          "using the RAY_ADDRESS environment variable."))
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.option(
    "--no-wait",
    is_flag=True,
    type=bool,
    default=False,
    help="If set, will not wait for the job to exit.")
    help="If set, will not wait for the job to exit.",
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_stop(address: Optional[str], no_wait: bool, job_id: str):

@ -232,14 +253,13 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    if no_wait:
        return
    else:
        cli_logger.print(f"Waiting for job '{job_id}' to exit "
                         f"(disable with --no-wait):")
        cli_logger.print(
            f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):"
        )

    while True:
        status = client.get_job_status(job_id)
        if status.status in {
                JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED
        }:
        if status.status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}:
            _log_job_status(client, job_id)
            break
        else:

@ -253,8 +273,11 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    type=str,
    default=None,
    required=False,
    help=("Address of the Ray cluster to connect to. Can also be specified "
          "using the RAY_ADDRESS environment variable."))
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.argument("job-id", type=str)
@click.option(
    "-f",

@ -262,7 +285,8 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    is_flag=True,
    type=bool,
    default=False,
    help="If set, follow the logs (like `tail -f`).")
    help="If set, follow the logs (like `tail -f`).",
)
@add_click_logging_options
def job_logs(address: Optional[str], job_id: str, follow: bool):
    """Gets the logs of a job.

@ -275,12 +299,12 @@ def job_logs(address: Optional[str], job_id: str, follow: bool):
    # sdk version 0 did not have log streaming
    if follow:
        if int(sdk_version) > 0:
            asyncio.get_event_loop().run_until_complete(
                _tail_logs(client, job_id))
            asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your ray to latest version "
                "for this feature.")
                "for this feature."
            )
    else:
        print(client.get_job_logs(job_id), end="")
@ -39,8 +39,10 @@ class JobStatusInfo:
    def __post_init__(self):
        if self.message is None:
            if self.status == JobStatus.PENDING:
                self.message = ("Job has not started yet, likely waiting "
                                "for the runtime_env to be set up.")
                self.message = (
                    "Job has not started yet, likely waiting "
                    "for the runtime_env to be set up."
                )
            elif self.status == JobStatus.RUNNING:
                self.message = "Job is currently running."
            elif self.status == JobStatus.STOPPED:

@ -55,6 +57,7 @@ class JobStatusStorageClient:
    """
    Handles formatting of status storage key given job id.
    """

    JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"

    def __init__(self):

@ -69,12 +72,14 @@ class JobStatusStorageClient:
        _internal_kv_put(
            self.JOB_STATUS_KEY.format(job_id=job_id),
            pickle.dumps(status),
            namespace=ray_constants.KV_NAMESPACE_JOB)
            namespace=ray_constants.KV_NAMESPACE_JOB,
        )

    def get_status(self, job_id: str) -> Optional[JobStatusInfo]:
        pickled_status = _internal_kv_get(
            self.JOB_STATUS_KEY.format(job_id=job_id),
            namespace=ray_constants.KV_NAMESPACE_JOB)
            namespace=ray_constants.KV_NAMESPACE_JOB,
        )
        if pickled_status is None:
            return None
        else:

@ -87,18 +92,16 @@ def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
    # We need to strip the gcs:// prefix and .zip suffix to make it
    # possible to pass the package_uri over HTTP.
    protocol, package_name = parse_uri(package_uri)
    return protocol.value, package_name[:-len(".zip")]
    return protocol.value, package_name[: -len(".zip")]


def http_uri_components_to_uri(protocol: str, package_name: str) -> str:
    if package_name.endswith(".zip"):
        raise ValueError(
            f"package_name ({package_name}) should not end in .zip")
        raise ValueError(f"package_name ({package_name}) should not end in .zip")
    return f"{protocol}://{package_name}.zip"


def validate_request_type(json_data: Dict[str, Any],
                          request_type: dataclass) -> Any:
def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any:
    return request_type(**json_data)



@ -124,8 +127,7 @@ class JobSubmitRequest:

    def __post_init__(self):
        if not isinstance(self.entrypoint, str):
            raise TypeError(
                f"entrypoint must be a string, got {type(self.entrypoint)}")
            raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}")

        if self.job_id is not None and not isinstance(self.job_id, str):
            raise TypeError(

@ -141,21 +143,21 @@ class JobSubmitRequest:
                for k in self.runtime_env.keys():
                    if not isinstance(k, str):
                        raise TypeError(
                            f"runtime_env keys must be strings, got {type(k)}")
                            f"runtime_env keys must be strings, got {type(k)}"
                        )

        if self.metadata is not None:
            if not isinstance(self.metadata, dict):
                raise TypeError(
                    f"metadata must be a dict, got {type(self.metadata)}")
                raise TypeError(f"metadata must be a dict, got {type(self.metadata)}")
            else:
                for k in self.metadata.keys():
                    if not isinstance(k, str):
                        raise TypeError(
                            f"metadata keys must be strings, got {type(k)}")
                        raise TypeError(f"metadata keys must be strings, got {type(k)}")
                for v in self.metadata.values():
                    if not isinstance(v, str):
                        raise TypeError(
                            f"metadata values must be strings, got {type(v)}")
                            f"metadata values must be strings, got {type(v)}"
                        )


@dataclass
@ -12,8 +12,7 @@ import ray
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_utils import use_gcs_for_bootstrap
from ray._private.runtime_env.packaging import (package_exists,
                                                upload_package_to_gcs)
from ray._private.runtime_env.packaging import package_exists, upload_package_to_gcs
from ray.dashboard.modules.job.common import (
    CURRENT_VERSION,
    http_uri_components_to_uri,

@ -45,19 +44,20 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
                    if use_gcs_for_bootstrap():
                        address = self._dashboard_head.gcs_address
                        redis_pw = None
                        logger.info(
                            f"Connecting to ray with address={address}")
                        logger.info(f"Connecting to ray with address={address}")
                    else:
                        ip, port = self._dashboard_head.redis_address
                        redis_pw = self._dashboard_head.redis_password
                        address = f"{ip}:{port}"
                        logger.info(
                            f"Connecting to ray with address={address}, "
                            f"redis_pw={redis_pw}")
                            f"redis_pw={redis_pw}"
                        )
                    ray.init(
                        address=address,
                        namespace=RAY_INTERNAL_JOBS_NAMESPACE,
                        _redis_password=redis_pw)
                        _redis_password=redis_pw,
                    )
                except Exception as e:
                    ray.shutdown()
                    raise e from None

@ -67,7 +67,8 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
            logger.exception(f"Unexpected error in handler: {e}")
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code)
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

    return check


@ -77,8 +78,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        super().__init__(dashboard_head)
        self._job_manager = None

    async def _parse_and_validate_request(self, req: Request,
                                          request_type: dataclass) -> Any:
    async def _parse_and_validate_request(
        self, req: Request, request_type: dataclass
    ) -> Any:
        """Parse request and cast to request type. If parsing failed, return a
        Response object with status 400 and stacktrace instead.
        """

@ -88,7 +90,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
            logger.info(f"Got invalid request type: {e}")
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPBadRequest.status_code)
                status=aiohttp.web.HTTPBadRequest.status_code,
            )

    def job_exists(self, job_id: str) -> bool:
        status = self._job_manager.get_job_status(job_id)

@ -101,7 +104,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        resp = VersionResponse(
            version=CURRENT_VERSION,
            ray_version=ray.__version__,
            ray_commit=ray.__commit__)
            ray_commit=ray.__commit__,
        )
        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json",

@ -113,12 +117,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    async def get_package(self, req: Request) -> Response:
        package_uri = http_uri_components_to_uri(
            protocol=req.match_info["protocol"],
            package_name=req.match_info["package_name"])
            package_name=req.match_info["package_name"],
        )

        if not package_exists(package_uri):
            return Response(
                text=f"Package {package_uri} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code)
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        return Response()


@ -127,14 +133,16 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    async def upload_package(self, req: Request):
        package_uri = http_uri_components_to_uri(
            protocol=req.match_info["protocol"],
            package_name=req.match_info["package_name"])
            package_name=req.match_info["package_name"],
        )
        logger.info(f"Uploading package {package_uri} to the GCS.")
        try:
            upload_package_to_gcs(package_uri, await req.read())
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code)
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

        return Response(status=aiohttp.web.HTTPOk.status_code)


@ -153,17 +161,20 @@ class JobHead(dashboard_utils.DashboardHeadModule):
                entrypoint=submit_request.entrypoint,
                job_id=submit_request.job_id,
                runtime_env=submit_request.runtime_env,
                metadata=submit_request.metadata)
                metadata=submit_request.metadata,
            )

            resp = JobSubmitResponse(job_id=job_id)
        except (TypeError, ValueError):
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPBadRequest.status_code)
                status=aiohttp.web.HTTPBadRequest.status_code,
            )
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code)
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

        return Response(
            text=json.dumps(dataclasses.asdict(resp)),

@ -178,7 +189,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        if not self.job_exists(job_id):
            return Response(
                text=f"Job {job_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code)
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        try:
            stopped = self._job_manager.stop_job(job_id)

@ -186,11 +198,12 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code)
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json")
            text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
        )

    @routes.get("/api/jobs/{job_id}")
    @_init_ray_and_catch_exceptions

@ -199,13 +212,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        if not self.job_exists(job_id):
            return Response(
                text=f"Job {job_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code)
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        status: JobStatusInfo = self._job_manager.get_job_status(job_id)
        resp = JobStatusResponse(status=status.status, message=status.message)
        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json")
            text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
        )

    @routes.get("/api/jobs/{job_id}/logs")
    @_init_ray_and_catch_exceptions

@ -214,12 +228,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        if not self.job_exists(job_id):
            return Response(
                text=f"Job {job_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code)
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id))
        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json")
            text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
        )

    @routes.get("/api/jobs/{job_id}/logs/tail")
    @_init_ray_and_catch_exceptions

@ -228,7 +243,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        if not self.job_exists(job_id):
            return Response(
                text=f"Job {job_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code)
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        ws = aiohttp.web.WebSocketResponse()
        await ws.prepare(req)
@ -15,8 +15,12 @@ from ray.exceptions import RuntimeEnvSetupError
import ray.ray_constants as ray_constants
from ray.actor import ActorHandle
from ray.dashboard.modules.job.common import (
    JobStatus, JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY,
    JOB_NAME_METADATA_KEY)
    JobStatus,
    JobStatusInfo,
    JobStatusStorageClient,
    JOB_ID_METADATA_KEY,
    JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.utils import file_tail_iterator
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR


@ -36,8 +40,8 @@ def generate_job_id() -> str:
    """
    rand = random.SystemRandom()
    possible_characters = list(
        set(string.ascii_letters + string.digits) -
        {"I", "l", "o", "O", "0"}  # No confusing characters
        set(string.ascii_letters + string.digits)
        - {"I", "l", "o", "O", "0"}  # No confusing characters
    )
    id_part = "".join(rand.choices(possible_characters, k=16))
    return f"raysubmit_{id_part}"

@ -47,6 +51,7 @@ class JobLogStorageClient:
    """
    Disk storage for stdout / stderr of driver script logs.
    """

    JOB_LOGS_PATH = "job-driver-{job_id}.log"
    # Number of last N lines to put in job message upon failure.
    NUM_LOG_LINES_ON_ERROR = 10

@ -61,9 +66,9 @@ class JobLogStorageClient:
    def tail_logs(self, job_id: str) -> Iterator[str]:
        return file_tail_iterator(self.get_log_file_path(job_id))

    def get_last_n_log_lines(self,
                             job_id: str,
                             num_log_lines=NUM_LOG_LINES_ON_ERROR) -> str:
    def get_last_n_log_lines(
        self, job_id: str, num_log_lines=NUM_LOG_LINES_ON_ERROR
    ) -> str:
        log_tail_iter = self.tail_logs(job_id)
        log_tail_deque = deque(maxlen=num_log_lines)
        for line in log_tail_iter:

@ -80,7 +85,8 @@ class JobLogStorageClient:
        """
        return os.path.join(
            ray.worker._global_node.get_logs_dir_path(),
            self.JOB_LOGS_PATH.format(job_id=job_id))
            self.JOB_LOGS_PATH.format(job_id=job_id),
        )


class JobSupervisor:

@ -95,8 +101,7 @@ class JobSupervisor:

    SUBPROCESS_POLL_PERIOD_S = 0.1

    def __init__(self, job_id: str, entrypoint: str,
                 user_metadata: Dict[str, str]):
    def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
        self._job_id = job_id
        self._status_client = JobStatusStorageClient()
        self._log_client = JobLogStorageClient()

@ -104,10 +109,7 @@ class JobSupervisor:
        self._entrypoint = entrypoint

        # Default metadata if not passed by the user.
        self._metadata = {
            JOB_ID_METADATA_KEY: job_id,
            JOB_NAME_METADATA_KEY: job_id
        }
        self._metadata = {JOB_ID_METADATA_KEY: job_id, JOB_NAME_METADATA_KEY: job_id}
        self._metadata.update(user_metadata)

        # fire and forget call from outer job manager to this actor

@ -142,7 +144,8 @@ class JobSupervisor:
            shell=True,
            start_new_session=True,
            stdout=logs_file,
            stderr=subprocess.STDOUT)
            stderr=subprocess.STDOUT,
        )
        parent_pid = os.getpid()
        # Create new pgid with new subprocess to execute driver command
        child_pid = child_process.pid

@ -177,9 +180,10 @@ class JobSupervisor:
            return 1

    async def run(
            self,
            # Signal actor used in testing to capture PENDING -> RUNNING cases
            _start_signal_actor: Optional[ActorHandle] = None):
        self,
        # Signal actor used in testing to capture PENDING -> RUNNING cases
        _start_signal_actor: Optional[ActorHandle] = None,
    ):
        """
        Stop and start both happen asynchrously, coordinated by asyncio event
        and coroutine, respectively.

@ -190,26 +194,26 @@ class JobSupervisor:
        3) Handle concurrent events of driver execution and
        """
        cur_status = self._get_status()
        assert cur_status.status == JobStatus.PENDING, (
            "Run should only be called once.")
        assert cur_status.status == JobStatus.PENDING, "Run should only be called once."

        if _start_signal_actor:
            # Block in PENDING state until start signal received.
            await _start_signal_actor.wait.remote()

        self._status_client.put_status(self._job_id,
                                       JobStatusInfo(JobStatus.RUNNING))
        self._status_client.put_status(self._job_id, JobStatusInfo(JobStatus.RUNNING))

        try:
            # Set JobConfig for the child process (runtime_env, metadata).
            os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps({
                "runtime_env": self._runtime_env,
                "metadata": self._metadata,
            })
            os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps(
                {
                    "runtime_env": self._runtime_env,
                    "metadata": self._metadata,
                }
            )
            # Set RAY_ADDRESS to local Ray address, if it is not set.
            os.environ[
                ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = \
                ray._private.services.get_ray_address_from_environment()
                ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE
            ] = ray._private.services.get_ray_address_from_environment()
            # Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
            # only streaming them upon completion of the job.
            os.environ["PYTHONUNBUFFERED"] = "1"

@ -218,8 +222,8 @@ class JobSupervisor:

            polling_task = create_task(self._polling(child_process))
            finished, _ = await asyncio.wait(
                [polling_task, self._stop_event.wait()],
                return_when=FIRST_COMPLETED)
                [polling_task, self._stop_event.wait()], return_when=FIRST_COMPLETED
            )

            if self._stop_event.is_set():
                polling_task.cancel()

@ -229,29 +233,29 @@ class JobSupervisor:
            else:
                # Child process finished execution and no stop event is set
                # at the same time
                assert len(
                    finished) == 1, "Should have only one coroutine done"
                assert len(finished) == 1, "Should have only one coroutine done"
                [child_process_task] = finished
                return_code = child_process_task.result()
                if return_code == 0:
                    self._status_client.put_status(self._job_id,
                                                   JobStatus.SUCCEEDED)
                    self._status_client.put_status(self._job_id, JobStatus.SUCCEEDED)
                else:
                    log_tail = self._log_client.get_last_n_log_lines(
                        self._job_id)
                    log_tail = self._log_client.get_last_n_log_lines(self._job_id)
                    if log_tail is not None and log_tail != "":
                        message = ("Job failed due to an application error, "
                                   "last available logs:\n" + log_tail)
                        message = (
                            "Job failed due to an application error, "
                            "last available logs:\n" + log_tail
                        )
                    else:
                        message = None
                    self._status_client.put_status(
                        self._job_id,
                        JobStatusInfo(
                            status=JobStatus.FAILED, message=message))
                        JobStatusInfo(status=JobStatus.FAILED, message=message),
                    )
        except Exception:
            logger.error(
                "Got unexpected exception while trying to execute driver "
                f"command. {traceback.format_exc()}")
                f"command. {traceback.format_exc()}"
            )
        finally:
            # clean up actor after tasks are finished
            ray.actor.exit_actor()

@ -260,8 +264,7 @@ class JobSupervisor:
        return self._status_client.get_status(self._job_id)

    def stop(self):
        """Set step_event and let run() handle the rest in its asyncio.wait().
        """
        """Set step_event and let run() handle the rest in its asyncio.wait()."""
        self._stop_event.set()



@ -271,6 +274,7 @@ class JobManager:
    It does not provide persistence, all info will be lost if the cluster
    goes down.
    """

    JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
    # Time that we will sleep while tailing logs if no new log line is
    # available.

@ -300,11 +304,9 @@ class JobManager:
            if key.startswith("node:"):
                return key
        else:
            raise ValueError(
                "Cannot find the node dictionary for current node.")
            raise ValueError("Cannot find the node dictionary for current node.")

    def _handle_supervisor_startup(self, job_id: str,
                                   result: Optional[Exception]):
    def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]):
        """Handle the result of starting a job supervisor actor.

        If started successfully, result should be None. Otherwise it should be

@ -321,26 +323,30 @@ class JobManager:
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=(f"runtime_env setup failed: {result}")))
                    message=(f"runtime_env setup failed: {result}"),
                ),
            )
        elif isinstance(result, Exception):
            logger.error(
                f"Failed to start supervisor for job {job_id}: {result}.")
            logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
            self._status_client.put_status(
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=f"Error occurred while starting the job: {result}")
                    message=f"Error occurred while starting the job: {result}",
                ),
            )
        else:
            assert False, "This should not be reached."

    def submit_job(self,
                   *,
                   entrypoint: str,
                   job_id: Optional[str] = None,
                   runtime_env: Optional[Dict[str, Any]] = None,
                   metadata: Optional[Dict[str, str]] = None,
                   _start_signal_actor: Optional[ActorHandle] = None) -> str:
    def submit_job(
        self,
        *,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
        _start_signal_actor: Optional[ActorHandle] = None,
    ) -> str:
        """
        Job execution happens asynchronously.


@ -390,8 +396,8 @@ class JobManager:
            resources={
                self._get_current_node_resource_key(): 0.001,
            },
            runtime_env=runtime_env).remote(job_id, entrypoint, metadata
                                            or {})
            runtime_env=runtime_env,
        ).remote(job_id, entrypoint, metadata or {})
        actor.run.remote(_start_signal_actor=_start_signal_actor)

        def callback(result: Optional[Exception]):

@ -441,7 +447,8 @@ class JobManager:
            # updating GCS with latest status.
            last_status = self._status_client.get_status(job_id)
            if last_status and last_status.status in {
                JobStatus.PENDING, JobStatus.RUNNING
                JobStatus.PENDING,
                JobStatus.RUNNING,
            }:
                self._status_client.put_status(job_id, JobStatus.FAILED)
@ -13,10 +13,19 @@ except ImportError:
    requests = None

from ray._private.runtime_env.packaging import (
    create_package, get_uri_for_directory, parse_uri)
    create_package,
    get_uri_for_directory,
    parse_uri,
)
from ray.dashboard.modules.job.common import (
    JobSubmitRequest, JobSubmitResponse, JobStopResponse, JobStatusInfo,
    JobStatusResponse, JobLogsResponse, uri_to_http_components)
    JobSubmitRequest,
    JobSubmitResponse,
    JobStopResponse,
    JobStatusInfo,
    JobStatusResponse,
    JobLogsResponse,
    uri_to_http_components,
)

from ray.client_builder import _split_address


@ -33,51 +42,49 @@ class ClusterInfo:


def get_job_submission_client_cluster_info(
        address: str,
        # For backwards compatibility
        *,
        # only used in importlib case in parse_cluster_info, but needed
        # in function signature.
        create_cluster_if_needed: Optional[bool] = False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
    address: str,
    # For backwards compatibility
    *,
    # only used in importlib case in parse_cluster_info, but needed
    # in function signature.
    create_cluster_if_needed: Optional[bool] = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
    """Get address, cookies, and metadata used for JobSubmissionClient.

        Args:
            address (str): Address without the module prefix that is passed
                to JobSubmissionClient.
            create_cluster_if_needed (bool): Indicates whether the cluster
                of the address returned needs to be running. Ray doesn't
                start a cluster before interacting with jobs, but other
                implementations may do so.
    Args:
        address (str): Address without the module prefix that is passed
            to JobSubmissionClient.
        create_cluster_if_needed (bool): Indicates whether the cluster
            of the address returned needs to be running. Ray doesn't
            start a cluster before interacting with jobs, but other
            implementations may do so.

        Returns:
            ClusterInfo object consisting of address, cookies, and metadata
            for JobSubmissionClient to use.
        """
    Returns:
        ClusterInfo object consisting of address, cookies, and metadata
        for JobSubmissionClient to use.
    """
    return ClusterInfo(
        address="http://" + address,
        cookies=cookies,
        metadata=metadata,
        headers=headers)
        address="http://" + address, cookies=cookies, metadata=metadata, headers=headers
    )


def parse_cluster_info(
        address: str,
        create_cluster_if_needed: bool = False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
    address: str,
    create_cluster_if_needed: bool = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
    module_string, inner_address = _split_address(address.rstrip("/"))

    # If user passes in a raw HTTP(S) address, just pass it through.
    if module_string == "http" or module_string == "https":
        return ClusterInfo(
            address=address,
            cookies=cookies,
            metadata=metadata,
            headers=headers)
            address=address, cookies=cookies, metadata=metadata, headers=headers
        )
    # If user passes in a Ray address, convert it to HTTP.
    elif module_string == "ray":
        return get_job_submission_client_cluster_info(

@ -85,7 +92,8 @@ def parse_cluster_info(
            create_cluster_if_needed=create_cluster_if_needed,
            cookies=cookies,
            metadata=metadata,
            headers=headers)
            headers=headers,
        )
    # Try to dynamically import the function to get cluster info.
    else:
        try:

@ -93,33 +101,40 @@ def parse_cluster_info(
        except Exception:
            raise RuntimeError(
                f"Module: {module_string} does not exist.\n"
                f"This module was parsed from Address: {address}") from None
                f"This module was parsed from Address: {address}"
            ) from None
        assert "get_job_submission_client_cluster_info" in dir(module), (
            f"Module: {module_string} does "
            "not have `get_job_submission_client_cluster_info`.")
            "not have `get_job_submission_client_cluster_info`."
        )

        return module.get_job_submission_client_cluster_info(
            inner_address,
            create_cluster_if_needed=create_cluster_if_needed,
            cookies=cookies,
            metadata=metadata,
            headers=headers)
            headers=headers,
        )


class JobSubmissionClient:
    def __init__(self,
                 address: str,
                 create_cluster_if_needed=False,
                 cookies: Optional[Dict[str, Any]] = None,
                 metadata: Optional[Dict[str, Any]] = None,
                 headers: Optional[Dict[str, Any]] = None):
    def __init__(
        self,
        address: str,
        create_cluster_if_needed=False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None,
    ):
        if requests is None:
            raise RuntimeError(
                "The Ray jobs CLI & SDK require the ray[default] "
                "installation: `pip install 'ray[default']``")
                "installation: `pip install 'ray[default']``"
            )

        cluster_info = parse_cluster_info(address, create_cluster_if_needed,
                                          cookies, metadata, headers)
        cluster_info = parse_cluster_info(
            address, create_cluster_if_needed, cookies, metadata, headers
        )
        self._address = cluster_info.address
        self._cookies = cluster_info.cookies
        self._default_metadata = cluster_info.metadata or {}

@ -136,38 +151,43 @@ class JobSubmissionClient:
            raise RuntimeError(
                "Jobs API not supported on the Ray cluster. "
                "Please ensure the cluster is running "
                "Ray 1.9 or higher.")
                "Ray 1.9 or higher."
            )

            r.raise_for_status()
            # TODO(edoakes): check the version if/when we break compatibility.
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                f"Failed to connect to Ray at address: {self._address}.")
                f"Failed to connect to Ray at address: {self._address}."
            )

    def _raise_error(self, r: "requests.Response"):
        raise RuntimeError(
            f"Request failed with status code {r.status_code}: {r.text}.")
            f"Request failed with status code {r.status_code}: {r.text}."
        )

    def _do_request(self,
                    method: str,
                    endpoint: str,
                    *,
                    data: Optional[bytes] = None,
                    json_data: Optional[dict] = None) -> Optional[object]:
    def _do_request(
        self,
        method: str,
        endpoint: str,
        *,
        data: Optional[bytes] = None,
        json_data: Optional[dict] = None,
    ) -> Optional[object]:
        url = self._address + endpoint
        logger.debug(
            f"Sending request to {url} with json data: {json_data or {}}.")
        logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
        return requests.request(
            method,
            url,
            cookies=self._cookies,
            data=data,
            json=json_data,
            headers=self._headers)
            headers=self._headers,
        )

    def _package_exists(
            self,
            package_uri: str,
        self,
        package_uri: str,
    ) -> bool:
        protocol, package_name = uri_to_http_components(package_uri)
        r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")

@ -181,11 +201,13 @@ class JobSubmissionClient:
        else:
            self._raise_error(r)

    def _upload_package(self,
                        package_uri: str,
                        package_path: str,
                        include_parent_dir: Optional[bool] = False,
                        excludes: Optional[List[str]] = None) -> bool:
    def _upload_package(
        self,
        package_uri: str,
        package_path: str,
        include_parent_dir: Optional[bool] = False,
        excludes: Optional[List[str]] = None,
    ) -> bool:
        logger.info(f"Uploading package {package_uri}.")
        with tempfile.TemporaryDirectory() as tmp_dir:
            protocol, package_name = uri_to_http_components(package_uri)

@ -194,26 +216,27 @@ class JobSubmissionClient:
                package_path,
                package_file,
                include_parent_dir=include_parent_dir,
                excludes=excludes)
                excludes=excludes,
            )
            try:
                r = self._do_request(
                    "PUT",
                    f"/api/packages/{protocol}/{package_name}",
                    data=package_file.read_bytes())
                    data=package_file.read_bytes(),
                )
                if r.status_code != 200:
                    self._raise_error(r)
            finally:
                package_file.unlink()

    def _upload_package_if_needed(self,
                                  package_path: str,
                                  excludes: Optional[List[str]] = None) -> str:
    def _upload_package_if_needed(
        self, package_path: str, excludes: Optional[List[str]] = None
    ) -> str:
        package_uri = get_uri_for_directory(package_path, excludes=excludes)
        if not self._package_exists(package_uri):
            self._upload_package(package_uri, package_path, excludes=excludes)
        else:
            logger.info(
                f"Package {package_uri} already exists, skipping upload.")
            logger.info(f"Package {package_uri} already exists, skipping upload.")

        return package_uri


@ -230,7 +253,8 @@ class JobSubmissionClient:
        if not is_uri:
            logger.debug("working_dir is not a URI, attempting to upload.")
            package_uri = self._upload_package_if_needed(
                working_dir, excludes=runtime_env.get("excludes", None))
                working_dir, excludes=runtime_env.get("excludes", None)
            )
            runtime_env["working_dir"] = package_uri

    def get_version(self) -> str:

@ -241,12 +265,12 @@ class JobSubmissionClient:
            self._raise_error(r)

    def submit_job(
            self,
            *,
            entrypoint: str,
            job_id: Optional[str] = None,
            runtime_env: Optional[Dict[str, Any]] = None,
            metadata: Optional[Dict[str, str]] = None,
        self,
        *,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        runtime_env = runtime_env or {}
        metadata = metadata or {}

@ -257,11 +281,11 @@ class JobSubmissionClient:
            entrypoint=entrypoint,
|
||||
job_id=job_id,
|
||||
runtime_env=runtime_env,
|
||||
metadata=metadata)
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
logger.debug(f"Submitting job with job_id={job_id}.")
|
||||
r = self._do_request(
|
||||
"POST", "/api/jobs/", json_data=dataclasses.asdict(req))
|
||||
r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))
|
||||
|
||||
if r.status_code == 200:
|
||||
return JobSubmitResponse(**r.json()).job_id
|
||||
|
@ -269,8 +293,8 @@ class JobSubmissionClient:
|
|||
self._raise_error(r)
|
||||
|
||||
def stop_job(
|
||||
self,
|
||||
job_id: str,
|
||||
self,
|
||||
job_id: str,
|
||||
) -> bool:
|
||||
logger.debug(f"Stopping job with job_id={job_id}.")
|
||||
r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
|
||||
|
@ -281,15 +305,14 @@ class JobSubmissionClient:
|
|||
self._raise_error(r)
|
||||
|
||||
def get_job_status(
|
||||
self,
|
||||
job_id: str,
|
||||
self,
|
||||
job_id: str,
|
||||
) -> JobStatusInfo:
|
||||
r = self._do_request("GET", f"/api/jobs/{job_id}")
|
||||
|
||||
if r.status_code == 200:
|
||||
response = JobStatusResponse(**r.json())
|
||||
return JobStatusInfo(
|
||||
status=response.status, message=response.message)
|
||||
return JobStatusInfo(status=response.status, message=response.message)
|
||||
else:
|
||||
self._raise_error(r)
|
||||
|
||||
|
@ -304,7 +327,8 @@ class JobSubmissionClient:
|
|||
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
|
||||
async with aiohttp.ClientSession(cookies=self._cookies) as session:
|
||||
ws = await session.ws_connect(
|
||||
f"{self._address}/api/jobs/{job_id}/logs/tail")
|
||||
f"{self._address}/api/jobs/{job_id}/logs/tail"
|
||||
)
|
||||
|
||||
while True:
|
||||
msg = await ws.receive()
|
||||
|
|
|
@@ -7,13 +7,13 @@ we ended up using job submission API call's runtime_env instead of scripts
def run():
import ray
import os

ray.init(
address=os.environ["RAY_ADDRESS"],
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"}
},
)

@ray.remote
def foo():

@@ -21,15 +21,14 @@ def conda_env(env_name):
# Clean up created conda env upon test exit to prevent leaking
del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"]
subprocess.run(
f"conda env remove -y --name {env_name}",
shell=True,
stdout=subprocess.PIPE)
f"conda env remove -y --name {env_name}", shell=True, stdout=subprocess.PIPE
)


def _compatibility_script_path(file_name: str) -> str:
return os.path.join(
os.path.dirname(__file__), "backwards_compatibility_scripts",
file_name)
os.path.dirname(__file__), "backwards_compatibility_scripts", file_name
)


class TestBackwardsCompatibility:

@@ -48,8 +47,7 @@ class TestBackwardsCompatibility:
shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501

try:
subprocess.check_output(
shell_cmd, shell=True, stderr=subprocess.STDOUT)
subprocess.check_output(shell_cmd, shell=True, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
logger.error(str(e))
logger.error(e.stdout.decode())

@@ -34,8 +34,7 @@ def mock_sdk_client():

if "RAY_ADDRESS" in os.environ:
del os.environ["RAY_ADDRESS"]
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient"
) as mock_client:
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient") as mock_client:
# In python 3.6 it will fail with error
# 'async for' requires an object with __aiter__ method, got MagicMock"
mock_client().tail_job_logs.return_value = AsyncIterator(range(10))

@@ -52,9 +51,7 @@ def runtime_env_formats():
"working_dir": "s3://bogus.zip",
"conda": "conda_env",
"pip": ["pip-install-test"],
"env_vars": {
"hi": "hi2"
}
"env_vars": {"hi": "hi2"},
}

yaml_file = path / "env.yaml"

@@ -86,14 +83,13 @@ class TestSubmit:

# Test passing address via command line.
result = runner.invoke(
job_cli_group,
["submit", "--address=arg_addr", "--", "echo hello"])
job_cli_group, ["submit", "--address=arg_addr", "--", "echo hello"]
)
assert mock_sdk_client.called_with("arg_addr")
assert result.exit_code == 0
# Test passing address via env var.
with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_sdk_client.called_with("env_addr")
# Test passing no address.

@@ -106,24 +102,22 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value

with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env={})

result = runner.invoke(
job_cli_group,
["submit", "--", "--working-dir", "blah", "--", "echo hello"])
["submit", "--", "--working-dir", "blah", "--", "echo hello"],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(
runtime_env={"working_dir": "blah"})
assert mock_client_instance.called_with(runtime_env={"working_dir": "blah"})

result = runner.invoke(
job_cli_group,
["submit", "--", "--working-dir='.'", "--", "echo hello"])
job_cli_group, ["submit", "--", "--working-dir='.'", "--", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(
runtime_env={"working_dir": "."})
assert mock_client_instance.called_with(runtime_env={"working_dir": "."})

def test_runtime_env(self, mock_sdk_client, runtime_env_formats):
runner = CliRunner()

@@ -133,39 +127,64 @@ class TestSubmit:
with set_env_var("RAY_ADDRESS", "env_addr"):
# Test passing via file.
result = runner.invoke(
job_cli_group,
["submit", "--runtime-env", env_yaml, "--", "echo hello"])
job_cli_group, ["submit", "--runtime-env", env_yaml, "--", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)

# Test passing via json.
result = runner.invoke(
job_cli_group,
["submit", "--runtime-env-json", env_json, "--", "echo hello"])
["submit", "--runtime-env-json", env_json, "--", "echo hello"],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)

# Test passing both throws an error.
result = runner.invoke(job_cli_group, [
"submit", "--runtime-env", env_yaml, "--runtime-env-json",
env_json, "--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env",
env_yaml,
"--runtime-env-json",
env_json,
"--",
"echo hello",
],
)
assert result.exit_code == 1
assert "Only one of" in str(result.exception)

# Test overriding working_dir.
env_dict.update(working_dir=".")
result = runner.invoke(job_cli_group, [
"submit", "--runtime-env", env_yaml, "--working-dir", ".",
"--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env",
env_yaml,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)

result = runner.invoke(job_cli_group, [
"submit", "--runtime-env-json", env_json, "--working-dir", ".",
"--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env-json",
env_json,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)


@@ -174,18 +193,18 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value

with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_client_instance.called_with(job_id=None)

result = runner.invoke(
job_cli_group,
["submit", "--", "--job-id=my_job_id", "echo hello"])
job_cli_group, ["submit", "--", "--job-id=my_job_id", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(job_id="my_job_id")


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))

@@ -60,25 +60,27 @@ class TestRayAddress:
def test_empty_ray_address(self, ray_start_stop):
with set_env_var("RAY_ADDRESS", None):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stderr=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stderr=subprocess.PIPE
)
stderr = completed_process.stderr.decode("utf-8")
# Current dashboard module that raises no exception from requests..
assert ("Address must be specified using either the "
"--address flag or RAY_ADDRESS environment") in stderr
assert (
"Address must be specified using either the "
"--address flag or RAY_ADDRESS environment"
) in stderr

def test_ray_client_address(self, ray_start_stop):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout

def test_valid_http_ray_address(self, ray_start_stop):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout

@@ -87,8 +89,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"):
with ray_cluster_manager():
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout

@@ -97,8 +99,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"):
with ray_cluster_manager():
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout

@@ -109,7 +111,8 @@ class TestJobSubmit:
"""Should tail logs and wait for process to exit."""
cmd = "sleep 1 && echo hello && sleep 1 && echo hello"
completed_process = subprocess.run(
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE)
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello\nhello" in stdout
assert "succeeded" in stdout

@@ -118,8 +121,8 @@ class TestJobSubmit:
"""Should exit immediately w/o printing logs."""
cmd = "echo hello && sleep 1000"
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", "--", cmd],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--no-wait", "--", cmd], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" not in stdout
assert "Tailing logs until the job exits" not in stdout

@@ -130,13 +133,13 @@ class TestJobStop:
"""Should wait until the job is stopped."""
cmd = "sleep 1000"
job_id = "test_basic_stop"
completed_process = subprocess.run([
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
cmd
])
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
)

completed_process = subprocess.run(
["ray", "job", "stop", job_id], stdout=subprocess.PIPE)
["ray", "job", "stop", job_id], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" in stdout
assert f"Job '{job_id}' was stopped" in stdout

@@ -145,14 +148,13 @@ class TestJobStop:
"""Should not wait until the job is stopped."""
cmd = "echo hello && sleep 1000"
job_id = "test_stop_no_wait"
completed_process = subprocess.run([
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
cmd
])
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
)

completed_process = subprocess.run(
["ray", "job", "stop", "--no-wait", job_id],
stdout=subprocess.PIPE)
["ray", "job", "stop", "--no-wait", job_id], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" not in stdout
assert f"Job '{job_id}' was stopped" not in stdout

@@ -24,82 +24,61 @@ class TestJobSubmitRequestValidation:
assert r.entrypoint == "abc"
assert r.job_id is None

r = validate_request_type({
"entrypoint": "abc",
"job_id": "123"
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "job_id": "123"}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.job_id == "123"

with pytest.raises(TypeError, match="must be a string"):
validate_request_type({
"entrypoint": 123,
"job_id": 1
}, JobSubmitRequest)
validate_request_type({"entrypoint": 123, "job_id": 1}, JobSubmitRequest)

def test_validate_runtime_env(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc"
assert r.runtime_env is None

r = validate_request_type({
"entrypoint": "abc",
"runtime_env": {
"hi": "hi2"
}
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "runtime_env": {"hi": "hi2"}}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.runtime_env == {"hi": "hi2"}

with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({
"entrypoint": "abc",
"runtime_env": 123
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "runtime_env": 123}, JobSubmitRequest
)

with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({
"entrypoint": "abc",
"runtime_env": {
1: "hi"
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "runtime_env": {1: "hi"}}, JobSubmitRequest
)

def test_validate_metadata(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc"
assert r.metadata is None

r = validate_request_type({
"entrypoint": "abc",
"metadata": {
"hi": "hi2"
}
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "metadata": {"hi": "hi2"}}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.metadata == {"hi": "hi2"}

with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({
"entrypoint": "abc",
"metadata": 123
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": 123}, JobSubmitRequest
)

with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({
"entrypoint": "abc",
"metadata": {
1: "hi"
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": {1: "hi"}}, JobSubmitRequest
)

with pytest.raises(TypeError, match="values must be strings"):
validate_request_type({
"entrypoint": "abc",
"metadata": {
"hi": 1
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": {"hi": 1}}, JobSubmitRequest
)


def test_uri_to_http_and_back():

@@ -127,4 +106,5 @@ def test_uri_to_http_and_back():

if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))

@@ -8,11 +8,17 @@ import pytest
import ray
from ray.dashboard.tests.conftest import * # noqa
from ray.tests.conftest import _ray_start
from ray._private.test_utils import (format_web_url, wait_for_condition,
wait_until_server_available)
from ray._private.test_utils import (
format_web_url,
wait_for_condition,
wait_until_server_available,
)
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
from ray.dashboard.modules.job.sdk import (ClusterInfo, JobSubmissionClient,
parse_cluster_info)
from ray.dashboard.modules.job.sdk import (
ClusterInfo,
JobSubmissionClient,
parse_cluster_info,
)
from unittest.mock import patch

logger = logging.getLogger(__name__)

@@ -50,8 +56,8 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:


@pytest.fixture(
scope="module",
params=["no_working_dir", "local_working_dir", "s3_working_dir"])
scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"]
)
def working_dir_option(request):
if request.param == "no_working_dir":
yield {

@@ -81,9 +87,7 @@ def working_dir_option(request):
f.write("from test_module.test import run_test\n")

yield {
"runtime_env": {
"working_dir": tmp_dir
},
"runtime_env": {"working_dir": tmp_dir},
"entrypoint": "python test.py",
"expected_logs": "Hello from test_module!\n",
}

@@ -104,7 +108,8 @@ def test_submit_job(job_sdk_client, working_dir_option):

job_id = client.submit_job(
entrypoint=working_dir_option["entrypoint"],
runtime_env=working_dir_option["runtime_env"])
runtime_env=working_dir_option["runtime_env"],
)

wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)


@@ -133,7 +138,8 @@ def test_http_bad_request(job_sdk_client):
def test_invalid_runtime_env(job_sdk_client):
client = job_sdk_client
job_id = client.submit_job(
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"})
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
)

wait_for_condition(_check_job_failed, client=client, job_id=job_id)
status = client.get_job_status(job_id)

@@ -143,8 +149,8 @@ def test_invalid_runtime_env(job_sdk_client):
def test_runtime_env_setup_failure(job_sdk_client):
client = job_sdk_client
job_id = client.submit_job(
entrypoint="echo hello",
runtime_env={"working_dir": "s3://does_not_exist.zip"})
entrypoint="echo hello", runtime_env={"working_dir": "s3://does_not_exist.zip"}
)

wait_for_condition(_check_job_failed, client=client, job_id=job_id)
status = client.get_job_status(job_id)

@@ -168,8 +174,8 @@ raise RuntimeError('Intentionally failed.')
file.write(driver_script)

job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)

wait_for_condition(_check_job_failed, client=client, job_id=job_id)
logs = client.get_job_logs(job_id)

@@ -196,8 +202,8 @@ raise RuntimeError('Intentionally failed.')
file.write(driver_script)

job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)
assert client.stop_job(job_id) is True
wait_for_condition(_check_job_stopped, client=client, job_id=job_id)


@@ -206,28 +212,31 @@ def test_job_metadata(job_sdk_client):
client = job_sdk_client

print_metadata_cmd = (
"python -c\""
'python -c"'
"import ray;"
"ray.init();"
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
"print(dict(sorted(job_config.metadata.items())))"
"\"")
'"'
)

job_id = client.submit_job(
entrypoint=print_metadata_cmd,
metadata={
"key1": "val1",
"key2": "val2"
})
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)

wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)

assert str({
"job_name": job_id,
"job_submission_id": job_id,
"key1": "val1",
"key2": "val2"
}) in client.get_job_logs(job_id)
assert (
str(
{
"job_name": job_id,
"job_submission_id": job_id,
"key1": "val1",
"key2": "val2",
}
)
in client.get_job_logs(job_id)
)


def test_pass_job_id(job_sdk_client):

@@ -261,19 +270,19 @@ def test_submit_optional_args(job_sdk_client):
json_data={"entrypoint": "ls"},
)

wait_for_condition(
_check_job_succeeded, client=client, job_id=r.json()["job_id"])
wait_for_condition(_check_job_succeeded, client=client, job_id=r.json()["job_id"])


def test_missing_resources(job_sdk_client):
"""Check that 404s are raised for resources that don't exist."""
client = job_sdk_client

conditions = [("GET",
"/api/jobs/fake_job_id"), ("GET",
"/api/jobs/fake_job_id/logs"),
("POST", "/api/jobs/fake_job_id/stop"),
("GET", "/api/packages/fake_package_uri")]
conditions = [
("GET", "/api/jobs/fake_job_id"),
("GET", "/api/jobs/fake_job_id/logs"),
("POST", "/api/jobs/fake_job_id/stop"),
("GET", "/api/packages/fake_package_uri"),
]

for method, route in conditions:
assert client._do_request(method, route).status_code == 404

@@ -287,7 +296,7 @@ def test_version_endpoint(job_sdk_client):
assert r.json() == {
"version": CURRENT_VERSION,
"ray_version": ray.__version__,
"ray_commit": ray.__commit__
"ray_commit": ray.__commit__,
}


@@ -306,26 +315,31 @@ def test_request_headers(job_sdk_client):
cookies=None,
data=None,
json={"entrypoint": "ls"},
headers={
"Connection": "keep-alive",
"Authorization": "TOK:<MY_TOKEN>"
})
headers={"Connection": "keep-alive", "Authorization": "TOK:<MY_TOKEN>"},
)


@pytest.mark.parametrize("address", [
"http://127.0.0.1", "https://127.0.0.1", "ray://127.0.0.1",
"fake_module://127.0.0.1"
])
@pytest.mark.parametrize(
"address",
[
"http://127.0.0.1",
"https://127.0.0.1",
"ray://127.0.0.1",
"fake_module://127.0.0.1",
],
)
def test_parse_cluster_info(address: str):
if address.startswith("ray"):
assert parse_cluster_info(address, False) == ClusterInfo(
address="http" + address[address.index("://"):],
address="http" + address[address.index("://") :],
cookies=None,
metadata=None,
headers=None)
headers=None,
)
elif address.startswith("http") or address.startswith("https"):
assert parse_cluster_info(address, False) == ClusterInfo(
address=address, cookies=None, metadata=None, headers=None)
address=address, cookies=None, metadata=None, headers=None
)
else:
with pytest.raises(RuntimeError):
parse_cluster_info(address, False)

@@ -347,8 +361,8 @@ for i in range(100):
f.write(driver_script)

job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)

i = 0
async for lines in client.tail_job_logs(job_id):

@@ -8,8 +8,11 @@ import signal
import pytest

import ray
from ray.dashboard.modules.job.common import (JobStatus, JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY)
from ray.dashboard.modules.job.common import (
JobStatus,
JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
from ray._private.test_utils import SignalActor, wait_for_condition

@@ -28,7 +31,8 @@ def job_manager(shared_ray_instance):

def _driver_script_path(file_name: str) -> str:
return os.path.join(
os.path.dirname(__file__), "subprocess_driver_scripts", file_name)
os.path.dirname(__file__), "subprocess_driver_scripts", file_name
)


def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):

@@ -36,12 +40,15 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
pid_file = os.path.join(tmp_dir, "pid")

# Write subprocess pid to pid_file and block until tmp_file is present.
wait_for_file_cmd = (f"echo $$ > {pid_file} && "
f"until [ -f {tmp_file} ]; "
"do echo 'Waiting...' && sleep 1; "
"done")
wait_for_file_cmd = (
f"echo $$ > {pid_file} && "
f"until [ -f {tmp_file} ]; "
"do echo 'Waiting...' && sleep 1; "
"done"
)
job_id = job_manager.submit_job(
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor)
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
)

status = job_manager.get_job_status(job_id)
if start_signal_actor:

@@ -50,11 +57,9 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
logs = job_manager.get_job_logs(job_id)
assert logs == ""
else:
wait_for_condition(
check_job_running, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)

wait_for_condition(
lambda: "Waiting..." in job_manager.get_job_logs(job_id))
wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))

return pid_file, tmp_file, job_id

@@ -63,25 +68,19 @@ def check_job_succeeded(job_manager, job_id):
status = job_manager.get_job_status(job_id)
if status.status == JobStatus.FAILED:
raise RuntimeError(f"Job failed! {status.message}")
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED}
return status.status == JobStatus.SUCCEEDED


def check_job_failed(job_manager, job_id):
status = job_manager.get_job_status(job_id)
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
return status.status == JobStatus.FAILED


def check_job_stopped(job_manager, job_id):
status = job_manager.get_job_status(job_id)
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
return status.status == JobStatus.STOPPED


@@ -111,12 +110,10 @@ def test_generate_job_id():
def test_pass_job_id(job_manager):
job_id = "my_custom_id"

returned_id = job_manager.submit_job(
entrypoint="echo hello", job_id=job_id)
returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
assert returned_id == job_id

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)

# Check that the same job_id is rejected.
with pytest.raises(RuntimeError):

@@ -127,23 +124,20 @@ class TestShellScriptExecution:
def test_submit_basic_echo(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo hello")

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "hello\n"

def test_submit_stderr(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "error\n"

def test_submit_ls_grep(self, job_manager):
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
job_id = job_manager.submit_job(entrypoint=grep_cmd)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"

def test_subprocess_exception(self, job_manager):

@@ -161,8 +155,7 @@ class TestShellScriptExecution:
status = job_manager.get_job_status(job_id)
if status.status != JobStatus.FAILED:
return False
if ("Exception: Script failed with exception !" not in
status.message):
if "Exception: Script failed with exception !" not in status.message:
return False

return job_manager._get_actor_for_job(job_id) is None

@@ -172,14 +165,13 @@ class TestShellScriptExecution:
def test_submit_with_s3_runtime_env(self, job_manager):
job_id = job_manager.submit_job(
entrypoint="python script.py",
runtime_env={
"working_dir": "s3://runtime-env-test/script_runtime_env.zip"
})
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(
job_id) == "Executing main() from script.py !!\n"
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert (
job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
)


class TestRuntimeEnv:

@@ -193,14 +185,10 @@ class TestRuntimeEnv:
"""
job_id = job_manager.submit_job(
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"
}
})
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "233\n"

def test_multiple_runtime_envs(self, job_manager):

@@ -208,28 +196,32 @@ class TestRuntimeEnv:
job_id_1 = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_1)
check_job_succeeded, job_manager=job_manager, job_id=job_id_1
)
logs = job_manager.get_job_logs(job_id_1)
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs # noqa: E501
assert (
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
) # noqa: E501

job_id_2 = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_2)
check_job_succeeded, job_manager=job_manager, job_id=job_id_2
)
logs = job_manager.get_job_logs(job_id_2)
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs # noqa: E501
assert (
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
) # noqa: E501

def test_env_var_and_driver_job_config_warning(self, job_manager):
"""Ensure we got error message from worker.py and job logs

@@ -238,17 +230,15 @@ class TestRuntimeEnv:
job_id = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('override_env_var.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
logs = job_manager.get_job_logs(job_id)
assert logs.startswith(
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) "
"are provided")
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
)
assert "JOB_1_VAR" in logs

def test_failed_runtime_env_validation(self, job_manager):

@@ -257,7 +247,8 @@ class TestRuntimeEnv:
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"})
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
)

status = job_manager.get_job_status(job_id)
assert status.status == JobStatus.FAILED

@@ -269,11 +260,10 @@ class TestRuntimeEnv:
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
entrypoint=run_cmd,
runtime_env={"working_dir": "s3://does_not_exist.zip"})
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
)

wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

status = job_manager.get_job_status(job_id)
assert "runtime_env setup failed" in status.message

@@ -283,69 +273,67 @@ class TestRuntimeEnv:
return str(dict(sorted(d.items())))

print_metadata_cmd = (
"python -c\""
'python -c"'
"import ray;"
"ray.init();"
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
"print(dict(sorted(job_config.metadata.items())))"
"\"")
'"'
)

# Check that we default to only the job ID and job name.
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)

# Check that we can pass custom metadata.
job_id = job_manager.submit_job(
entrypoint=print_metadata_cmd,
metadata={
"key1": "val1",
"key2": "val2"
})
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id,
"key1": "val1",
"key2": "val2"
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert (
dict_to_str(
{
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id,
"key1": "val1",
"key2": "val2",
}
)
in job_manager.get_job_logs(job_id)
)

# Check that we can override job name.
job_id = job_manager.submit_job(
entrypoint=print_metadata_cmd,
metadata={JOB_NAME_METADATA_KEY: "custom_name"})
metadata={JOB_NAME_METADATA_KEY: "custom_name"},
)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: "custom_name",
JOB_ID_METADATA_KEY: job_id
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)


class TestAsyncAPI:
def test_status_and_logs_while_blocking(self, job_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, tmp_file, job_id = _run_hanging_command(
job_manager, tmp_dir)
pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"

# Signal the job to exit by writing to the file.
with open(tmp_file, "w") as f:
print("hello", file=f)

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
wait_for_condition(check_subprocess_cleaned, pid=pid)

@@ -356,7 +344,8 @@ class TestAsyncAPI:

assert job_manager.stop_job(job_id) is True
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Assert re-stopping a stopped job also returns False
wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
# Assert stopping non-existent job returns False

@@ -375,13 +364,11 @@ class TestAsyncAPI:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"

actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

# Ensure driver subprocess gets cleaned up after job reached
# termination state

@@ -398,16 +385,18 @@ class TestAsyncAPI:

with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
assert not os.path.exists(pid_file), (
"driver subprocess should NOT be running while job is "
"still PENDING.")
"driver subprocess should NOT be running while job is " "still PENDING."
)

assert job_manager.stop_job(job_id) is True
# Send run signal to unblock run function
ray.get(start_signal_actor.send.remote())
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)

def test_kill_job_actor_in_pending(self, job_manager):
"""

@@ -420,16 +409,16 @@ class TestAsyncAPI:

with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)

assert not os.path.exists(pid_file), (
"driver subprocess should NOT be running while job is "
"still PENDING.")
"driver subprocess should NOT be running while job is " "still PENDING."
)

actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
"""

@@ -442,12 +431,12 @@ class TestAsyncAPI:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"

assert job_manager.stop_job(job_id) is True
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)

# Ensure driver subprocess gets cleaned up after job reached
# termination state

@@ -455,11 +444,9 @@ class TestAsyncAPI:


class TestTailLogs:
async def _tail_and_assert_logs(self,
job_id,
job_manager,
expected_log="",
num_iteration=5):
async def _tail_and_assert_logs(
self, job_id, job_manager, expected_log="", num_iteration=5
):
i = 0
async for lines in job_manager.tail_job_logs(job_id):
assert all(s == expected_log for s in lines.strip().split("\n"))

@@ -470,8 +457,7 @@ class TestTailLogs:

@pytest.mark.asyncio
async def test_unknown_job(self, job_manager):
with pytest.raises(
RuntimeError, match="Job 'unknown' does not exist."):
with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
async for _ in job_manager.tail_job_logs("unknown"):
pass

@@ -482,33 +468,31 @@ class TestTailLogs:

with tempfile.TemporaryDirectory() as tmp_dir:
_, tmp_file, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)

# TODO(edoakes): check we get no logs before actor starts (not sure
# how to timeout the iterator call).
assert job_manager.get_job_status(
job_id).status == JobStatus.PENDING
assert job_manager.get_job_status(job_id).status == JobStatus.PENDING

# Signal job to start.
ray.get(start_signal_actor.send.remote())

await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)

# Signal the job to exit by writing to the file.
with open(tmp_file, "w") as f:
print("hello", file=f)

async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")

wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
check_job_succeeded, job_manager=job_manager, job_id=job_id
)

@pytest.mark.asyncio
async def test_failed_job(self, job_manager):

@@ -517,22 +501,18 @@ class TestTailLogs:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)

await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)

# Kill the job unexpectedly.
with open(pid_file, "r") as f:
os.kill(int(f.read()), signal.SIGKILL)

async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")

wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

@pytest.mark.asyncio
async def test_stopped_job(self, job_manager):

@@ -541,21 +521,19 @@ class TestTailLogs:
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)

await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)

# Stop the job via the API.
job_manager.stop_job(job_id)

async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")

wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)


def test_logs_streaming(job_manager):

@@ -568,7 +546,7 @@ while True:
time.sleep(1)
"""

stream_logs_cmd = f"python -c \"{stream_logs_script}\""
stream_logs_cmd = f'python -c "{stream_logs_script}"'

job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))

@@ -12,7 +12,7 @@ def tmp():
yield f.name


class TestIterLine():
class TestIterLine:
def test_invalid_type(self):
with pytest.raises(TypeError, match="path must be a string"):
next(file_tail_iterator(1))

@@ -19,10 +19,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
self._proxy_session = aiohttp.ClientSession(auto_decompress=False)
log_utils.register_mimetypes()
routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
GlobalSignals.node_info_fetched.append(
self.insert_log_url_to_node_info)
GlobalSignals.node_summary_fetched.append(
self.insert_log_url_to_node_info)
GlobalSignals.node_info_fetched.append(self.insert_log_url_to_node_info)
GlobalSignals.node_summary_fetched.append(self.insert_log_url_to_node_info)

async def insert_log_url_to_node_info(self, node_info):
node_id = node_info.get("raylet", {}).get("nodeId")

@@ -33,7 +31,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
return
agent_http_port, _ = agent_port
log_url = self.LOG_URL_TEMPLATE.format(
ip=node_info.get("ip"), port=agent_http_port)
ip=node_info.get("ip"), port=agent_http_port
)
node_info["logUrl"] = log_url

@routes.get("/log_index")

@@ -43,15 +42,16 @@ class LogHead(dashboard_utils.DashboardHeadModule):
for node_id, ports in DataSource.agents.items():
ip = DataSource.node_id_to_ip[node_id]
agent_ips.append(ip)
url_list.append(
self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
url_list.append(self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
if self._dashboard_head.ip not in agent_ips:
url_list.append(
self.LOG_URL_TEMPLATE.format(
ip=self._dashboard_head.ip,
port=self._dashboard_head.http_port))
ip=self._dashboard_head.ip, port=self._dashboard_head.http_port
)
)
return aiohttp.web.Response(
text=self._directory_as_html(url_list), content_type="text/html")
text=self._directory_as_html(url_list), content_type="text/html"
)

@routes.get("/log_proxy")
async def get_log_from_proxy(self, req) -> aiohttp.web.StreamResponse:

@@ -60,9 +60,11 @@ class LogHead(dashboard_utils.DashboardHeadModule):
raise Exception("url is None.")
body = await req.read()
async with self._proxy_session.request(
req.method, url, data=body, headers=req.headers) as r:
req.method, url, data=body, headers=req.headers
) as r:
sr = aiohttp.web.StreamResponse(
status=r.status, reason=r.reason, headers=req.headers)
status=r.status, reason=r.reason, headers=req.headers
)
sr.content_length = r.content_length
sr.content_type = r.content_type
sr.charset = r.charset

@@ -40,8 +40,7 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):

test_log_text = "test_log_text"
ray.get(write_log.remote(test_log_text))
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
node_id = ray_start_with_dashboard["node_id"]

@@ -82,8 +81,8 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):

# Test range request.
response = requests.get(
webui_url + "/logs/dashboard.log",
headers={"Range": "bytes=44-52"})
webui_url + "/logs/dashboard.log", headers={"Range": "bytes=44-52"}
)
response.raise_for_status()
assert response.text == "Dashboard"

@@ -100,16 +99,19 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")


def test_log_proxy(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)

@@ -122,21 +124,27 @@ def test_log_proxy(ray_start_with_dashboard):
# Test range request.
response = requests.get(
f"{webui_url}/log_proxy?url={webui_url}/logs/dashboard.log",
headers={"Range": "bytes=44-52"})
headers={"Range": "bytes=44-52"},
)
response.raise_for_status()
assert response.text == "Dashboard"
# Test 404.
response = requests.get(f"{webui_url}/log_proxy?"
f"url={webui_url}/logs/not_exist_file.log")
response = requests.get(
f"{webui_url}/log_proxy?" f"url={webui_url}/logs/not_exist_file.log"
)
assert response.status_code == 404
break
except Exception as ex:
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")

@@ -9,8 +9,10 @@ import ray._private.utils
import ray._private.gcs_utils as gcs_utils
from ray import ray_constants
from ray.dashboard.modules.node import node_consts
from ray.dashboard.modules.node.node_consts import (MAX_LOGS_TO_CACHE,
LOG_PRUNE_THREASHOLD)
from ray.dashboard.modules.node.node_consts import (
MAX_LOGS_TO_CACHE,
LOG_PRUNE_THREASHOLD,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.consts as dashboard_consts

@@ -28,13 +30,21 @@ routes = dashboard_optional_utils.ClassMethodRouteTable

def gcs_node_info_to_dict(message):
return dashboard_utils.message_to_dict(
message, {"nodeId"}, including_default_value_fields=True)
message, {"nodeId"}, including_default_value_fields=True
)


def node_stats_to_dict(message):
decode_keys = {
"actorId", "jobId", "taskId", "parentTaskId", "sourceActorId",
"callerId", "rayletId", "workerId", "placementGroupId"
"actorId",
"jobId",
"taskId",
"parentTaskId",
"sourceActorId",
"callerId",
"rayletId",
"workerId",
"placementGroupId",
}
core_workers_stats = message.core_workers_stats
message.ClearField("core_workers_stats")

@@ -42,7 +52,8 @@ def node_stats_to_dict(message):
result = dashboard_utils.message_to_dict(message, decode_keys)
result["coreWorkersStats"] = [
dashboard_utils.message_to_dict(
m, decode_keys, including_default_value_fields=True)
m, decode_keys, including_default_value_fields=True
)
for m in core_workers_stats
]
return result

@@ -66,11 +77,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if change.new:
# TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub

@@ -81,8 +94,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
A dict of information about the nodes in the cluster.
"""
request = gcs_service_pb2.GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
request, timeout=2)
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=2)
if reply.status.code == 0:
result = {}
for node_info in reply.node_info_list:

@@ -116,11 +128,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

agents = dict(DataSource.agents)
for node_id in alive_node_ids:
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" \
f"{node_id}"
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
# TODO: Use async version if performance is an issue
agent_port = ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
if agent_port:
agents[node_id] = json.loads(agent_port)
for node_id in agents.keys() - set(alive_node_ids):

@@ -142,9 +154,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if view == "summary":
all_node_summary = await DataOrganizer.get_all_node_summary()
return dashboard_optional_utils.rest_response(
success=True,
message="Node summary fetched.",
summary=all_node_summary)
success=True, message="Node summary fetched.", summary=all_node_summary
)
elif view == "details":
all_node_details = await DataOrganizer.get_all_node_details()
return dashboard_optional_utils.rest_response(

@@ -160,10 +171,12 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
return dashboard_optional_utils.rest_response(
success=True,
message="Node hostname list fetched.",
host_name_list=list(alive_hostnames))
host_name_list=list(alive_hostnames),
)
else:
return dashboard_optional_utils.rest_response(
success=False, message=f"Unknown view {view}")
success=False, message=f"Unknown view {view}"
)

@routes.get("/nodes/{node_id}")
@dashboard_optional_utils.aiohttp_cache

@@ -171,7 +184,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
node_id = req.match_info.get("node_id")
node_info = await DataOrganizer.get_node_info(node_id)
return dashboard_optional_utils.rest_response(
success=True, message="Node details fetched.", detail=node_info)
success=True, message="Node details fetched.", detail=node_info
)

@routes.get("/memory/memory_table")
async def get_memory_table(self, req) -> aiohttp.web.Response:

@@ -187,7 +201,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
return dashboard_optional_utils.rest_response(
success=True,
message="Fetched memory table",
memory_table=memory_table.as_dict())
memory_table=memory_table.as_dict(),
)

@routes.get("/memory/set_fetch")
async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:

@@ -198,11 +213,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
self._collect_memory_info = False
else:
return dashboard_optional_utils.rest_response(
success=False,
message=f"Unknown argument to set_fetch {should_fetch}")
success=False, message=f"Unknown argument to set_fetch {should_fetch}"
)
return dashboard_optional_utils.rest_response(
success=True,
message=f"Successfully set fetching to {should_fetch}")
success=True, message=f"Successfully set fetching to {should_fetch}"
)

@routes.get("/node_logs")
async def get_logs(self, req) -> aiohttp.web.Response:

@@ -212,7 +227,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if pid:
node_logs = {str(pid): node_logs.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched logs.", logs=node_logs)
success=True, message="Fetched logs.", logs=node_logs
)

@routes.get("/node_errors")
async def get_errors(self, req) -> aiohttp.web.Response:

@@ -222,7 +238,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if pid:
node_errors = {str(pid): node_errors.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched errors.", errors=node_errors)
success=True, message="Fetched errors.", errors=node_errors
)

@async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
async def _update_node_stats(self):

@@ -234,8 +251,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
try:
reply = await stub.GetNodeStats(
node_manager_pb2.GetNodeStatsRequest(
include_memory_info=self._collect_memory_info),
timeout=2)
include_memory_info=self._collect_memory_info
),
timeout=2,
)
reply_dict = node_stats_to_dict(reply)
DataSource.node_stats[node_id] = reply_dict
except Exception:

@@ -263,8 +282,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if self._dashboard_head.gcs_log_subscriber:
while True:
try:
log_batch = await \
self._dashboard_head.gcs_log_subscriber.poll()
log_batch = await self._dashboard_head.gcs_log_subscriber.poll()
if log_batch is None:
continue
process_log_batch(log_batch)

@@ -296,11 +314,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
ip = match.group(2)
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
pid_errors = list(errs_for_ip.get(pid, []))
pid_errors.append({
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type
})
pid_errors.append(
{
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type,
}
)
errs_for_ip[pid] = pid_errors
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
logger.info(f"Received error entry for {ip} {pid}")

@@ -308,8 +328,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if self._dashboard_head.gcs_error_subscriber:
while True:
try:
_, error_data = await \
self._dashboard_head.gcs_error_subscriber.poll()
(
_,
error_data,
) = await self._dashboard_head.gcs_error_subscriber.poll()
if error_data is None:
continue
process_error(error_data)

@@ -328,20 +350,23 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
try:
_, data = msg
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
error_data = gcs_utils.ErrorTableData.FromString(
pubsub_msg.data)
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
process_error(error_data)
except Exception:
logger.exception("Error receiving error info from Redis.")

async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_node_info_stub = \
gcs_service_pb2_grpc.NodeInfoGcsServiceStub(gcs_channel)
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
gcs_channel
)

await asyncio.gather(self._update_nodes(), self._update_node_stats(),
self._update_log_info(),
self._update_error_info())
await asyncio.gather(
self._update_nodes(),
self._update_node_stats(),
self._update_log_info(),
self._update_error_info(),
)

@staticmethod
def is_minimal_module():

@@ -10,19 +10,23 @@ import ray
import threading
from datetime import datetime, timedelta
from ray.cluster_utils import Cluster
from ray.dashboard.modules.node.node_consts import (LOG_PRUNE_THREASHOLD,
MAX_LOGS_TO_CACHE)
from ray.dashboard.modules.node.node_consts import (
LOG_PRUNE_THREASHOLD,
MAX_LOGS_TO_CACHE,
)
from ray.dashboard.tests.conftest import * # noqa
from ray._private.test_utils import (
format_web_url, wait_until_server_available, wait_for_condition,
wait_until_succeeded_without_exception)
format_web_url,
wait_until_server_available,
wait_for_condition,
wait_until_succeeded_without_exception,
)

logger = logging.getLogger(__name__)


def test_nodes_update(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)

@@ -44,8 +48,7 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard):
assert len(dump_data["agents"]) == 1
assert len(dump_data["nodeIdToIp"]) == 1
assert len(dump_data["nodeIdToHostname"]) == 1
assert dump_data["nodes"].keys() == dump_data[
"nodeIdToHostname"].keys()
assert dump_data["nodes"].keys() == dump_data["nodeIdToHostname"].keys()

response = requests.get(webui_url + "/test/notified_agents")
response.raise_for_status()

@@ -77,8 +80,7 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
actor_pids = [actor.getpid.remote() for actor in actors]
actor_pids = set(ray.get(actor_pids))

assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
node_id = ray_start_with_dashboard["node_id"]

@@ -134,15 +136,19 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")


def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])

@ray.remote
class ActorWithObjs:

@@ -156,8 +162,7 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
actors = [ActorWithObjs.remote() for _ in range(2)] # noqa
results = ray.get([actor.get_obj.remote() for actor in actors]) # noqa
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
resp = requests.get(
webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
resp = requests.get(webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
resp.raise_for_status()

def check_mem_table():

@@ -172,11 +177,12 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
assert summary["totalLocalRefCount"] == 3

assert wait_until_succeeded_without_exception(
check_mem_table, (AssertionError, ), timeout_ms=10000)
check_mem_table, (AssertionError,), timeout_ms=10000
)


def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])

webui_url = format_web_url(ray_start_with_dashboard["webui_url"])

@@ -220,21 +226,25 @@ def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")


@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_multi_nodes_info(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
cluster.add_node()

@@ -269,13 +279,13 @@ def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,


@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_multi_node_churn(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = format_web_url(cluster.webui_url)

def cluster_chaos_monkey():

@@ -315,13 +325,11 @@ def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,


@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs(enable_test_module, disable_aiohttp_cache, ray_start_cluster_head):
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()

@@ -348,21 +356,18 @@ def test_logs(enable_test_module, disable_aiohttp_cache,

def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert type(node_logs["data"]["logs"]) is dict
assert all(
pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
assert all(pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
assert len(node_logs["data"]["logs"][la2_pid]) == 1

actor_one_logs_response = requests.get(
f"{webui_url}/node_logs",
params={
"ip": node_ip,
"pid": str(la_pid)
})
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
)
actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"]

@@ -370,19 +375,19 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
assert len(actor_one_logs["data"]["logs"][la_pid]) == 4

assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000)
check_logs, (AssertionError,), timeout_ms=1000
)


@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"""Check if logs from the dead pids are GC'ed.
"""
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs_clean_up(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
"""Check if logs from the dead pids are GC'ed."""
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()

@@ -406,38 +411,41 @@ def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,

def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert la_pid in node_logs["data"]["logs"]

assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000)
check_logs, (AssertionError,), timeout_ms=1000
)
ray.kill(la)

def check_logs_not_exist():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert la_pid not in node_logs["data"]["logs"]

assert wait_until_succeeded_without_exception(
check_logs_not_exist, (AssertionError, ), timeout_ms=10000)
check_logs_not_exist, (AssertionError,), timeout_ms=10000
)


@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"""Test that each Ray worker cannot cache more than 1000 logs at a time.
"""
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs_max_count(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
"""Test that each Ray worker cannot cache more than 1000 logs at a time."""
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()

@@ -461,7 +469,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,

def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]

@@ -472,11 +481,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD

actor_one_logs_response = requests.get(
f"{webui_url}/node_logs",
params={
"ip": node_ip,
"pid": str(la_pid)
})
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
)
actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"]

@@ -486,7 +492,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD

assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=10000)
check_logs, (AssertionError,), timeout_ms=10000
)


if __name__ == "__main__":

@@ -39,10 +39,12 @@ try:
except (ModuleNotFoundError, ImportError):
gpustat = None
if log_once("gpustat_import_warning"):
warnings.warn("`gpustat` package is not installed. GPU monitoring is "
"not available. To have full functionality of the "
"dashboard please install `pip install ray["
"default]`.)")
warnings.warn(
"`gpustat` package is not installed. GPU monitoring is "
"not available. To have full functionality of the "
"dashboard please install `pip install ray["
"default]`.)"
)


def recursive_asdict(o):

@@ -68,68 +70,81 @@ def jsonify_asdict(o) -> str:

# A list of gauges to record and export metrics.
METRICS_GAUGES = {
"node_cpu_utilization": Gauge("node_cpu_utilization",
"Total CPU usage on a ray node",
"percentage", ["ip"]),
"node_cpu_count": Gauge("node_cpu_count",
"Total CPUs available on a ray node", "cores",
["ip"]),
"node_mem_used": Gauge("node_mem_used", "Memory usage on a ray node",
"bytes", ["ip"]),
"node_mem_available": Gauge("node_mem_available",
"Memory available on a ray node", "bytes",
["ip"]),
"node_mem_total": Gauge("node_mem_total", "Total memory on a ray node",
"bytes", ["ip"]),
"node_gpus_available": Gauge("node_gpus_available",
"Total GPUs available on a ray node",
"percentage", ["ip"]),
"node_gpus_utilization": Gauge("node_gpus_utilization",
"Total GPUs usage on a ray node",
"percentage", ["ip"]),
"node_gram_used": Gauge("node_gram_used",
"Total GPU RAM usage on a ray node", "bytes",
["ip"]),
"node_gram_available": Gauge("node_gram_available",
"Total GPU RAM available on a ray node",
"bytes", ["ip"]),
"node_disk_usage": Gauge("node_disk_usage",
"Total disk usage (bytes) on a ray node", "bytes",
["ip"]),
"node_disk_free": Gauge("node_disk_free",
"Total disk free (bytes) on a ray node", "bytes",
["ip"]),
"node_cpu_utilization": Gauge(
"node_cpu_utilization", "Total CPU usage on a ray node", "percentage", ["ip"]
),
"node_cpu_count": Gauge(
"node_cpu_count", "Total CPUs available on a ray node", "cores", ["ip"]
),
"node_mem_used": Gauge(
"node_mem_used", "Memory usage on a ray node", "bytes", ["ip"]
),
"node_mem_available": Gauge(
"node_mem_available", "Memory available on a ray node", "bytes", ["ip"]
),
"node_mem_total": Gauge(
"node_mem_total", "Total memory on a ray node", "bytes", ["ip"]
),
"node_gpus_available": Gauge(
"node_gpus_available",
"Total GPUs available on a ray node",
"percentage",
["ip"],
),
"node_gpus_utilization": Gauge(
"node_gpus_utilization", "Total GPUs usage on a ray node", "percentage", ["ip"]
),
"node_gram_used": Gauge(
"node_gram_used", "Total GPU RAM usage on a ray node", "bytes", ["ip"]
),
"node_gram_available": Gauge(
"node_gram_available", "Total GPU RAM available on a ray node", "bytes", ["ip"]
),
"node_disk_usage": Gauge(
"node_disk_usage", "Total disk usage (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_free": Gauge(
"node_disk_free", "Total disk free (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_utilization_percentage": Gauge(
"node_disk_utilization_percentage",
"Total disk utilization (percentage) on a ray node", "percentage",
["ip"]),
"node_network_sent": Gauge("node_network_sent", "Total network sent",
"bytes", ["ip"]),
"node_network_received": Gauge("node_network_received",
"Total network received", "bytes", ["ip"]),
"Total disk utilization (percentage) on a ray node",
"percentage",
["ip"],
),
"node_network_sent": Gauge(
"node_network_sent", "Total network sent", "bytes", ["ip"]
),
"node_network_received": Gauge(
"node_network_received", "Total network received", "bytes", ["ip"]
),
"node_network_send_speed": Gauge(
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]),
"node_network_receive_speed": Gauge("node_network_receive_speed",
"Network receive speed", "bytes/sec",
["ip"]),
"raylet_cpu": Gauge("raylet_cpu", "CPU usage of the raylet on a node.",
"percentage", ["ip", "pid"]),
"raylet_mem": Gauge("raylet_mem", "Memory usage of the raylet on a node",
"mb", ["ip", "pid"]),
"cluster_active_nodes": Gauge("cluster_active_nodes",
"Active nodes on the cluster", "count",
["node_type"]),
"cluster_failed_nodes": Gauge("cluster_failed_nodes",
"Failed nodes on the cluster", "count",
["node_type"]),
"cluster_pending_nodes": Gauge("cluster_pending_nodes",
"Pending nodes on the cluster", "count",
["node_type"]),
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]
),
"node_network_receive_speed": Gauge(
"node_network_receive_speed", "Network receive speed", "bytes/sec", ["ip"]
),
"raylet_cpu": Gauge(
"raylet_cpu", "CPU usage of the raylet on a node.", "percentage", ["ip", "pid"]
),
"raylet_mem": Gauge(
"raylet_mem", "Memory usage of the raylet on a node", "mb", ["ip", "pid"]
),
"cluster_active_nodes": Gauge(
"cluster_active_nodes", "Active nodes on the cluster", "count", ["node_type"]
),
"cluster_failed_nodes": Gauge(
"cluster_failed_nodes", "Failed nodes on the cluster", "count", ["node_type"]
),
"cluster_pending_nodes": Gauge(
"cluster_pending_nodes", "Pending nodes on the cluster", "count", ["node_type"]
),
}


class ReporterAgent(dashboard_utils.DashboardAgentModule,
reporter_pb2_grpc.ReporterServiceServicer):
class ReporterAgent(
dashboard_utils.DashboardAgentModule, reporter_pb2_grpc.ReporterServiceServicer
):
"""A monitor process for monitoring Ray nodes.

Attributes:

@@ -145,37 +160,39 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
cpu_count = ray._private.utils.get_num_cpus()
self._cpu_counts = (cpu_count, cpu_count)
else:
self._cpu_counts = (psutil.cpu_count(),
psutil.cpu_count(logical=False))
self._cpu_counts = (psutil.cpu_count(), psutil.cpu_count(logical=False))

self._ip = dashboard_agent.ip
if not use_gcs_for_bootstrap():
self._redis_address, _ = dashboard_agent.redis_address
self._is_head_node = (self._ip == self._redis_address)
self._is_head_node = self._ip == self._redis_address
else:
self._is_head_node = (
self._ip == dashboard_agent.gcs_address.split(":")[0])
self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
self._hostname = socket.gethostname()
self._workers = set()
self._network_stats_hist = [(0, (0.0, 0.0))]  # time, (sent, recv)
self._metrics_agent = MetricsAgent(
"127.0.0.1" if self._ip == "127.0.0.1" else "",
dashboard_agent.metrics_export_port)
self._key = f"{reporter_consts.REPORTER_PREFIX}" \
f"{self._dashboard_agent.node_id}"
dashboard_agent.metrics_export_port,
)
self._key = (
f"{reporter_consts.REPORTER_PREFIX}" f"{self._dashboard_agent.node_id}"
)

async def GetProfilingStats(self, request, context):
pid = request.pid
duration = request.duration
profiling_file_path = os.path.join(
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt")
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt"
)
sudo = "sudo" if ray._private.utils.get_user() != "root" else ""
process = await asyncio.create_subprocess_shell(
f"{sudo} $(which py-spy) record "
f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
shell=True,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
profiling_stats = ""

@@ -183,14 +200,14 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
with open(profiling_file_path, "r") as f:
profiling_stats = f.read()
return reporter_pb2.GetProfilingStatsReply(
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr)
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr
)

async def ReportOCMetrics(self, request, context):
# This function receives a GRPC containing OpenCensus (OC) metrics
# from a Ray process, then exposes those metrics to Prometheus.
try:
self._metrics_agent.record_metric_points_from_protobuf(
request.metrics)
self._metrics_agent.record_metric_points_from_protobuf(request.metrics)
except Exception:
logger.error(traceback.format_exc())
return reporter_pb2.ReportOCMetricsReply()

@@ -227,10 +244,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
for gpu in gpus:
# Note the keys in this dict have periods which throws
# off javascript so we change .s to _s
gpu_data = {
"_".join(key.split(".")): val
for key, val in gpu.entry.items()
}
gpu_data = {"_".join(key.split(".")): val for key, val in gpu.entry.items()}
gpu_utilizations.append(gpu_data)
return gpu_utilizations

@@ -245,8 +259,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
@staticmethod
def _get_network_stats():
ifaces = [
v for k, v in psutil.net_io_counters(pernic=True).items()
if k[0] == "e"
v for k, v in psutil.net_io_counters(pernic=True).items() if k[0] == "e"
]

sent = sum((iface.bytes_sent for iface in ifaces))

@@ -266,8 +279,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if IN_KUBERNETES_POD:
# If in a K8s pod, disable disk display by passing in dummy values.
return {
"/": psutil._common.sdiskusage(
total=1, used=0, free=1, percent=0.0)
"/": psutil._common.sdiskusage(total=1, used=0, free=1, percent=0.0)
}
root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep
tmp = ray._private.utils.get_user_temp_dir()

@@ -286,14 +298,18 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
self._workers.update(workers)
self._workers.discard(psutil.Process())
return [
w.as_dict(attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE
w.as_dict(
attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]
)
for w in self._workers
if w.status() != psutil.STATUS_ZOMBIE
]

@staticmethod

@@ -318,14 +334,16 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if raylet_proc is None:
return {}
else:
return raylet_proc.as_dict(attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
])
return raylet_proc.as_dict(
attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]
)

def _get_load_avg(self):
if sys.platform == "win32":

@@ -345,8 +363,10 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
then, prev_network_stats = self._network_stats_hist[0]
prev_send, prev_recv = prev_network_stats
now_send, now_recv = network_stats
network_speed_stats = ((now_send - prev_send) / (now - then),
(now_recv - prev_recv) / (now - then))
network_speed_stats = (
(now_send - prev_send) / (now - then),
(now_recv - prev_recv) / (now - then),
)
return {
"now": now,
"hostname": self._hostname,

@@ -379,7 +399,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_active_nodes"],
value=active_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)

failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"]
failed_nodes_dict = {}

@@ -394,7 +416,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_failed_nodes"],
value=failed_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)

pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"]
pending_nodes_dict = {}

@@ -409,35 +433,36 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_pending_nodes"],
value=pending_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)

# -- CPU per node --
cpu_usage = float(stats["cpu"])
cpu_record = Record(
gauge=METRICS_GAUGES["node_cpu_utilization"],
value=cpu_usage,
tags={"ip": ip})
tags={"ip": ip},
)

cpu_count, _ = stats["cpus"]
cpu_count_record = Record(
gauge=METRICS_GAUGES["node_cpu_count"],
value=cpu_count,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_cpu_count"], value=cpu_count, tags={"ip": ip}
)

# -- Mem per node --
mem_total, mem_available, _, mem_used = stats["mem"]
mem_used_record = Record(
gauge=METRICS_GAUGES["node_mem_used"],
value=mem_used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_mem_used"], value=mem_used, tags={"ip": ip}
)
mem_available_record = Record(
gauge=METRICS_GAUGES["node_mem_available"],
value=mem_available,
tags={"ip": ip})
tags={"ip": ip},
)
mem_total_record = Record(
gauge=METRICS_GAUGES["node_mem_total"],
value=mem_total,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_mem_total"], value=mem_total, tags={"ip": ip}
)

# -- GPU per node --
gpus = stats["gpus"]

@@ -455,23 +480,29 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
gpus_available_record = Record(
gauge=METRICS_GAUGES["node_gpus_available"],
value=gpus_available,
tags={"ip": ip})
tags={"ip": ip},
)
gpus_utilization_record = Record(
gauge=METRICS_GAUGES["node_gpus_utilization"],
value=gpus_utilization,
tags={"ip": ip})
tags={"ip": ip},
)
gram_used_record = Record(
gauge=METRICS_GAUGES["node_gram_used"],
value=gram_used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_gram_used"], value=gram_used, tags={"ip": ip}
)
gram_available_record = Record(
gauge=METRICS_GAUGES["node_gram_available"],
value=gram_available,
tags={"ip": ip})
records_reported.extend([
gpus_available_record, gpus_utilization_record,
gram_used_record, gram_available_record
])
tags={"ip": ip},
)
records_reported.extend(
[
gpus_available_record,
gpus_utilization_record,
gram_used_record,
gram_available_record,
]
)

# -- Disk per node --
used, free = 0, 0

@@ -480,39 +511,42 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
free += entry.free
disk_utilization = float(used / (used + free)) * 100
disk_usage_record = Record(
gauge=METRICS_GAUGES["node_disk_usage"],
value=used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_disk_usage"], value=used, tags={"ip": ip}
)
disk_free_record = Record(
gauge=METRICS_GAUGES["node_disk_free"],
value=free,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_disk_free"], value=free, tags={"ip": ip}
)
disk_utilization_percentage_record = Record(
gauge=METRICS_GAUGES["node_disk_utilization_percentage"],
value=disk_utilization,
tags={"ip": ip})
tags={"ip": ip},
)

# -- Network speed (send/receive) stats per node --
network_stats = stats["network"]
network_sent_record = Record(
gauge=METRICS_GAUGES["node_network_sent"],
value=network_stats[0],
tags={"ip": ip})
tags={"ip": ip},
)
network_received_record = Record(
gauge=METRICS_GAUGES["node_network_received"],
value=network_stats[1],
tags={"ip": ip})
tags={"ip": ip},
)

# -- Network speed (send/receive) per node --
network_speed_stats = stats["network_speed"]
network_send_speed_record = Record(
gauge=METRICS_GAUGES["node_network_send_speed"],
value=network_speed_stats[0],
tags={"ip": ip})
tags={"ip": ip},
)
network_receive_speed_record = Record(
gauge=METRICS_GAUGES["node_network_receive_speed"],
value=network_speed_stats[1],
tags={"ip": ip})
tags={"ip": ip},
)

raylet_stats = stats["raylet"]
if raylet_stats:

@@ -522,29 +556,34 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
raylet_cpu_record = Record(
gauge=METRICS_GAUGES["raylet_cpu"],
value=raylet_cpu_usage,
tags={
"ip": ip,
"pid": raylet_pid
})
tags={"ip": ip, "pid": raylet_pid},
)

# -- raylet mem --
raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6
raylet_mem_record = Record(
gauge=METRICS_GAUGES["raylet_mem"],
value=raylet_mem_usage,
tags={
"ip": ip,
"pid": raylet_pid
})
tags={"ip": ip, "pid": raylet_pid},
)
records_reported.extend([raylet_cpu_record, raylet_mem_record])

records_reported.extend([
cpu_record, cpu_count_record, mem_used_record,
mem_available_record, mem_total_record, disk_usage_record,
disk_free_record, disk_utilization_percentage_record,
network_sent_record, network_received_record,
network_send_speed_record, network_receive_speed_record
])
records_reported.extend(
[
cpu_record,
cpu_count_record,
mem_used_record,
mem_available_record,
mem_total_record,
disk_usage_record,
disk_free_record,
disk_utilization_percentage_record,
network_sent_record,
network_received_record,
network_send_speed_record,
network_receive_speed_record,
]
)
return records_reported

async def _perform_iteration(self, publish):

@@ -552,9 +591,13 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
while True:
try:
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS)
cluster_stats = json.loads(formatted_status_string.decode(
)) if formatted_status_string else {}
DEBUG_AUTOSCALING_STATUS
)
cluster_stats = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)

stats = self._get_all_stats()
records_reported = self._record_stats(stats, cluster_stats)

@@ -563,8 +606,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,

except Exception:
logger.exception("Error publishing node physical stats.")
await asyncio.sleep(
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
await asyncio.sleep(reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)

async def run(self, server):
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)

@@ -573,17 +615,20 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if gcs_addr is None:
aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password)
password=self._dashboard_agent.redis_password,
)
gcs_addr = await aioredis_client.get("GcsServerAddress")
gcs_addr = gcs_addr.decode()
publisher = GcsAioPublisher(address=gcs_addr)

async def publish(key: str, data: str):
await publisher.publish_resource_usage(key, data)

else:
aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password)
password=self._dashboard_agent.redis_password,
)

async def publish(key: str, data: str):
await aioredis_client.publish(key, data)

@@ -3,4 +3,5 @@ import ray.ray_constants as ray_constants
REPORTER_PREFIX = "RAY_REPORTER:"
# The reporter will report its statistics this often (milliseconds).
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
"REPORTER_UPDATE_INTERVAL_MS", 2500)
"REPORTER_UPDATE_INTERVAL_MS", 2500
)

@@ -9,13 +9,14 @@ import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
GcsAioResourceUsageSubscriber
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
import ray._private.services
import ray._private.utils
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR)
from ray.ray_constants import (
DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR,
)
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
import ray.experimental.internal_kv as internal_kv

@@ -40,9 +41,10 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
if change.new:
node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True)
f"{ip}:{ports[1]}", options=options, asynchronous=True
)
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub

@@ -53,13 +55,16 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
duration = int(req.query["duration"])
reporter_stub = self._stubs[ip]
reply = await reporter_stub.GetProfilingStats(
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
profiling_info = (json.loads(reply.profiling_stats)
if reply.profiling_stats else reply.std_out)
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)
)
profiling_info = (
json.loads(reply.profiling_stats)
if reply.profiling_stats
else reply.std_out
)
return dashboard_optional_utils.rest_response(
success=True,
message="Profiling success.",
profiling_info=profiling_info)
success=True, message="Profiling success.", profiling_info=profiling_info
)

@routes.get("/api/ray_config")
async def get_ray_config(self, req) -> aiohttp.web.Response:

@@ -75,12 +80,12 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
)
except FileNotFoundError:
return dashboard_optional_utils.rest_response(
success=False,
message="Invalid config, could not load YAML.")
success=False, message="Invalid config, could not load YAML."
)

payload = {
"min_workers": cfg.get("min_workers", "unspecified"),
"max_workers": cfg.get("max_workers", "unspecified")
"max_workers": cfg.get("max_workers", "unspecified"),
}

try:

@@ -115,18 +120,18 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
"""

assert ray.experimental.internal_kv._internal_kv_initialized()
legacy_status = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS)
formatted_status = json.loads(formatted_status_string.decode()
) if formatted_status_string else {}
legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
formatted_status = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
return dashboard_optional_utils.rest_response(
success=True,
message="Got cluster status.",
autoscaling_status=legacy_status.decode()
if legacy_status else None,
autoscaling_status=legacy_status.decode() if legacy_status else None,
autoscaling_error=error.decode() if error else None,
cluster_status=formatted_status if formatted_status else None,
)

@@ -148,8 +153,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception("Error receiving node physical stats "
"from reporter agent.")
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)
else:
receiver = Receiver()
aioredis_client = self._dashboard_head.aioredis_client

@@ -165,8 +171,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception("Error receiving node physical stats "
"from reporter agent.")
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)

@staticmethod
def is_minimal_module():

@@ -10,9 +10,13 @@ from ray import ray_constants
from ray.dashboard.tests.conftest import * # noqa
from ray.dashboard.utils import Bunch
from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent
from ray._private.test_utils import (format_web_url, RayTestTimeoutException,
wait_until_server_available,
wait_for_condition, fetch_prometheus)
from ray._private.test_utils import (
format_web_url,
RayTestTimeoutException,
wait_until_server_available,
wait_for_condition,
fetch_prometheus,
)

try:
import prometheus_client

@@ -34,7 +38,7 @@ def test_profiling(shutdown_only):
actor_pid = ray.get(c.getpid.remote())

webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True)
assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url)

start_time = time.time()

@@ -44,14 +48,16 @@ def test_profiling(shutdown_only):
if time.time() - start_time > 15:
raise RayTestTimeoutException(
"Timed out while collecting profiling stats, "
f"launch_profiling: {launch_profiling}")
f"launch_profiling: {launch_profiling}"
)
launch_profiling = requests.get(
webui_url + "/api/launch_profiling",
params={
"ip": ray.nodes()[0]["NodeManagerAddress"],
"pid": actor_pid,
"duration": 5
}).json()
"duration": 5,
},
).json()
if launch_profiling["result"]:
profiling_info = launch_profiling["data"]["profilingInfo"]
break

@@ -72,13 +78,12 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
actor_pids = set(actor_pids)

webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True)
assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url)

def _check_workers():
try:
resp = requests.get(webui_url +
"/test/dump?key=node_physical_stats")
resp = requests.get(webui_url + "/test/dump?key=node_physical_stats")
resp.raise_for_status()
result = resp.json()
assert result["result"] is True

@@ -101,8 +106,7 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
wait_for_condition(_check_workers, timeout=10)


@pytest.mark.skipif(
prometheus_client is None, reason="prometheus_client not installed")
@pytest.mark.skipif(prometheus_client is None, reason="prometheus_client not installed")
def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
addresses = ray.init(include_dashboard=True, num_cpus=1)
metrics_export_port = addresses["metrics_export_port"]

@@ -110,29 +114,31 @@ def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
prom_addresses = [f"{addr}:{metrics_export_port}"]

def test_case_stats_exist():
components_dict, metric_names, metric_samples = fetch_prometheus(
prom_addresses)
return all([
"ray_node_cpu_utilization" in metric_names,
"ray_node_cpu_count" in metric_names,
"ray_node_mem_used" in metric_names,
"ray_node_mem_available" in metric_names,
"ray_node_mem_total" in metric_names,
"ray_raylet_cpu" in metric_names, "ray_raylet_mem" in metric_names,
"ray_node_disk_usage" in metric_names,
"ray_node_disk_free" in metric_names,
"ray_node_disk_utilization_percentage" in metric_names,
"ray_node_network_sent" in metric_names,
"ray_node_network_received" in metric_names,
"ray_node_network_send_speed" in metric_names,
"ray_node_network_receive_speed" in metric_names
])
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
return all(
[
"ray_node_cpu_utilization" in metric_names,
"ray_node_cpu_count" in metric_names,
"ray_node_mem_used" in metric_names,
"ray_node_mem_available" in metric_names,
"ray_node_mem_total" in metric_names,
"ray_raylet_cpu" in metric_names,
"ray_raylet_mem" in metric_names,
"ray_node_disk_usage" in metric_names,
"ray_node_disk_free" in metric_names,
"ray_node_disk_utilization_percentage" in metric_names,
"ray_node_network_sent" in metric_names,
"ray_node_network_received" in metric_names,
"ray_node_network_send_speed" in metric_names,
"ray_node_network_receive_speed" in metric_names,
]
)

def test_case_ip_correct():
components_dict, metric_names, metric_samples = fetch_prometheus(
prom_addresses)
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
raylet_proc = ray.worker._global_node.all_processes[
ray_constants.PROCESS_TYPE_RAYLET][0]
ray_constants.PROCESS_TYPE_RAYLET
][0]
raylet_pid = None
# Find the raylet pid recorded in the tag.
for sample in metric_samples:

@@ -159,24 +165,25 @@ def test_report_stats():
"cpu": 57.4,
"cpus": (8, 4),
"mem": (17179869184, 5723353088, 66.7, 9234341888),
"workers": [{
"memory_info": Bunch(
rss=55934976, vms=7026937856, pfaults=15354, pageins=0),
"cpu_percent": 0.0,
"cmdline": [
"ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""
],
"create_time": 1614826391.338613,
"pid": 7174,
"cpu_times": Bunch(
user=0.607899328,
system=0.274044032,
children_user=0.0,
children_system=0.0)
}],
"workers": [
{
"memory_info": Bunch(
rss=55934976, vms=7026937856, pfaults=15354, pageins=0
),
"cpu_percent": 0.0,
"cmdline": ["ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""],
"create_time": 1614826391.338613,
"pid": 7174,
"cpu_times": Bunch(
user=0.607899328,
system=0.274044032,
children_user=0.0,
children_system=0.0,
),
}
],
"raylet": {
"memory_info": Bunch(
rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
"memory_info": Bunch(rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
"cpu_percent": 0.0,
"cmdline": ["fake raylet cmdline"],
"create_time": 1614826390.274854,

@@ -185,22 +192,18 @@ def test_report_stats():
user=0.03683138,
system=0.035913716,
children_user=0.0,
children_system=0.0)
children_system=0.0,
),
},
"bootTime": 1612934656.0,
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45,
0.44)),
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45, 0.44)),
"disk": {
"/": Bunch(
total=250790436864,
used=11316781056,
free=22748921856,
|
||||
percent=33.2),
|
||||
total=250790436864, used=11316781056, free=22748921856, percent=33.2
|
||||
),
|
||||
"/tmp": Bunch(
|
||||
total=250790436864,
|
||||
used=209532035072,
|
||||
free=22748921856,
|
||||
percent=90.2)
|
||||
total=250790436864, used=209532035072, free=22748921856, percent=90.2
|
||||
),
|
||||
},
|
||||
"gpus": [],
|
||||
"network": (13621160960, 11914936320),
|
||||
|
@ -209,13 +212,10 @@ def test_report_stats():
|
|||
|
||||
cluster_stats = {
|
||||
"autoscaler_report": {
|
||||
"active_nodes": {
|
||||
"head_node": 1,
|
||||
"worker-node-0": 2
|
||||
},
|
||||
"active_nodes": {"head_node": 1, "worker-node-0": 2},
|
||||
"failed_nodes": [],
|
||||
"pending_launches": {},
|
||||
"pending_nodes": []
|
||||
"pending_nodes": [],
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -226,11 +226,9 @@ def test_report_stats():
|
|||
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
||||
assert len(records) == 14
|
||||
# Test stats with gpus
|
||||
test_stats["gpus"] = [{
|
||||
"utilization_gpu": 1,
|
||||
"memory_used": 100,
|
||||
"memory_total": 1000
|
||||
}]
|
||||
test_stats["gpus"] = [
|
||||
{"utilization_gpu": 1, "memory_used": 100, "memory_total": 1000}
|
||||
]
|
||||
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
|
||||
assert len(records) == 18
|
||||
# Test stats without autoscaler report
|
||||
|
|
|
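Note: every hunk in this file shows the same two Black behaviors: a call or literal that fits in the 88-column limit is collapsed onto one line, and one that does not is exploded one element per line with a trailing comma. A minimal sketch (hypothetical names, not taken from this diff) that can be run through "black -" to reproduce both layouts:

    # Hypothetical example: pipe this through "black -" to see both layouts.
    short = dict(a=1, b=2)  # fits within 88 columns, so it stays on one line

    exploded = dict(
        first_key="a value long enough that the call exceeds the line limit",
        second_key="so Black breaks it up, one argument per line",
    )  # the trailing comma keeps later additions to one-line diffs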
@ -12,10 +12,11 @@ from ray.core.generated import runtime_env_agent_pb2
from ray.core.generated import runtime_env_agent_pb2_grpc
from ray.core.generated import agent_manager_pb2
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.modules.runtime_env.runtime_env_consts \
    as runtime_env_consts
from ray.experimental.internal_kv import _internal_kv_initialized, \
    _initialize_internal_kv
import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
from ray.experimental.internal_kv import (
    _internal_kv_initialized,
    _initialize_internal_kv,
)
from ray._private.ray_logging import setup_component_logger
from ray._private.runtime_env.pip import PipManager
from ray._private.runtime_env.conda import CondaManager

@ -42,8 +43,10 @@ class CreatedEnvResult:
    result: str


class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
                      runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer):
class RuntimeEnvAgent(
    dashboard_utils.DashboardAgentModule,
    runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
):
    """An RPC server to create and delete runtime envs.

    Attributes:

@ -86,32 +89,33 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
        return self._per_job_logger_cache[job_id]

    async def CreateRuntimeEnv(self, request, context):
        async def _setup_runtime_env(serialized_runtime_env,
                                     serialized_allocated_resource_instances):
        async def _setup_runtime_env(
            serialized_runtime_env, serialized_allocated_resource_instances
        ):
            # This function will be ran inside a thread
            def run_setup_with_logger():
                runtime_env = RuntimeEnv(
                    serialized_runtime_env=serialized_runtime_env)
                runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
                allocated_resource: dict = json.loads(
                    serialized_allocated_resource_instances or "{}")
                    serialized_allocated_resource_instances or "{}"
                )

                # Use a separate logger for each job.
                per_job_logger = self.get_or_create_logger(request.job_id)
                # TODO(chenk008): Add log about allocated_resource to
                # avoid lint error. That will be moved to cgroup plugin.
                per_job_logger.debug(f"Worker has resource :"
                                     f"{allocated_resource}")
                per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
                context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
                self._pip_manager.setup(
                    runtime_env, context, logger=per_job_logger)
                self._conda_manager.setup(
                    runtime_env, context, logger=per_job_logger)
                self._pip_manager.setup(runtime_env, context, logger=per_job_logger)
                self._conda_manager.setup(runtime_env, context, logger=per_job_logger)
                self._py_modules_manager.setup(
                    runtime_env, context, logger=per_job_logger)
                    runtime_env, context, logger=per_job_logger
                )
                self._working_dir_manager.setup(
                    runtime_env, context, logger=per_job_logger)
                    runtime_env, context, logger=per_job_logger
                )
                self._container_manager.setup(
                    runtime_env, context, logger=per_job_logger)
                    runtime_env, context, logger=per_job_logger
                )

                # Add the mapping of URIs -> the serialized environment to be
                # used for cache invalidation.

@ -133,14 +137,15 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,

            # Run setup function from all the plugins
            for plugin_class_path, config in runtime_env.plugins():
                logger.debug(
                    f"Setting up runtime env plugin {plugin_class_path}")
                logger.debug(f"Setting up runtime env plugin {plugin_class_path}")
                plugin_class = import_attr(plugin_class_path)
                # TODO(simon): implement uri support
                plugin_class.create("uri not implemented",
                                    json.loads(config), context)
                plugin_class.modify_context("uri not implemented",
                                            json.loads(config), context)
                plugin_class.create(
                    "uri not implemented", json.loads(config), context
                )
                plugin_class.modify_context(
                    "uri not implemented", json.loads(config), context
                )

            return context

@ -159,18 +164,24 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
            result = self._env_cache[serialized_env]
            if result.success:
                context = result.result
                logger.info("Runtime env already created successfully. "
                            f"Env: {serialized_env}, context: {context}")
                logger.info(
                    "Runtime env already created successfully. "
                    f"Env: {serialized_env}, context: {context}"
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                    serialized_runtime_env_context=context)
                    serialized_runtime_env_context=context,
                )
            else:
                error_message = result.result
                logger.info("Runtime env already failed. "
                            f"Env: {serialized_env}, err: {error_message}")
                logger.info(
                    "Runtime env already failed. "
                    f"Env: {serialized_env}, err: {error_message}"
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                    error_message=error_message)
                    error_message=error_message,
                )

        if SLEEP_FOR_TESTING_S:
            logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")

@ -182,8 +193,8 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
        for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
            try:
                runtime_env_context = await _setup_runtime_env(
                    serialized_env,
                    request.serialized_allocated_resource_instances)
                    serialized_env, request.serialized_allocated_resource_instances
                )
                break
            except Exception as ex:
                logger.exception("Runtime env creation failed.")

@ -195,22 +206,25 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
                logger.error(
                    "Runtime env creation failed for %d times, "
                    "don't retry any more.",
                    runtime_env_consts.RUNTIME_ENV_RETRY_TIMES)
                self._env_cache[serialized_env] = CreatedEnvResult(
                    False, error_message)
                    runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
                )
                self._env_cache[serialized_env] = CreatedEnvResult(False, error_message)
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                    error_message=error_message)
                    error_message=error_message,
                )

        serialized_context = runtime_env_context.serialize()
        self._env_cache[serialized_env] = CreatedEnvResult(
            True, serialized_context)
        self._env_cache[serialized_env] = CreatedEnvResult(True, serialized_context)
        logger.info(
            "Successfully created runtime env: %s, the context: %s",
            serialized_env, serialized_context)
            serialized_env,
            serialized_context,
        )
        return runtime_env_agent_pb2.CreateRuntimeEnvReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
            serialized_runtime_env_context=serialized_context)
            serialized_runtime_env_context=serialized_context,
        )

    async def DeleteURIs(self, request, context):
        logger.info(f"Got request to delete URIs: {request.uris}.")

@ -239,20 +253,21 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
            else:
                raise ValueError(
                    "RuntimeEnvAgent received DeleteURI request "
                    f"for unsupported plugin {plugin}. URI: {uri}")
                    f"for unsupported plugin {plugin}. URI: {uri}"
                )

        if failed_uris:
            return runtime_env_agent_pb2.DeleteURIsReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="Local files for URI(s) "
                f"{failed_uris} not found.")
                error_message="Local files for URI(s) " f"{failed_uris} not found.",
            )
        else:
            return runtime_env_agent_pb2.DeleteURIsReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK)
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK
            )

    async def run(self, server):
        runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(
            self, server)
        runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(self, server)

    @staticmethod
    def is_minimal_module():
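Note: the CreateRuntimeEnv changes above reformat a bounded retry loop whose final outcome, success or failure, is cached per serialized environment so repeated requests short-circuit. A standalone sketch of that pattern, with a hypothetical setup_env callable standing in for the agent's internals:

    import time

    _result_cache = {}  # key -> (success, payload), like _env_cache above

    def create_with_retry(key, setup_env, times=3, interval_s=1.0):
        if key in _result_cache:  # previously attempted: reuse the outcome
            return _result_cache[key]
        error = None
        for _ in range(times):
            try:
                _result_cache[key] = (True, setup_env(key))
                return _result_cache[key]
            except Exception as exc:
                error = str(exc)
                time.sleep(interval_s)
        _result_cache[key] = (False, error)  # failures are cached too
        return _result_cache[key]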
@ -1,7 +1,7 @@
import ray.ray_constants as ray_constants

RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES",
                                                    3)
RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES", 3)

RUNTIME_ENV_RETRY_INTERVAL_MS = ray_constants.env_integer(
    "RUNTIME_ENV_RETRY_INTERVAL_MS", 1000)
    "RUNTIME_ENV_RETRY_INTERVAL_MS", 1000
)
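Note: env_integer reads an integer override from the environment, falling back to the given default. A plausible standalone equivalent (the real helper lives in ray.ray_constants and may differ in error handling):

    import os

    def env_integer(key, default):
        # Environment variables win over the hard-coded default.
        value = os.environ.get(key)
        return int(value) if value is not None else default

    RETRY_TIMES = env_integer("RUNTIME_ENV_RETRY_TIMES", 3)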
@ -8,13 +8,19 @@ from ray import ray_constants
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.experimental.internal_kv import (_internal_kv_initialized,
                                          _internal_kv_get, _internal_kv_list)
from ray.experimental.internal_kv import (
    _internal_kv_initialized,
    _internal_kv_get,
    _internal_kv_list,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.runtime_env.validation import ParsedRuntimeEnv
from ray.dashboard.modules.job.common import (
    JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY)
    JobStatusInfo,
    JobStatusStorageClient,
    JOB_ID_METADATA_KEY,
)

import json
import aiohttp.web

@ -32,7 +38,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        self._job_status_client = JobStatusStorageClient()
        # For offloading CPU intensive work.
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2, thread_name_prefix="api_head")
            max_workers=2, thread_name_prefix="api_head"
        )

    @routes.get("/api/actors/kill")
    async def kill_actor_gcs(self, req) -> aiohttp.web.Response:

@ -41,7 +48,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        no_restart = req.query.get("no_restart", False) in ("true", "True")
        if not actor_id:
            return dashboard_optional_utils.rest_response(
                success=False, message="actor_id is required.")
                success=False, message="actor_id is required."
            )

        request = gcs_service_pb2.KillActorViaGcsRequest()
        request.actor_id = bytes.fromhex(actor_id)

@ -49,31 +57,36 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        request.no_restart = no_restart
        await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)

        message = (f"Force killed actor with id {actor_id}" if force_kill else
                   f"Requested actor with id {actor_id} to terminate. " +
                   "It will exit once running tasks complete")
        message = (
            f"Force killed actor with id {actor_id}"
            if force_kill
            else f"Requested actor with id {actor_id} to terminate. "
            + "It will exit once running tasks complete"
        )

        return dashboard_optional_utils.rest_response(
            success=True, message=message)
        return dashboard_optional_utils.rest_response(success=True, message=message)

    @routes.get("/api/snapshot")
    async def snapshot(self, req):
        job_data, actor_data, serve_data, session_name = await asyncio.gather(
            self.get_job_info(), self.get_actor_info(), self.get_serve_info(),
            self.get_session_name())
            self.get_job_info(),
            self.get_actor_info(),
            self.get_serve_info(),
            self.get_session_name(),
        )
        snapshot = {
            "jobs": job_data,
            "actors": actor_data,
            "deployments": serve_data,
            "session_name": session_name,
            "ray_version": ray.__version__,
            "ray_commit": ray.__commit__
            "ray_commit": ray.__commit__,
        }
        return dashboard_optional_utils.rest_response(
            success=True, message="hello", snapshot=snapshot)
            success=True, message="hello", snapshot=snapshot
        )

    def _get_job_status(self,
                        metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
    def _get_job_status(self, metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
        # If a job submission ID has been added to a job, the status is
        # guaranteed to be returned.
        job_submission_id = metadata.get(JOB_ID_METADATA_KEY)

@ -91,8 +104,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
            "namespace": job_table_entry.config.ray_namespace,
            "metadata": metadata,
            "runtime_env": ParsedRuntimeEnv.deserialize(
                job_table_entry.config.runtime_env_info.
                serialized_runtime_env),
                job_table_entry.config.runtime_env_info.serialized_runtime_env
            ),
        }
        status = self._get_job_status(metadata)
        entry = {

@ -111,8 +124,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        # TODO (Alex): GCS still needs to return actors from dead jobs.
        request = gcs_service_pb2.GetAllActorInfoRequest()
        request.show_dead_jobs = True
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(
            request, timeout=5)
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(request, timeout=5)
        actors = {}
        for actor_table_entry in reply.actor_table_data:
            actor_id = actor_table_entry.actor_id.hex()

@ -120,37 +132,33 @@ class APIHead(dashboard_utils.DashboardHeadModule):
            entry = {
                "job_id": actor_table_entry.job_id.hex(),
                "state": gcs_pb2.ActorTableData.ActorState.Name(
                    actor_table_entry.state),
                    actor_table_entry.state
                ),
                "name": actor_table_entry.name,
                "namespace": actor_table_entry.ray_namespace,
                "runtime_env": runtime_env,
                "start_time": actor_table_entry.start_time,
                "end_time": actor_table_entry.end_time,
                "is_detached": actor_table_entry.is_detached,
                "resources": dict(
                    actor_table_entry.task_spec.required_resources),
                "resources": dict(actor_table_entry.task_spec.required_resources),
                "actor_class": actor_table_entry.class_name,
                "current_worker_id": actor_table_entry.address.worker_id.hex(),
                "current_raylet_id": actor_table_entry.address.raylet_id.hex(),
                "ip_address": actor_table_entry.address.ip_address,
                "port": actor_table_entry.address.port,
                "metadata": dict()
                "metadata": dict(),
            }
            actors[actor_id] = entry

        deployments = await self.get_serve_info()
        for _, deployment_info in deployments.items():
            for replica_actor_id, actor_info in deployment_info[
                    "actors"].items():
            for replica_actor_id, actor_info in deployment_info["actors"].items():
                if replica_actor_id in actors:
                    serve_metadata = dict()
                    serve_metadata["replica_tag"] = actor_info[
                        "replica_tag"]
                    serve_metadata["deployment_name"] = deployment_info[
                        "name"]
                    serve_metadata["replica_tag"] = actor_info["replica_tag"]
                    serve_metadata["deployment_name"] = deployment_info["name"]
                    serve_metadata["version"] = actor_info["version"]
                    actors[replica_actor_id]["metadata"][
                        "serve"] = serve_metadata
                    actors[replica_actor_id]["metadata"]["serve"] = serve_metadata
        return actors

    async def get_serve_info(self) -> Dict[str, Any]:

@ -168,22 +176,21 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        # TODO: Convert to async GRPC, if CPU usage is not a concern.
        def get_deployments():
            serve_keys = _internal_kv_list(
                SERVE_CONTROLLER_NAME,
                namespace=ray_constants.KV_NAMESPACE_SERVE)
                SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
            )
            serve_snapshot_keys = filter(
                lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys)
                lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
            )

            deployments_per_controller: List[Dict[str, Any]] = []
            for key in serve_snapshot_keys:
                val_bytes = _internal_kv_get(
                    key, namespace=ray_constants.KV_NAMESPACE_SERVE
                ) or "{}".encode("utf-8")
                deployments_per_controller.append(
                    json.loads(val_bytes.decode("utf-8")))
                deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
            # Merge the deployments dicts of all controllers.
            deployments: Dict[str, Any] = {
                k: v
                for d in deployments_per_controller for k, v in d.items()
                k: v for d in deployments_per_controller for k, v in d.items()
            }
            # Replace the keys (deployment names) with their hashes to prevent
            # collisions caused by the automatic conversion to camelcase by the

@ -194,24 +201,27 @@ class APIHead(dashboard_utils.DashboardHeadModule):
            }

        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_deployments)
            executor=self._thread_pool, func=get_deployments
        )

    async def get_session_name(self):
        # TODO(yic): Convert to async GRPC.
        def get_session():
            return ray.experimental.internal_kv._internal_kv_get(
                "session_name",
                namespace=ray_constants.KV_NAMESPACE_SESSION).decode()
                "session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
            ).decode()

        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_session)
            executor=self._thread_pool, func=get_session
        )

    async def run(self, server):
        self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel)
        self._gcs_actor_info_stub = \
            gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
                self._dashboard_head.aiogrpc_gcs_channel)
            self._dashboard_head.aiogrpc_gcs_channel
        )
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel
        )

    @staticmethod
    def is_minimal_module():
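Note: get_serve_info and get_session_name above share one idiom: synchronous KV reads are pushed onto the module's two-worker thread pool so the dashboard's event loop stays responsive. A self-contained sketch of the idiom with a hypothetical blocking function:

    import asyncio
    import concurrent.futures
    import time

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

    def blocking_lookup():
        time.sleep(0.1)  # stands in for a synchronous KV or gRPC call
        return "value"

    async def handler():
        loop = asyncio.get_running_loop()
        # The loop is free to serve other requests while the pool works.
        return await loop.run_in_executor(pool, blocking_lookup)

    print(asyncio.run(handler()))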
@ -36,15 +36,14 @@ def _actor_killed_loop(worker_pid: str, timeout_secs=3) -> bool:
    return dead


def _kill_actor_using_dashboard_gcs(webui_url: str,
                                    actor_id: str,
                                    force_kill=False):
def _kill_actor_using_dashboard_gcs(webui_url: str, actor_id: str, force_kill=False):
    resp = requests.get(
        webui_url + KILL_ACTOR_ENDPOINT,
        params={
            "actor_id": actor_id,
            "force_kill": force_kill,
        })
        },
    )
    resp.raise_for_status()
    resp_json = resp.json()
    assert resp_json["result"] is True, "msg" in resp_json
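Note: the helper above drives the kill endpoint through the dashboard's REST envelope, whose JSON body carries "result" and "msg" fields; the route path matches the @routes.get("/api/actors/kill") handler in the snapshot module's diff above. A minimal sketch of the same call shape:

    import requests

    def kill_via_rest(base_url, actor_id, force_kill=False):
        # Query parameters are URL-encoded by requests; booleans arrive
        # as the strings "True"/"False" on the server side.
        resp = requests.get(
            base_url + "/api/actors/kill",
            params={"actor_id": actor_id, "force_kill": force_kill},
        )
        resp.raise_for_status()  # surface HTTP errors as exceptions
        body = resp.json()
        assert body["result"] is True, body.get("msg")
        return body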
@ -8,8 +8,11 @@ import pprint
import pytest
import requests

from ray._private.test_utils import (format_web_url, wait_for_condition,
                                     wait_until_server_available)
from ray._private.test_utils import (
    format_web_url,
    wait_for_condition,
    wait_until_server_available,
)
from ray.dashboard import dashboard
from ray.dashboard.tests.conftest import *  # noqa
from ray.dashboard.modules.job.sdk import JobSubmissionClient

@ -22,25 +25,23 @@ def _get_snapshot(address: str):
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__),
        "modules/snapshot/snapshot_schema.json")
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
    return data


def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
                               enable_test_module):
def test_successful_job_status(
    ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
    address = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(address)
    address = format_web_url(address)

    entrypoint_cmd = ("python -c\""
                      "import ray;"
                      "ray.init();"
                      "import time;"
                      "time.sleep(5);"
                      "\"")
    entrypoint_cmd = (
        'python -c"' "import ray;" "ray.init();" "import time;" "time.sleep(5);" '"'
    )

    client = JobSubmissionClient(address)
    job_id = client.submit_job(entrypoint=entrypoint_cmd)

@ -49,11 +50,8 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
        data = _get_snapshot(address)
        for job_entry in data["data"]["snapshot"]["jobs"].values():
            if job_entry["status"] is not None:
                assert job_entry["config"]["metadata"][
                    "jobSubmissionId"] == job_id
                assert job_entry["status"] in {
                    "PENDING", "RUNNING", "SUCCEEDED"
                }
                assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
                assert job_entry["status"] in {"PENDING", "RUNNING", "SUCCEEDED"}
                assert job_entry["statusMessage"] is not None
                return job_entry["status"] == "SUCCEEDED"

@ -62,20 +60,23 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
    wait_for_condition(wait_for_job_to_succeed, timeout=30)


def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
                           enable_test_module):
def test_failed_job_status(
    ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
    address = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(address)
    address = format_web_url(address)

    entrypoint_cmd = ("python -c\""
                      "import ray;"
                      "ray.init();"
                      "import time;"
                      "time.sleep(5);"
                      "import sys;"
                      "sys.exit(1);"
                      "\"")
    entrypoint_cmd = (
        'python -c"'
        "import ray;"
        "ray.init();"
        "import time;"
        "time.sleep(5);"
        "import sys;"
        "sys.exit(1);"
        '"'
    )
    client = JobSubmissionClient(address)
    job_id = client.submit_job(entrypoint=entrypoint_cmd)


@ -83,8 +84,7 @@ def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
        data = _get_snapshot(address)
        for job_entry in data["data"]["snapshot"]["jobs"].values():
            if job_entry["status"] is not None:
                assert job_entry["config"]["metadata"][
                    "jobSubmissionId"] == job_id
                assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
                assert job_entry["status"] in {"PENDING", "RUNNING", "FAILED"}
                assert job_entry["statusMessage"] is not None
                return job_entry["status"] == "FAILED"
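Note: the entrypoint strings above rely on Python joining adjacent string literals at compile time, so the reformatted fragments still assemble into a single python -c command. A quick demonstration of that mechanic:

    # Adjacent literals are concatenated into one string at compile time.
    entrypoint = (
        'python -c"'
        "import ray;"
        "ray.init();"
        '"'
    )
    assert entrypoint == 'python -c"import ray;ray.init();"'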
@ -34,11 +34,14 @@ ray.get(a.ping.remote())
    """
    address = ray_start_with_dashboard["address"]
    detached_driver = driver_template.format(
        address=address, lifetime="'detached'", name="'abc'")
        address=address, lifetime="'detached'", name="'abc'"
    )
    named_driver = driver_template.format(
        address=address, lifetime="None", name="'xyz'")
        address=address, lifetime="None", name="'xyz'"
    )
    unnamed_driver = driver_template.format(
        address=address, lifetime="None", name="None")
        address=address, lifetime="None", name="None"
    )

    run_string_as_driver(detached_driver)
    run_string_as_driver(named_driver)

@ -50,8 +53,8 @@ ray.get(a.ping.remote())
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__),
        "modules/snapshot/snapshot_schema.json")
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))


@ -72,10 +75,7 @@ ray.get(a.ping.remote())
    assert data["data"]["snapshot"]["rayVersion"] == ray.__version__


@pytest.mark.parametrize(
    "ray_start_with_dashboard", [{
        "num_cpus": 4
    }], indirect=True)
@pytest.mark.parametrize("ray_start_with_dashboard", [{"num_cpus": 4}], indirect=True)
def test_serve_snapshot(ray_start_with_dashboard):
    """Test detached and nondetached Serve instances running concurrently."""


@ -115,8 +115,7 @@ my_func_deleted.delete()

    my_func_nondetached.deploy()

    assert requests.get(
        "http://127.0.0.1:8123/my_func_nondetached").text == "hello"
    assert requests.get("http://127.0.0.1:8123/my_func_nondetached").text == "hello"

    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

@ -124,15 +123,16 @@ my_func_deleted.delete()
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__),
        "modules/snapshot/snapshot_schema.json")
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))

    assert len(data["data"]["snapshot"]["deployments"]) == 3

    entry = data["data"]["snapshot"]["deployments"][hashlib.sha1(
        "my_func".encode()).hexdigest()]
    entry = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func".encode()).hexdigest()
    ]
    assert entry["name"] == "my_func"
    assert entry["version"] is None
    assert entry["namespace"] == "serve"

@ -145,14 +145,14 @@ my_func_deleted.delete()

    assert len(entry["actors"]) == 1
    actor_id = next(iter(entry["actors"]))
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
        "serve"]
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
    assert metadata["deploymentName"] == "my_func"
    assert metadata["version"] is None
    assert len(metadata["replicaTag"]) > 0

    entry_deleted = data["data"]["snapshot"]["deployments"][hashlib.sha1(
        "my_func_deleted".encode()).hexdigest()]
    entry_deleted = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func_deleted".encode()).hexdigest()
    ]
    assert entry_deleted["name"] == "my_func_deleted"
    assert entry_deleted["version"] == "v1"
    assert entry_deleted["namespace"] == "serve"

@ -163,8 +163,9 @@ my_func_deleted.delete()
    assert entry_deleted["startTime"] > 0
    assert entry_deleted["endTime"] > entry_deleted["startTime"]

    entry_nondetached = data["data"]["snapshot"]["deployments"][hashlib.sha1(
        "my_func_nondetached".encode()).hexdigest()]
    entry_nondetached = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func_nondetached".encode()).hexdigest()
    ]
    assert entry_nondetached["name"] == "my_func_nondetached"
    assert entry_nondetached["version"] == "v1"
    assert entry_nondetached["namespace"] == "default_test_namespace"

@ -177,8 +178,7 @@ my_func_deleted.delete()

    assert len(entry_nondetached["actors"]) == 1
    actor_id = next(iter(entry_nondetached["actors"]))
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
        "serve"]
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
    assert metadata["deploymentName"] == "my_func_nondetached"
    assert metadata["version"] == "v1"
    assert len(metadata["replicaTag"]) > 0
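Note: the snapshot tests index deployments by the SHA-1 hex digest of the deployment name; per the comment in the snapshot module's diff above, hashing avoids key collisions introduced by the automatic camel-case conversion of the response. The lookup key is computed like this:

    import hashlib

    name = "my_func"
    key = hashlib.sha1(name.encode()).hexdigest()
    # then: snapshot["deployments"][key], as in the assertions above
    print(key)  # a stable 40-character hex string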
@ -13,7 +13,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable


@dashboard_utils.dashboard_module(
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestAgent(dashboard_utils.DashboardAgentModule):
    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)

@ -25,8 +26,7 @@ class TestAgent(dashboard_utils.DashboardAgentModule):
    @routes.get("/test/http_get_from_agent")
    async def get_url(self, req) -> aiohttp.web.Response:
        url = req.query.get("url")
        result = await test_utils.http_get(self._dashboard_agent.http_session,
                                           url)
        result = await test_utils.http_get(self._dashboard_agent.http_session, url)
        return aiohttp.web.json_response(result)

    @routes.head("/test/route_head")
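Note: TestAgent is only registered when the test-module environment flag is set; env_bool is the usual truthy-string check over os.environ. A plausible standalone version (assumed semantics; the real helper may accept more spellings):

    import os

    def env_bool(key, default):
        # Assumed semantics: "1" or "true" (any case) count as enabled.
        value = os.environ.get(key)
        if value is None:
            return default
        return value.lower() in ("1", "true")

    print(env_bool("RAY_DASHBOARD_TEST_MODULE", False))  # hypothetical key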
@ -15,7 +15,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable


@dashboard_utils.dashboard_module(
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestHead(dashboard_utils.DashboardHeadModule):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)

@ -62,26 +63,28 @@ class TestHead(dashboard_utils.DashboardHeadModule):
            return dashboard_optional_utils.rest_response(
                success=True,
                message="Fetch all data from datacenter success.",
                **all_data)
                **all_data,
            )
        else:
            data = dict(DataSource.__dict__.get(key))
            return dashboard_optional_utils.rest_response(
                success=True,
                message=f"Fetch {key} from datacenter success.",
                **{key: data})
                **{key: data},
            )

    @routes.get("/test/notified_agents")
    async def get_notified_agents(self, req) -> aiohttp.web.Response:
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Fetch notified agents success.",
            **self._notified_agents)
            **self._notified_agents,
        )

    @routes.get("/test/http_get")
    async def get_url(self, req) -> aiohttp.web.Response:
        url = req.query.get("url")
        result = await test_utils.http_get(self._dashboard_head.http_session,
                                           url)
        result = await test_utils.http_get(self._dashboard_head.http_session, url)
        return aiohttp.web.json_response(result)

    @routes.get("/test/aiohttp_cache/{sub_path}")

@ -89,14 +92,16 @@ class TestHead(dashboard_utils.DashboardHeadModule):
    async def test_aiohttp_cache(self, req) -> aiohttp.web.Response:
        value = req.query["value"]
        return dashboard_optional_utils.rest_response(
            success=True, message="OK", value=value, timestamp=time.time())
            success=True, message="OK", value=value, timestamp=time.time()
        )

    @routes.get("/test/aiohttp_cache_lru/{sub_path}")
    @dashboard_optional_utils.aiohttp_cache(ttl_seconds=60, maxsize=5)
    async def test_aiohttp_cache_lru(self, req) -> aiohttp.web.Response:
        value = req.query.get("value")
        return dashboard_optional_utils.rest_response(
            success=True, message="OK", value=value, timestamp=time.time())
            success=True, message="OK", value=value, timestamp=time.time()
        )

    @routes.get("/test/file")
    async def test_file(self, req) -> aiohttp.web.FileResponse:
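Note: the rest_response calls in TestHead forward arbitrary keyword arguments straight into the response payload; the reformatting only adds the trailing commas Black requires once a call is exploded. A toy version of that envelope builder:

    import json

    def rest_response(success, message, **kwargs):
        # kwargs become the "data" section, as in the dashboard's envelope.
        return json.dumps({"result": success, "msg": message, "data": kwargs})

    print(rest_response(True, "OK", value="42", timestamp=0.0))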
@ -4,8 +4,7 @@ import copy
import os
import aiohttp.web

import ray.dashboard.modules.tune.tune_consts \
    as tune_consts
import ray.dashboard.modules.tune.tune_consts as tune_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.utils import async_loop_forever

@ -45,19 +44,17 @@ class TuneController(dashboard_utils.DashboardHeadModule):
    @routes.get("/tune/info")
    async def tune_info(self, req) -> aiohttp.web.Response:
        stats = self.get_stats()
        return rest_response(
            success=True, message="Fetched tune info", result=stats)
        return rest_response(success=True, message="Fetched tune info", result=stats)

    @routes.get("/tune/availability")
    async def get_availability(self, req) -> aiohttp.web.Response:
        availability = {
            "available": ExperimentAnalysis is not None,
            "trials_available": self._trials_available
            "trials_available": self._trials_available,
        }
        return rest_response(
            success=True,
            message="Fetched tune availability",
            result=availability)
            success=True, message="Fetched tune availability", result=availability
        )

    @routes.get("/tune/set_experiment")
    async def set_tune_experiment(self, req) -> aiohttp.web.Response:

@ -66,25 +63,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
        if err:
            return rest_response(success=False, error=err)
        return rest_response(
            success=True, message="Successfully set experiment", **experiment)
            success=True, message="Successfully set experiment", **experiment
        )

    @routes.get("/tune/enable_tensorboard")
    async def enable_tensorboard(self, req) -> aiohttp.web.Response:
        self._enable_tensorboard()
        if not self._tensor_board_dir:
            return rest_response(
                success=False, message="Error enabling tensorboard")
            return rest_response(success=False, message="Error enabling tensorboard")
        return rest_response(success=True, message="Enabled tensorboard")

    def get_stats(self):
        tensor_board_info = {
            "tensorboard_current": self._logdir == self._tensor_board_dir,
            "tensorboard_enabled": self._tensor_board_dir != ""
            "tensorboard_enabled": self._tensor_board_dir != "",
        }
        return {
            "trial_records": copy.deepcopy(self._trial_records),
            "errors": copy.deepcopy(self._errors),
            "tensorboard": tensor_board_info
            "tensorboard": tensor_board_info,
        }

    def set_experiment(self, experiment):

@ -104,7 +101,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
    def collect_errors(self, df):
        sub_dirs = os.listdir(self._logdir)
        trial_names = filter(
            lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs)
            lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs
        )
        for trial in trial_names:
            error_path = os.path.join(self._logdir, trial, "error.txt")
            if os.path.isfile(error_path):

@ -114,7 +112,7 @@ class TuneController(dashboard_utils.DashboardHeadModule):
                self._errors[str(trial)] = {
                    "text": text,
                    "job_id": os.path.basename(self._logdir),
                    "trial_id": "No Trial ID"
                    "trial_id": "No Trial ID",
                }
                other_data = df[df["logdir"].str.contains(trial)]
                if len(other_data) > 0:

@ -175,12 +173,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):

        # list of static attributes for trial
        default_names = {
            "logdir", "time_this_iter_s", "done", "episodes_total",
            "training_iteration", "timestamp", "timesteps_total",
            "experiment_id", "date", "timestamp", "time_total_s", "pid",
            "hostname", "node_ip", "time_since_restore",
            "timesteps_since_restore", "iterations_since_restore",
            "experiment_tag", "trial_id"
            "logdir",
            "time_this_iter_s",
            "done",
            "episodes_total",
            "training_iteration",
            "timestamp",
            "timesteps_total",
            "experiment_id",
            "date",
            "timestamp",
            "time_total_s",
            "pid",
            "hostname",
            "node_ip",
            "time_since_restore",
            "timesteps_since_restore",
            "iterations_since_restore",
            "experiment_tag",
            "trial_id",
        }

        # filter attributes into floats, metrics, and config variables

@ -196,7 +207,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
        for trial, details in trial_details.items():
            ts = os.path.getctime(details["logdir"])
            formatted_time = datetime.datetime.fromtimestamp(ts).strftime(
                "%Y-%m-%d %H:%M:%S")
                "%Y-%m-%d %H:%M:%S"
            )
            details["start_time"] = formatted_time
            details["params"] = {}
            details["metrics"] = {}
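Note: the last hunk formats a trial's start time from the log directory's creation timestamp; the strftime pattern itself is unchanged by Black. The same computation stands alone as:

    import datetime
    import os

    ts = os.path.getctime(".")  # "." stands in for details["logdir"]
    formatted = datetime.datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S")
    print(formatted)  # e.g. 2022-01-29 13:05:00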
@ -25,7 +25,7 @@ except AttributeError:
# All third-party dependencies that are not included in the minimal Ray
# installation must be included in this file. This allows us to determine if
# the agent has the necessary dependencies to be started.
from ray.dashboard.optional_deps import (aiohttp, hdrs, PathLike, RouteDef)
from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
from ray.dashboard.utils import to_google_style, CustomEncoder

logger = logging.getLogger(__name__)

@ -68,12 +68,15 @@ class ClassMethodRouteTable:
        def _wrapper(handler):
            if path in cls._bind_map[method]:
                bind_info = cls._bind_map[method][path]
                raise Exception(f"Duplicated route path: {path}, "
                                f"previous one registered at "
                                f"{bind_info.filename}:{bind_info.lineno}")
                raise Exception(
                    f"Duplicated route path: {path}, "
                    f"previous one registered at "
                    f"{bind_info.filename}:{bind_info.lineno}"
                )

            bind_info = cls._BindInfo(handler.__code__.co_filename,
                                      handler.__code__.co_firstlineno, None)
            bind_info = cls._BindInfo(
                handler.__code__.co_filename, handler.__code__.co_firstlineno, None
            )

            @functools.wraps(handler)
            async def _handler_route(*args) -> aiohttp.web.Response:

@ -86,8 +89,7 @@ class ClassMethodRouteTable:
                    return await handler(bind_info.instance, req)
                except Exception:
                    logger.exception("Handle %s %s failed.", method, path)
                    return rest_response(
                        success=False, message=traceback.format_exc())
                    return rest_response(success=False, message=traceback.format_exc())

            cls._bind_map[method][path] = bind_info
            _handler_route.__route_method__ = method

@ -132,18 +134,19 @@ class ClassMethodRouteTable:
    def bind(cls, instance):
        def predicate(o):
            if inspect.ismethod(o):
                return hasattr(o, "__route_method__") and hasattr(
                    o, "__route_path__")
                return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
            return False

        handler_routes = inspect.getmembers(instance, predicate)
        for _, h in handler_routes:
            cls._bind_map[h.__func__.__route_method__][
                h.__func__.__route_path__].instance = instance
                h.__func__.__route_path__
            ].instance = instance


def rest_response(success, message, convert_google_style=True,
                  **kwargs) -> aiohttp.web.Response:
def rest_response(
    success, message, convert_google_style=True, **kwargs
) -> aiohttp.web.Response:
    # In the dev context we allow a dev server running on a
    # different port to consume the API, meaning we need to allow
    # cross-origin access

@ -155,24 +158,24 @@ def rest_response(success, message, convert_google_style=True,
        {
            "result": success,
            "msg": message,
            "data": to_google_style(kwargs) if convert_google_style else kwargs
            "data": to_google_style(kwargs) if convert_google_style else kwargs,
        },
        dumps=functools.partial(json.dumps, cls=CustomEncoder),
        headers=headers)
        headers=headers,
    )


# The cache value type used by aiohttp_cache.
_AiohttpCacheValue = namedtuple("AiohttpCacheValue",
                                ["data", "expiration", "task"])
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
# The methods with no request body used by aiohttp_cache.
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}


def aiohttp_cache(
        ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
        maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
        enable=not env_bool(
            dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False)):
    ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
    maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
    enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
):
    assert maxsize > 0
    cache = collections.OrderedDict()


@ -195,8 +198,7 @@ def aiohttp_cache(
            value = cache.get(key)
            if value is not None:
                cache.move_to_end(key)
                if (not value.task.done()
                        or value.expiration >= time.time()):
                if not value.task.done() or value.expiration >= time.time():
                    # Update task not done or the data is not expired.
                    return aiohttp.web.Response(**value.data)


@ -205,15 +207,16 @@ def aiohttp_cache(
                response = task.result()
            except Exception:
                response = rest_response(
                    success=False, message=traceback.format_exc())
                    success=False, message=traceback.format_exc()
                )
            data = {
                "status": response.status,
                "headers": dict(response.headers),
                "body": response.body,
            }
            cache[key] = _AiohttpCacheValue(data,
                                            time.time() + ttl_seconds,
                                            task)
            cache[key] = _AiohttpCacheValue(
                data, time.time() + ttl_seconds, task
            )
            cache.move_to_end(key)
            if len(cache) > maxsize:
                cache.popitem(last=False)
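Note: aiohttp_cache above combines a TTL with LRU eviction: an OrderedDict keyed per request, move_to_end on every hit, popitem(last=False) once maxsize is exceeded, and a per-entry expiration timestamp. The same mechanics in a synchronous sketch:

    import collections
    import time

    def ttl_lru_cache(ttl_seconds=60.0, maxsize=5):
        cache = collections.OrderedDict()

        def decorator(func):
            def wrapper(key):
                hit = cache.get(key)
                if hit is not None and hit[1] >= time.time():
                    cache.move_to_end(key)  # refresh recency on a hit
                    return hit[0]
                value = func(key)
                cache[key] = (value, time.time() + ttl_seconds)
                cache.move_to_end(key)
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # drop least recently used
                return value

            return wrapper

        return decorator

    @ttl_lru_cache(ttl_seconds=1.0, maxsize=2)
    def fetch(key):
        return (key, time.time())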
@ -19,11 +19,14 @@ import requests
|
|||
|
||||
from ray import ray_constants
|
||||
from ray._private.test_utils import (
|
||||
format_web_url, wait_for_condition, wait_until_server_available,
|
||||
run_string_as_driver, wait_until_succeeded_without_exception)
|
||||
format_web_url,
|
||||
wait_for_condition,
|
||||
wait_until_server_available,
|
||||
run_string_as_driver,
|
||||
wait_until_succeeded_without_exception,
|
||||
)
|
||||
from ray._private.gcs_pubsub import gcs_pubsub_enabled
|
||||
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS_LEGACY,
|
||||
DEBUG_AUTOSCALING_ERROR)
|
||||
from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
|
||||
from ray.dashboard import dashboard
|
||||
import ray.dashboard.consts as dashboard_consts
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
|
@ -43,7 +46,8 @@ def make_gcs_client(address_info):
|
|||
client = redis.StrictRedis(
|
||||
host=address[0],
|
||||
port=int(address[1]),
|
||||
password=ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
password=ray_constants.REDIS_DEFAULT_PASSWORD,
|
||||
)
|
||||
gcs_client = ray._private.gcs_utils.GcsClient.create_from_redis(client)
|
||||
else:
|
||||
address = address_info["gcs_address"]
|
||||
|
@ -73,17 +77,14 @@ cleanup_test_files()
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_with_dashboard", [{
|
||||
"_system_config": {
|
||||
"agent_register_timeout_ms": 5000
|
||||
}
|
||||
}],
|
||||
indirect=True)
|
||||
"ray_start_with_dashboard",
|
||||
[{"_system_config": {"agent_register_timeout_ms": 5000}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_basic(ray_start_with_dashboard):
|
||||
"""Dashboard test that starts a Ray cluster with a dashboard server running,
|
||||
then hits the dashboard API and asserts that it receives sensible data."""
|
||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||
is True)
|
||||
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||
address_info = ray_start_with_dashboard
|
||||
node_id = address_info["node_id"]
|
||||
gcs_client = make_gcs_client(address_info)
|
||||
|
@ -92,11 +93,12 @@ def test_basic(ray_start_with_dashboard):
|
|||
all_processes = ray.worker._global_node.all_processes
|
||||
assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
|
||||
assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
|
||||
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][
|
||||
0]
|
||||
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
|
||||
dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
|
||||
assert dashboard_proc.status() in [
|
||||
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
|
||||
psutil.STATUS_RUNNING,
|
||||
psutil.STATUS_SLEEPING,
|
||||
psutil.STATUS_DISK_SLEEP,
|
||||
]
|
||||
raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
|
||||
raylet_proc = psutil.Process(raylet_proc_info.process.pid)
|
||||
|
@ -140,9 +142,7 @@ def test_basic(ray_start_with_dashboard):
|
|||
|
||||
logger.info("Test agent register is OK.")
|
||||
wait_for_condition(lambda: _search_agent(raylet_proc.children()))
|
||||
assert dashboard_proc.status() in [
|
||||
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING
|
||||
]
|
||||
assert dashboard_proc.status() in [psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING]
|
||||
agent_proc = _search_agent(raylet_proc.children())
|
||||
agent_pid = agent_proc.pid
|
||||
|
||||
|
@ -161,40 +161,39 @@ def test_basic(ray_start_with_dashboard):
|
|||
# Check kv keys are set.
|
||||
logger.info("Check kv keys are set.")
|
||||
dashboard_address = ray.experimental.internal_kv._internal_kv_get(
|
||||
ray_constants.DASHBOARD_ADDRESS,
|
||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
||||
ray_constants.DASHBOARD_ADDRESS, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
||||
)
|
||||
assert dashboard_address is not None
|
||||
dashboard_rpc_address = ray.experimental.internal_kv._internal_kv_get(
|
||||
dashboard_consts.DASHBOARD_RPC_ADDRESS,
|
||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
||||
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
|
||||
)
|
||||
assert dashboard_rpc_address is not None
|
||||
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
|
||||
agent_ports = ray.experimental.internal_kv._internal_kv_get(
|
||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
|
||||
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
|
||||
)
|
||||
assert agent_ports is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_with_dashboard", [{
|
||||
"dashboard_host": "127.0.0.1"
|
||||
}, {
|
||||
"dashboard_host": "0.0.0.0"
|
||||
}, {
|
||||
"dashboard_host": "::"
|
||||
}],
|
||||
indirect=True)
|
||||
"ray_start_with_dashboard",
|
||||
[
|
||||
{"dashboard_host": "127.0.0.1"},
|
||||
{"dashboard_host": "0.0.0.0"},
|
||||
{"dashboard_host": "::"},
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_dashboard_address(ray_start_with_dashboard):
|
||||
webui_url = ray_start_with_dashboard["webui_url"]
|
||||
webui_ip = webui_url.split(":")[0]
|
||||
assert not ipaddress.ip_address(webui_ip).is_unspecified
|
||||
assert webui_ip in [
|
||||
"127.0.0.1", ray_start_with_dashboard["node_ip_address"]
|
||||
]
|
||||
assert webui_ip in ["127.0.0.1", ray_start_with_dashboard["node_ip_address"]]
|
||||
|
||||
|
||||
def test_http_get(enable_test_module, ray_start_with_dashboard):
|
||||
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
|
||||
is True)
|
||||
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
|
||||
webui_url = ray_start_with_dashboard["webui_url"]
|
||||
webui_url = format_web_url(webui_url)
|
||||
|
||||
|
@ -205,8 +204,7 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
|||
while True:
|
||||
time.sleep(3)
|
||||
try:
|
||||
response = requests.get(webui_url + "/test/http_get?url=" +
|
||||
target_url)
|
||||
response = requests.get(webui_url + "/test/http_get?url=" + target_url)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
dump_info = response.json()
|
||||
|
@ -221,8 +219,8 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
|||
http_port, grpc_port = ports
|
||||
|
||||
response = requests.get(
|
||||
f"http://{ip}:{http_port}"
|
||||
f"/test/http_get_from_agent?url={target_url}")
|
||||
f"http://{ip}:{http_port}" f"/test/http_get_from_agent?url={target_url}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
dump_info = response.json()
|
||||
|
@ -239,10 +237,10 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
|
|||
|
||||
|
||||
def test_class_method_route_table(enable_test_module):
|
||||
head_cls_list = dashboard_utils.get_all_modules(
|
||||
dashboard_utils.DashboardHeadModule)
|
||||
head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
|
||||
agent_cls_list = dashboard_utils.get_all_modules(
|
||||
dashboard_utils.DashboardAgentModule)
|
||||
dashboard_utils.DashboardAgentModule
|
||||
)
|
||||
test_head_cls = None
|
||||
for cls in head_cls_list:
|
||||
if cls.__name__ == "TestHead":
|
||||
|
@ -274,28 +272,23 @@ def test_class_method_route_table(enable_test_module):
|
|||
assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
|
||||
assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
|
||||
assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
|
||||
assert any(
|
||||
_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
|
||||
assert any(_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
|
||||
assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)
|
||||
|
||||
# Test bind()
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
||||
)
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||
assert len(bound_routes) == 0
|
||||
dashboard_optional_utils.ClassMethodRouteTable.bind(
|
||||
test_agent_cls.__new__(test_agent_cls))
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
||||
test_agent_cls.__new__(test_agent_cls)
|
||||
)
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||
assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
|
||||
assert all(
|
||||
not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
|
||||
assert all(not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
|
||||
|
||||
# Static def should be in bound routes.
|
||||
routes.static("/test/route_static", "/path")
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
|
||||
)
|
||||
assert any(
|
||||
_has_static(r, "/path", "/test/route_static") for r in bound_routes)
|
||||
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
|
||||
assert any(_has_static(r, "/path", "/test/route_static") for r in bound_routes)
|
||||
|
||||
# Test duplicated routes should raise exception.
|
||||
try:
|
||||
|
@@ -358,10 +351,10 @@ def test_async_loop_forever():


def test_dashboard_module_decorator(enable_test_module):
    head_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardHeadModule)
    head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
    agent_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardAgentModule)
        dashboard_utils.DashboardAgentModule
    )

    assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
    assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)

@@ -385,8 +378,7 @@ print("success")


def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

@@ -397,8 +389,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
        time.sleep(1)
        try:
            for x in range(10):
                response = requests.get(webui_url +
                                        "/test/aiohttp_cache/t1?value=1")
                response = requests.get(webui_url + "/test/aiohttp_cache/t1?value=1")
                response.raise_for_status()
                timestamp = response.json()["data"]["timestamp"]
                value1_timestamps.append(timestamp)

@@ -412,8 +403,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):

    sub_path_timestamps = []
    for x in range(10):
        response = requests.get(webui_url +
                                f"/test/aiohttp_cache/tt{x}?value=1")
        response = requests.get(webui_url + f"/test/aiohttp_cache/tt{x}?value=1")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        sub_path_timestamps.append(timestamp)

@@ -421,8 +411,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):

    volatile_value_timestamps = []
    for x in range(10):
        response = requests.get(webui_url +
                                f"/test/aiohttp_cache/tt?value={x}")
        response = requests.get(webui_url + f"/test/aiohttp_cache/tt?value={x}")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        volatile_value_timestamps.append(timestamp)

@@ -436,8 +425,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):

    volatile_value_timestamps = []
    for x in range(10):
        response = requests.get(webui_url +
                                f"/test/aiohttp_cache_lru/tt{x % 4}")
        response = requests.get(webui_url + f"/test/aiohttp_cache_lru/tt{x % 4}")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        volatile_value_timestamps.append(timestamp)

@@ -446,8 +434,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    volatile_value_timestamps = []
    data = collections.defaultdict(set)
    for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]:
        response = requests.get(webui_url +
                                f"/test/aiohttp_cache_lru/t1?value={x}")
        response = requests.get(webui_url + f"/test/aiohttp_cache_lru/t1?value={x}")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        data[x].add(timestamp)

@@ -458,8 +445,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):


def test_get_cluster_status(ray_start_with_dashboard):
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    address_info = ray_start_with_dashboard
    webui_url = address_info["webui_url"]
    webui_url = format_web_url(webui_url)

@@ -478,14 +464,15 @@ def test_get_cluster_status(ray_start_with_dashboard):
        assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]

    assert wait_until_succeeded_without_exception(
        get_cluster_status, (requests.RequestException, ))
        get_cluster_status, (requests.RequestException,)
    )

    gcs_client = make_gcs_client(address_info)
    ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
    ray.experimental.internal_kv._internal_kv_put(
        DEBUG_AUTOSCALING_STATUS_LEGACY, "hello")
    ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR,
                                                  "world")
        DEBUG_AUTOSCALING_STATUS_LEGACY, "hello"
    )
    ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR, "world")

    response = requests.get(f"{webui_url}/api/cluster_status")
    response.raise_for_status()

@@ -508,20 +495,19 @@ def test_immutable_types():
    assert immutable_dict == dashboard_utils.ImmutableDict(d)
    assert immutable_dict == d
    assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
    assert dashboard_utils.ImmutableList(
        immutable_dict["list"]) == immutable_dict["list"]
    assert (
        dashboard_utils.ImmutableList(immutable_dict["list"]) == immutable_dict["list"]
    )
    assert "512" in d
    assert "512" in d["list"][0]
    assert "512" in d["dict"]

    # Test type conversion
    assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
    assert type(list(
        immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
    assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict

    # Test json dumps / loads
    json_str = json.dumps(
        immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
    json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
    deserialized_immutable_dict = json.loads(json_str)
    assert type(deserialized_immutable_dict) == dict
    assert type(deserialized_immutable_dict["list"]) == list

@@ -577,7 +563,7 @@ def test_immutable_types():

def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
    address_info = ray.init(num_cpus=1, include_dashboard=True)
    assert (wait_until_server_available(address_info["webui_url"]) is True)
    assert wait_until_server_available(address_info["webui_url"]) is True

    webui_url = address_info["webui_url"]
    webui_url = format_web_url(webui_url)

@@ -588,11 +574,8 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
        time.sleep(1)
        try:
            response = requests.get(
                webui_url + "/test/dump",
                proxies={
                    "http": None,
                    "https": None
                })
                webui_url + "/test/dump", proxies={"http": None, "https": None}
            )
            response.raise_for_status()
            try:
                response.json()

@@ -609,8 +592,7 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):


def test_dashboard_port_conflict(ray_start_with_dashboard):
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    address_info = ray_start_with_dashboard
    gcs_client = make_gcs_client(address_info)
    ray.experimental.internal_kv._initialize_internal_kv(gcs_client)

@@ -618,11 +600,15 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
    temp_dir = "/tmp/ray"
    log_dir = "/tmp/ray/session_latest/logs"
    dashboard_cmd = [
        sys.executable, dashboard.__file__, f"--host={host}", f"--port={port}",
        f"--temp-dir={temp_dir}", f"--log-dir={log_dir}",
        sys.executable,
        dashboard.__file__,
        f"--host={host}",
        f"--port={port}",
        f"--temp-dir={temp_dir}",
        f"--log-dir={log_dir}",
        f"--redis-address={address_info['redis_address']}",
        f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}",
        f"--gcs-address={address_info['gcs_address']}"
        f"--gcs-address={address_info['gcs_address']}",
    ]
    logger.info("The dashboard should be exit: %s", dashboard_cmd)
    p = subprocess.Popen(dashboard_cmd)

@@ -638,7 +624,8 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
    try:
        dashboard_url = ray.experimental.internal_kv._internal_kv_get(
            ray_constants.DASHBOARD_ADDRESS,
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        )
        if dashboard_url:
            new_port = int(dashboard_url.split(b":")[-1])
            assert new_port > int(port)

@@ -651,8 +638,7 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):


def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True

    all_processes = ray.worker._global_node.all_processes
    dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]

@@ -661,7 +647,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
    gcs_server_proc = psutil.Process(gcs_server_info.process.pid)

    assert dashboard_proc.status() in [
        psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
        psutil.STATUS_RUNNING,
        psutil.STATUS_SLEEPING,
        psutil.STATUS_DISK_SLEEP,
    ]

    gcs_server_proc.kill()
@@ -1,7 +1,12 @@
import ray
from ray.dashboard.memory_utils import (
    ReferenceType, decode_object_ref_if_needed, MemoryTableEntry, MemoryTable,
    SortingType)
    ReferenceType,
    decode_object_ref_if_needed,
    MemoryTableEntry,
    MemoryTable,
    SortingType,
)

"""Memory Table Unit Test"""

NODE_ADDRESS = "127.0.0.1"

@@ -14,15 +19,17 @@ DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
OBJECT_SIZE = 100


def build_memory_entry(*,
                       local_ref_count,
                       pinned_in_memory,
                       submitted_task_reference_count,
                       contained_in_owned,
                       object_size,
                       pid,
                       object_id=OBJECT_ID,
                       node_address=NODE_ADDRESS):
def build_memory_entry(
    *,
    local_ref_count,
    pinned_in_memory,
    submitted_task_reference_count,
    contained_in_owned,
    object_size,
    pid,
    object_id=OBJECT_ID,
    node_address=NODE_ADDRESS
):
    object_ref = {
        "objectId": object_id,
        "callSite": "(task call) /Users:458",

@@ -30,18 +37,16 @@ def build_memory_entry(*,
        "localRefCount": local_ref_count,
        "pinnedInMemory": pinned_in_memory,
        "submittedTaskRefCount": submitted_task_reference_count,
        "containedInOwned": contained_in_owned
        "containedInOwned": contained_in_owned,
    }
    return MemoryTableEntry(
        object_ref=object_ref,
        node_address=node_address,
        is_driver=IS_DRIVER,
        pid=pid)
        object_ref=object_ref, node_address=node_address, is_driver=IS_DRIVER, pid=pid
    )


def build_local_reference_entry(object_size=OBJECT_SIZE,
                                pid=PID,
                                node_address=NODE_ADDRESS):
def build_local_reference_entry(
    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
    return build_memory_entry(
        local_ref_count=1,
        pinned_in_memory=False,

@@ -49,12 +54,13 @@ def build_local_reference_entry(object_size=OBJECT_SIZE,
        contained_in_owned=[],
        object_size=object_size,
        pid=pid,
        node_address=node_address)
        node_address=node_address,
    )


def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
                                     pid=PID,
                                     node_address=NODE_ADDRESS):
def build_used_by_pending_task_entry(
    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
    return build_memory_entry(
        local_ref_count=0,
        pinned_in_memory=False,

@@ -62,12 +68,13 @@ def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
        contained_in_owned=[],
        object_size=object_size,
        pid=pid,
        node_address=node_address)
        node_address=node_address,
    )


def build_captured_in_object_entry(object_size=OBJECT_SIZE,
                                   pid=PID,
                                   node_address=NODE_ADDRESS):
def build_captured_in_object_entry(
    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
    return build_memory_entry(
        local_ref_count=0,
        pinned_in_memory=False,

@@ -75,12 +82,13 @@ def build_captured_in_object_entry(object_size=OBJECT_SIZE,
        contained_in_owned=[OBJECT_ID],
        object_size=object_size,
        pid=pid,
        node_address=node_address)
        node_address=node_address,
    )


def build_actor_handle_entry(object_size=OBJECT_SIZE,
                             pid=PID,
                             node_address=NODE_ADDRESS):
def build_actor_handle_entry(
    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
    return build_memory_entry(
        local_ref_count=1,
        pinned_in_memory=False,

@@ -89,12 +97,13 @@ def build_actor_handle_entry(object_size=OBJECT_SIZE,
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        object_id=ACTOR_ID)
        object_id=ACTOR_ID,
    )


def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
                                 pid=PID,
                                 node_address=NODE_ADDRESS):
def build_pinned_in_memory_entry(
    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
    return build_memory_entry(
        local_ref_count=0,
        pinned_in_memory=True,

@@ -102,28 +111,36 @@ def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
        contained_in_owned=[],
        object_size=object_size,
        pid=pid,
        node_address=node_address)
        node_address=node_address,
    )


def build_entry(object_size=OBJECT_SIZE,
                pid=PID,
                node_address=NODE_ADDRESS,
                reference_type=ReferenceType.PINNED_IN_MEMORY):
def build_entry(
    object_size=OBJECT_SIZE,
    pid=PID,
    node_address=NODE_ADDRESS,
    reference_type=ReferenceType.PINNED_IN_MEMORY,
):
    if reference_type == ReferenceType.USED_BY_PENDING_TASK:
        return build_used_by_pending_task_entry(
            pid=pid, object_size=object_size, node_address=node_address)
            pid=pid, object_size=object_size, node_address=node_address
        )
    elif reference_type == ReferenceType.LOCAL_REFERENCE:
        return build_local_reference_entry(
            pid=pid, object_size=object_size, node_address=node_address)
            pid=pid, object_size=object_size, node_address=node_address
        )
    elif reference_type == ReferenceType.PINNED_IN_MEMORY:
        return build_pinned_in_memory_entry(
            pid=pid, object_size=object_size, node_address=node_address)
            pid=pid, object_size=object_size, node_address=node_address
        )
    elif reference_type == ReferenceType.ACTOR_HANDLE:
        return build_actor_handle_entry(
            pid=pid, object_size=object_size, node_address=node_address)
            pid=pid, object_size=object_size, node_address=node_address
        )
    elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
        return build_captured_in_object_entry(
            pid=pid, object_size=object_size, node_address=node_address)
            pid=pid, object_size=object_size, node_address=node_address
        )


def test_invalid_memory_entry():

@@ -133,7 +150,8 @@ def test_invalid_memory_entry():
        submitted_task_reference_count=0,
        contained_in_owned=[],
        object_size=OBJECT_SIZE,
        pid=PID)
        pid=PID,
    )
    assert memory_entry.is_valid() is False
    memory_entry = build_memory_entry(
        local_ref_count=0,

@@ -141,7 +159,8 @@ def test_invalid_memory_entry():
        submitted_task_reference_count=0,
        contained_in_owned=[],
        object_size=-1,
        pid=PID)
        pid=PID,
    )
    assert memory_entry.is_valid() is False


@@ -149,7 +168,8 @@ def test_valid_reference_memory_entry():
    memory_entry = build_local_reference_entry()
    assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
    assert memory_entry.object_ref == ray.ObjectRef(
        decode_object_ref_if_needed(OBJECT_ID))
        decode_object_ref_if_needed(OBJECT_ID)
    )
    assert memory_entry.is_valid() is True


@@ -178,15 +198,14 @@ def test_memory_table_summary():
        build_captured_in_object_entry(),
        build_actor_handle_entry(),
        build_local_reference_entry(),
        build_local_reference_entry()
        build_local_reference_entry(),
    ]
    memory_table = MemoryTable(entries)
    assert len(memory_table.group) == 1
    assert memory_table.summary["total_actor_handles"] == 1
    assert memory_table.summary["total_captured_in_objects"] == 1
    assert memory_table.summary["total_local_ref_count"] == 2
    assert memory_table.summary[
        "total_object_size"] == len(entries) * OBJECT_SIZE
    assert memory_table.summary["total_object_size"] == len(entries) * OBJECT_SIZE
    assert memory_table.summary["total_pinned_in_memory"] == 1
    assert memory_table.summary["total_used_by_pending_task"] == 1

@@ -202,14 +221,13 @@ def test_memory_table_sort_by_pid():

def test_memory_table_sort_by_reference_type():
    unsort = [
        ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
        ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
        ReferenceType.USED_BY_PENDING_TASK,
        ReferenceType.LOCAL_REFERENCE,
        ReferenceType.LOCAL_REFERENCE,
        ReferenceType.PINNED_IN_MEMORY,
    ]
    entries = [
        build_entry(reference_type=reference_type) for reference_type in unsort
    ]
    memory_table = MemoryTable(
        entries, sort_by_type=SortingType.REFERENCE_TYPE)
    entries = [build_entry(reference_type=reference_type) for reference_type in unsort]
    memory_table = MemoryTable(entries, sort_by_type=SortingType.REFERENCE_TYPE)
    sort = sorted(unsort)
    for reference_type, entry in zip(sort, memory_table.table):
        assert reference_type == entry.reference_type

@@ -231,7 +249,7 @@ def test_group_by():
        build_entry(node_address=node_second, pid=2),
        build_entry(node_address=node_second, pid=1),
        build_entry(node_address=node_first, pid=2),
        build_entry(node_address=node_first, pid=1)
        build_entry(node_address=node_first, pid=1),
    ]
    memory_table = MemoryTable(entries)

@@ -250,4 +268,5 @@ def test_group_by():
if __name__ == "__main__":
    import sys
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
@@ -18,8 +18,7 @@ import aiosignal  # noqa: F401
from google.protobuf.json_format import MessageToDict
from frozenlist import FrozenList  # noqa: F401

from ray._private.utils import (binary_to_hex,
                                check_dashboard_dependencies_installed)
from ray._private.utils import binary_to_hex, check_dashboard_dependencies_installed

try:
    create_task = asyncio.create_task

@@ -97,23 +96,26 @@ def get_all_modules(module_type):
    """
    logger.info(f"Get all modules by type: {module_type.__name__}")
    import ray.dashboard.modules
    should_only_load_minimal_modules = (
        not check_dashboard_dependencies_installed())

    should_only_load_minimal_modules = not check_dashboard_dependencies_installed()

    for module_loader, name, ispkg in pkgutil.walk_packages(
            ray.dashboard.modules.__path__,
            ray.dashboard.modules.__name__ + "."):
        ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
    ):
        try:
            importlib.import_module(name)
        except ModuleNotFoundError as e:
            logger.info(f"Module {name} cannot be loaded because "
                        "we cannot import all dependencies. Download "
                        "`pip install ray[default]` for the full "
                        f"dashboard functionality. Error: {e}")
            logger.info(
                f"Module {name} cannot be loaded because "
                "we cannot import all dependencies. Download "
                "`pip install ray[default]` for the full "
                f"dashboard functionality. Error: {e}"
            )
            if not should_only_load_minimal_modules:
                logger.info(
                    "Although `pip install ray[default] is downloaded, "
                    "module couldn't be imported`")
                    "module couldn't be imported`"
                )
                raise e

    imported_modules = []

@@ -202,7 +204,8 @@ def message_to_dict(message, decode_keys=None, **kwargs):

    if decode_keys:
        return _decode_keys(
            MessageToDict(message, use_integers_for_enums=False, **kwargs))
            MessageToDict(message, use_integers_for_enums=False, **kwargs)
        )
    else:
        return MessageToDict(message, use_integers_for_enums=False, **kwargs)

@@ -251,8 +254,9 @@ class Change:
        self.new = new

    def __str__(self):
        return f"Change(owner: {type(self.owner)}), " \
               f"old: {self.old}, new: {self.new}"
        return (
            f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}"
        )


class NotifyQueue:

@@ -289,10 +293,7 @@ https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
    | None              | null          |
    +-------------------+---------------+
    """
    _json_compatible_types = {
        dict, list, tuple, str, int, float, bool,
        type(None), bytes
    }
    _json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes}


    def is_immutable(self):

@@ -318,8 +319,7 @@ class Immutable(metaclass=ABCMeta):


class ImmutableList(Immutable, Sequence):
    """Makes a :class:`list` immutable.
    """
    """Makes a :class:`list` immutable."""

    __slots__ = ("_list", "_proxy")

@@ -332,7 +332,7 @@ class ImmutableList(Immutable, Sequence):
        self._proxy = [None] * len(list_value)

    def __reduce_ex__(self, protocol):
        return type(self), (self._list, )
        return type(self), (self._list,)

    def mutable(self):
        return self._list

@@ -366,8 +366,7 @@ class ImmutableList(Immutable, Sequence):


class ImmutableDict(Immutable, Mapping):
    """Makes a :class:`dict` immutable.
    """
    """Makes a :class:`dict` immutable."""

    __slots__ = ("_dict", "_proxy")

@@ -380,7 +379,7 @@ class ImmutableDict(Immutable, Mapping):
        self._proxy = {}

    def __reduce_ex__(self, protocol):
        return type(self), (self._dict, )
        return type(self), (self._dict,)

    def mutable(self):
        return self._dict

@@ -443,21 +442,23 @@ class Dict(ImmutableDict, MutableMapping):
        if len(self.signal) and old != value:
            if old is None:
                co = self.signal.send(
                    Change(owner=self, new=Dict.ChangeItem(key, value)))
                    Change(owner=self, new=Dict.ChangeItem(key, value))
                )
            else:
                co = self.signal.send(
                    Change(
                        owner=self,
                        old=Dict.ChangeItem(key, old),
                        new=Dict.ChangeItem(key, value)))
                        new=Dict.ChangeItem(key, value),
                    )
                )
            NotifyQueue.put(co)

    def __delitem__(self, key):
        old = self._dict.pop(key, None)
        self._proxy.pop(key, None)
        if len(self.signal) and old is not None:
            co = self.signal.send(
                Change(owner=self, old=Dict.ChangeItem(key, old)))
            co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old)))
            NotifyQueue.put(co)

    def reset(self, d):

@@ -482,12 +483,15 @@ def async_loop_forever(interval_seconds, cancellable=False):
                await coro(*args, **kwargs)
            except asyncio.CancelledError as ex:
                if cancellable:
                    logger.info(f"An async loop forever coroutine "
                                f"is cancelled {coro}.")
                    logger.info(
                        f"An async loop forever coroutine " f"is cancelled {coro}."
                    )
                    raise ex
                else:
                    logger.exception(f"Can not cancel the async loop "
                                     f"forever coroutine {coro}.")
                    logger.exception(
                        f"Can not cancel the async loop "
                        f"forever coroutine {coro}."
                    )
            except Exception:
                logger.exception(f"Error looping coroutine {coro}.")
            await asyncio.sleep(interval_seconds)

@@ -497,15 +501,18 @@ def async_loop_forever(interval_seconds, cancellable=False):
    return _wrapper


async def get_aioredis_client(redis_address, redis_password,
                              retry_interval_seconds, retry_times):
async def get_aioredis_client(
    redis_address, redis_password, retry_interval_seconds, retry_times
):
    for x in range(retry_times):
        try:
            return await aioredis.create_redis_pool(
                address=redis_address, password=redis_password)
                address=redis_address, password=redis_password
            )
        except (socket.gaierror, ConnectionError) as ex:
            logger.error("Connect to Redis failed: %s, retry...", ex)
            await asyncio.sleep(retry_interval_seconds)
    # Raise exception from create_redis_pool
    return await aioredis.create_redis_pool(
        address=redis_address, password=redis_password)
        address=redis_address, password=redis_password
    )
@@ -2,6 +2,7 @@ from collections import Counter
import sys
import time
import ray

""" This script is meant to be run from a pod in the same Kubernetes namespace
as your Ray cluster.
"""

@@ -11,8 +12,9 @@ as your Ray cluster.
def gethostname(x):
    import platform
    import time

    time.sleep(0.01)
    return x + (platform.node(), )
    return x + (platform.node(),)


def wait_for_nodes(expected):

@@ -22,8 +24,11 @@ def wait_for_nodes(expected):
        node_keys = [key for key in resources if "node" in key]
        num_nodes = sum(resources[node_key] for node_key in node_keys)
        if num_nodes < expected:
            print("{} nodes have joined so far, waiting for {} more.".format(
                num_nodes, expected - num_nodes))
            print(
                "{} nodes have joined so far, waiting for {} more.".format(
                    num_nodes, expected - num_nodes
                )
            )
            sys.stdout.flush()
            time.sleep(1)
        else:

@@ -36,9 +41,7 @@ def main():
    # Check that objects can be transferred from each node to each other node.
    for i in range(10):
        print("Iteration {}".format(i))
        results = [
            gethostname.remote(gethostname.remote(())) for _ in range(100)
        ]
        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
        print(Counter(ray.get(results)))
        sys.stdout.flush()
@@ -2,6 +2,7 @@ from collections import Counter
import sys
import time
import ray

""" Run this script locally to execute a Ray program on your Ray cluster on
Kubernetes.

@@ -18,8 +19,9 @@ LOCAL_PORT = 10001
def gethostname(x):
    import platform
    import time

    time.sleep(0.01)
    return x + (platform.node(), )
    return x + (platform.node(),)


def wait_for_nodes(expected):

@@ -29,8 +31,11 @@ def wait_for_nodes(expected):
        node_keys = [key for key in resources if "node" in key]
        num_nodes = sum(resources[node_key] for node_key in node_keys)
        if num_nodes < expected:
            print("{} nodes have joined so far, waiting for {} more.".format(
                num_nodes, expected - num_nodes))
            print(
                "{} nodes have joined so far, waiting for {} more.".format(
                    num_nodes, expected - num_nodes
                )
            )
            sys.stdout.flush()
            time.sleep(1)
        else:

@@ -43,9 +48,7 @@ def main():
    # Check that objects can be transferred from each node to each other node.
    for i in range(10):
        print("Iteration {}".format(i))
        results = [
            gethostname.remote(gethostname.remote(())) for _ in range(100)
        ]
        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
        print(Counter(ray.get(results)))
        sys.stdout.flush()
@@ -10,8 +10,9 @@ import ray
def gethostname(x):
    import platform
    import time

    time.sleep(0.01)
    return x + (platform.node(), )
    return x + (platform.node(),)


def wait_for_nodes(expected):

@@ -21,8 +22,11 @@ def wait_for_nodes(expected):
        node_keys = [key for key in resources if "node" in key]
        num_nodes = sum(resources[node_key] for node_key in node_keys)
        if num_nodes < expected:
            print("{} nodes have joined so far, waiting for {} more.".format(
                num_nodes, expected - num_nodes))
            print(
                "{} nodes have joined so far, waiting for {} more.".format(
                    num_nodes, expected - num_nodes
                )
            )
            sys.stdout.flush()
            time.sleep(1)
        else:

@@ -35,9 +39,7 @@ def main():
    # Check that objects can be transferred from each node to each other node.
    for i in range(10):
        print("Iteration {}".format(i))
        results = [
            gethostname.remote(gethostname.remote(())) for _ in range(100)
        ]
        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
        print(Counter(ray.get(results)))
        sys.stdout.flush()
@@ -25,24 +25,24 @@ if __name__ == "__main__":
        "--exp-name",
        type=str,
        required=True,
        help="The job name and path to logging file (exp_name.log).")
        help="The job name and path to logging file (exp_name.log).",
    )
    parser.add_argument(
        "--num-nodes",
        "-n",
        type=int,
        default=1,
        help="Number of nodes to use.")
        "--num-nodes", "-n", type=int, default=1, help="Number of nodes to use."
    )
    parser.add_argument(
        "--node",
        "-w",
        type=str,
        help="The specified nodes to use. Same format as the "
        "return of 'sinfo'. Default: ''.")
        "return of 'sinfo'. Default: ''.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        default=0,
        help="Number of GPUs to use in each node. (Default: 0)")
        help="Number of GPUs to use in each node. (Default: 0)",
    )
    parser.add_argument(
        "--partition",
        "-p",

@@ -51,14 +51,16 @@ if __name__ == "__main__":
    parser.add_argument(
        "--load-env",
        type=str,
        help="The script to load your environment ('module load cuda/10.1')")
        help="The script to load your environment ('module load cuda/10.1')",
    )
    parser.add_argument(
        "--command",
        type=str,
        required=True,
        help="The command you wish to execute. For example: "
        " --command 'python test.py'. "
        "Note that the command must be a string.")
        "Note that the command must be a string.",
    )
    args = parser.parse_args()

    if args.node:

@@ -67,11 +69,13 @@ if __name__ == "__main__":
    else:
        node_info = ""

    job_name = "{}_{}".format(args.exp_name,
                              time.strftime("%m%d-%H%M", time.localtime()))
    job_name = "{}_{}".format(
        args.exp_name, time.strftime("%m%d-%H%M", time.localtime())
    )

    partition_option = "#SBATCH --partition={}".format(
        args.partition) if args.partition else ""
    partition_option = (
        "#SBATCH --partition={}".format(args.partition) if args.partition else ""
    )

    # ===== Modified the template script =====
    with open(template_file, "r") as f:

@@ -84,10 +88,10 @@ if __name__ == "__main__":
    text = text.replace(LOAD_ENV, str(args.load_env))
    text = text.replace(GIVEN_NODE, node_info)
    text = text.replace(
        "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
        "PRODUCTION!",
        "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
        "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
        "RUNNABLE!")
        "RUNNABLE!",
    )

    # ===== Save the script =====
    script_file = "{}.sh".format(job_name)

@@ -99,5 +103,7 @@ if __name__ == "__main__":
    subprocess.Popen(["sbatch", script_file])
    print(
        "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
            script_file, "{}.log".format(job_name)))
            script_file, "{}.log".format(job_name)
        )
    )
    sys.exit(0)
@@ -89,7 +89,7 @@ myst_enable_extensions = [
]

external_toc_exclude_missing = False
external_toc_path = '_toc.yml'
external_toc_path = "_toc.yml"

# There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
# of imports.

@@ -112,7 +112,8 @@ versionwarning_messages = {
        "<b>Got questions?</b> Join "
        f'<a href="{FORUM_LINK}">the Ray Community forum</a> '
        "for Q&A on all things Ray, as well as to share and learn use cases "
        "and best practices with the Ray community."),
        "and best practices with the Ray community."
    ),
}

versionwarning_body_selector = "#main-content"

@@ -189,11 +190,16 @@ exclude_patterns += sphinx_gallery_conf["examples_dirs"]
# If "DOC_LIB" is found, only build that top-level navigation item.
build_one_lib = os.getenv("DOC_LIB")

all_toc_libs = [
    f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path
]
all_toc_libs = [f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path]
all_toc_libs += [
    "cluster", "tune", "data", "raysgd", "train", "rllib", "serve", "workflows"
    "cluster",
    "tune",
    "data",
    "raysgd",
    "train",
    "rllib",
    "serve",
    "workflows",
]
if build_one_lib and build_one_lib in all_toc_libs:
    all_toc_libs.remove(build_one_lib)

@@ -405,7 +411,8 @@ def setup(app):
    # Custom JS
    app.add_js_file(
        "https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js",
        defer="defer")
        defer="defer",
    )
    app.add_js_file("js/docsearch.js", defer="defer")
    # Custom Sphinx directives
    app.add_directive("customgalleryitem", CustomGalleryItemDirective)
@@ -6,13 +6,17 @@ from docutils import nodes
import os
import sphinx_gallery
import urllib

# Note: the scipy import has to stay here, it's used implicitly down the line
import scipy.stats  # noqa: F401
import scipy.linalg  # noqa: F401

__all__ = [
    "CustomGalleryItemDirective", "fix_xgb_lgbm_docs", "MOCK_MODULES",
    "CHILD_MOCK_MODULES", "update_context"
    "CustomGalleryItemDirective",
    "fix_xgb_lgbm_docs",
    "MOCK_MODULES",
    "CHILD_MOCK_MODULES",
    "update_context",
]

try:

@@ -60,7 +64,7 @@ class CustomGalleryItemDirective(Directive):
    option_spec = {
        "tooltip": directives.unchanged,
        "figure": directives.unchanged,
        "description": directives.unchanged
        "description": directives.unchanged,
    }

    has_content = False

@@ -73,8 +77,9 @@ class CustomGalleryItemDirective(Directive):
            if len(self.options["tooltip"]) > 195:
                tooltip = tooltip[:195] + "..."
        else:
            raise ValueError("Need to provide :tooltip: under "
                             "`.. customgalleryitem::`.")
            raise ValueError(
                "Need to provide :tooltip: under " "`.. customgalleryitem::`."
            )

        # Generate `thumbnail` used in the gallery.
        if "figure" in self.options:

@@ -95,11 +100,13 @@ class CustomGalleryItemDirective(Directive):
        if "description" in self.options:
            description = self.options["description"]
        else:
            raise ValueError("Need to provide :description: under "
                             "`customgalleryitem::`.")
            raise ValueError(
                "Need to provide :description: under " "`customgalleryitem::`."
            )

        thumbnail_rst = GALLERY_TEMPLATE.format(
            tooltip=tooltip, thumbnail=thumbnail, description=description)
            tooltip=tooltip, thumbnail=thumbnail, description=description
        )
        thumbnail = StringList(thumbnail_rst.split("\n"))
        thumb = nodes.paragraph()
        self.state.nested_parse(thumbnail, self.content_offset, thumb)

@@ -146,29 +153,30 @@ def fix_xgb_lgbm_docs(app, what, name, obj, options, lines):


# Taken from https://github.com/edx/edx-documentation
FEEDBACK_FORM_FMT = "https://github.com/ray-project/ray/issues/new?" \
                    "title={title}&labels=docs&body={body}"
FEEDBACK_FORM_FMT = (
    "https://github.com/ray-project/ray/issues/new?"
    "title={title}&labels=docs&body={body}"
)


def feedback_form_url(project, page):
    """Create a URL for feedback on a particular page in a project."""
    return FEEDBACK_FORM_FMT.format(
        title=urllib.parse.quote(
            "[docs] Issue on `{page}.rst`".format(page=page)),
        title=urllib.parse.quote("[docs] Issue on `{page}.rst`".format(page=page)),
        body=urllib.parse.quote(
            "# Documentation Problem/Question/Comment\n"
            "<!-- Describe your issue/question/comment below. -->\n"
            "<!-- If there are typos or errors in the docs, feel free "
            "to create a pull-request. -->\n"
            "\n\n\n\n"
            "(Created directly from the docs)\n"),
            "(Created directly from the docs)\n"
        ),
    )


def update_context(app, pagename, templatename, context, doctree):
    """Update the page rendering context to include ``feedback_form_url``."""
    context["feedback_form_url"] = feedback_form_url(app.config.project,
                                                     pagename)
    context["feedback_form_url"] = feedback_form_url(app.config.project, pagename)


MOCK_MODULES = [

@@ -187,8 +195,7 @@ MOCK_MODULES = [
    "horovod.ray.runner",
    "horovod.ray.utils",
    "hyperopt",
    "hyperopt.hp"
    "kubernetes",
    "hyperopt.hp" "kubernetes",
    "mlflow",
    "modin",
    "mxnet",
@@ -59,13 +59,16 @@ from ray.data.datasource.datasource import RandomIntRowDatasource
# Let’s see how we implement such pipeline using Ray Dataset:


def create_shuffle_pipeline(training_data_dir: str, num_epochs: int,
                            num_shards: int) -> List[DatasetPipeline]:
def create_shuffle_pipeline(
    training_data_dir: str, num_epochs: int, num_shards: int
) -> List[DatasetPipeline]:

    return ray.data.read_parquet(training_data_dir) \
        .repeat(num_epochs) \
        .random_shuffle_each_window() \
    return (
        ray.data.read_parquet(training_data_dir)
        .repeat(num_epochs)
        .random_shuffle_each_window()
        .split(num_shards, equal=True)
    )


############################################################################

@@ -117,7 +120,8 @@ parser = argparse.ArgumentParser()
parser.add_argument(
    "--large-scale-test",
    action="store_true",
    help="Run large scale test (500GiB of data).")
    help="Run large scale test (500GiB of data).",
)

args, _ = parser.parse_known_args()

@@ -142,13 +146,15 @@ if not args.large_scale_test:
        ray.data.read_datasource(
            RandomIntRowDatasource(),
            n=size_bytes // 8 // NUM_COLUMNS,
            num_columns=NUM_COLUMNS).write_parquet(tmpdir)
            num_columns=NUM_COLUMNS,
        ).write_parquet(tmpdir)
        return tmpdir

    example_files_dir = generate_example_files(SIZE_100MiB)

    splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS,
                                     NUM_TRAINING_WORKERS)
    splits = create_shuffle_pipeline(
        example_files_dir, NUM_EPOCHS, NUM_TRAINING_WORKERS
    )

    training_workers = [
        TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)

@@ -198,18 +204,22 @@ if not args.large_scale_test:
# generated data.


def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int,
                                  num_columns: int,
                                  num_shards: int) -> List[DatasetPipeline]:
def create_large_shuffle_pipeline(
    data_size_bytes: int, num_epochs: int, num_columns: int, num_shards: int
) -> List[DatasetPipeline]:
    # _spread_resource_prefix is used to ensure tasks are evenly spread to all
    # CPU nodes.
    return ray.data.read_datasource(
        RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns,
    return (
        ray.data.read_datasource(
            RandomIntRowDatasource(),
            n=data_size_bytes // 8 // num_columns,
            num_columns=num_columns,
            _spread_resource_prefix="node:") \
        .repeat(num_epochs) \
        .random_shuffle_each_window(_spread_resource_prefix="node:") \
            _spread_resource_prefix="node:",
        )
        .repeat(num_epochs)
        .random_shuffle_each_window(_spread_resource_prefix="node:")
        .split(num_shards, equal=True)
    )


#################################################################################

@@ -229,19 +239,18 @@ if args.large_scale_test:

    # waiting for cluster nodes to come up.
    while len(ray.nodes()) < TOTAL_NUM_NODES:
        print(
            f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}"
        )
        print(f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}")
        time.sleep(5)

    splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS,
                                           NUM_COLUMNS, NUM_TRAINING_WORKERS)
    splits = create_large_shuffle_pipeline(
        SIZE_500GiB, NUM_EPOCHS, NUM_COLUMNS, NUM_TRAINING_WORKERS
    )

    # Note we set num_gpus=1 for workers so that
    # the workers will only run on GPU nodes.
    training_workers = [
        TrainingWorker.options(num_gpus=1) \
        .remote(rank, shard) for rank, shard in enumerate(splits)
        TrainingWorker.options(num_gpus=1).remote(rank, shard)
        for rank, shard in enumerate(splits)
    ]

    start = time.time()
@@ -4,21 +4,30 @@ import pyximport

pyximport.install(setup_args={"include_dirs": numpy.get_include()})

from .cython_simple import simple_func, fib, fib_int, \
    fib_cpdef, fib_cdef, simple_class
from .cython_simple import simple_func, fib, fib_int, fib_cpdef, fib_cdef, simple_class
from .masked_log import masked_log

from .cython_blas import \
    compute_self_corr_for_voxel_sel, \
    compute_kernel_matrix, \
    compute_single_self_corr_syrk, \
    compute_single_self_corr_gemm, \
    compute_corr_vectors, \
    compute_single_matrix_multiplication
from .cython_blas import (
    compute_self_corr_for_voxel_sel,
    compute_kernel_matrix,
    compute_single_self_corr_syrk,
    compute_single_self_corr_gemm,
    compute_corr_vectors,
    compute_single_matrix_multiplication,
)

__all__ = [
    "simple_func", "fib", "fib_int", "fib_cpdef", "fib_cdef", "simple_class",
    "masked_log", "compute_self_corr_for_voxel_sel", "compute_kernel_matrix",
    "compute_single_self_corr_syrk", "compute_single_self_corr_gemm",
    "compute_corr_vectors", "compute_single_matrix_multiplication"
    "simple_func",
    "fib",
    "fib_int",
    "fib_cpdef",
    "fib_cdef",
    "simple_class",
    "masked_log",
    "compute_self_corr_for_voxel_sel",
    "compute_kernel_matrix",
    "compute_single_self_corr_syrk",
    "compute_single_self_corr_gemm",
    "compute_corr_vectors",
    "compute_single_matrix_multiplication",
]
@@ -94,11 +94,11 @@ def example8():

    # See cython_blas.pyx for argument documentation
    mat = np.array(
        [[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32)
        [[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32
    )
    result = np.zeros((2, 2), np.float32, order="C")

    run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0,
             result, 2)
    run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0, result, 2)


if __name__ == "__main__":
@@ -13,6 +13,7 @@ include_dirs = [numpy.get_include()]
# dependencies
try:
    import scipy  # noqa

    modules.append("cython_blas.pyx")
    install_requires.append("scipy")
except ImportError as e:  # noqa

@@ -27,4 +28,5 @@ setup(
    packages=[pkg_dir],
    ext_modules=cythonize(modules),
    install_requires=install_requires,
    include_dirs=include_dirs)
    include_dirs=include_dirs,
)
@@ -56,4 +56,5 @@ class CythonTest(unittest.TestCase):
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
@@ -65,31 +65,34 @@ from ray.util.dask import ray_dask_get

parser = argparse.ArgumentParser()
parser.add_argument(
    "--address", type=str, default="auto", help="The address to use for Ray.")
    "--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument(
    "--smoke-test",
    action="store_true",
    help="Read a smaller dataset for quick testing purposes.")
    help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument(
    "--num-actors",
    type=int,
    default=4,
    help="Sets number of actors for training.")
    "--num-actors", type=int, default=4, help="Sets number of actors for training."
)
parser.add_argument(
    "--cpus-per-actor",
    type=int,
    default=6,
    help="The number of CPUs per actor for training.")
    help="The number of CPUs per actor for training.",
)
parser.add_argument(
    "--num-actors-inference",
    type=int,
    default=16,
    help="Sets number of actors for inference.")
    help="Sets number of actors for inference.",
)
parser.add_argument(
    "--cpus-per-actor-inference",
    type=int,
    default=2,
    help="The number of CPUs per actor for inference.")
    help="The number of CPUs per actor for inference.",
)
# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()

@@ -125,12 +128,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
    # Test dataset with only 10,000 records.
    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
               ".csv"
    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
    # Full dataset. This may take a couple of minutes to load.
    FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
               "/00280/HIGGS.csv.gz"
    FILE_URL = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases"
        "/00280/HIGGS.csv.gz"
    )
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
dask.config.set(scheduler=ray_dask_get)

@@ -192,7 +196,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
        dtrain=train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        ray_params=ray_params)
        ray_params=ray_params,
    )

    train_end_time = time.time()
    train_duration = train_end_time - train_start_time

@@ -200,8 +205,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):

    model_path = "model.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(
        evals_result["eval"]["error"][-1]))
    print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))

    return bst, evals_result

@@ -221,8 +225,12 @@ config = {
}

bst, evals_result = train_xgboost(
    config, train_df, eval_df, LABEL_COLUMN,
    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
    config,
    train_df,
    eval_df,
    LABEL_COLUMN,
    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")

###############################################################################

@@ -260,13 +268,12 @@ def tune_xgboost(train_df, test_df, target_column):
        "eval_metric": ["logloss", "error"],
        "eta": tune.loguniform(1e-4, 1e-1),
        "subsample": tune.uniform(0.5, 1.0),
        "max_depth": tune.randint(1, 9)
        "max_depth": tune.randint(1, 9),
    }

    ray_params = RayParams(
        max_actor_restarts=1,
        cpus_per_actor=cpus_per_actor,
        num_actors=num_actors)
        max_actor_restarts=1, cpus_per_actor=cpus_per_actor, num_actors=num_actors
    )

    tune_start_time = time.time()

@@ -276,19 +283,21 @@ def tune_xgboost(train_df, test_df, target_column):
            train_df=train_df,
            test_df=test_df,
            target_column=target_column,
            ray_params=ray_params),
            ray_params=ray_params,
        ),
        # Use the `get_tune_resources` helper function to set the resources.
        resources_per_trial=ray_params.get_tune_resources(),
        config=config,
        num_samples=10,
        metric="eval-error",
        mode="min")
        mode="min",
    )

    tune_end_time = time.time()
    tune_duration = tune_end_time - tune_start_time
    print(f"Total time taken: {tune_duration} seconds.")

    accuracy = 1. - analysis.best_result["eval-error"]
    accuracy = 1.0 - analysis.best_result["eval-error"]
    print(f"Best model parameters: {analysis.best_config}")
    print(f"Best model total accuracy: {accuracy:.4f}")

@@ -315,7 +324,8 @@ results = predict(
    bst,
    inference_df,
    ray_params=RayParams(
        cpus_per_actor=cpus_per_actor_inference,
        num_actors=num_actors_inference))
        cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
    ),
)

print(results)
@ -59,7 +59,8 @@ def make_and_upload_dataset(dir_path):
|
|||
shift=0.0,
|
||||
scale=1.0,
|
||||
shuffle=False,
|
||||
random_state=seed)
|
||||
random_state=seed,
|
||||
)
|
||||
|
||||
# turn into dataframe with column names
|
||||
col_names = ["feature_%0d" % i for i in range(1, d + 1, 1)]
|
||||
|
@ -91,10 +92,8 @@ def make_and_upload_dataset(dir_path):
|
|||
path = os.path.join(data_path, f"data_{i:05d}.parquet.snappy")
|
||||
if not os.path.exists(path):
|
||||
tmp_df = create_data_chunk(
|
||||
n=PARQUET_FILE_CHUNK_SIZE,
|
||||
d=NUM_FEATURES,
|
||||
seed=i,
|
||||
include_label=True)
|
||||
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=True
|
||||
)
|
||||
tmp_df.to_parquet(path, compression="snappy", index=False)
|
||||
print(f"Wrote {path} to disk...")
|
||||
# todo: at large enough scale we might want to upload the rest after
|
||||
|
@ -108,10 +107,8 @@ def make_and_upload_dataset(dir_path):
|
|||
path = os.path.join(inference_path, f"data_{i:05d}.parquet.snappy")
|
||||
if not os.path.exists(path):
|
||||
tmp_df = create_data_chunk(
|
||||
n=PARQUET_FILE_CHUNK_SIZE,
|
||||
d=NUM_FEATURES,
|
||||
seed=i,
|
||||
include_label=False)
|
||||
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=False
|
||||
)
|
||||
tmp_df.to_parquet(path, compression="snappy", index=False)
|
||||
print(f"Wrote {path} to disk...")
|
||||
# todo: at large enough scale we might want to upload the rest after
|
||||
|
@ -124,8 +121,9 @@ def make_and_upload_dataset(dir_path):
|
|||
|
||||
def read_dataset(path: str) -> ray.data.Dataset:
|
||||
print(f"reading data from {path}")
|
||||
return ray.data.read_parquet(path, _spread_resource_prefix="node:") \
|
||||
.random_shuffle(_spread_resource_prefix="node:")
|
||||
return ray.data.read_parquet(path, _spread_resource_prefix="node:").random_shuffle(
|
||||
_spread_resource_prefix="node:"
|
||||
)
|
||||
|
||||
|
||||
class DataPreprocessor:
|
||||
|
@ -141,20 +139,20 @@ class DataPreprocessor:
|
|||
# columns.
|
||||
self.standard_stats = None
|
||||
|
||||
def preprocess_train_data(self, ds: ray.data.Dataset
|
||||
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
def preprocess_train_data(
self, ds: ray.data.Dataset
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print("\n\nPreprocessing training dataset.\n")
return self._preprocess(ds, False)

def preprocess_inference_data(self,
df: ray.data.Dataset) -> ray.data.Dataset:
def preprocess_inference_data(self, df: ray.data.Dataset) -> ray.data.Dataset:
print("\n\nPreprocessing inference dataset.\n")
return self._preprocess(df, True)[0]

def _preprocess(self, ds: ray.data.Dataset, inferencing: bool
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print(
"\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")
def _preprocess(
self, ds: ray.data.Dataset, inferencing: bool
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print("\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")

def batch_transformer(df: pd.DataFrame):
# Disable chained assignment warning.

@@ -165,25 +163,27 @@ class DataPreprocessor:

# Add new column.
df["new_col"] = (
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]) / 3.
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]
) / 3.0

# Transform column.
df["feature_1"] = 2. * df["feature_1"] + 0.1
df["feature_1"] = 2.0 * df["feature_1"] + 0.1

return df

ds = ds.map_batches(batch_transformer, batch_format="pandas")

print("\nStep 2: Precalculating fruit-grouped mean for new column and "
"for one-hot encoding (latter only uses fruit groups)\n")
print(
"\nStep 2: Precalculating fruit-grouped mean for new column and "
"for one-hot encoding (latter only uses fruit groups)\n"
)
agg_ds = ds.groupby("fruit").mean("feature_1")
fruit_means = {
r["fruit"]: r["mean(feature_1)"]
for r in agg_ds.take_all()
}
fruit_means = {r["fruit"]: r["mean(feature_1)"] for r in agg_ds.take_all()}

print("\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
"fruit; one-hot encode fruit column\n")
print(
"\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
"fruit; one-hot encode fruit column\n"
)

if inferencing:
assert self.fruits is not None

@@ -192,8 +192,7 @@ class DataPreprocessor:
self.fruits = list(fruit_means.keys())

fruit_one_hots = {
fruit: collections.defaultdict(int, fruit=1)
for fruit in self.fruits
fruit: collections.defaultdict(int, fruit=1) for fruit in self.fruits
}

def batch_transformer(df: pd.DataFrame):

@@ -224,12 +223,12 @@ class DataPreprocessor:
# Split into 90% training set, 10% test set.
train_ds, test_ds = ds.split_at_indices([split_index])

print("\nStep 4b: Precalculate training dataset stats for "
"standard scaling\n")
print(
"\nStep 4b: Precalculate training dataset stats for "
"standard scaling\n"
)
# Calculate stats needed for standard scaling feature columns.
feature_columns = [
col for col in train_ds.schema().names if col != "label"
]
feature_columns = [col for col in train_ds.schema().names if col != "label"]
standard_aggs = [
agg(on=col) for col in feature_columns for agg in (Mean, Std)
]
@@ -252,30 +251,29 @@ class DataPreprocessor:

if inferencing:
# Apply standard scaling to inference dataset.
inference_ds = ds.map_batches(
batch_standard_scaler, batch_format="pandas")
inference_ds = ds.map_batches(batch_standard_scaler, batch_format="pandas")
return inference_ds, None
else:
# Apply standard scaling to both training dataset and test dataset.
train_ds = train_ds.map_batches(
batch_standard_scaler, batch_format="pandas")
test_ds = test_ds.map_batches(
batch_standard_scaler, batch_format="pandas")
batch_standard_scaler, batch_format="pandas"
)
test_ds = test_ds.map_batches(batch_standard_scaler, batch_format="pandas")
return train_ds, test_ds


def inference(dataset, model_cls: type, batch_size: int, result_path: str,
use_gpu: bool):
def inference(
dataset, model_cls: type, batch_size: int, result_path: str, use_gpu: bool
):
print("inferencing...")
num_gpus = 1 if use_gpu else 0
dataset \
.map_batches(
model_cls,
compute="actors",
batch_size=batch_size,
num_gpus=num_gpus,
num_cpus=0) \
.write_parquet(result_path)
dataset.map_batches(
model_cls,
compute="actors",
batch_size=batch_size,
num_gpus=num_gpus,
num_cpus=0,
).write_parquet(result_path)


"""

@@ -295,8 +293,7 @@ P1:


class Net(nn.Module):
def __init__(self, n_layers, n_features, num_hidden, dropout_every,
drop_prob):
def __init__(self, n_layers, n_features, num_hidden, dropout_every, drop_prob):
super().__init__()
self.n_layers = n_layers
self.dropout_every = dropout_every

@@ -406,8 +403,9 @@ def train_func(config):
print("Defining model, loss, and optimizer...")

# Setup device.
device = torch.device(f"cuda:{train.local_rank()}"
if use_gpu and torch.cuda.is_available() else "cpu")
device = torch.device(
f"cuda:{train.local_rank()}" if use_gpu and torch.cuda.is_available() else "cpu"
)
print(f"Device: {device}")

# Setup data.

@@ -415,7 +413,8 @@ def train_func(config):
train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
test_dataset = train.get_dataset_shard("test_dataset")
test_torch_dataset = test_dataset.to_torch(
label_column="label", batch_size=batch_size)
label_column="label", batch_size=batch_size
)

net = Net(
n_layers=num_layers,

@@ -436,30 +435,37 @@ def train_func(config):
train_dataset = next(train_dataset_epoch_iterator)

train_torch_dataset = train_dataset.to_torch(
label_column="label", batch_size=batch_size)
label_column="label", batch_size=batch_size
)

train_running_loss, train_num_correct, train_num_total = train_epoch(
train_torch_dataset, net, device, criterion, optimizer)
train_torch_dataset, net, device, criterion, optimizer
)
train_acc = train_num_correct / train_num_total
print(f"epoch [{epoch + 1}]: training accuracy: "
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
print(
f"epoch [{epoch + 1}]: training accuracy: "
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
)

test_running_loss, test_num_correct, test_num_total = test_epoch(
test_torch_dataset, net, device, criterion)
test_torch_dataset, net, device, criterion
)
test_acc = test_num_correct / test_num_total
print(f"epoch [{epoch + 1}]: testing accuracy: "
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}")
print(
f"epoch [{epoch + 1}]: testing accuracy: "
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}"
)

# Record and log stats.
train.report(
train_acc=train_acc,
train_loss=train_running_loss,
test_acc=test_acc,
test_loss=test_running_loss)
test_loss=test_running_loss,
)

# Checkpoint model.
module = (net.module
if isinstance(net, DistributedDataParallel) else net)
module = net.module if isinstance(net, DistributedDataParallel) else net
train.save_checkpoint(model_state_dict=module.state_dict())

if train.world_rank() == 0:
@@ -469,46 +475,44 @@ def train_func(config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dir-path",
default=".",
type=str,
help="Path to read and write data from")
"--dir-path", default=".", type=str, help="Path to read and write data from"
)
parser.add_argument(
"--use-s3",
action="store_true",
default=False,
help="Use data from s3 for testing.")
help="Use data from s3 for testing.",
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.")
help="Finish quickly for testing.",
)
parser.add_argument(
"--address",
required=False,
type=str,
help="The address to use for Ray. "
"`auto` if running through `ray submit.")
help="The address to use for Ray. " "`auto` if running through `ray submit.",
)
parser.add_argument(
"--num-workers",
default=1,
type=int,
help="The number of Ray workers to use for distributed training")
help="The number of Ray workers to use for distributed training",
)
parser.add_argument(
"--large-dataset",
action="store_true",
default=False,
help="Use 500GB dataset")
"--large-dataset", action="store_true", default=False, help="Use 500GB dataset"
)
parser.add_argument(
"--use-gpu",
action="store_true",
default=False,
help="Use GPU for training.")
"--use-gpu", action="store_true", default=False, help="Use GPU for training."
)
parser.add_argument(
"--mlflow-register-model",
action="store_true",
help="Whether to use mlflow model registry. If set, a local MLflow "
"tracking server is expected to have already been started.")
"tracking server is expected to have already been started.",
)

args = parser.parse_args()
smoke_test = args.smoke_test

@@ -553,8 +557,11 @@ if __name__ == "__main__":
if len(list(count)) == 0:
print("please run `python make_and_upload_dataset.py` first")
sys.exit(1)
data_path = ("s3://cuj-big-data/big-data/"
if large_dataset else "s3://cuj-big-data/data/")
data_path = (
"s3://cuj-big-data/big-data/"
if large_dataset
else "s3://cuj-big-data/data/"
)
inference_path = "s3://cuj-big-data/inference/"
inference_output_path = "s3://cuj-big-data/output/"
else:

@@ -562,20 +569,19 @@ if __name__ == "__main__":
inference_path = os.path.join(dir_path, "inference")
inference_output_path = "/tmp"

if len(os.listdir(data_path)) <= 1 or len(
os.listdir(inference_path)) <= 1:
if len(os.listdir(data_path)) <= 1 or len(os.listdir(inference_path)) <= 1:
print("please run `python make_and_upload_dataset.py` first")
sys.exit(1)

if smoke_test:
# Only read a single file.
data_path = os.path.join(data_path, "data_00000.parquet.snappy")
inference_path = os.path.join(inference_path,
"data_00000.parquet.snappy")
inference_path = os.path.join(inference_path, "data_00000.parquet.snappy")

preprocessor = DataPreprocessor()
train_dataset, test_dataset = preprocessor.preprocess_train_data(
read_dataset(data_path))
read_dataset(data_path)
)

num_columns = len(train_dataset.schema().names)
# remove label column and internal Arrow column.
@@ -589,14 +595,12 @@ if __name__ == "__main__":
DROPOUT_PROB = 0.2

# Random global shuffle
train_dataset_pipeline = train_dataset.repeat() \
.random_shuffle_each_window(_spread_resource_prefix="node:")
train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window(
_spread_resource_prefix="node:"
)
del train_dataset

datasets = {
"train_dataset": train_dataset_pipeline,
"test_dataset": test_dataset
}
datasets = {"train_dataset": train_dataset_pipeline, "test_dataset": test_dataset}

config = {
"use_gpu": use_gpu,

@@ -606,7 +610,7 @@ if __name__ == "__main__":
"num_layers": NUM_LAYERS,
"dropout_every": DROPOUT_EVERY,
"dropout_prob": DROPOUT_PROB,
"num_features": num_features
"num_features": num_features,
}

# Create 2 callbacks: one for Tensorboard Logging and one for MLflow

@@ -619,7 +623,8 @@ if __name__ == "__main__":
callbacks = [
TBXLoggerCallback(logdir=tbx_logdir),
MLflowLoggerCallback(
experiment_name="cuj-big-data-training", save_artifact=True)
experiment_name="cuj-big-data-training", save_artifact=True
),
]

# Remove CPU resource so Datasets can be scheduled.

@@ -629,19 +634,19 @@ if __name__ == "__main__":
backend="torch",
num_workers=num_workers,
use_gpu=use_gpu,
resources_per_worker=resources_per_worker)
resources_per_worker=resources_per_worker,
)
trainer.start()
results = trainer.run(
train_func=train_func,
config=config,
callbacks=callbacks,
dataset=datasets)
train_func=train_func, config=config, callbacks=callbacks, dataset=datasets
)
model = results[0]
trainer.shutdown()

if args.mlflow_register_model:
mlflow.pytorch.log_model(
model, artifact_path="models", registered_model_name="torch_model")
model, artifact_path="models", registered_model_name="torch_model"
)

# Get the latest model from mlflow model registry.
client = mlflow.tracking.MlflowClient()

@@ -649,12 +654,14 @@ if __name__ == "__main__":
# Get the info for the latest model.
# By default, registered models are in stage "None".
latest_model_info = client.get_latest_versions(
registered_model_name, stages=["None"])[0]
registered_model_name, stages=["None"]
)[0]
latest_version = latest_model_info.version

def load_model_func():
model_uri = f"models:/torch_model/{latest_version}"
return mlflow.pytorch.load_model(model_uri)

else:
state_dict = model.state_dict()

@@ -670,25 +677,30 @@ if __name__ == "__main__":
n_features=num_features,
num_hidden=num_hidden,
dropout_every=dropout_every,
drop_prob=dropout_prob)
drop_prob=dropout_prob,
)
model.load_state_dict(state_dict)
return model

class BatchInferModel:
def __init__(self, load_model_func):
self.device = torch.device("cuda:0"
if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = load_model_func().to(self.device)

def __call__(self, batch) -> "pd.DataFrame":
tensor = torch.FloatTensor(batch.to_pandas().values).to(
self.device)
tensor = torch.FloatTensor(batch.to_pandas().values).to(self.device)
return pd.DataFrame(self.model(tensor).cpu().detach().numpy())

inference_dataset = preprocessor.preprocess_inference_data(
read_dataset(inference_path))
inference(inference_dataset, BatchInferModel(load_model_func), 100,
inference_output_path, use_gpu)
read_dataset(inference_path)
)
inference(
inference_dataset,
BatchInferModel(load_model_func),
100,
inference_output_path,
use_gpu,
)

end_time = time.time()
@@ -14,20 +14,23 @@ class MyActor:
self.counter = Counter(
"num_requests",
description="Number of requests processed by the actor.",
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.counter.set_default_tags({"actor_name": name})

self.gauge = Gauge(
"curr_count",
description="Current count held by the actor. Goes up and down.",
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.gauge.set_default_tags({"actor_name": name})

self.histogram = Histogram(
"request_latency",
description="Latencies of requests in ms.",
boundaries=[0.1, 1],
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.histogram.set_default_tags({"actor_name": name})

def process_request(self, num):

@@ -46,7 +46,8 @@ class LinearModel(object):
y_ = tf.placeholder(tf.float32, [None, shape[1]])
self.y_ = y_
cross_entropy = tf.reduce_mean(
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
)
self.cross_entropy = cross_entropy
self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
self.sess = tf.Session()

@@ -54,24 +55,20 @@ class LinearModel(object):
# Ray's TensorFlowVariables to automatically create methods to modify
# the weights.
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
cross_entropy, self.sess)
cross_entropy, self.sess
)

def loss(self, xs, ys):
"""Computes the loss of the network."""
return float(
self.sess.run(
self.cross_entropy, feed_dict={
self.x: xs,
self.y_: ys
}))
self.sess.run(self.cross_entropy, feed_dict={self.x: xs, self.y_: ys})
)

def grad(self, xs, ys):
"""Computes the gradients of the network."""
return self.sess.run(
self.cross_entropy_grads, feed_dict={
self.x: xs,
self.y_: ys
})
self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys}
)


@ray.remote

@@ -143,4 +140,5 @@ if __name__ == "__main__":
# Use L-BFGS to minimize the loss function.
print("Running L-BFGS.")
result = scipy.optimize.fmin_l_bfgs_b(
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True)
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True
)
@@ -26,8 +26,7 @@ class RayDistributedActor:
"""

# Set the init_method and rank of the process for distributed training.
print("Ray worker at {url} rank {rank}".format(
url=url, rank=world_rank))
print("Ray worker at {url} rank {rank}".format(url=url, rank=world_rank))
self.url = url
self.world_rank = world_rank
args.distributed_rank = world_rank

@@ -55,8 +54,10 @@ class RayDistributedActor:
n_cpus = int(ray.cluster_resources()["CPU"])
if n_cpus > original_n_cpus:
raise Exception(
"New CPUs find (original %d CPUs, now %d CPUs)" %
(original_n_cpus, n_cpus))
"New CPUs find (original %d CPUs, now %d CPUs)"
% (original_n_cpus, n_cpus)
)

else:
original_n_gpus = args.distributed_world_size

@@ -65,8 +66,9 @@ class RayDistributedActor:
n_gpus = int(ray.cluster_resources().get("GPU", 0))
if n_gpus > original_n_gpus:
raise Exception(
"New GPUs find (original %d GPUs, now %d GPUs)" %
(original_n_gpus, n_gpus))
"New GPUs find (original %d GPUs, now %d GPUs)"
% (original_n_gpus, n_gpus)
)

fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint

@@ -103,8 +105,7 @@ def run_fault_tolerant_loop():
set_batch_size(args)

# Set up Ray distributed actors.
Actor = ray.remote(
num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
Actor = ray.remote(num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
workers = [Actor.remote() for i in range(args.distributed_world_size)]

# Get the IP address and a free port of actor 0, which is used for

@@ -116,8 +117,7 @@ def run_fault_tolerant_loop():
# Start the remote processes, and check whether their are any process
# fails. If so, restart all the processes.
unfinished = [
worker.run.remote(address, i, args)
for i, worker in enumerate(workers)
worker.run.remote(address, i, args) for i, worker in enumerate(workers)
]
try:
while len(unfinished) > 0:

@@ -135,10 +135,8 @@ def add_ray_args(parser):
"""Add ray and fault-tolerance related parser arguments to the parser."""
group = parser.add_argument_group("Ray related arguments")
group.add_argument(
"--ray-address",
default="auto",
type=str,
help="address for ray initialization")
"--ray-address", default="auto", type=str, help="address for ray initialization"
)
group.add_argument(
"--fix-batch-size",
default=None,

@@ -147,7 +145,8 @@ def add_ray_args(parser):
help="fix the actual batch size (max_sentences * update_freq "
"* n_GPUs) to be the fixed input values by adjusting update_freq "
"accroding to actual n_GPUs; the batch size is fixed to B_i for "
"epoch i; all epochs >N are fixed to B_N")
"epoch i; all epochs >N are fixed to B_N",
)
return group


@@ -168,13 +167,13 @@ def set_batch_size(args):
"""Fixes the total batch_size to be agnostic to the GPU count."""
if args.fix_batch_size is not None:
args.update_freq = [
math.ceil(batch_size /
(args.max_sentences * args.distributed_world_size))
math.ceil(batch_size / (args.max_sentences * args.distributed_world_size))
for batch_size in args.fix_batch_size
]
print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
(args.distributed_world_size, args.max_sentences,
repr(args.update_freq)))
print(
"Training on %d GPUs, max_sentences=%d, update_freq=%s"
% (args.distributed_world_size, args.max_sentences, repr(args.update_freq))
)


if __name__ == "__main__":
@@ -62,31 +62,34 @@ import ray

parser = argparse.ArgumentParser()
parser.add_argument(
"--address", type=str, default="auto", help="The address to use for Ray.")
"--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument(
"--smoke-test",
action="store_true",
help="Read a smaller dataset for quick testing purposes.")
help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument(
"--num-actors",
type=int,
default=4,
help="Sets number of actors for training.")
"--num-actors", type=int, default=4, help="Sets number of actors for training."
)
parser.add_argument(
"--cpus-per-actor",
type=int,
default=8,
help="The number of CPUs per actor for training.")
help="The number of CPUs per actor for training.",
)
parser.add_argument(
"--num-actors-inference",
type=int,
default=16,
help="Sets number of actors for inference.")
help="Sets number of actors for inference.",
)
parser.add_argument(
"--cpus-per-actor-inference",
type=int,
default=2,
help="The number of CPUs per actor for inference.")
help="The number of CPUs per actor for inference.",
)
# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()

@@ -119,12 +122,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
# Test dataset with only 10,000 records.
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
".csv"
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
# Full dataset. This may take a couple of minutes to load.
FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
"/00280/HIGGS.csv.gz"
FILE_URL = (
"https://archive.ics.uci.edu/ml/machine-learning-databases"
"/00280/HIGGS.csv.gz"
)

colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]

@@ -182,7 +186,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
evals_result=evals_result,
verbose_eval=False,
num_boost_round=100,
ray_params=ray_params)
ray_params=ray_params,
)

train_end_time = time.time()
train_duration = train_end_time - train_start_time

@@ -190,8 +195,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):

model_path = "model.xgb"
bst.save_model(model_path)
print("Final validation error: {:.4f}".format(
evals_result["eval"]["error"][-1]))
print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))

return bst, evals_result

@@ -208,8 +212,12 @@ config = {
}

bst, evals_result = train_xgboost(
config, df_train, df_validation, LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
config,
df_train,
df_validation,
LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")

###############################################################################

@@ -227,7 +235,8 @@ results = predict(
bst,
inference_df,
ray_params=RayParams(
cpus_per_actor=cpus_per_actor_inference,
num_actors=num_actors_inference))
cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
),
)

print(results)
@@ -12,10 +12,12 @@ class NewsServer(object):
def __init__(self):
self.conn = sqlite3.connect("newsreader.db")
c = self.conn.cursor()
c.execute("""CREATE TABLE IF NOT EXISTS news
c.execute(
"""CREATE TABLE IF NOT EXISTS news
(title text, link text,
description text, published timestamp,
feed url, liked bool)""")
feed url, liked bool)"""
)
self.conn.commit()

def retrieve_feed(self, url):

@@ -24,36 +26,41 @@ class NewsServer(object):
items = []
c = self.conn.cursor()
for item in feed.items:
items.append({
"title": item.title,
"link": item.link,
"description": item.description,
"description_text": item.description,
"pubDate": str(item.pub_date)
})
items.append(
{
"title": item.title,
"link": item.link,
"description": item.description,
"description_text": item.description,
"pubDate": str(item.pub_date),
}
)
c.execute(
"""INSERT INTO news (title, link, description,
published, feed, liked) values
(?, ?, ?, ?, ?, ?)""",
(item.title, item.link, item.description, item.pub_date,
feed.link, False))
(
item.title,
item.link,
item.description,
item.pub_date,
feed.link,
False,
),
)
self.conn.commit()

return {
"channel": {
"title": feed.title,
"link": feed.link,
"url": feed.link
},
"items": items
"channel": {"title": feed.title, "link": feed.link, "url": feed.link},
"items": items,
}

def like_item(self, url, is_faved):
c = self.conn.cursor()
if is_faved:
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url, ))
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url,))
else:
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url, ))
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url,))
self.conn.commit()


@@ -77,9 +84,7 @@ def dispatcher():
result = ray.get(method.remote(*method_args))
return jsonify(result)
else:
return jsonify({
"error": "method_name '" + method_name + "' not found"
})
return jsonify({"error": "method_name '" + method_name + "' not found"})


if __name__ == "__main__":
@@ -44,16 +44,16 @@ num_evaluations = 10
# A function for generating random hyperparameters.
def generate_hyperparameters():
return {
"learning_rate": 10**np.random.uniform(-5, 1),
"learning_rate": 10 ** np.random.uniform(-5, 1),
"batch_size": np.random.randint(1, 100),
"momentum": np.random.uniform(0, 1)
"momentum": np.random.uniform(0, 1),
}


def get_data_loaders(batch_size):
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since

@@ -61,16 +61,16 @@ def get_data_loaders(batch_size):
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=mnist_transforms),
"~/data", train=True, download=True, transform=mnist_transforms
),
batch_size=batch_size,
shuffle=True)
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=batch_size,
shuffle=True)
shuffle=True,
)
return train_loader, test_loader


@@ -152,9 +152,8 @@ def evaluate_hyperparameters(config):
model = ConvNet()
train_loader, test_loader = get_data_loaders(config["batch_size"])
optimizer = optim.SGD(
model.parameters(),
lr=config["learning_rate"],
momentum=config["momentum"])
model.parameters(), lr=config["learning_rate"], momentum=config["momentum"]
)
train(model, optimizer, train_loader)
return test(model, test_loader)

@@ -202,22 +201,33 @@ while remaining_ids:

hyperparameters = hyperparameters_mapping[result_id]
accuracy = ray.get(result_id)
print("""We achieve accuracy {:.3}% with
print(
"""We achieve accuracy {:.3}% with
learning_rate: {:.2}
batch_size: {}
momentum: {:.2}
""".format(100 * accuracy, hyperparameters["learning_rate"],
hyperparameters["batch_size"], hyperparameters["momentum"]))
""".format(
100 * accuracy,
hyperparameters["learning_rate"],
hyperparameters["batch_size"],
hyperparameters["momentum"],
)
)
if accuracy > best_accuracy:
best_hyperparameters = hyperparameters
best_accuracy = accuracy

# Record the best performing set of hyperparameters.
print("""Best accuracy over {} trials was {:.3} with
print(
"""Best accuracy over {} trials was {:.3} with
learning_rate: {:.2}
batch_size: {}
momentum: {:.2}
""".format(num_evaluations, 100 * best_accuracy,
best_hyperparameters["learning_rate"],
best_hyperparameters["batch_size"],
best_hyperparameters["momentum"]))
""".format(
num_evaluations,
100 * best_accuracy,
best_hyperparameters["learning_rate"],
best_hyperparameters["batch_size"],
best_hyperparameters["momentum"],
)
)
@@ -39,8 +39,8 @@ import ray
def get_data_loader():
"""Safely downloads data. Returns training/validation set dataloader."""
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since

@@ -48,16 +48,16 @@ def get_data_loader():
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=mnist_transforms),
"~/data", train=True, download=True, transform=mnist_transforms
),
batch_size=128,
shuffle=True)
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=128,
shuffle=True)
shuffle=True,
)
return train_loader, test_loader


@@ -75,7 +75,7 @@ def evaluate(model, test_loader):
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return 100. * correct / total
return 100.0 * correct / total

#######################################################################

@@ -144,8 +144,7 @@ class ParameterServer(object):

def apply_gradients(self, *gradients):
summed_gradients = [
np.stack(gradient_zip).sum(axis=0)
for gradient_zip in zip(*gradients)
np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)
]
self.optimizer.zero_grad()
self.model.set_gradients(summed_gradients)

@@ -215,9 +214,7 @@ test_loader = get_data_loader()[1]
print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations):
gradients = [
worker.compute_gradients.remote(current_weights) for worker in workers
]
gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
# Calculate update after all gradients are available.
current_weights = ps.apply_gradients.remote(*gradients)

@@ -197,7 +197,7 @@ class Model(object):
"""Applies the gradients to the model parameters with RMSProp."""
for k, v in self.weights.items():
g = grad_buffer[k]
rmsprop_cache[k] = (decay * rmsprop_cache[k] + (1 - decay) * g**2)
rmsprop_cache[k] = decay * rmsprop_cache[k] + (1 - decay) * g ** 2
self.weights[k] += lr * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)


@@ -278,20 +278,24 @@ for i in range(1, 1 + iterations):
gradient_ids = []
# Launch tasks to compute gradients from multiple rollouts in parallel.
start_time = time.time()
gradient_ids = [
actor.compute_gradient.remote(model_id) for actor in actors
]
gradient_ids = [actor.compute_gradient.remote(model_id) for actor in actors]
for batch in range(batch_size):
[grad_id], gradient_ids = ray.wait(gradient_ids)
grad, reward_sum = ray.get(grad_id)
# Accumulate the gradient over batch.
for k in model.weights:
grad_buffer[k] += grad[k]
running_reward = (reward_sum if running_reward is None else
running_reward * 0.99 + reward_sum * 0.01)
running_reward = (
reward_sum
if running_reward is None
else running_reward * 0.99 + reward_sum * 0.01
)
end_time = time.time()
print("Batch {} computed {} rollouts in {} seconds, "
"running mean is {}".format(i, batch_size, end_time - start_time,
running_reward))
print(
"Batch {} computed {} rollouts in {} seconds, "
"running mean is {}".format(
i, batch_size, end_time - start_time, running_reward
)
)
model.update(grad_buffer, rmsprop_cache, learning_rate, decay_rate)
zero_grads(grad_buffer)

@@ -23,6 +23,7 @@ from typing import Tuple
from time import sleep

import ray

# For typing purposes
from ray.actor import ActorHandle
from tqdm import tqdm
@@ -8,12 +8,11 @@ import wikipedia

parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mappers", help="number of mapper actors used", default=3, type=int)
"--num-mappers", help="number of mapper actors used", default=3, type=int
)
parser.add_argument(
"--num-reducers",
help="number of reducer actors used",
default=4,
type=int)
"--num-reducers", help="number of reducer actors used", default=4, type=int
)


@ray.remote

@@ -36,8 +35,11 @@ class Mapper(object):
while self.num_articles_processed < article_index + 1:
self.get_new_article()
# Return the word counts from within a given character range.
return [(k, v) for k, v in self.word_counts[article_index].items()
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]]
return [
(k, v)
for k, v in self.word_counts[article_index].items()
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]
]


@ray.remote

@@ -51,8 +53,7 @@ class Reducer(object):
# Get the word counts for this Reducer's keys from all of the Mappers
# and aggregate the results.
count_ids = [
mapper.get_range.remote(article_index, self.keys)
for mapper in self.mappers
mapper.get_range.remote(article_index, self.keys) for mapper in self.mappers
]

while len(count_ids) > 0:

@@ -87,9 +88,9 @@ if __name__ == "__main__":
streams.append(Stream([line.strip() for line in f.readlines()]))

# Partition the keys among the reducers.
chunks = np.array_split([chr(i)
for i in range(ord("a"),
ord("z") + 1)], args.num_reducers)
chunks = np.array_split(
[chr(i) for i in range(ord("a"), ord("z") + 1)], args.num_reducers
)
keys = [[chunk[0], chunk[-1]] for chunk in chunks]

# Create a number of mappers.

@@ -103,14 +104,12 @@ if __name__ == "__main__":
while True:
print("article index = {}".format(article_index))
wordcounts = {}
counts = ray.get([
reducer.next_reduce_result.remote(article_index)
for reducer in reducers
])
counts = ray.get(
[reducer.next_reduce_result.remote(article_index) for reducer in reducers]
)
for count in counts:
wordcounts.update(count)
most_frequent_words = heapq.nlargest(
10, wordcounts, key=wordcounts.get)
most_frequent_words = heapq.nlargest(10, wordcounts, key=wordcounts.get)
for word in most_frequent_words:
print(" ", word, wordcounts[word])
article_index += 1
@@ -125,12 +125,12 @@ class MNISTDataInterface(object):
self.data_dir = data_dir
self.max_days = max_days

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
self.dataset = MNIST(
self.data_dir, train=True, download=True, transform=transform)
self.data_dir, train=True, download=True, transform=transform
)

def _get_day_slice(self, day=0):
if day < 0:

@@ -154,8 +154,7 @@ class MNISTDataInterface(object):
end = self._get_day_slice(day)

available_data = Subset(self.dataset, list(range(start, end)))
train_n = int(
0.8 * (end - start)) # 80% train data, 20% validation data
train_n = int(0.8 * (end - start)) # 80% train data, 20% validation data

return random_split(available_data, [train_n, end - start - train_n])

@@ -223,13 +222,15 @@ def test(model, data_loader, device=None):
# will take care of creating the model and optimizer and repeatedly
# call the ``train`` function to train the model. Also, this function
# will report the training progress back to Tune.
def train_mnist(config,
start_model=None,
checkpoint_dir=None,
num_epochs=10,
use_gpus=False,
data_fn=None,
day=0):
def train_mnist(
config,
start_model=None,
checkpoint_dir=None,
num_epochs=10,
use_gpus=False,
data_fn=None,
day=0,
):
# Create model
use_cuda = use_gpus and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

@@ -237,7 +238,8 @@ def train_mnist(config,

# Create optimizer
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
model.parameters(), lr=config["lr"], momentum=config["momentum"]
)

# Load checkpoint, or load start model if no checkpoint has been
# passed and a start model is specified

@@ -248,8 +250,7 @@ def train_mnist(config,
load_dir = start_model

if load_dir:
model_state, optimizer_state = torch.load(
os.path.join(load_dir, "checkpoint"))
model_state, optimizer_state = torch.load(os.path.join(load_dir, "checkpoint"))
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)

@@ -257,18 +258,22 @@ def train_mnist(config,
train_dataset, validation_dataset = data_fn(day=day)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=config["batch_size"], shuffle=True)
train_dataset, batch_size=config["batch_size"], shuffle=True
)

validation_loader = torch.utils.data.DataLoader(
validation_dataset, batch_size=config["batch_size"], shuffle=True)
validation_dataset, batch_size=config["batch_size"], shuffle=True
)

for i in range(num_epochs):
train(model, optimizer, train_loader, device)
acc = test(model, validation_loader, device)
if i == num_epochs - 1:
with tune.checkpoint_dir(step=i) as checkpoint_dir:
torch.save((model.state_dict(), optimizer.state_dict()),
os.path.join(checkpoint_dir, "checkpoint"))
torch.save(
(model.state_dict(), optimizer.state_dict()),
os.path.join(checkpoint_dir, "checkpoint"),
)
tune.report(mean_accuracy=acc, done=True)
else:
tune.report(mean_accuracy=acc)
@@ -286,7 +291,7 @@ def train_mnist(config,
# until the given day. Our search space can thus also contain parameters
# that affect the model complexity (such as the layer size), since it
# does not have to be compatible to an existing model.
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):
data_interface = MNISTDataInterface("~/data", max_days=10)
num_examples = data_interface._get_day_slice(day)

@@ -302,11 +307,13 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
mode="max",
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reduction_factor=2,
)

reporter = CLIReporter(
parameter_columns=["layer_size", "lr", "momentum", "batch_size"],
metric_columns=["mean_accuracy", "training_iteration"])
metric_columns=["mean_accuracy", "training_iteration"],
)

analysis = tune.run(
partial(

@@ -315,17 +322,16 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
data_fn=data_interface.get_data,
num_epochs=num_epochs,
use_gpus=True if gpus_per_trial > 0 else False,
day=day),
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
day=day,
),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
verbose=0,
name="tune_serve_mnist_fromscratch")
name="tune_serve_mnist_fromscratch",
)

best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]

@@ -344,33 +350,35 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
# layer size parameter. Since we continue to train an existing model,
# we cannot change the layer size mid training, so we just continue
# to use the existing one.
def tune_from_existing(start_model,
start_config,
num_samples=10,
num_epochs=10,
gpus_per_trial=0.,
day=0):
def tune_from_existing(
start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0
):
data_interface = MNISTDataInterface("/tmp/mnist_data", max_days=10)
num_examples = data_interface._get_day_slice(day) - \
data_interface._get_day_slice(day - 1)
num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(
day - 1
)

config = start_config.copy()
config.update({
"batch_size": tune.choice([16, 32, 64]),
"lr": tune.loguniform(1e-4, 1e-1),
"momentum": tune.uniform(0.1, 0.9),
})
config.update(
{
"batch_size": tune.choice([16, 32, 64]),
"lr": tune.loguniform(1e-4, 1e-1),
"momentum": tune.uniform(0.1, 0.9),
}
)

scheduler = ASHAScheduler(
metric="mean_accuracy",
mode="max",
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reduction_factor=2,
)

reporter = CLIReporter(
parameter_columns=["lr", "momentum", "batch_size"],
metric_columns=["mean_accuracy", "training_iteration"])
metric_columns=["mean_accuracy", "training_iteration"],
)

analysis = tune.run(
partial(
@@ -379,17 +387,16 @@ def tune_from_existing(start_model,
data_fn=data_interface.get_incremental_data,
num_epochs=num_epochs,
use_gpus=True if gpus_per_trial > 0 else False,
day=day),
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
day=day,
),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
verbose=0,
name="tune_serve_mnist_fromsexisting")
name="tune_serve_mnist_fromsexisting",
)

best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]

@@ -423,8 +430,8 @@ class MNISTDeployment:
model = ConvNet(layer_size=self.config["layer_size"]).to(self.device)

model_state, optimizer_state = torch.load(
os.path.join(self.checkpoint_dir, "checkpoint"),
map_location=self.device)
os.path.join(self.checkpoint_dir, "checkpoint"), map_location=self.device
)
model.load_state_dict(model_state)

self.model = model

@@ -442,12 +449,12 @@ class MNISTDeployment:
# active model. We call this directory ``model_dir``. Every time we
# would like to update our model, we copy the checkpoint of the new
# model to this directory. We then update the deployment to the new version.
def serve_new_model(model_dir, checkpoint, config, metrics, day,
use_gpu=False):
def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):
print("Serving checkpoint: {}".format(checkpoint))

checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint,
config, metrics)
checkpoint_path = _move_checkpoint_to_model_dir(
model_dir, checkpoint, config, metrics
)

serve.start(detached=True)
MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)

@@ -482,8 +489,7 @@ def get_current_model(model_dir):
checkpoint_path = os.path.join(model_dir, "checkpoint")
meta_path = os.path.join(model_dir, "meta.json")

if not os.path.exists(checkpoint_path) or \
not os.path.exists(meta_path):
if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):
return None, None, None

with open(meta_path, "rt") as fp:

@@ -559,28 +565,33 @@ if __name__ == "__main__":
"--from_scratch",
action="store_true",
help="Train and select best model from scratch",
default=False)
default=False,
)

parser.add_argument(
"--from_existing",
action="store_true",
help="Train and select best model from existing model",
default=False)
default=False,
)

parser.add_argument(
"--day",
help="Indicate the day to simulate the amount of data available to us",
type=int,
default=0)
default=0,
)

parser.add_argument(
"--query", help="Query endpoint with example", type=int, default=-1)
"--query", help="Query endpoint with example", type=int, default=-1
)

parser.add_argument(
"--smoke-test",
action="store_true",
help="Finish quickly for testing",
default=False)
default=False,
)

args = parser.parse_args()
@@ -600,20 +611,23 @@ if __name__ == "__main__":

# Query our model
response = requests.post(
"http://localhost:8000/mnist",
json={"images": [data[0].numpy().tolist()]})
"http://localhost:8000/mnist", json={"images": [data[0].numpy().tolist()]}
)

try:
pred = response.json()["result"][0]
except: # noqa: E722
pred = -1

print("Querying model with example #{}. "
"Label = {}, Response = {}, Correct = {}".format(
args.query, label, pred, label == pred))
print(
"Querying model with example #{}. "
"Label = {}, Response = {}, Correct = {}".format(
args.query, label, pred, label == pred
)
)
sys.exit(0)

gpus_per_trial = 0.5 if not args.smoke_test else 0.
gpus_per_trial = 0.5 if not args.smoke_test else 0.0
serve_gpu = True if gpus_per_trial > 0 else False
num_samples = 8 if not args.smoke_test else 1
num_epochs = 10 if not args.smoke_test else 1

@@ -621,23 +635,22 @@ if __name__ == "__main__":
if args.from_scratch: # train everyday from scratch
print("Start training job from scratch on day {}.".format(args.day))
acc, config, best_checkpoint, num_examples = tune_from_scratch(
num_samples, num_epochs, gpus_per_trial, day=args.day)
print("Trained day {} from scratch on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config))
num_samples, num_epochs, gpus_per_trial, day=args.day
)
print(
"Trained day {} from scratch on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model(
model_dir,
best_checkpoint,
config,
acc,
args.day,
use_gpu=serve_gpu)
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
)

if args.from_existing:
old_checkpoint, old_config, old_acc = get_current_model(model_dir)
if not old_checkpoint or not old_config or not old_acc:
print("No existing model found. Train one with --from_scratch "
"first.")
print("No existing model found. Train one with --from_scratch " "first.")
sys.exit(1)
acc, config, best_checkpoint, num_examples = tune_from_existing(
old_checkpoint,

@@ -645,17 +658,17 @@ if __name__ == "__main__":
num_samples,
num_epochs,
gpus_per_trial,
day=args.day)
print("Trained day {} from existing on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config))
day=args.day,
)
print(
"Trained day {} from existing on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model(
model_dir,
best_checkpoint,
config,
acc,
args.day,
use_gpu=serve_gpu)
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
)

#######################################################################
# That's it! We now have an end-to-end workflow to train and update a
@@ -38,6 +38,7 @@ To start out, change the import statement to get tune-scikit-learn’s grid sear
"""
# Keep this here for https://github.com/ray-project/ray/issues/11547
from sklearn.model_selection import GridSearchCV

# Replace above line with:
from ray.tune.sklearn import TuneGridSearchCV

@@ -60,7 +61,8 @@ X, y = make_classification(
n_informative=50,
n_redundant=0,
n_classes=10,
class_sep=2.5)
class_sep=2.5,
)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=1000)

# Example parameters to tune from SGDClassifier

@@ -70,9 +72,11 @@ parameter_grid = {"alpha": [1e-4, 1e-1, 1], "epsilon": [0.01, 0.1]}
# As you can see, the setup here is exactly how you would do it for Scikit-Learn. Now, let's try fitting a model.

tune_search = TuneGridSearchCV(
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10)
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10
)

import time # Just to compare fit times

start = time.time()
tune_search.fit(x_train, y_train)
end = time.time()

@@ -93,6 +97,7 @@ print("Tune GridSearch Fit Time:", end - start)
# Try running this compared to the GridSearchCV equivalent, and see the speedup for yourself!

from sklearn.model_selection import GridSearchCV

# n_jobs=-1 enables use of all cores like Tune does
sklearn_search = GridSearchCV(SGDClassifier(), parameter_grid, n_jobs=-1)

@@ -120,7 +125,7 @@ import numpy as np
digits = datasets.load_digits()
x = digits.data
y = digits.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.2)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

clf = SGDClassifier()
parameter_grid = {"alpha": (1e-4, 1), "epsilon": (0.01, 0.1)}
@@ -8,8 +8,9 @@ import ray
def get_host_name(x):
import platform
import time

time.sleep(0.01)
return x + (platform.node(), )
return x + (platform.node(),)


def wait_for_nodes(expected):

@@ -17,8 +18,11 @@ def wait_for_nodes(expected):
while True:
num_nodes = len(ray.nodes())
if num_nodes < expected:
print("{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes))
print(
"{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes
)
)
sys.stdout.flush()
time.sleep(1)
else:

@@ -31,9 +35,7 @@ def main():
# Check that objects can be transferred from each node to each other node.
for i in range(10):
print("Iteration {}".format(i))
results = [
get_host_name.remote(get_host_name.remote(())) for _ in range(100)
]
results = [get_host_name.remote(get_host_name.remote(())) for _ in range(100)]
print(Counter(ray.get(results)))
sys.stdout.flush()
@@ -20,8 +20,9 @@ def setup_logging() -> None:
setup_component_logger(
logging_level=ray_constants.LOGGER_LEVEL, # info
logging_format=ray_constants.LOGGER_FORMAT,
log_dir=os.path.join(ray._private.utils.get_ray_temp_dir(),
ray.node.SESSION_LATEST, "logs"),
log_dir=os.path.join(
ray._private.utils.get_ray_temp_dir(), ray.node.SESSION_LATEST, "logs"
),
filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log
max_bytes=ray_constants.LOGGING_ROTATE_BYTES,
backup_count=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,

@@ -47,11 +48,11 @@ if __name__ == "__main__":
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
args = parser.parse_args()

cluster_name = yaml.safe_load(
open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
cluster_name = yaml.safe_load(open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
head_ip = get_node_ip_address()
Monitor(
address=f"{head_ip}:6379",
Some files were not shown because too many files have changed in this diff.