[CI] Format Python code with Black (#21975)

See #21316 and #21311 for the motivation behind these changes.
Balaji Veeramani 2022-01-29 18:41:57 -08:00 committed by GitHub
parent 95877be8ee
commit 7f1bacc7dc
1637 changed files with 75167 additions and 59155 deletions
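
Every hunk below is the same mechanical rewrite: Black joins or explodes bracketed constructs against its default 88-character line limit, adds a "magic" trailing comma when it explodes one, and moves the closing bracket onto its own line. As a sketch (not part of this commit; the exact Black version and invocation Ray's CI uses are not shown here), the rewrite of the destination-set assertion in the first file can be reproduced with Black's Python API, where format_str and FileMode are public:

import black

# Pre-Black form of the assert from the S3 upload helper below. Joined onto
# one line it exceeds Black's default 88-character limit, so Black explodes
# the set one element per line and adds a trailing comma.
src = """\
assert args.destination in {
    "branch_jars", "branch_wheels", "jars", "logs", "wheels",
    "docker_login"
}
"""

# format_str applies the same rewrite that running `black` over the
# repository performs on files.
print(black.format_str(src, mode=black.FileMode()))

The printed output matches the "+" side of the corresponding hunk in the first file below.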

View file

@@ -39,14 +39,16 @@ def perform_auth():
     resp = requests.get(
         "https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
         auth=auth,
-        params={"job_id": os.environ["BUILDKITE_JOB_ID"]})
+        params={"job_id": os.environ["BUILDKITE_JOB_ID"]},
+    )
     return resp
 
 
 def handle_docker_login(resp):
     pwd = resp.json()["docker_password"]
     subprocess.call(
-        ["docker", "login", "--username", "raytravisbot", "--password", pwd])
+        ["docker", "login", "--username", "raytravisbot", "--password", pwd]
+    )
 
 
 def gather_paths(dir_path) -> List[str]:
@@ -86,7 +88,7 @@ def upload_paths(paths, resp, destination):
         "branch_wheels": f"{branch}/{sha}/{fn}",
         "jars": f"jars/latest/{current_os}/{fn}",
         "branch_jars": f"jars/{branch}/{sha}/{current_os}/{fn}",
-        "logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}"
+        "logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}",
     }[destination]
     of["file"] = open(path, "rb")
     r = requests.post(c["url"], files=of)
@@ -95,14 +97,19 @@ def upload_paths(paths, resp, destination):
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Helper script to upload files to S3 bucket")
+        description="Helper script to upload files to S3 bucket"
+    )
     parser.add_argument("--path", type=str, required=False)
     parser.add_argument("--destination", type=str)
     args = parser.parse_args()
     assert args.destination in {
-        "branch_jars", "branch_wheels", "jars", "logs", "wheels",
-        "docker_login"
+        "branch_jars",
+        "branch_wheels",
+        "jars",
+        "logs",
+        "wheels",
+        "docker_login",
     }
     assert "BUILDKITE_JOB_ID" in os.environ
     assert "BUILDKITE_COMMIT" in os.environ

View file

@@ -51,8 +51,10 @@ del monitor_actor
 test_utils.wait_for_condition(no_resource_leaks)
 
 rate = MAX_ACTORS_IN_CLUSTER / (end_time - start_time)
-print(f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
-      f"{end_time - start_time}s. ({rate} actors/s)")
+print(
+    f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
+    f"{end_time - start_time}s. ({rate} actors/s)"
+)
 
 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@@ -62,6 +64,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "time": end_time - start_time,
         "success": "1",
         "_peak_memory": round(used_gb, 2),
-        "_peak_process_memory": usage
+        "_peak_process_memory": usage,
     }
     json.dump(results, out_file)

View file

@@ -77,8 +77,10 @@ del monitor_actor
 test_utils.wait_for_condition(no_resource_leaks)
 
 rate = MAX_PLACEMENT_GROUPS / (end_time - start_time)
-print(f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
-      f"{end_time - start_time}s. ({rate} pgs/s)")
+print(
+    f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
+    f"{end_time - start_time}s. ({rate} pgs/s)"
+)
 
 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@@ -88,6 +90,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "time": end_time - start_time,
         "success": "1",
         "_peak_memory": round(used_gb, 2),
-        "_peak_process_memory": usage
+        "_peak_process_memory": usage,
     }
     json.dump(results, out_file)

View file

@@ -16,9 +16,7 @@ def test_max_running_tasks(num_tasks):
     def task():
         time.sleep(sleep_time)
 
-    refs = [
-        task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")
-    ]
+    refs = [task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")]
 
     max_cpus = ray.cluster_resources()["CPU"]
     min_cpus_available = max_cpus
@@ -48,8 +46,7 @@ def no_resource_leaks():
 
 
 @click.command()
-@click.option(
-    "--num-tasks", required=True, type=int, help="Number of tasks to launch.")
+@click.option("--num-tasks", required=True, type=int, help="Number of tasks to launch.")
 def test(num_tasks):
     ray.init(address="auto")
@@ -66,8 +63,10 @@ def test(num_tasks):
     test_utils.wait_for_condition(no_resource_leaks)
 
     rate = num_tasks / (end_time - start_time - sleep_time)
-    print(f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
-          f"({rate} tasks/s)")
+    print(
+        f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
+        f"({rate} tasks/s)"
+    )
 
     if "TEST_OUTPUT_JSON" in os.environ:
         out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@@ -77,7 +76,7 @@ def test(num_tasks):
             "time": end_time - start_time,
             "success": "1",
             "_peak_memory": round(used_gb, 2),
-            "_peak_process_memory": usage
+            "_peak_process_memory": usage,
         }
         json.dump(results, out_file)

View file

@@ -25,10 +25,12 @@ class SimpleActor:
 
 
 def start_tasks(num_task, num_cpu_per_task, task_duration):
-    ray.get([
-        simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
-        for _ in range(num_task)
-    ])
+    ray.get(
+        [
+            simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
+            for _ in range(num_task)
+        ]
+    )
 
 
 def measure(f):
@@ -40,13 +42,16 @@ def measure(f):
 def start_actor(num_actors, num_actors_per_nodes, job):
     resources = {"node": floor(1.0 / num_actors_per_nodes)}
-    submission_cost, actors = measure(lambda: [
-        SimpleActor.options(resources=resources, num_cpus=0).remote(job)
-        for _ in range(num_actors)])
-    ready_cost, _ = measure(
-        lambda: ray.get([actor.ready.remote() for actor in actors]))
+    submission_cost, actors = measure(
+        lambda: [
+            SimpleActor.options(resources=resources, num_cpus=0).remote(job)
+            for _ in range(num_actors)
+        ]
+    )
+    ready_cost, _ = measure(lambda: ray.get([actor.ready.remote() for actor in actors]))
     actor_job_cost, _ = measure(
-        lambda: ray.get([actor.do_job.remote() for actor in actors]))
+        lambda: ray.get([actor.do_job.remote() for actor in actors])
+    )
 
     return (submission_cost, ready_cost, actor_job_cost)
@@ -54,33 +59,32 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog="Test Scheduling")
     # Task workloads
     parser.add_argument(
-        "--total-num-task",
-        type=int,
-        help="Total number of tasks.",
-        required=False)
+        "--total-num-task", type=int, help="Total number of tasks.", required=False
+    )
     parser.add_argument(
         "--num-cpu-per-task",
         type=int,
         help="Resources needed for tasks.",
-        required=False)
+        required=False,
+    )
     parser.add_argument(
         "--task-duration-s",
         type=int,
         help="How long does each task execute.",
         required=False,
-        default=1)
+        default=1,
+    )
     # Actor workloads
     parser.add_argument(
-        "--total-num-actors",
-        type=int,
-        help="Total number of actors.",
-        required=True)
+        "--total-num-actors", type=int, help="Total number of actors.", required=True
+    )
     parser.add_argument(
         "--num-actors-per-nodes",
         type=int,
         help="How many actors to allocate for each nodes.",
-        required=True)
+        required=True,
+    )
 
     ray.init(address="auto")
@@ -92,13 +96,14 @@ if __name__ == "__main__":
     job = None
     if args.total_num_task is not None:
         if args.num_cpu_per_task is None:
-            args.num_cpu_per_task = floor(
-                1.0 * total_cpus / args.total_num_task)
+            args.num_cpu_per_task = floor(1.0 * total_cpus / args.total_num_task)
         job = lambda: start_tasks(  # noqa: E731
-            args.total_num_task, args.num_cpu_per_task, args.task_duration_s)
+            args.total_num_task, args.num_cpu_per_task, args.task_duration_s
+        )
 
     submission_cost, ready_cost, actor_job_cost = start_actor(
-        args.total_num_actors, args.num_actors_per_nodes, job)
+        args.total_num_actors, args.num_actors_per_nodes, job
+    )
 
     output = os.environ.get("TEST_OUTPUT_JSON")
@@ -118,6 +123,7 @@ if __name__ == "__main__":
     if output is not None:
         from pathlib import Path
+
         p = Path(output)
         p.write_text(json.dumps(result))

View file

@@ -12,8 +12,7 @@ def num_alive_nodes():
 
 
 @click.command()
-@click.option(
-    "--num-nodes", required=True, type=int, help="The target number of nodes")
+@click.option("--num-nodes", required=True, type=int, help="The target number of nodes")
 def wait_cluster(num_nodes: int):
     ray.init(address="auto")
     while num_alive_nodes() != num_nodes:

View file

@@ -9,7 +9,7 @@ from time import perf_counter
 from tqdm import tqdm
 
 NUM_NODES = 50
-OBJECT_SIZE = 2**30
+OBJECT_SIZE = 2 ** 30
 
 
 def num_alive_nodes():
@@ -60,6 +60,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "broadcast_time": end - start,
         "object_size": OBJECT_SIZE,
         "num_nodes": NUM_NODES,
-        "success": "1"
+        "success": "1",
     }
     json.dump(results, out_file)

View file

@@ -13,7 +13,7 @@ MAX_ARGS = 10000
 MAX_RETURNS = 3000
 MAX_RAY_GET_ARGS = 10000
 MAX_QUEUED_TASKS = 1_000_000
-MAX_RAY_GET_SIZE = 100 * 2**30
+MAX_RAY_GET_SIZE = 100 * 2 ** 30
 
 
 def assert_no_leaks():
@@ -189,8 +189,7 @@ print(f"Many args time: {args_time} ({MAX_ARGS} args)")
 print(f"Many returns time: {returns_time} ({MAX_RETURNS} returns)")
 print(f"Ray.get time: {get_time} ({MAX_RAY_GET_ARGS} args)")
 print(f"Queued task time: {queued_time} ({MAX_QUEUED_TASKS} tasks)")
-print(f"Ray.get large object time: {large_object_time} "
-      f"({MAX_RAY_GET_SIZE} bytes)")
+print(f"Ray.get large object time: {large_object_time} " f"({MAX_RAY_GET_SIZE} bytes)")
 
 if "TEST_OUTPUT_JSON" in os.environ:
     out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@@ -205,6 +204,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
         "num_queued": MAX_QUEUED_TASKS,
         "large_object_time": large_object_time,
         "large_object_size": MAX_RAY_GET_SIZE,
-        "success": "1"
+        "success": "1",
     }
     json.dump(results, out_file)

View file

@@ -66,7 +66,7 @@ def get_remote_url(remote):
 
 
 def replace_suffix(base, old_suffix, new_suffix=""):
     if base.endswith(old_suffix):
-        base = base[:len(base) - len(old_suffix)] + new_suffix
+        base = base[: len(base) - len(old_suffix)] + new_suffix
     return base
@@ -199,12 +199,21 @@ def monitor():
     expected_line = "{}\t{}".format(expected_sha, ref)
     if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
-        logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
-                    remote, expected_line)
+        logger.info(
+            "Not monitoring %s on %s due to keep-alive on: %s",
+            ref,
+            remote,
+            expected_line,
+        )
         return
-    logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
-                get_remote_url(remote), ref, expected_line)
+    logger.info(
+        "Monitoring %s (%s) for changes in %s: %s",
+        remote,
+        get_remote_url(remote),
+        ref,
+        expected_line,
+    )
 
     for to_wait in yield_poll_schedule():
         time.sleep(to_wait)
@@ -217,12 +226,21 @@ def monitor():
             status = ex.returncode
         if status == 2:
-            logger.info("Terminating job as %s has been deleted on %s: %s",
-                        ref, remote, expected_line)
+            logger.info(
+                "Terminating job as %s has been deleted on %s: %s",
+                ref,
+                remote,
+                expected_line,
+            )
             break
         elif status != 0:
-            logger.error("Error %d: unable to check %s on %s: %s", status, ref,
-                         remote, expected_line)
+            logger.error(
+                "Error %d: unable to check %s on %s: %s",
+                status,
+                ref,
+                remote,
+                expected_line,
+            )
         else:
             prev = expected_line
             expected_line = detect_spurious_commit(line, expected_line, remote)
@@ -230,14 +248,24 @@ def monitor():
             logger.info(
                 "Terminating job as %s has been updated on %s\n"
                 " from:\t%s\n"
-                " to: \t%s", ref, remote, expected_line, line)
+                " to: \t%s",
+                ref,
+                remote,
+                expected_line,
+                line,
+            )
             time.sleep(1)  # wait for CI to flush output
             break
         if expected_line != prev:
             logger.info(
                 "%s appeared to spuriously change on %s\n"
                 " from:\t%s\n"
-                " to: \t%s", ref, remote, prev, expected_line)
+                " to: \t%s",
+                ref,
+                remote,
+                prev,
+                expected_line,
+            )
 
     return terminate_my_process_group()
@@ -259,9 +287,8 @@ def main(program, *args):
 
 if __name__ == "__main__":
     logging.basicConfig(
-        format="%(levelname)s: %(message)s",
-        stream=sys.stderr,
-        level=logging.DEBUG)
+        format="%(levelname)s: %(message)s", stream=sys.stderr, level=logging.DEBUG
+    )
     try:
         raise SystemExit(main(*sys.argv) or 0)
     except KeyboardInterrupt:

View file

@@ -53,7 +53,10 @@ def maybe_fetch_buildkite_token():
             "secretsmanager", region_name="us-west-2"
         ).get_secret_value(
             SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
-            "buildkite/ro-token")["SecretString"]
+            "buildkite/ro-token"
+        )[
+            "SecretString"
+        ]
 
 
 def escape(v: Any):
@@ -85,26 +88,26 @@ def env_str(env: Dict[str, Any]):
 def script_str(v: Any):
     if isinstance(v, bool):
-        return f"\"{int(v)}\""
+        return f'"{int(v)}"'
     elif isinstance(v, Number):
-        return f"\"{v}\""
+        return f'"{v}"'
     elif isinstance(v, list):
-        return "(" + " ".join(f"\"{shlex.quote(w)}\"" for w in v) + ")"
+        return "(" + " ".join(f'"{shlex.quote(w)}"' for w in v) + ")"
     else:
-        return f"\"{shlex.quote(v)}\""
+        return f'"{shlex.quote(v)}"'
 
 
 class ReproSession:
     plugin_default_env = {
-        "docker": {
-            "BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False
-        }
+        "docker": {"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False}
     }
 
-    def __init__(self,
-                 buildkite_token: str,
-                 instance_name: Optional[str] = None,
-                 logger: Optional[logging.Logger] = None):
+    def __init__(
+        self,
+        buildkite_token: str,
+        instance_name: Optional[str] = None,
+        logger: Optional[logging.Logger] = None,
+    ):
         self.logger = logger or logging.getLogger(self.__class__.__name__)
 
         self.bk = Buildkite()
@@ -139,12 +142,15 @@ class ReproSession:
         # https://buildkite.com/ray-project/ray-builders-pr/
         # builds/19635#55a0d71a-831e-4f68-b668-2b10c6f65ee6
         pattern = re.compile(
-            "https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)")
+            "https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)"
+        )
         org, pipeline, build_id, job_id = pattern.match(session_url).groups()
 
-        self.logger.debug(f"Parsed session URL: {session_url}. "
-                          f"Got org='{org}', pipeline='{pipeline}', "
-                          f"build_id='{build_id}', job_id='{job_id}'.")
+        self.logger.debug(
+            f"Parsed session URL: {session_url}. "
+            f"Got org='{org}', pipeline='{pipeline}', "
+            f"build_id='{build_id}', job_id='{job_id}'."
+        )
 
         self.org = org
         self.pipeline = pipeline
@@ -155,7 +161,8 @@ class ReproSession:
         assert self.bk
 
         self.env = self.bk.jobs().get_job_environment_variables(
-            self.org, self.pipeline, self.build_id, self.job_id)["env"]
+            self.org, self.pipeline, self.build_id, self.job_id
+        )["env"]
 
         if overwrite:
             self.env.update(overwrite)
@@ -166,33 +173,30 @@ class ReproSession:
         assert self.env
 
         if not self.aws_instance_name:
-            self.aws_instance_name = (
-                f"repro_ci_{self.build_id}_{self.job_id[:8]}")
+            self.aws_instance_name = f"repro_ci_{self.build_id}_{self.job_id[:8]}"
             self.logger.info(
-                f"No instance name provided, using {self.aws_instance_name}")
+                f"No instance name provided, using {self.aws_instance_name}"
+            )
 
         instance_type = self.env["BUILDKITE_AGENT_META_DATA_AWS_INSTANCE_TYPE"]
         instance_ami = self.env["BUILDKITE_AGENT_META_DATA_AWS_AMI_ID"]
         instance_sg = "sg-0ccfca2ef191c04ae"
-        instance_block_device_mappings = [{
-            "DeviceName": "/dev/xvda",
-            "Ebs": {
-                "VolumeSize": 500
-            }
-        }]
+        instance_block_device_mappings = [
+            {"DeviceName": "/dev/xvda", "Ebs": {"VolumeSize": 500}}
+        ]
 
         # Check if instance exists:
-        running_instances = self.ec2_resource.instances.filter(Filters=[{
-            "Name": "tag:repro_name",
-            "Values": [self.aws_instance_name]
-        }, {
-            "Name": "instance-state-name",
-            "Values": ["running"]
-        }])
+        running_instances = self.ec2_resource.instances.filter(
+            Filters=[
+                {"Name": "tag:repro_name", "Values": [self.aws_instance_name]},
+                {"Name": "instance-state-name", "Values": ["running"]},
+            ]
+        )
 
         self.logger.info(
             f"Check if instance with name {self.aws_instance_name} "
-            f"already exists...")
+            f"already exists..."
+        )
 
         for instance in running_instances:
             self.aws_instance_id = instance.id
@@ -201,8 +205,8 @@ class ReproSession:
             return
 
         self.logger.info(
-            f"Instance with name {self.aws_instance_name} not found, "
-            f"creating...")
+            f"Instance with name {self.aws_instance_name} not found, " f"creating..."
+        )
 
         # Else, not running, yet, start.
         instance = self.ec2_resource.create_instances(
@@ -211,20 +215,18 @@ class ReproSession:
             InstanceType=instance_type,
             KeyName=self.ssh_key_name,
             SecurityGroupIds=[instance_sg],
-            TagSpecifications=[{
-                "ResourceType": "instance",
-                "Tags": [{
-                    "Key": "repro_name",
-                    "Value": self.aws_instance_name
-                }]
-            }],
+            TagSpecifications=[
+                {
+                    "ResourceType": "instance",
+                    "Tags": [{"Key": "repro_name", "Value": self.aws_instance_name}],
+                }
+            ],
             MinCount=1,
             MaxCount=1,
         )[0]
         self.aws_instance_id = instance.id
-        self.logger.info(
-            f"Created new instance with ID {self.aws_instance_id}")
+        self.logger.info(f"Created new instance with ID {self.aws_instance_id}")
 
     def aws_wait_for_instance(self):
         assert self.aws_instance_id
@@ -234,28 +236,32 @@ class ReproSession:
         repro_instance_state = None
         while repro_instance_state != "running":
             detail = self.ec2_client.describe_instances(
-                InstanceIds=[self.aws_instance_id], )
-            repro_instance_state = \
-                detail["Reservations"][0]["Instances"][0]["State"]["Name"]
+                InstanceIds=[self.aws_instance_id],
+            )
+            repro_instance_state = detail["Reservations"][0]["Instances"][0]["State"][
+                "Name"
+            ]
 
             if repro_instance_state != "running":
                 time.sleep(2)
 
         self.aws_instance_ip = detail["Reservations"][0]["Instances"][0][
-            "PublicIpAddress"]
+            "PublicIpAddress"
+        ]
 
     def aws_stop_instance(self):
         assert self.aws_instance_id
 
         self.ec2_client.terminate_instances(
-            InstanceIds=[self.aws_instance_id], )
+            InstanceIds=[self.aws_instance_id],
+        )
 
     def print_stop_command(self):
         click.secho("To stop this instance in the future, run this: ")
         click.secho(
-            f"aws ec2 terminate-instances "
-            f"--instance-ids={self.aws_instance_id}",
-            bold=True)
+            f"aws ec2 terminate-instances " f"--instance-ids={self.aws_instance_id}",
+            bold=True,
+        )
 
     def create_new_ssh_client(self):
         assert self.aws_instance_ip
@@ -264,7 +270,8 @@ class ReproSession:
             self.ssh.close()
 
         self.logger.info(
-            "Creating SSH client and waiting for SSH to become available...")
+            "Creating SSH client and waiting for SSH to become available..."
+        )
 
         ssh = paramiko.client.SSHClient()
         ssh.load_system_host_keys()
@@ -275,7 +282,8 @@ class ReproSession:
                 ssh.connect(
                     self.aws_instance_ip,
                     username=self.ssh_user,
-                    key_filename=os.path.expanduser(self.ssh_key_file))
+                    key_filename=os.path.expanduser(self.ssh_key_file),
+                )
                 break
             except paramiko.ssh_exception.NoValidConnectionsError:
                 self.logger.info("SSH not ready, yet, sleeping 5 seconds")
@@ -291,8 +299,7 @@ class ReproSession:
         result = {}
 
         def exec():
-            stdin, stdout, stderr = self.ssh.exec_command(
-                command, get_pty=True)
+            stdin, stdout, stderr = self.ssh.exec_command(command, get_pty=True)
 
             output = ""
             for line in stdout.readlines():
@@ -321,12 +328,13 @@ class ReproSession:
         return result.get("output", "")
 
     def execute_ssh_command(
-            self,
-            command: str,
-            env: Optional[Dict[str, str]] = None,
-            as_script: bool = False,
-            quiet: bool = False,
-            command_wrapper: Optional[Callable[[str], str]] = None) -> str:
+        self,
+        command: str,
+        env: Optional[Dict[str, str]] = None,
+        as_script: bool = False,
+        quiet: bool = False,
+        command_wrapper: Optional[Callable[[str], str]] = None,
+    ) -> str:
         assert self.ssh
 
         if not command_wrapper:
@@ -360,23 +368,25 @@ class ReproSession:
 
         return output
 
-    def execute_ssh_commands(self,
-                             commands: List[str],
-                             env: Optional[Dict[str, str]] = None,
-                             quiet: bool = False):
+    def execute_ssh_commands(
+        self,
+        commands: List[str],
+        env: Optional[Dict[str, str]] = None,
+        quiet: bool = False,
+    ):
         for command in commands:
             self.execute_ssh_command(command, env=env, quiet=quiet)
 
-    def execute_docker_command(self,
-                               command: str,
-                               env: Optional[Dict[str, str]] = None,
-                               quiet: bool = False):
+    def execute_docker_command(
+        self, command: str, env: Optional[Dict[str, str]] = None, quiet: bool = False
+    ):
         def command_wrapper(s):
             escaped = s.replace("'", "'\"'\"'")
             return f"docker exec -it ray_container /bin/bash -ci '{escaped}'"
 
         self.execute_ssh_command(
-            command, env=env, quiet=quiet, command_wrapper=command_wrapper)
+            command, env=env, quiet=quiet, command_wrapper=command_wrapper
+        )
 
     def prepare_instance(self):
         self.create_new_ssh_client()
@@ -387,8 +397,9 @@ class ReproSession:
         self.logger.info("Preparing instance (installing docker etc.)")
         commands = [
-            "sudo yum install -y docker", "sudo service docker start",
-            f"sudo usermod -aG docker {self.ssh_user}"
+            "sudo yum install -y docker",
+            "sudo service docker start",
+            f"sudo usermod -aG docker {self.ssh_user}",
         ]
         self.execute_ssh_commands(commands, quiet=True)
         self.create_new_ssh_client()
@@ -398,13 +409,18 @@ class ReproSession:
     def docker_login(self):
         self.logger.info("Logging into docker...")
         credentials = boto3.client(
-            "ecr", region_name="us-west-2").get_authorization_token()
-        token = base64.b64decode(credentials["authorizationData"][0][
-            "authorizationToken"]).decode("utf-8").replace("AWS:", "")
+            "ecr", region_name="us-west-2"
+        ).get_authorization_token()
+        token = (
+            base64.b64decode(credentials["authorizationData"][0]["authorizationToken"])
+            .decode("utf-8")
+            .replace("AWS:", "")
+        )
         endpoint = credentials["authorizationData"][0]["proxyEndpoint"]
 
         self.execute_ssh_command(
-            f"docker login -u AWS -p {token} {endpoint}", quiet=True)
+            f"docker login -u AWS -p {token} {endpoint}", quiet=True
+        )
 
     def fetch_buildkite_plugins(self):
         assert self.env
@@ -415,8 +431,9 @@ class ReproSession:
         for collection in plugins:
             for plugin, options in collection.items():
                 plugin_url, plugin_version = plugin.split("#")
-                if not plugin_url.startswith(
-                        "http://") or not plugin_url.startswith("https://"):
+                if not plugin_url.startswith("http://") or not plugin_url.startswith(
+                    "https://"
+                ):
                     plugin_url = f"https://{plugin_url}"
                 plugin_name = plugin_url.split("/")[-1].rstrip(".git")
@@ -432,7 +449,7 @@ class ReproSession:
                     "version": plugin_version,
                     "dir": plugin_dir,
                     "env": plugin_env,
-                    "details": {}
+                    "details": {},
                 }
 
     def get_plugin_env(self, plugin_short: str, options: Dict[str, Any]):
@@ -457,30 +474,33 @@ class ReproSession:
         self.execute_ssh_command(
             f"[ ! -e {plugin_dir} ] && git clone --depth 1 "
             f"--branch {plugin_version} {plugin_url} {plugin_dir}",
-            quiet=True)
+            quiet=True,
+        )
 
     def load_plugin_details(self, plugin: str):
         assert plugin in self.plugins
 
         plugin_dir = self.plugins[plugin]["dir"]
 
-        yaml_str = self.execute_ssh_command(
-            f"cat {plugin_dir}/plugin.yml", quiet=True)
+        yaml_str = self.execute_ssh_command(f"cat {plugin_dir}/plugin.yml", quiet=True)
         details = yaml.safe_load(yaml_str)
         self.plugins[plugin]["details"] = details
         return details
 
-    def execute_plugin_hook(self,
-                            plugin: str,
-                            hook: str,
-                            env: Optional[Dict[str, Any]] = None,
-                            script_command: Optional[str] = None):
+    def execute_plugin_hook(
+        self,
+        plugin: str,
+        hook: str,
+        env: Optional[Dict[str, Any]] = None,
+        script_command: Optional[str] = None,
+    ):
         assert plugin in self.plugins
 
         self.logger.info(
             f"Executing Buildkite hook for plugin {plugin}: {hook}. "
-            f"This pulls a Docker image and could take a while.")
+            f"This pulls a Docker image and could take a while."
+        )
 
         plugin_dir = self.plugins[plugin]["dir"]
         plugin_env = self.plugins[plugin]["env"].copy()
@@ -500,21 +520,23 @@ class ReproSession:
     def print_buildkite_command(self, skipped: bool = False):
         print("-" * 80)
-        print("These are the commands you need to execute to fully reproduce "
-              "the run")
+        print(
+            "These are the commands you need to execute to fully reproduce " "the run"
+        )
         print("-" * 80)
         print(self.env["BUILDKITE_COMMAND"])
         print("-" * 80)
 
         if skipped and self.skipped_commands:
-            print("Some of the commands above have already been run. "
-                  "Remaining commands:")
+            print(
+                "Some of the commands above have already been run. "
+                "Remaining commands:"
+            )
             print("-" * 80)
             print("\n".join(self.skipped_commands))
             print("-" * 80)
 
-    def run_buildkite_command(self,
-                              command_filter: Optional[List[str]] = None):
+    def run_buildkite_command(self, command_filter: Optional[List[str]] = None):
         commands = self.env["BUILDKITE_COMMAND"].split("\n")
         regexes = [re.compile(cf) for cf in command_filter or []]
@@ -537,15 +559,18 @@ class ReproSession:
             f"grep -q 'source ~/.env' $HOME/.bashrc "
             f"|| echo 'source ~/.env' >> $HOME/.bashrc; "
             f"echo 'export {escaped}' > $HOME/.env",
-            quiet=True)
+            quiet=True,
+        )
 
     def attach_to_container(self):
         self.logger.info("Attaching to AWS instance...")
-        ssh_command = (f"ssh -ti {self.ssh_key_file} "
-                       f"-o StrictHostKeyChecking=no "
-                       f"-o ServerAliveInterval=30 "
-                       f"{self.ssh_user}@{self.aws_instance_ip} "
-                       f"'docker exec -it ray_container bash -l'")
+        ssh_command = (
+            f"ssh -ti {self.ssh_key_file} "
+            f"-o StrictHostKeyChecking=no "
+            f"-o ServerAliveInterval=30 "
+            f"{self.ssh_user}@{self.aws_instance_ip} "
+            f"'docker exec -it ray_container bash -l'"
+        )
         subprocess.run(ssh_command, shell=True)
@@ -555,29 +580,32 @@ class ReproSession:
 @click.option("-n", "--instance-name", default=None)
 @click.option("-c", "--commands", is_flag=True, default=False)
 @click.option("-f", "--filters", multiple=True, default=[])
-def main(session_url: Optional[str],
-         instance_name: Optional[str] = None,
-         commands: bool = False,
-         filters: Optional[List[str]] = None):
+def main(
+    session_url: Optional[str],
+    instance_name: Optional[str] = None,
+    commands: bool = False,
+    filters: Optional[List[str]] = None,
+):
     random.seed(1235)
 
     logger = logging.getLogger("main")
     logger.setLevel(logging.INFO)
     handler = logging.StreamHandler()
     handler.setFormatter(
-        logging.Formatter("[%(levelname)s %(asctime)s] "
-                          "%(filename)s: %(lineno)d "
-                          "%(message)s"))
+        logging.Formatter(
+            "[%(levelname)s %(asctime)s] " "%(filename)s: %(lineno)d " "%(message)s"
+        )
+    )
     logger.addHandler(handler)
 
     maybe_fetch_buildkite_token()
     repro = ReproSession(
-        os.environ["BUILDKITE_TOKEN"],
-        instance_name=instance_name,
-        logger=logger)
+        os.environ["BUILDKITE_TOKEN"], instance_name=instance_name, logger=logger
+    )
 
     session_url = session_url or click.prompt(
-        "Please copy and paste the Buildkite job build URI here")
+        "Please copy and paste the Buildkite job build URI here"
+    )
 
     repro.set_session(session_url)
@@ -610,13 +638,16 @@ def main(session_url: Optional[str],
             "BUILDKITE_PLUGIN_DOCKER_TTY": "0",
             "BUILDKITE_PLUGIN_DOCKER_MOUNT_CHECKOUT": "0",
         },
-        script_command=("sed -E 's/"
-                        "docker run/"
-                        "docker run "
-                        "--cap-add=SYS_PTRACE "
-                        "--name ray_container "
-                        "-d/g' | "
-                        "bash -l"))
+        script_command=(
+            "sed -E 's/"
+            "docker run/"
+            "docker run "
+            "--cap-add=SYS_PTRACE "
+            "--name ray_container "
+            "-d/g' | "
+            "bash -l"
+        ),
+    )
 
     repro.create_new_ssh_client()

View file

@@ -54,7 +54,8 @@ def get_target_expansion_query(targets, tests_only, exclude_manual):
     if exclude_manual:
         query = '{} except tests(attr("tags", "manual", set({})))'.format(
-            query, included_targets)
+            query, included_targets
+        )
     return query
@@ -82,17 +83,16 @@ def get_targets_for_shard(targets, index, count):
 def main():
-    parser = argparse.ArgumentParser(
-        description="Expand and shard Bazel targets.")
+    parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--tests_only", action="store_true")
     parser.add_argument("--exclude_manual", action="store_true")
     parser.add_argument(
-        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1))
+        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1)
+    )
     parser.add_argument(
-        "--count",
-        type=int,
-        default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1))
+        "--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
+    )
     parser.add_argument("targets", nargs="+")
     args, extra_args = parser.parse_known_args()
     args.targets = list(args.targets) + list(extra_args)
@@ -100,11 +100,11 @@ def main():
     if args.index >= args.count:
         parser.error("--index must be between 0 and {}".format(args.count - 1))
 
-    query = get_target_expansion_query(args.targets, args.tests_only,
-                                       args.exclude_manual)
+    query = get_target_expansion_query(
+        args.targets, args.tests_only, args.exclude_manual
+    )
     expanded_targets = run_bazel_query(query, args.debug)
-    my_targets = get_targets_for_shard(expanded_targets, args.index,
-                                       args.count)
+    my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
     print(" ".join(my_targets))
     return 0

View file

@@ -14,10 +14,10 @@ from collections import defaultdict, OrderedDict
 def textproto_format(space, key, value, json_encoder):
     """Rewrites a key-value pair from textproto as JSON."""
-    if value.startswith(b"\""):
+    if value.startswith(b'"'):
         evaluated = ast.literal_eval(value.decode("utf-8"))
         value = json_encoder.encode(evaluated).encode("utf-8")
-    return b"%s[\"%s\", %s]" % (space, key, value)
+    return b'%s["%s", %s]' % (space, key, value)
 
 
 def textproto_split(input_lines, json_encoder):
@@ -50,19 +50,21 @@ def textproto_split(input_lines, json_encoder):
             pieces = re.split(b"(\\r|\\n)", full_line, 1)
             pieces[1:] = [b"".join(pieces[1:])]
             [line, tail] = pieces
-            next_line = pat_open.sub(b"\\1[\"\\2\",\\3[", line)
-            outputs.append(b"" if not prev_comma else b"]"
-                           if next_line.endswith(b"}") else b",")
+            next_line = pat_open.sub(b'\\1["\\2",\\3[', line)
+            outputs.append(
+                b"" if not prev_comma else b"]" if next_line.endswith(b"}") else b","
+            )
             next_line = pat_close.sub(b"]", next_line)
             next_line = pat_line.sub(
-                lambda m: textproto_format(*(m.groups() + (json_encoder, ))),
-                next_line)
+                lambda m: textproto_format(*(m.groups() + (json_encoder,))), next_line
+            )
             outputs.append(prev_tail + next_line)
             if line == b"}":
                 yield b"".join(outputs)
                 del outputs[:]
-            prev_comma = line != b"}" and (next_line.endswith(b"]")
-                                           or next_line.endswith(b"\""))
+            prev_comma = line != b"}" and (
+                next_line.endswith(b"]") or next_line.endswith(b'"')
+            )
             prev_tail = tail
     if len(outputs) > 0:
         yield b"".join(outputs)
@@ -80,13 +82,14 @@ class Bazel(object):
     def __init__(self, program=None):
         if program is None:
             program = os.getenv("BAZEL_EXECUTABLE", "bazel")
-        self.argv = (program, )
-        self.extra_args = ("--show_progress=no", )
+        self.argv = (program,)
+        self.extra_args = ("--show_progress=no",)
 
     def _call(self, command, *args):
         return subprocess.check_output(
-            self.argv + (command, ) + args[:1] + self.extra_args + args[1:],
-            stdin=subprocess.PIPE)
+            self.argv + (command,) + args[:1] + self.extra_args + args[1:],
+            stdin=subprocess.PIPE,
+        )
 
     def info(self, *args):
         result = OrderedDict()
@@ -248,8 +251,7 @@ def shellcheck(bazel_aquery, *shellcheck_argv):
 def main(program, command, *command_args):
     result = 0
     if command == textproto2json.__name__:
-        result = textproto2json(sys.stdin.buffer, sys.stdout.buffer,
-                                *command_args)
+        result = textproto2json(sys.stdin.buffer, sys.stdout.buffer, *command_args)
     elif command == shellcheck.__name__:
         result = shellcheck(*command_args)
     elif command == preclean.__name__:

View file

@@ -20,21 +20,16 @@ DOCKER_CLIENT = None
 PYTHON_WHL_VERSION = "cp3"
 
 DOCKER_HUB_DESCRIPTION = {
-    "base-deps": ("Internal Image, refer to "
-                  "https://hub.docker.com/r/rayproject/ray"),
-    "ray-deps": ("Internal Image, refer to "
-                 "https://hub.docker.com/r/rayproject/ray"),
+    "base-deps": (
+        "Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"
+    ),
+    "ray-deps": ("Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"),
     "ray": "Official Docker Images for Ray, the distributed computing API.",
     "ray-ml": "Developer ready Docker Image for Ray.",
     "ray-worker-container": "Internal Image for CI test",
 }
 
-PY_MATRIX = {
-    "py36": "3.6.12",
-    "py37": "3.7.7",
-    "py38": "3.8.5",
-    "py39": "3.9.5"
-}
+PY_MATRIX = {"py36": "3.6.12", "py37": "3.7.7", "py38": "3.8.5", "py39": "3.9.5"}
 
 BASE_IMAGES = {
     "cu112": "nvidia/cuda:11.2.0-cudnn8-devel-ubuntu18.04",
@@ -50,7 +45,7 @@ CUDA_FULL = {
     "cu111": "CUDA 11.1",
     "cu110": "CUDA 11.0",
     "cu102": "CUDA 10.2",
-    "cu101": "CUDA 10.1"
+    "cu101": "CUDA 10.1",
 }
 
 # The CUDA version to use for the ML Docker image.
@@ -62,8 +57,7 @@ IMAGE_NAMES = list(DOCKER_HUB_DESCRIPTION.keys())
 
 
 def _get_branch():
-    branch = (os.environ.get("TRAVIS_BRANCH")
-              or os.environ.get("BUILDKITE_BRANCH"))
+    branch = os.environ.get("TRAVIS_BRANCH") or os.environ.get("BUILDKITE_BRANCH")
     if not branch:
         print("Branch not found!")
         print(os.environ)
@@ -94,8 +88,7 @@ def _get_root_dir():
 
 
 def _get_commit_sha():
-    sha = (os.environ.get("TRAVIS_COMMIT")
-           or os.environ.get("BUILDKITE_COMMIT") or "")
+    sha = os.environ.get("TRAVIS_COMMIT") or os.environ.get("BUILDKITE_COMMIT") or ""
     if len(sha) < 6:
         print("INVALID SHA FOUND")
         return "ERROR"
@@ -105,8 +98,9 @@ def _get_commit_sha():
 def _configure_human_version():
     global _get_branch
     global _get_commit_sha
-    fake_branch_name = input("Provide a 'branch name'. For releases, it "
-                             "should be `releases/x.x.x`")
+    fake_branch_name = input(
+        "Provide a 'branch name'. For releases, it " "should be `releases/x.x.x`"
+    )
     _get_branch = lambda: fake_branch_name  # noqa: E731
     fake_sha = input("Provide a SHA (used for tag value)")
     _get_commit_sha = lambda: fake_sha  # noqa: E731
@@ -115,38 +109,44 @@ def _configure_human_version():
 def _get_wheel_name(minor_version_number):
     if minor_version_number:
         matches = [
-            file for file in glob.glob(
+            file
+            for file in glob.glob(
                 f"{_get_root_dir()}/.whl/ray-*{PYTHON_WHL_VERSION}"
-                f"{minor_version_number}*-manylinux*")
+                f"{minor_version_number}*-manylinux*"
+            )
             if "+" not in file  # Exclude dbg, asan builds
         ]
         assert len(matches) == 1, (
             f"Found ({len(matches)}) matches for 'ray-*{PYTHON_WHL_VERSION}"
             f"{minor_version_number}*-manylinux*' instead of 1.\n"
-            f"wheel matches: {matches}")
+            f"wheel matches: {matches}"
+        )
         return os.path.basename(matches[0])
     else:
-        matches = glob.glob(
-            f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
+        matches = glob.glob(f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
         return [os.path.basename(i) for i in matches]
 
 
 def _check_if_docker_files_modified():
-    stdout = subprocess.check_output([
-        sys.executable, f"{_get_curr_dir()}/determine_tests_to_run.py",
-        "--output=json"
-    ])
+    stdout = subprocess.check_output(
+        [
+            sys.executable,
+            f"{_get_curr_dir()}/determine_tests_to_run.py",
+            "--output=json",
+        ]
+    )
     affected_env_var_list = json.loads(stdout)
-    affected = ("RAY_CI_DOCKER_AFFECTED" in affected_env_var_list or
-                "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list)
+    affected = (
+        "RAY_CI_DOCKER_AFFECTED" in affected_env_var_list
+        or "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list
+    )
     print(f"Docker affected: {affected}")
     return affected
 
 
-def _build_docker_image(image_name: str,
-                        py_version: str,
-                        image_type: str,
-                        no_cache=True):
+def _build_docker_image(
+    image_name: str, py_version: str, image_type: str, no_cache=True
+):
     """Builds Docker image with the provided info.
 
     image_name (str): The name of the image to build. Must be one of
@@ -161,23 +161,27 @@ def _build_docker_image(image_name: str,
     if image_name not in IMAGE_NAMES:
         raise ValueError(
             f"The provided image name {image_name} is not "
-            f"recognized. Image names must be one of {IMAGE_NAMES}")
+            f"recognized. Image names must be one of {IMAGE_NAMES}"
+        )
 
     if py_version not in PY_MATRIX.keys():
-        raise ValueError(f"The provided python version {py_version} is not "
-                         f"recognized. Python version must be one of"
-                         f" {PY_MATRIX.keys()}")
+        raise ValueError(
+            f"The provided python version {py_version} is not "
+            f"recognized. Python version must be one of"
+            f" {PY_MATRIX.keys()}"
+        )
 
     if image_type not in BASE_IMAGES.keys():
-        raise ValueError(f"The provided CUDA version {image_type} is not "
-                         f"recognized. CUDA version must be one of"
-                         f" {image_type.keys()}")
+        raise ValueError(
+            f"The provided CUDA version {image_type} is not "
+            f"recognized. CUDA version must be one of"
+            f" {image_type.keys()}"
+        )
 
     # TODO(https://github.com/ray-project/ray/issues/16599):
     # remove below after supporting ray-ml images with Python 3.9
     if image_name == "ray-ml" and py_version == "py39":
-        print(f"{image_name} image is currently unsupported with "
-              "Python 3.9")
+        print(f"{image_name} image is currently unsupported with " "Python 3.9")
         return
 
     build_args = {}
@@ -212,7 +216,7 @@ def _build_docker_image(image_name: str,
     labels = {
         "image-name": image_name,
         "python-version": PY_MATRIX[py_version],
-        "ray-commit": _get_commit_sha()
+        "ray-commit": _get_commit_sha(),
     }
     if image_type in CUDA_FULL:
         labels["cuda-version"] = CUDA_FULL[image_type]
@@ -222,7 +226,8 @@ def _build_docker_image(image_name: str,
         tag=tagged_name,
         nocache=no_cache,
         labels=labels,
-        buildargs=build_args)
+        buildargs=build_args,
+    )
 
     cmd_output = []
     try:
@@ -230,12 +235,15 @@ def _build_docker_image(image_name: str,
         current_iter = start
         for line in output:
             cmd_output.append(line.decode("utf-8"))
-            if datetime.datetime.now(
-            ) - current_iter >= datetime.timedelta(minutes=5):
+            if datetime.datetime.now() - current_iter >= datetime.timedelta(
+                minutes=5
+            ):
                 current_iter = datetime.datetime.now()
                 elapsed = datetime.datetime.now() - start
-                print(f"Still building {tagged_name} after "
-                      f"{elapsed.seconds} seconds")
+                print(
+                    f"Still building {tagged_name} after "
+                    f"{elapsed.seconds} seconds"
+                )
                 if elapsed >= datetime.timedelta(minutes=15):
                     print("Additional build output:")
                     print(*cmd_output, sep="\n")
@@ -259,8 +267,10 @@ def _build_docker_image(image_name: str,
 def copy_wheels(human_build):
     if human_build:
-        print("Please download images using:\n"
-              "`pip download --python-version <py_version> ray==<ray_version>")
+        print(
+            "Please download images using:\n"
+            "`pip download --python-version <py_version> ray==<ray_version>"
+        )
     root_dir = _get_root_dir()
     wheels = _get_wheel_name(None)
     for wheel in wheels:
@@ -268,7 +278,8 @@ def copy_wheels(human_build):
         ray_dst = os.path.join(root_dir, "docker/ray/.whl/")
         ray_dep_dst = os.path.join(root_dir, "docker/ray-deps/.whl/")
         ray_worker_container_dst = os.path.join(
-            root_dir, "docker/ray-worker-container/.whl/")
+            root_dir, "docker/ray-worker-container/.whl/"
+        )
         os.makedirs(ray_dst, exist_ok=True)
         shutil.copy(source, ray_dst)
         os.makedirs(ray_dep_dst, exist_ok=True)
@@ -282,8 +293,7 @@ def check_staleness(repository, tag):
     age = DOCKER_CLIENT.api.inspect_image(f"{repository}:{tag}")["Created"]
     short_date = datetime.datetime.strptime(age.split("T")[0], "%Y-%m-%d")
-    is_stale = (
-        datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
+    is_stale = (datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
 
     return is_stale
@@ -292,28 +302,23 @@ def build_for_all_versions(image_name, py_versions, image_types, **kwargs):
     for py_version in py_versions:
         for image_type in image_types:
             _build_docker_image(
-                image_name,
-                py_version=py_version,
-                image_type=image_type,
-                **kwargs)
+                image_name, py_version=py_version, image_type=image_type, **kwargs
+            )
 
 
 def build_base_images(py_versions, image_types):
-    build_for_all_versions(
-        "base-deps", py_versions, image_types, no_cache=False)
-    build_for_all_versions(
-        "ray-deps", py_versions, image_types, no_cache=False)
+    build_for_all_versions("base-deps", py_versions, image_types, no_cache=False)
+    build_for_all_versions("ray-deps", py_versions, image_types, no_cache=False)
 
 
-def build_or_pull_base_images(py_versions: List[str],
-                              image_types: List[str],
-                              rebuild_base_images: bool = True) -> bool:
+def build_or_pull_base_images(
+    py_versions: List[str], image_types: List[str], rebuild_base_images: bool = True
+) -> bool:
     """Returns images to tag and build."""
     repositories = ["rayproject/base-deps", "rayproject/ray-deps"]
     tags = [
         f"nightly-{py_version}-{image_type}"
-        for py_version, image_type in itertools.product(
-            py_versions, image_types)
+        for py_version, image_type in itertools.product(py_versions, image_types)
     ]
 
     try:
@@ -339,12 +344,15 @@ def build_or_pull_base_images(py_versions: List[str],
 def prep_ray_ml():
     root_dir = _get_root_dir()
     requirement_files = glob.glob(
-        f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True)
+        f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True
+    )
     for fl in requirement_files:
         shutil.copy(fl, os.path.join(root_dir, "docker/ray-ml/"))
     # Install atari roms script
-    shutil.copy(f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
-                os.path.join(root_dir, "docker/ray-ml/"))
+    shutil.copy(
+        f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
+        os.path.join(root_dir, "docker/ray-ml/"),
+    )
 
 
 def _get_docker_creds() -> Tuple[str, str]:
@@ -377,10 +385,13 @@ def _tag_and_push(full_image_name, old_tag, new_tag, merge_build=False):
     DOCKER_CLIENT.api.tag(
         image=f"{full_image_name}:{old_tag}",
         repository=full_image_name,
-        tag=new_tag)
+        tag=new_tag,
+    )
 
     if not merge_build:
-        print("This is a PR Build! On a merge build, we would normally push"
-              f"to: {full_image_name}:{new_tag}")
+        print(
+            "This is a PR Build! On a merge build, we would normally push"
+            f"to: {full_image_name}:{new_tag}"
+        )
     else:
         _docker_push(full_image_name, new_tag)
@@ -395,16 +406,17 @@ def _create_new_tags(all_tags, old_str, new_str):
 
 # For non-release builds, push "nightly" & "sha"
 # For release builds, push "nightly" & "latest" & "x.x.x"
-def push_and_tag_images(py_versions: List[str],
-                        image_types: List[str],
-                        push_base_images: bool,
-                        merge_build: bool = False):
+def push_and_tag_images(
+    py_versions: List[str],
+    image_types: List[str],
+    push_base_images: bool,
+    merge_build: bool = False,
+):
     date_tag = datetime.datetime.now().strftime("%Y-%m-%d")
     sha_tag = _get_commit_sha()
     if _release_build():
-        release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*",
-                                 _get_branch()).group(0)
+        release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*", _get_branch()).group(0)
         date_tag = release_name
         sha_tag = release_name
@@ -423,16 +435,19 @@ def push_and_tag_images(py_versions: List[str],
     for py_name in py_versions:
         for image_type in image_types:
             if image_name == "ray-ml" and image_type != ML_CUDA_VERSION:
-                print("ML Docker image is not built for the following "
-                      f"device type: {image_type}")
+                print(
+                    "ML Docker image is not built for the following "
+                    f"device type: {image_type}"
+                )
                 continue
 
             # TODO(https://github.com/ray-project/ray/issues/16599):
             # remove below after supporting ray-ml images with Python 3.9
-            if image_name in ["ray-ml"
-                              ] and PY_MATRIX[py_name].startswith("3.9"):
-                print(f"{image_name} image is currently "
-                      f"unsupported with Python 3.9")
+            if image_name in ["ray-ml"] and PY_MATRIX[py_name].startswith("3.9"):
+                print(
+                    f"{image_name} image is currently "
+                    f"unsupported with Python 3.9"
+                )
                 continue
 
             tag = f"nightly-{py_name}-{image_type}"
@@ -445,20 +460,19 @@ def push_and_tag_images(py_versions: List[str],
             for old_tag in tag_mapping.keys():
                 if "cpu" in old_tag:
                     new_tags = _create_new_tags(
-                        tag_mapping[old_tag], old_str="-cpu", new_str="")
+                        tag_mapping[old_tag], old_str="-cpu", new_str=""
+                    )
                     tag_mapping[old_tag].extend(new_tags)
                 elif ML_CUDA_VERSION in old_tag:
                     new_tags = _create_new_tags(
-                        tag_mapping[old_tag],
-                        old_str=f"-{ML_CUDA_VERSION}",
-                        new_str="-gpu")
+                        tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str="-gpu"
+                    )
                     tag_mapping[old_tag].extend(new_tags)
 
                     if image_name == "ray-ml":
                         new_tags = _create_new_tags(
-                            tag_mapping[old_tag],
-                            old_str=f"-{ML_CUDA_VERSION}",
-                            new_str="")
+                            tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str=""
+                        )
                         tag_mapping[old_tag].extend(new_tags)
 
             # No Python version specified should refer to DEFAULT_PYTHON_VERSION
@@ -467,7 +481,8 @@ def push_and_tag_images(py_versions: List[str],
                 new_tags = _create_new_tags(
                     tag_mapping[old_tag],
                     old_str=f"-{DEFAULT_PYTHON_VERSION}",
-                    new_str="")
+                    new_str="",
+                )
                 tag_mapping[old_tag].extend(new_tags)
 
             # For all tags, create Date/Sha tags
@@ -475,7 +490,8 @@ def push_and_tag_images(py_versions: List[str],
                 new_tags = _create_new_tags(
                     tag_mapping[old_tag],
                     old_str="nightly",
-                    new_str=date_tag if "-deps" in image_name else sha_tag)
+                    new_str=date_tag if "-deps" in image_name else sha_tag,
+                )
                 tag_mapping[old_tag].extend(new_tags)
 
             # Sanity checking.
@@ -511,7 +527,8 @@ def push_and_tag_images(py_versions: List[str],
                     full_image_name,
                     old_tag=old_tag,
                     new_tag=new_tag,
-                    merge_build=merge_build)
+                    merge_build=merge_build,
+                )
 
 
 # Push infra here:
@@ -527,9 +544,9 @@ def push_readmes(merge_build: bool):
             "DOCKER_PASS": password,
             "PUSHRM_FILE": f"/myvol/docker/{image}/README.md",
             "PUSHRM_DEBUG": 1,
-            "PUSHRM_SHORT": tag_line
+            "PUSHRM_SHORT": tag_line,
         }
-        cmd_string = (f"rayproject/{image}")
+        cmd_string = f"rayproject/{image}"
 
         print(
             DOCKER_CLIENT.containers.run(
@@ -546,7 +563,9 @@ def push_readmes(merge_build: bool):
                 detach=False,
                 stderr=True,
                 stdout=True,
-                tty=False))
+                tty=False,
+            )
+        )
 
 
 # Build base-deps/ray-deps only on file change, 2 weeks, per release
@@ -566,63 +585,73 @@ if __name__ == "__main__":
         choices=list(PY_MATRIX.keys()),
         default="py37",
         nargs="*",
-        help="Which python versions to build. "
-        "Must be in (py36, py37, py38, py39)")
+        help="Which python versions to build. " "Must be in (py36, py37, py38, py39)",
+    )
     parser.add_argument(
         "--device-types",
         choices=list(BASE_IMAGES.keys()),
         default=None,
         nargs="*",
         help="Which device types (CPU/CUDA versions) to build images for. "
-        "If not specified, images will be built for all device types.")
+        "If not specified, images will be built for all device types.",
+    )
     parser.add_argument(
         "--build-type",
        choices=BUILD_TYPES,
         required=True,
-        help="Whether to bypass checking if docker is affected")
+        help="Whether to bypass checking if docker is affected",
+    )
     parser.add_argument(
         "--build-base",
dest="base", dest="base",
action="store_true", action="store_true",
help="Whether to build base-deps & ray-deps") help="Whether to build base-deps & ray-deps",
)
parser.add_argument("--no-build-base", dest="base", action="store_false") parser.add_argument("--no-build-base", dest="base", action="store_false")
parser.set_defaults(base=True) parser.set_defaults(base=True)
parser.add_argument( parser.add_argument(
"--only-build-worker-container", "--only-build-worker-container",
dest="only_build_worker_container", dest="only_build_worker_container",
action="store_true", action="store_true",
help="Whether only to build ray-worker-container") help="Whether only to build ray-worker-container",
)
parser.set_defaults(only_build_worker_container=False) parser.set_defaults(only_build_worker_container=False)
args = parser.parse_args() args = parser.parse_args()
py_versions = args.py_versions py_versions = args.py_versions
py_versions = py_versions if isinstance(py_versions, py_versions = py_versions if isinstance(py_versions, list) else [py_versions]
list) else [py_versions]
image_types = args.device_types if args.device_types else list( image_types = args.device_types if args.device_types else list(BASE_IMAGES.keys())
BASE_IMAGES.keys())
assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys()) assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys())
# Make sure the python images and cuda versions we build here are # Make sure the python images and cuda versions we build here are
# consistent with the ones used with fix-latest-docker.sh script. # consistent with the ones used with fix-latest-docker.sh script.
py_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda", py_version_file = os.path.join(
"python_versions.txt") _get_root_dir(), "docker/retag-lambda", "python_versions.txt"
)
with open(py_version_file) as f: with open(py_version_file) as f:
py_file_versions = f.read().splitlines() py_file_versions = f.read().splitlines()
assert set(PY_MATRIX.keys()) == set(py_file_versions), \ assert set(PY_MATRIX.keys()) == set(py_file_versions), (
(PY_MATRIX.keys(), py_file_versions) PY_MATRIX.keys(),
py_file_versions,
)
cuda_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda", cuda_version_file = os.path.join(
"cuda_versions.txt") _get_root_dir(), "docker/retag-lambda", "cuda_versions.txt"
)
with open(cuda_version_file) as f: with open(cuda_version_file) as f:
cuda_file_versions = f.read().splitlines() cuda_file_versions = f.read().splitlines()
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]),\ assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]), (
(BASE_IMAGES.keys(), cuda_file_versions + ["cpu"]) BASE_IMAGES.keys(),
cuda_file_versions + ["cpu"],
)
print("Building the following python versions: ", print(
[PY_MATRIX[py_version] for py_version in py_versions]) "Building the following python versions: ",
[PY_MATRIX[py_version] for py_version in py_versions],
)
print("Building images for the following devices: ", image_types) print("Building images for the following devices: ", image_types)
print("Building base images: ", args.base) print("Building base images: ", args.base)
@ -639,9 +668,11 @@ if __name__ == "__main__":
if build_type == HUMAN: if build_type == HUMAN:
# If manually triggered, request user for branch and SHA value to use. # If manually triggered, request user for branch and SHA value to use.
_configure_human_version() _configure_human_version()
if (build_type in {HUMAN, MERGE, BUILDKITE, LOCAL} if (
or _check_if_docker_files_modified() build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
or args.only_build_worker_container): or _check_if_docker_files_modified()
or args.only_build_worker_container
):
DOCKER_CLIENT = docker.from_env() DOCKER_CLIENT = docker.from_env()
is_merge = build_type == MERGE is_merge = build_type == MERGE
# Buildkite is authenticated in the background. # Buildkite is authenticated in the background.
@ -652,11 +683,11 @@ if __name__ == "__main__":
DOCKER_CLIENT.api.login(username=username, password=password) DOCKER_CLIENT.api.login(username=username, password=password)
copy_wheels(build_type == HUMAN) copy_wheels(build_type == HUMAN)
is_base_images_built = build_or_pull_base_images( is_base_images_built = build_or_pull_base_images(
py_versions, image_types, args.base) py_versions, image_types, args.base
)
if args.only_build_worker_container: if args.only_build_worker_container:
build_for_all_versions("ray-worker-container", py_versions, build_for_all_versions("ray-worker-container", py_versions, image_types)
image_types)
# TODO Currently don't push ray_worker_container # TODO Currently don't push ray_worker_container
else: else:
# Build Ray Docker images. # Build Ray Docker images.
@ -668,15 +699,19 @@ if __name__ == "__main__":
prep_ray_ml() prep_ray_ml()
# Only build ML Docker for the ML_CUDA_VERSION # Only build ML Docker for the ML_CUDA_VERSION
build_for_all_versions( build_for_all_versions(
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION]) "ray-ml", py_versions, image_types=[ML_CUDA_VERSION]
)
if build_type in {MERGE, PR}: if build_type in {MERGE, PR}:
valid_branch = _valid_branch() valid_branch = _valid_branch()
if (not valid_branch) and is_merge: if (not valid_branch) and is_merge:
print(f"Invalid Branch found: {_get_branch()}") print(f"Invalid Branch found: {_get_branch()}")
push_and_tag_images(py_versions, image_types, push_and_tag_images(
is_base_images_built, valid_branch py_versions,
and is_merge) image_types,
is_base_images_built,
valid_branch and is_merge,
)
# TODO(ilr) Re-Enable Push READMEs by using a normal password # TODO(ilr) Re-Enable Push READMEs by using a normal password
# (not auth token :/) # (not auth token :/)
View file
@ -20,7 +20,8 @@ def build_multinode_image(source_image: str, target_image: str):
f.write("RUN sudo apt install -y openssh-server\n") f.write("RUN sudo apt install -y openssh-server\n")
subprocess.check_output( subprocess.check_output(
f"docker build -t {target_image} .", shell=True, cwd=tempdir) f"docker build -t {target_image} .", shell=True, cwd=tempdir
)
shutil.rmtree(tempdir) shutil.rmtree(tempdir)
View file
@ -25,9 +25,7 @@ def perform_check(raw_xml_string: str):
missing_owners = [] missing_owners = []
for rule in tree.findall("rule"): for rule in tree.findall("rule"):
test_name = rule.attrib["name"] test_name = rule.attrib["name"]
tags = [ tags = [child.attrib["value"] for child in rule.find("list").getchildren()]
child.attrib["value"] for child in rule.find("list").getchildren()
]
team_owner = [t for t in tags if t.startswith("team")] team_owner = [t for t in tags if t.startswith("team")]
if len(team_owner) == 0: if len(team_owner) == 0:
missing_owners.append(test_name) missing_owners.append(test_name)
@ -36,7 +34,8 @@ def perform_check(raw_xml_string: str):
if len(missing_owners): if len(missing_owners):
raise Exception( raise Exception(
f"Cannot find owner for tests {missing_owners}, please add " f"Cannot find owner for tests {missing_owners}, please add "
"`team:*` to the tags.") "`team:*` to the tags."
)
print(owners) print(owners)
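An aside for readers unfamiliar with the XML this check walks: a hedged, self-contained sketch of the tag extraction, run on a minimal bazel query --output=xml style snippet (the exact XML shape is an assumption here):

import xml.etree.ElementTree as ET

raw_xml = """<query><rule name="//python:test_a">
<list name="tags"><string value="team:core"/><string value="exclusive"/></list>
</rule></query>"""
tree = ET.fromstring(raw_xml)
for rule in tree.findall("rule"):
    # Collect the value attribute of each <string> child of the tags list.
    tags = [child.attrib["value"] for child in rule.find("list")]
    assert [t for t in tags if t.startswith("team")] == ["team:core"]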
View file
@ -19,11 +19,7 @@ exit_with_error = False
def check_import(file): def check_import(file):
check_to_lines = { check_to_lines = {"import ray": -1, "import psutil": -1, "import setproctitle": -1}
"import ray": -1,
"import psutil": -1,
"import setproctitle": -1
}
with io.open(file, "r", encoding="utf-8") as f: with io.open(file, "r", encoding="utf-8") as f:
for i, line in enumerate(f): for i, line in enumerate(f):
@ -37,8 +33,10 @@ def check_import(file):
# It will not match the following # It will not match the following
# - submodule import: `import ray.constants as ray_constants` # - submodule import: `import ray.constants as ray_constants`
# - submodule import: `from ray import xyz` # - submodule import: `from ray import xyz`
if re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$", if (
line) and check_to_lines[check] == -1: re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$", line)
and check_to_lines[check] == -1
):
check_to_lines[check] = i check_to_lines[check] = i
for import_lib in ["import psutil", "import setproctitle"]: for import_lib in ["import psutil", "import setproctitle"]:
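A quick, hedged illustration of the import-matching regex above: it accepts a bare import, optionally followed by a noqa comment, and rejects the submodule imports the comments call out.

import re

check = "import ray"
pattern = r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$"
assert re.search(pattern, "import ray")
assert re.search(pattern, "import ray  # noqa F401")
assert not re.search(pattern, "import ray.constants as ray_constants")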
@ -48,8 +46,8 @@ def check_import(file):
if import_ray_line == -1 or import_ray_line > import_psutil_line: if import_ray_line == -1 or import_ray_line > import_psutil_line:
print( print(
"{}:{}".format(str(file), import_psutil_line + 1), "{}:{}".format(str(file), import_psutil_line + 1),
"{} without explicitly import ray before it.".format( "{} without explicitly import ray before it.".format(import_lib),
import_lib)) )
global exit_with_error global exit_with_error
exit_with_error = True exit_with_error = True
@ -59,8 +57,7 @@ if __name__ == "__main__":
parser.add_argument("path", help="File path to check. e.g. '.' or './src'") parser.add_argument("path", help="File path to check. e.g. '.' or './src'")
# TODO(simon): For the future, consider adding a feature to explicitly # TODO(simon): For the future, consider adding a feature to explicitly
# white-list the path instead of skipping them. # white-list the path instead of skipping them.
parser.add_argument( parser.add_argument("-s", "--skip", action="append", help="Skip certain directory")
"-s", "--skip", action="append", help="Skip certain directory")
args = parser.parse_args() args = parser.parse_args()
file_path = Path(args.path) file_path = Path(args.path)
View file
@ -18,7 +18,7 @@ DEFAULT_BLACKLIST = [
"gpustat", "gpustat",
"opencensus", "opencensus",
"prometheus_client", "prometheus_client",
"smart_open" "smart_open",
] ]
@ -28,19 +28,20 @@ def assert_packages_not_installed(blacklist: List[str]):
except ImportError: # pip < 10.0 except ImportError: # pip < 10.0
from pip.operations import freeze from pip.operations import freeze
installed_packages = [ installed_packages = [p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()]
p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()
]
assert not any(p in installed_packages for p in blacklist), \ assert not any(p in installed_packages for p in blacklist), (
f"Found blacklisted packages in installed python packages: " \ f"Found blacklisted packages in installed python packages: "
f"{[p for p in blacklist if p in installed_packages]}. " \ f"{[p for p in blacklist if p in installed_packages]}. "
f"Minimal dependency tests could be tainted by this. " \ f"Minimal dependency tests could be tainted by this. "
f"Check the install logs and primary dependencies if any of these " \ f"Check the install logs and primary dependencies if any of these "
f"packages were installed as part of another install step." f"packages were installed as part of another install step."
)
print(f"Confirmed that blacklisted packages are not installed in " print(
f"current Python environment: {blacklist}") f"Confirmed that blacklisted packages are not installed in "
f"current Python environment: {blacklist}"
)
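A hedged illustration of the freeze-line parsing above: both pinned entries ("pkg==1.0") and direct-URL entries ("pkg @ file://...") reduce to the bare package name.

freeze_lines = ["ray==1.10.0", "smart_open @ file:///tmp/smart_open"]
names = [p.split("==")[0].split(" @ ")[0] for p in freeze_lines]
assert names == ["ray", "smart_open"]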
if __name__ == "__main__": if __name__ == "__main__":
View file
@ -54,7 +54,8 @@ def run_tidy(task_queue, lock, timeout):
command = task_queue.get() command = task_queue.get()
try: try:
proc = subprocess.Popen( proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if timeout is not None: if timeout is not None:
watchdog = threading.Timer(timeout, proc.kill) watchdog = threading.Timer(timeout, proc.kill)
@ -70,22 +71,21 @@ def run_tidy(task_queue, lock, timeout):
sys.stderr.flush() sys.stderr.flush()
except Exception as e: except Exception as e:
with lock: with lock:
sys.stderr.write("Failed: " + str(e) + ": ".join(command) + sys.stderr.write("Failed: " + str(e) + ": ".join(command) + "\n")
"\n")
finally: finally:
with lock: with lock:
if timeout is not None and watchdog is not None: if timeout is not None and watchdog is not None:
if not watchdog.is_alive(): if not watchdog.is_alive():
sys.stderr.write("Terminated by timeout: " + sys.stderr.write(
" ".join(command) + "\n") "Terminated by timeout: " + " ".join(command) + "\n"
)
watchdog.cancel() watchdog.cancel()
task_queue.task_done() task_queue.task_done()
def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout): def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout):
for _ in range(max_tasks): for _ in range(max_tasks):
t = threading.Thread( t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout))
target=tidy_caller, args=(task_queue, lock, timeout))
t.daemon = True t.daemon = True
t.start() t.start()
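The watchdog pattern run_tidy uses above, as a minimal standalone sketch (the POSIX sleep command stands in for any long-running tool):

import subprocess
import threading

proc = subprocess.Popen(["sleep", "10"])
watchdog = threading.Timer(2.0, proc.kill)  # kill the process after 2 seconds
watchdog.start()
proc.wait()
watchdog.cancel()  # harmless if the timer already fired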
@ -119,84 +119,87 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run clang-tidy against changed files, and " description="Run clang-tidy against changed files, and "
"output diagnostics only for modified " "output diagnostics only for modified "
"lines.") "lines."
)
parser.add_argument( parser.add_argument(
"-clang-tidy-binary", "-clang-tidy-binary",
metavar="PATH", metavar="PATH",
default="clang-tidy", default="clang-tidy",
help="path to clang-tidy binary") help="path to clang-tidy binary",
)
parser.add_argument( parser.add_argument(
"-p", "-p",
metavar="NUM", metavar="NUM",
default=0, default=0,
help="strip the smallest prefix containing P slashes") help="strip the smallest prefix containing P slashes",
)
parser.add_argument( parser.add_argument(
"-regex", "-regex",
metavar="PATTERN", metavar="PATTERN",
default=None, default=None,
help="custom pattern selecting file paths to check " help="custom pattern selecting file paths to check "
"(case sensitive, overrides -iregex)") "(case sensitive, overrides -iregex)",
)
parser.add_argument( parser.add_argument(
"-iregex", "-iregex",
metavar="PATTERN", metavar="PATTERN",
default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)", default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)",
help="custom pattern selecting file paths to check " help="custom pattern selecting file paths to check "
"(case insensitive, overridden by -regex)") "(case insensitive, overridden by -regex)",
)
parser.add_argument( parser.add_argument(
"-j", "-j",
type=int, type=int,
default=1, default=1,
help="number of tidy instances to be run in parallel.") help="number of tidy instances to be run in parallel.",
)
parser.add_argument( parser.add_argument(
"-timeout", "-timeout", type=int, default=None, help="timeout per each file in seconds."
type=int, )
default=None,
help="timeout per each file in seconds.")
parser.add_argument( parser.add_argument(
"-fix", "-fix", action="store_true", default=False, help="apply suggested fixes"
action="store_true", )
default=False,
help="apply suggested fixes")
parser.add_argument( parser.add_argument(
"-checks", "-checks",
help="checks filter, when not specified, use clang-tidy " help="checks filter, when not specified, use clang-tidy " "default",
"default", default="",
default="") )
parser.add_argument( parser.add_argument(
"-path", "-path", dest="build_path", help="Path used to read a compile command database."
dest="build_path", )
help="Path used to read a compile command database.")
if yaml: if yaml:
parser.add_argument( parser.add_argument(
"-export-fixes", "-export-fixes",
metavar="FILE", metavar="FILE",
dest="export_fixes", dest="export_fixes",
help="Create a yaml file to store suggested fixes in, " help="Create a yaml file to store suggested fixes in, "
"which can be applied with clang-apply-replacements.") "which can be applied with clang-apply-replacements.",
)
parser.add_argument( parser.add_argument(
"-extra-arg", "-extra-arg",
dest="extra_arg", dest="extra_arg",
action="append", action="append",
default=[], default=[],
help="Additional argument to append to the compiler " help="Additional argument to append to the compiler " "command line.",
"command line.") )
parser.add_argument( parser.add_argument(
"-extra-arg-before", "-extra-arg-before",
dest="extra_arg_before", dest="extra_arg_before",
action="append", action="append",
default=[], default=[],
help="Additional argument to prepend to the compiler " help="Additional argument to prepend to the compiler " "command line.",
"command line.") )
parser.add_argument( parser.add_argument(
"-quiet", "-quiet",
action="store_true", action="store_true",
default=False, default=False,
help="Run clang-tidy in quiet mode") help="Run clang-tidy in quiet mode",
)
clang_tidy_args = [] clang_tidy_args = []
argv = sys.argv[1:] argv = sys.argv[1:]
if "--" in argv: if "--" in argv:
clang_tidy_args.extend(argv[argv.index("--"):]) clang_tidy_args.extend(argv[argv.index("--") :])
argv = argv[:argv.index("--")] argv = argv[: argv.index("--")]
args = parser.parse_args(argv) args = parser.parse_args(argv)
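An aside on the "--" handling just above: everything from the first "--" onward is forwarded verbatim to clang-tidy, a common argv-splitting convention. A small sketch:

argv = ["-j", "4", "--", "-checks=-*,clang-analyzer-*"]
i = argv.index("--")
ours, forwarded = argv[:i], argv[i:]
assert ours == ["-j", "4"]
assert forwarded == ["--", "-checks=-*,clang-analyzer-*"]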
@ -204,7 +207,7 @@ def main():
filename = None filename = None
lines_by_file = {} lines_by_file = {}
for line in sys.stdin: for line in sys.stdin:
match = re.search('^\+\+\+\ \"?(.*?/){%s}([^ \t\n\"]*)' % args.p, line) match = re.search('^\+\+\+\ "?(.*?/){%s}([^ \t\n"]*)' % args.p, line)
if match: if match:
filename = match.group(2) filename = match.group(2)
if filename is None: if filename is None:
@ -226,8 +229,7 @@ def main():
if line_count == 0: if line_count == 0:
continue continue
end_line = start_line + line_count - 1 end_line = start_line + line_count - 1
lines_by_file.setdefault(filename, lines_by_file.setdefault(filename, []).append([start_line, end_line])
[]).append([start_line, end_line])
if not any(lines_by_file): if not any(lines_by_file):
print("No relevant changes found.") print("No relevant changes found.")
@ -267,11 +269,8 @@ def main():
for name in lines_by_file: for name in lines_by_file:
line_filter_json = json.dumps( line_filter_json = json.dumps(
[{ [{"name": name, "lines": lines_by_file[name]}], separators=(",", ":")
"name": name, )
"lines": lines_by_file[name]
}],
separators=(",", ":"))
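Hedged aside: separators=(",", ":") makes json.dumps emit compact, whitespace-free JSON, convenient to pass as a single -line-filter argument:

import json

line_filter = json.dumps([{"name": "a.cc", "lines": [[1, 5]]}], separators=(",", ":"))
assert line_filter == '[{"name":"a.cc","lines":[[1,5]]}]'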
# Run clang-tidy on files containing changes. # Run clang-tidy on files containing changes.
command = [args.clang_tidy_binary] command = [args.clang_tidy_binary]
View file
@ -38,8 +38,10 @@ def is_pull_request():
for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]: for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]:
event_type = os.getenv(key, event_type) event_type = os.getenv(key, event_type)
if (os.environ.get("BUILDKITE") if (
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"): os.environ.get("BUILDKITE")
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"
):
event_type = "pull_request" event_type = "pull_request"
return event_type == "pull_request" return event_type == "pull_request"
@ -67,8 +69,7 @@ def get_commit_range():
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--output", type=str, help="json or envvars", default="envvars")
"--output", type=str, help="json or envvars", default="envvars")
args = parser.parse_args() args = parser.parse_args()
RAY_CI_TUNE_AFFECTED = 0 RAY_CI_TUNE_AFFECTED = 0
@ -103,8 +104,7 @@ if __name__ == "__main__":
try: try:
graph = pda.build_dep_graph() graph = pda.build_dep_graph()
rllib_tests = pda.list_rllib_tests() rllib_tests = pda.list_rllib_tests()
print( print("Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
"Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
impacted = {} impacted = {}
for test in rllib_tests: for test in rllib_tests:
@ -120,9 +120,7 @@ if __name__ == "__main__":
print(e, file=sys.stderr) print(e, file=sys.stderr)
# End of dry run. # End of dry run.
skip_prefix_list = [ skip_prefix_list = ["doc/", "examples/", "dev/", "kubernetes/", "site/"]
"doc/", "examples/", "dev/", "kubernetes/", "site/"
]
for changed_file in files: for changed_file in files:
if changed_file.startswith("python/ray/tune"): if changed_file.startswith("python/ray/tune"):
@ -181,7 +179,8 @@ if __name__ == "__main__":
# Java also depends on Python CLI to manage processes. # Java also depends on Python CLI to manage processes.
RAY_CI_JAVA_AFFECTED = 1 RAY_CI_JAVA_AFFECTED = 1
if changed_file.startswith("python/setup.py") or re.match( if changed_file.startswith("python/setup.py") or re.match(
".*requirements.*\.txt", changed_file): ".*requirements.*\.txt", changed_file
):
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1 RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1
elif changed_file.startswith("java/"): elif changed_file.startswith("java/"):
RAY_CI_JAVA_AFFECTED = 1 RAY_CI_JAVA_AFFECTED = 1
@ -190,12 +189,9 @@ if __name__ == "__main__":
elif changed_file.startswith("docker/"): elif changed_file.startswith("docker/"):
RAY_CI_DOCKER_AFFECTED = 1 RAY_CI_DOCKER_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1
elif changed_file.startswith("doc/") and changed_file.endswith( elif changed_file.startswith("doc/") and changed_file.endswith(".py"):
".py"):
RAY_CI_DOC_AFFECTED = 1 RAY_CI_DOC_AFFECTED = 1
elif any( elif any(changed_file.startswith(prefix) for prefix in skip_prefix_list):
changed_file.startswith(prefix)
for prefix in skip_prefix_list):
# nothing is run but linting in these cases # nothing is run but linting in these cases
pass pass
elif changed_file.endswith("build-docker-images.py"): elif changed_file.endswith("build-docker-images.py"):
@ -246,26 +242,28 @@ if __name__ == "__main__":
RAY_CI_DASHBOARD_AFFECTED = 1 RAY_CI_DASHBOARD_AFFECTED = 1
# Log the modified environment variables visible in console. # Log the modified environment variables visible in console.
output_string = " ".join([ output_string = " ".join(
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED), [
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED), "RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED), "RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED), "RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format( "RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
RAY_CI_RLLIB_DIRECTLY_AFFECTED), "RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(RAY_CI_RLLIB_DIRECTLY_AFFECTED),
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED), "RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED), "RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED), "RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED), "RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED), "RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED), "RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED), "RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED), "RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED), "RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED), "RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format( "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED), RAY_CI_PYTHON_DEPENDENCIES_AFFECTED
]) ),
]
)
# Debug purpose # Debug purpose
print(output_string, file=sys.stderr) print(output_string, file=sys.stderr)

View file

@ -5,7 +5,7 @@
# Cause the script to exit if a single command fails # Cause the script to exit if a single command fails
set -euo pipefail set -euo pipefail
BLACK_IS_ENABLED=false BLACK_IS_ENABLED=true
FLAKE8_VERSION_REQUIRED="3.9.1" FLAKE8_VERSION_REQUIRED="3.9.1"
BLACK_VERSION_REQUIRED="21.12b0" BLACK_VERSION_REQUIRED="21.12b0"
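This flips the Black gate on in the lint script, pinned to the version above. A hedged Python sketch of the sort of version check such a script performs (the real check is implemented in shell; this snippet is illustrative only):

import subprocess

required = "21.12b0"
out = subprocess.check_output(["black", "--version"], text=True)
assert required in out, f"black {required} required, got: {out.strip()}"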
View file
@ -17,19 +17,19 @@ import json
def gha_get_self_url(): def gha_get_self_url():
import requests import requests
# Strung-together API call to get the current check's html url. # Strung-together API call to get the current check's html url.
sha = os.environ["GITHUB_SHA"] sha = os.environ["GITHUB_SHA"]
repo = os.environ["GITHUB_REPOSITORY"] repo = os.environ["GITHUB_REPOSITORY"]
resp = requests.get( resp = requests.get(
"https://api.github.com/repos/{}/commits/{}/check-suites".format( "https://api.github.com/repos/{}/commits/{}/check-suites".format(repo, sha)
repo, sha)) )
data = resp.json() data = resp.json()
for check in data["check_suites"]: for check in data["check_suites"]:
slug = check["app"]["slug"] slug = check["app"]["slug"]
if slug == "github-actions": if slug == "github-actions":
run_url = check["check_runs_url"] run_url = check["check_runs_url"]
html_url = ( html_url = requests.get(run_url).json()["check_runs"][0]["html_url"]
requests.get(run_url).json()["check_runs"][0]["html_url"])
return html_url return html_url
# Return a fallback url # Return a fallback url
@ -47,10 +47,12 @@ def get_build_env():
if os.environ.get("BUILDKITE"): if os.environ.get("BUILDKITE"):
return { return {
"TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"], "TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"],
"TRAVIS_JOB_WEB_URL": (os.environ["BUILDKITE_BUILD_URL"] + "#" + "TRAVIS_JOB_WEB_URL": (
os.environ["BUILDKITE_BUILD_ID"]), os.environ["BUILDKITE_BUILD_URL"]
"TRAVIS_OS_NAME": # The map is used to stay consistent with Travis + "#"
{ + os.environ["BUILDKITE_BUILD_ID"]
),
"TRAVIS_OS_NAME": { # The map is used to stay consistent with Travis
"linux": "linux", "linux": "linux",
"darwin": "osx", "darwin": "osx",
"win32": "windows", "win32": "windows",
@ -70,13 +72,10 @@ def get_build_config():
return {"config": {"env": "Windows CI"}} return {"config": {"env": "Windows CI"}}
if os.environ.get("BUILDKITE"): if os.environ.get("BUILDKITE"):
return { return {"config": {"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]}}
"config": {
"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]
}
}
import requests import requests
url = "https://api.travis-ci.com/job/{job_id}?include=job.config" url = "https://api.travis-ci.com/job/{job_id}?include=job.config"
url = url.format(job_id=os.environ["TRAVIS_JOB_ID"]) url = url.format(job_id=os.environ["TRAVIS_JOB_ID"])
resp = requests.get(url, headers={"Travis-API-Version": "3"}) resp = requests.get(url, headers={"Travis-API-Version": "3"})
@ -87,9 +86,4 @@ if __name__ == "__main__":
build_env = get_build_env() build_env = get_build_env()
build_config = get_build_config() build_config = get_build_config()
print( print(json.dumps({"build_env": build_env, "build_config": build_config}, indent=2))
json.dumps(
{
"build_env": build_env,
"build_config": build_config
}, indent=2))
View file
@ -43,7 +43,8 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
test: only return information about a specific test. test: only return information about a specific test.
""" """
tests_res = _run_shell( tests_res = _run_shell(
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"]) ["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"]
)
all_tests = [] all_tests = []
@ -53,15 +54,18 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
if test and t != test: if test and t != test:
continue continue
src_out = _run_shell([ src_out = _run_shell(
"bazel", "query", "kind(\"source file\", deps({}))".format(t), [
"--output", "label" "bazel",
]) "query",
'kind("source file", deps({}))'.format(t),
"--output",
"label",
]
)
srcs = [f.strip() for f in src_out.splitlines()] srcs = [f.strip() for f in src_out.splitlines()]
srcs = [ srcs = [f for f in srcs if f.startswith("//python") and f.endswith(".py")]
f for f in srcs if f.startswith("//python") and f.endswith(".py")
]
if srcs: if srcs:
all_tests.append((t, srcs)) all_tests.append((t, srcs))
@ -73,8 +77,7 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
def _new_dep(graph: DepGraph, src_module: str, dep: str): def _new_dep(graph: DepGraph, src_module: str, dep: str):
"""Create a new dependency between src_module and dep. """Create a new dependency between src_module and dep."""
"""
if dep not in graph.ids: if dep not in graph.ids:
graph.ids[dep] = len(graph.ids) graph.ids[dep] = len(graph.ids)
@ -87,8 +90,7 @@ def _new_dep(graph: DepGraph, src_module: str, dep: str):
def _new_import(graph: DepGraph, src_module: str, dep_module: str): def _new_import(graph: DepGraph, src_module: str, dep_module: str):
"""Process a new import statement in src_module. """Process a new import statement in src_module."""
"""
# We don't care about system imports. # We don't care about system imports.
if not dep_module.startswith("ray"): if not dep_module.startswith("ray"):
return return
@ -97,8 +99,7 @@ def _new_import(graph: DepGraph, src_module: str, dep_module: str):
def _is_path_module(module: str, name: str, _base_dir: str) -> bool: def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
"""Figure out if base.sub is a python module or not. """Figure out if base.sub is a python module or not."""
"""
# Special handling for _raylet, which is a C++ lib. # Special handling for _raylet, which is a C++ lib.
if module == "ray._raylet": if module == "ray._raylet":
return False return False
@ -110,10 +111,10 @@ def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
return False return False
def _new_from_import(graph: DepGraph, src_module: str, dep_module: str, def _new_from_import(
dep_name: str, _base_dir: str): graph: DepGraph, src_module: str, dep_module: str, dep_name: str, _base_dir: str
"""Process a new "from ... import ..." statement in src_module. ):
""" """Process a new "from ... import ..." statement in src_module."""
# We don't care about imports outside of ray package. # We don't care about imports outside of ray package.
if not dep_module or not dep_module.startswith("ray"): if not dep_module or not dep_module.startswith("ray"):
return return
@ -126,10 +127,7 @@ def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
_new_dep(graph, src_module, dep_module) _new_dep(graph, src_module, dep_module)
def _process_file(graph: DepGraph, def _process_file(graph: DepGraph, src_path: str, src_module: str, _base_dir=""):
src_path: str,
src_module: str,
_base_dir=""):
"""Create dependencies from src_module to all the valid imports in src_path. """Create dependencies from src_module to all the valid imports in src_path.
Args: Args:
@ -147,13 +145,13 @@ def _process_file(graph: DepGraph,
_new_import(graph, src_module, alias.name) _new_import(graph, src_module, alias.name)
elif isinstance(node, ast.ImportFrom): elif isinstance(node, ast.ImportFrom):
for alias in node.names: for alias in node.names:
_new_from_import(graph, src_module, node.module, _new_from_import(
alias.name, _base_dir) graph, src_module, node.module, alias.name, _base_dir
)
def build_dep_graph() -> DepGraph: def build_dep_graph() -> DepGraph:
"""Build index from py files to their immediate dependees. """Build index from py files to their immediate dependees."""
"""
graph = DepGraph() graph = DepGraph()
# Assuming we run from root /ray directory. # Assuming we run from root /ray directory.
@ -197,8 +195,7 @@ def _full_module_path(module, f) -> str:
def _should_skip(d: str) -> bool: def _should_skip(d: str) -> bool:
"""Skip directories that should not contain py sources. """Skip directories that should not contain py sources."""
"""
if d.startswith("python/.eggs/"): if d.startswith("python/.eggs/"):
return True return True
if d.startswith("python/."): if d.startswith("python/."):
@ -224,14 +221,14 @@ def _bazel_path_to_module_path(d: str) -> str:
def _file_path_to_module_path(f: str) -> str: def _file_path_to_module_path(f: str) -> str:
"""Return the corresponding module path for a .py file. """Return the corresponding module path for a .py file."""
"""
dir, fn = os.path.split(f) dir, fn = os.path.split(f)
return _full_module_path(_bazel_path_to_module_path(dir), fn) return _full_module_path(_bazel_path_to_module_path(dir), fn)
def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int, def _depends(
qid: int) -> List[int]: graph: DepGraph, visited: Dict[int, bool], tid: int, qid: int
) -> List[int]:
"""Whether there is a dependency path from module tid to module qid. """Whether there is a dependency path from module tid to module qid.
Given graph, and without going through visited. Given graph, and without going through visited.
@ -253,8 +250,9 @@ def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
return [] return []
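A hedged sketch of the reachability check _depends implements: depth-first search over an adjacency map from a module id to the ids that depend on it, returning one concrete path when qid is reachable from tid (the names and graph encoding here are illustrative):

from typing import Dict, List, Set

def depends(edges: Dict[int, Set[int]], visited: Dict[int, bool], tid: int, qid: int) -> List[int]:
    if tid == qid:
        return [tid]
    if tid in visited:  # already explored from this module
        return []
    visited[tid] = True
    for nid in edges.get(tid, ()):
        path = depends(edges, visited, nid, qid)
        if path:
            return [tid] + path
    return []

assert depends({1: {2}, 2: {3}}, {}, 1, 3) == [1, 2, 3]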
def test_depends_on_file(graph: DepGraph, test: Tuple[str, Tuple[str]], def test_depends_on_file(
path: str) -> List[int]: graph: DepGraph, test: Tuple[str, Tuple[str]], path: str
) -> List[int]:
"""Given a dependency graph, check if a test depends on a specific .py file. """Given a dependency graph, check if a test depends on a specific .py file.
Args: Args:
@ -307,8 +305,7 @@ def _find_circular_dep_impl(graph: DepGraph, id: str, branch: str) -> bool:
def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]: def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]:
"""Find circular dependencies among a dependency graph. """Find circular dependencies among a dependency graph."""
"""
known = {} known = {}
circles = {} circles = {}
for m, id in graph.ids.items(): for m, id in graph.ids.items():
@ -334,25 +331,29 @@ if __name__ == "__main__":
"--mode", "--mode",
type=str, type=str,
default="test-dep", default="test-dep",
help=("test-dep: find dependencies for a specified test. " help=(
"circular-dep: find circular dependencies in " "test-dep: find dependencies for a specified test. "
"the specific codebase.")) "circular-dep: find circular dependencies in "
"the specific codebase."
),
)
parser.add_argument( parser.add_argument(
"--file", "--file", type=str, help="Path of a .py source file relative to --base_dir."
type=str, )
help="Path of a .py source file relative to --base_dir.")
parser.add_argument("--test", type=str, help="Specific test to check.") parser.add_argument("--test", type=str, help="Specific test to check.")
parser.add_argument( parser.add_argument(
"--smoke-test", "--smoke-test", action="store_true", help="Load only a few tests for testing."
action="store_true", )
help="Load only a few tests for testing.")
args = parser.parse_args() args = parser.parse_args()
print("building dep graph ...") print("building dep graph ...")
graph = build_dep_graph() graph = build_dep_graph()
print("done. total {} files, {} of which have dependencies.".format( print(
len(graph.ids), len(graph.edges))) "done. total {} files, {} of which have dependencies.".format(
len(graph.ids), len(graph.edges)
)
)
if args.mode == "circular-dep": if args.mode == "circular-dep":
circles = find_circular_dep(graph) circles = find_circular_dep(graph)
View file
@ -12,30 +12,33 @@ class TestPyDepAnalysis(unittest.TestCase):
f.close() f.close()
def test_full_module_path(self): def test_full_module_path(self):
self.assertEqual( self.assertEqual(pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc") self.assertEqual(pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(
pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(pda._full_module_path("", "dd.py"), "dd") self.assertEqual(pda._full_module_path("", "dd.py"), "dd")
def test_bazel_path_to_module_path(self): def test_bazel_path_to_module_path(self):
self.assertEqual( self.assertEqual(
pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"), pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd") "ray.rllib.xxx.yyy.dd",
)
self.assertEqual( self.assertEqual(
pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"), pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd") "ray.rllib.xxx.yyy.dd",
)
self.assertEqual( self.assertEqual(
pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"), pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd") "ray.rllib.xxx.yyy.dd",
)
def test_file_path_to_module_path(self): def test_file_path_to_module_path(self):
self.assertEqual( self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/env.py"), pda._file_path_to_module_path("python/ray/rllib/env/env.py"),
"ray.rllib.env.env") "ray.rllib.env.env",
)
self.assertEqual( self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"), pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"),
"ray.rllib.env") "ray.rllib.env",
)
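The conversions these tests pin down can be summarized in a hedged sketch (behavior inferred from the assertions above; the helper name is illustrative):

def bazel_label_to_module(label: str) -> str:
    # "//python/ray/rllib:xxx/yyy/dd" -> "ray.rllib.xxx.yyy.dd"
    path = label.lstrip("/").replace(":", "/")
    parts = path.split("/")
    if parts[0] == "python":  # python sources live under python/
        parts = parts[1:]
    return ".".join(parts)

assert bazel_label_to_module("//python/ray/rllib:xxx/yyy/dd") == "ray.rllib.xxx.yyy.dd"
assert bazel_label_to_module("python:ray/rllib/xxx/yyy/dd") == "ray.rllib.xxx.yyy.dd"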
def test_import_line_continuation(self): def test_import_line_continuation(self):
graph = pda.DepGraph() graph = pda.DepGraph()
@ -44,11 +47,13 @@ class TestPyDepAnalysis(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py") src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file( self.create_tmp_file(
src_path, """ src_path,
"""
import ray.rllib.env.\\ import ray.rllib.env.\\
mock_env mock_env
b = 2 b = 2
""") """,
)
pda._process_file(graph, src_path, "ray") pda._process_file(graph, src_path, "ray")
self.assertEqual(len(graph.ids), 2) self.assertEqual(len(graph.ids), 2)
@ -64,11 +69,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py") src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file( self.create_tmp_file(
src_path, """ src_path,
"""
from ray.rllib.env import (ClassName, from ray.rllib.env import (ClassName,
module1, module2) module1, module2)
b = 2 b = 2
""") """,
)
pda._process_file(graph, src_path, "ray") pda._process_file(graph, src_path, "ray")
self.assertEqual(len(graph.ids), 2) self.assertEqual(len(graph.ids), 2)
@ -84,11 +91,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py" src_path = "multi_line_comment_3.py"
self.create_tmp_file( self.create_tmp_file(
os.path.join(tmpdir, src_path), """ os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import mock_env from ray.rllib.env import mock_env
a = 1 a = 1
b = 2 b = 2
""") """,
)
# Touch ray/rllib/env/mock_env.py in tmpdir, # Touch ray/rllib/env/mock_env.py in tmpdir,
# so that it looks like a module. # so that it looks like a module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env") module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env")
@ -112,11 +121,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py" src_path = "multi_line_comment_3.py"
self.create_tmp_file( self.create_tmp_file(
os.path.join(tmpdir, src_path), """ os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import MockEnv from ray.rllib.env import MockEnv
a = 1 a = 1
b = 2 b = 2
""") """,
)
# Touch ray/rllib/env.py in tmpdir, # Touch ray/rllib/env.py in tmpdir,
# MockEnv is a class on env module. # MockEnv is a class on env module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib") module_dir = os.path.join(tmpdir, "python", "ray", "rllib")
@ -138,4 +149,5 @@ b = 2
if __name__ == "__main__": if __name__ == "__main__":
import pytest import pytest
import sys import sys
sys.exit(pytest.main(["-v", __file__])) sys.exit(pytest.main(["-v", __file__]))
View file
@ -22,8 +22,11 @@ import ray.ray_constants as ray_constants
import ray._private.services import ray._private.services
import ray._private.utils import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, \ from ray._private.gcs_utils import (
get_gcs_address_from_redis, use_gcs_for_bootstrap GcsClient,
get_gcs_address_from_redis,
use_gcs_for_bootstrap,
)
from ray.core.generated import agent_manager_pb2 from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc from ray.core.generated import agent_manager_pb2_grpc
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
@ -42,23 +45,25 @@ aiogrpc.init_grpc_aio()
class DashboardAgent(object): class DashboardAgent(object):
def __init__(self, def __init__(
node_ip_address, self,
redis_address, node_ip_address,
dashboard_agent_port, redis_address,
gcs_address, dashboard_agent_port,
minimal, gcs_address,
redis_password=None, minimal,
temp_dir=None, redis_password=None,
session_dir=None, temp_dir=None,
runtime_env_dir=None, session_dir=None,
log_dir=None, runtime_env_dir=None,
metrics_export_port=None, log_dir=None,
node_manager_port=None, metrics_export_port=None,
listen_port=0, node_manager_port=None,
object_store_name=None, listen_port=0,
raylet_name=None, object_store_name=None,
logging_params=None): raylet_name=None,
logging_params=None,
):
"""Initialize the DashboardAgent object.""" """Initialize the DashboardAgent object."""
# Public attributes are accessible for all agent modules. # Public attributes are accessible for all agent modules.
self.ip = node_ip_address self.ip = node_ip_address
@ -92,15 +97,16 @@ class DashboardAgent(object):
self.ppid = int(os.environ["RAY_RAYLET_PID"]) self.ppid = int(os.environ["RAY_RAYLET_PID"])
assert self.ppid > 0 assert self.ppid > 0
logger.info("Parent pid is %s", self.ppid) logger.info("Parent pid is %s", self.ppid)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0" grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server( self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:{self.dashboard_agent_port}") self.server, f"{grpc_ip}:{self.dashboard_agent_port}"
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, )
self.grpc_port) logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port)
options = (("grpc.enable_http_proxy", 0), ) options = (("grpc.enable_http_proxy", 0),)
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel( self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True) f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
)
# If the agent is started as non-minimal version, http server should # If the agent is started as non-minimal version, http server should
# be configured to communicate with the dashboard in a head node. # be configured to communicate with the dashboard in a head node.
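Hedged aside on the grpc option tuples above: the trailing commas are load-bearing. gRPC expects a sequence of (key, value) pairs, so a single option still needs the one-element-tuple comma, which Black preserves:

options = (("grpc.enable_http_proxy", 0),)  # a tuple containing one pair
bare_pair = ("grpc.enable_http_proxy", 0)   # without the comma: just the pair
assert len(options) == 1 and options[0] == bare_pair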
@ -108,6 +114,7 @@ class DashboardAgent(object):
async def _configure_http_server(self, modules): async def _configure_http_server(self, modules):
from ray.dashboard.http_server_agent import HttpServerAgent from ray.dashboard.http_server_agent import HttpServerAgent
http_server = HttpServerAgent(self.ip, self.listen_port) http_server = HttpServerAgent(self.ip, self.listen_port)
await http_server.start(modules) await http_server.start(modules)
return http_server return http_server
@ -116,10 +123,12 @@ class DashboardAgent(object):
"""Load dashboard agent modules.""" """Load dashboard agent modules."""
modules = [] modules = []
agent_cls_list = dashboard_utils.get_all_modules( agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule) dashboard_utils.DashboardAgentModule
)
for cls in agent_cls_list: for cls in agent_cls_list:
logger.info("Loading %s: %s", logger.info(
dashboard_utils.DashboardAgentModule.__name__, cls) "Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
)
c = cls(self) c = cls(self)
modules.append(c) modules.append(c)
logger.info("Loaded %d modules.", len(modules)) logger.info("Loaded %d modules.", len(modules))
@ -137,13 +146,12 @@ class DashboardAgent(object):
curr_proc = psutil.Process() curr_proc = psutil.Process()
while True: while True:
parent = curr_proc.parent() parent = curr_proc.parent()
if (parent is None or parent.pid == 1 if parent is None or parent.pid == 1 or self.ppid != parent.pid:
or self.ppid != parent.pid):
logger.error("Raylet is dead, exiting.") logger.error("Raylet is dead, exiting.")
sys.exit(0) sys.exit(0)
await asyncio.sleep( await asyncio.sleep(
dashboard_consts. dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS) )
except Exception: except Exception:
logger.error("Failed to check parent PID, exiting.") logger.error("Failed to check parent PID, exiting.")
sys.exit(1) sys.exit(1)
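The parent-watch loop above can be read as this hedged standalone predicate (psutil semantics: parent() returns None once the parent has exited, and on Linux orphans are commonly re-parented to pid 1):

import psutil

def raylet_alive(expected_ppid: int) -> bool:
    # True only while our original parent (the raylet) is still our parent.
    parent = psutil.Process().parent()
    return parent is not None and parent.pid != 1 and parent.pid == expected_ppid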
@ -154,15 +162,17 @@ class DashboardAgent(object):
if not use_gcs_for_bootstrap(): if not use_gcs_for_bootstrap():
# Create an aioredis client for all modules. # Create an aioredis client for all modules.
try: try:
self.aioredis_client = \ self.aioredis_client = await dashboard_utils.get_aioredis_client(
await dashboard_utils.get_aioredis_client( self.redis_address,
self.redis_address, self.redis_password, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS, dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES) dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
)
except (socket.gaierror, ConnectionRefusedError): except (socket.gaierror, ConnectionRefusedError):
logger.error( logger.error(
"Dashboard agent exiting: " "Dashboard agent exiting: " "Failed to connect to redis at %s",
"Failed to connect to redis at %s", self.redis_address) self.redis_address,
)
sys.exit(-1) sys.exit(-1)
# Start a grpc asyncio server. # Start a grpc asyncio server.
@ -170,7 +180,8 @@ class DashboardAgent(object):
if not use_gcs_for_bootstrap(): if not use_gcs_for_bootstrap():
gcs_address = await self.aioredis_client.get( gcs_address = await self.aioredis_client.get(
dashboard_consts.GCS_SERVER_ADDRESS) dashboard_consts.GCS_SERVER_ADDRESS
)
self.gcs_client = GcsClient(address=gcs_address.decode()) self.gcs_client = GcsClient(address=gcs_address.decode())
else: else:
self.gcs_client = GcsClient(address=self.gcs_address) self.gcs_client = GcsClient(address=self.gcs_address)
@ -192,17 +203,21 @@ class DashboardAgent(object):
internal_kv._internal_kv_put( internal_kv._internal_kv_put(
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}", f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
json.dumps([http_port, self.grpc_port]), json.dumps([http_port, self.grpc_port]),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD) namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# Register agent to agent manager. # Register agent to agent manager.
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub( raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
self.aiogrpc_raylet_channel) self.aiogrpc_raylet_channel
)
await raylet_stub.RegisterAgent( await raylet_stub.RegisterAgent(
agent_manager_pb2.RegisterAgentRequest( agent_manager_pb2.RegisterAgentRequest(
agent_pid=os.getpid(), agent_pid=os.getpid(),
agent_port=self.grpc_port, agent_port=self.grpc_port,
agent_ip_address=self.ip)) agent_ip_address=self.ip,
)
)
tasks = [m.run(self.server) for m in modules] tasks = [m.run(self.server) for m in modules]
if sys.platform not in ["win32", "cygwin"]: if sys.platform not in ["win32", "cygwin"]:
@ -221,123 +236,139 @@ if __name__ == "__main__":
"--node-ip-address", "--node-ip-address",
required=True, required=True,
type=str, type=str,
help="the IP address of this node.") help="the IP address of this node.",
)
parser.add_argument( parser.add_argument(
"--gcs-address", "--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
required=False, )
type=str,
help="The address (ip:port) of GCS.")
parser.add_argument( parser.add_argument(
"--redis-address", "--redis-address", required=True, type=str, help="The address to use for Redis."
required=True, )
type=str,
help="The address to use for Redis.")
parser.add_argument( parser.add_argument(
"--metrics-export-port", "--metrics-export-port",
required=True, required=True,
type=int, type=int,
help="The port to expose metrics through Prometheus.") help="The port to expose metrics through Prometheus.",
)
parser.add_argument( parser.add_argument(
"--dashboard-agent-port", "--dashboard-agent-port",
required=True, required=True,
type=int, type=int,
help="The port on which the dashboard agent will receive GRPCs.") help="The port on which the dashboard agent will receive GRPCs.",
)
parser.add_argument( parser.add_argument(
"--node-manager-port", "--node-manager-port",
required=True, required=True,
type=int, type=int,
help="The port to use for starting the node manager") help="The port to use for starting the node manager",
)
parser.add_argument( parser.add_argument(
"--object-store-name", "--object-store-name",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="The socket name of the plasma store") help="The socket name of the plasma store",
)
parser.add_argument( parser.add_argument(
"--listen-port", "--listen-port",
required=False, required=False,
type=int, type=int,
default=0, default=0,
help="Port for HTTP server to listen on") help="Port for HTTP server to listen on",
)
parser.add_argument( parser.add_argument(
"--raylet-name", "--raylet-name",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="The socket path of the raylet process") help="The socket path of the raylet process",
)
parser.add_argument( parser.add_argument(
"--redis-password", "--redis-password",
required=False, required=False,
type=str, type=str,
default=None, default=None,
help="The password to use for Redis") help="The password to use for Redis",
)
parser.add_argument( parser.add_argument(
"--logging-level", "--logging-level",
required=False, required=False,
type=lambda s: logging.getLevelName(s.upper()), type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL, default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES, choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP) help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument( parser.add_argument(
"--logging-format", "--logging-format",
required=False, required=False,
type=str, type=str,
default=ray_constants.LOGGER_FORMAT, default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP) help=ray_constants.LOGGER_FORMAT_HELP,
)
parser.add_argument( parser.add_argument(
"--logging-filename", "--logging-filename",
required=False, required=False,
type=str, type=str,
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME, default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
help="Specify the name of log file, " help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\".".format( 'log to stdout if set empty, default is "{}".'.format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME)) dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
),
)
parser.add_argument( parser.add_argument(
"--logging-rotate-bytes", "--logging-rotate-bytes",
required=False, required=False,
type=int, type=int,
default=ray_constants.LOGGING_ROTATE_BYTES, default=ray_constants.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating " help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format( "log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
ray_constants.LOGGING_ROTATE_BYTES)) )
parser.add_argument( parser.add_argument(
"--logging-rotate-backup-count", "--logging-rotate-backup-count",
required=False, required=False,
type=int, type=int,
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT, default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.". help="Specify the backup count of rotated log file, default is {}.".format(
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT)) ray_constants.LOGGING_ROTATE_BACKUP_COUNT
),
)
parser.add_argument( parser.add_argument(
"--log-dir", "--log-dir",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="Specify the path of log directory.") help="Specify the path of log directory.",
)
parser.add_argument( parser.add_argument(
"--temp-dir", "--temp-dir",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="Specify the path of the temporary directory used by Ray process.") help="Specify the path of the temporary directory used by Ray process.",
)
parser.add_argument( parser.add_argument(
"--session-dir", "--session-dir",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="Specify the path of this session.") help="Specify the path of this session.",
)
parser.add_argument( parser.add_argument(
"--runtime-env-dir", "--runtime-env-dir",
required=True, required=True,
type=str, type=str,
default=None, default=None,
help="Specify the path of the resource directory used by runtime_env.") help="Specify the path of the resource directory used by runtime_env.",
)
parser.add_argument( parser.add_argument(
"--minimal", "--minimal",
action="store_true", action="store_true",
help=( help=(
"Minimal agent only contains a subset of features that don't " "Minimal agent only contains a subset of features that don't "
"require additional dependencies installed when ray is installed " "require additional dependencies installed when ray is installed "
"by `pip install ray[default]`.")) "by `pip install ray[default]`."
),
)
args = parser.parse_args() args = parser.parse_args()
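    # Hypothetical invocation, for orientation only (paths are placeholders;
    # in practice the raylet supplies these flags, they are not typed by hand):
    #   python agent.py --log-dir /tmp/ray/session_latest/logs \
    #       --temp-dir /tmp/ray --session-dir /tmp/ray/session_latest \
    #       --runtime-env-dir /tmp/ray/session_latest/runtime_resources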
    try:
@@ -347,7 +378,8 @@ if __name__ == "__main__":
            log_dir=args.log_dir,
            filename=args.logging_filename,
            max_bytes=args.logging_rotate_bytes,
            backup_count=args.logging_rotate_backup_count,
        )
        setup_component_logger(**logging_params)

        agent = DashboardAgent(
@@ -366,7 +398,8 @@ if __name__ == "__main__":
            listen_port=args.listen_port,
            object_store_name=args.object_store_name,
            raylet_name=args.raylet_name,
            logging_params=logging_params,
        )
        if os.environ.get("_RAY_AGENT_FAILING"):
            raise Exception("Failure injection failure.")
@@ -390,15 +423,19 @@ if __name__ == "__main__":
                gcs_publisher = GcsPublisher(args.gcs_address)
            else:
                redis_client = ray._private.services.create_redis_client(
                    args.redis_address, password=args.redis_password
                )
                gcs_publisher = GcsPublisher(
                    address=get_gcs_address_from_redis(redis_client)
                )
        else:
            redis_client = ray._private.services.create_redis_client(
                args.redis_address, password=args.redis_password
            )
        traceback_str = ray._private.utils.format_error_message(
            traceback.format_exc()
        )
        message = (
            f"(ip={node_ip}) "
            f"The agent on node {platform.uname()[1]} failed to "
@@ -409,12 +446,14 @@ if __name__ == "__main__":
            "\n 2. Metrics on this node won't be reported."
            "\n 3. runtime_env APIs won't work."
            "\nCheck out the `dashboard_agent.log` to see the "
            "detailed failure messages."
        )
        ray._private.utils.publish_error_to_driver(
            ray_constants.DASHBOARD_AGENT_DIED_ERROR,
            message,
            redis_client=redis_client,
            gcs_publisher=gcs_publisher,
        )
        logger.error(message)
        logger.exception(e)
        exit(1)
View file
@@ -12,12 +12,13 @@ DASHBOARD_RPC_ADDRESS = "dashboard_rpc"
GCS_SERVER_ADDRESS = "GcsServerAddress"
# GCS check alive
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR = env_integer(
    "GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10
)
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
    "GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
)
# aiohttp_cache
AIOHTTP_CACHE_TTL_SECONDS = 2
AIOHTTP_CACHE_MAX_SIZE = 128
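All of these knobs funnel through env_integer; a minimal sketch of such a helper, assuming it simply prefers an integer override from the environment (hypothetical reimplementation, not the Ray source):

import os


def env_integer(key: str, default: int) -> int:
    # Prefer an integer override from the environment, else fall back.
    if key in os.environ:
        return int(os.environ[key])
    return default


# e.g. exporting GCS_CHECK_ALIVE_RPC_TIMEOUT=30 would override the default 10.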
View file
@@ -38,17 +38,21 @@ class FrontendNotFoundError(OSError):
def setup_static_dir():
    build_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "client", "build"
    )
    module_name = os.path.basename(os.path.dirname(__file__))
    if not os.path.isdir(build_dir):
        raise FrontendNotFoundError(
            errno.ENOENT,
            "Dashboard build directory not found. If installing "
            "from source, please follow the additional steps "
            "required to build the dashboard"
            f"(cd python/ray/{module_name}/client "
            "&& npm install "
            "&& npm ci "
            "&& npm run build)",
            build_dir,
        )

    static_dir = os.path.join(build_dir, "static")
    routes.static("/static", static_dir, follow_symlinks=True)
@@ -72,14 +76,16 @@ class Dashboard:
        log_dir(str): Log directory of dashboard.
    """

    def __init__(
        self,
        host,
        port,
        port_retries,
        gcs_address,
        redis_address,
        redis_password=None,
        log_dir=None,
    ):
        self.dashboard_head = dashboard_head.DashboardHead(
            http_host=host,
            http_port=port,
@@ -87,7 +93,8 @@ class Dashboard:
            gcs_address=gcs_address,
            redis_address=redis_address,
            redis_password=redis_password,
            log_dir=log_dir,
        )

        # Setup Dashboard Routes
        try:
@@ -107,15 +114,17 @@ class Dashboard:
    async def get_index(self, req) -> aiohttp.web.FileResponse:
        return aiohttp.web.FileResponse(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "client/build/index.html"
            )
        )

    @routes.get("/favicon.ico")
    async def get_favicon(self, req) -> aiohttp.web.FileResponse:
        return aiohttp.web.FileResponse(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico"
            )
        )

    async def run(self):
        await self.dashboard_head.run()
@@ -124,92 +133,96 @@ class Dashboard:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ray dashboard.")
    parser.add_argument(
        "--host", required=True, type=str, help="The host to use for the HTTP server."
    )
    parser.add_argument(
        "--port", required=True, type=int, help="The port to use for the HTTP server."
    )
    parser.add_argument(
        "--port-retries",
        required=False,
        type=int,
        default=0,
        help="The retry times to select a valid port.",
    )
    parser.add_argument(
        "--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
    )
    parser.add_argument(
        "--redis-address", required=True, type=str, help="The address to use for Redis."
    )
    parser.add_argument(
        "--redis-password",
        required=False,
        type=str,
        default=None,
        help="The password to use for Redis",
    )
    parser.add_argument(
        "--logging-level",
        required=False,
        type=lambda s: logging.getLevelName(s.upper()),
        default=ray_constants.LOGGER_LEVEL,
        choices=ray_constants.LOGGER_LEVEL_CHOICES,
        help=ray_constants.LOGGER_LEVEL_HELP,
    )
    parser.add_argument(
        "--logging-format",
        required=False,
        type=str,
        default=ray_constants.LOGGER_FORMAT,
        help=ray_constants.LOGGER_FORMAT_HELP,
    )
    parser.add_argument(
        "--logging-filename",
        required=False,
        type=str,
        default=dashboard_consts.DASHBOARD_LOG_FILENAME,
        help="Specify the name of log file, "
        'log to stdout if set empty, default is "{}"'.format(
            dashboard_consts.DASHBOARD_LOG_FILENAME
        ),
    )
    parser.add_argument(
        "--logging-rotate-bytes",
        required=False,
        type=int,
        default=ray_constants.LOGGING_ROTATE_BYTES,
        help="Specify the max bytes for rotating "
        "log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
    )
    parser.add_argument(
        "--logging-rotate-backup-count",
        required=False,
        type=int,
        default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
        help="Specify the backup count of rotated log file, default is {}.".format(
            ray_constants.LOGGING_ROTATE_BACKUP_COUNT
        ),
    )
    parser.add_argument(
        "--log-dir",
        required=True,
        type=str,
        default=None,
        help="Specify the path of log directory.",
    )
    parser.add_argument(
        "--temp-dir",
        required=True,
        type=str,
        default=None,
        help="Specify the path of the temporary directory use by Ray process.",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        help=(
            "Minimal dashboard only contains a subset of features that don't "
            "require additional dependencies installed when ray is installed "
            "by `pip install ray[default]`."
        ),
    )

    args = parser.parse_args()
@@ -226,7 +239,8 @@ if __name__ == "__main__":
            log_dir=args.log_dir,
            filename=args.logging_filename,
            max_bytes=args.logging_rotate_bytes,
            backup_count=args.logging_rotate_backup_count,
        )

        dashboard = Dashboard(
            args.host,
@@ -235,25 +249,27 @@ if __name__ == "__main__":
            args.gcs_address,
            args.redis_address,
            redis_password=args.redis_password,
            log_dir=args.log_dir,
        )
        # TODO(fyrestone): Avoid using ray.state in dashboard, it's not
        # asynchronous and will lead to low performance. ray disconnect()
        # will be hang when the ray.state is connected and the GCS is exit.
        # Please refer to: https://github.com/ray-project/ray/issues/16328
        service_discovery = PrometheusServiceDiscoveryWriter(
            args.redis_address, args.redis_password, args.gcs_address, args.temp_dir
        )
        # Need daemon True to avoid dashboard hangs at exit.
        service_discovery.daemon = True
        service_discovery.start()
        loop = asyncio.get_event_loop()
        loop.run_until_complete(dashboard.run())
    except Exception as e:
        traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
        message = (
            f"The dashboard on node {platform.uname()[1]} "
            f"failed with the following "
            f"error:\n{traceback_str}"
        )
        if isinstance(e, FrontendNotFoundError):
            logger.warning(message)
        else:
@@ -268,17 +284,21 @@ if __name__ == "__main__":
                    gcs_publisher = GcsPublisher(args.gcs_address)
                else:
                    redis_client = ray._private.services.create_redis_client(
                        args.redis_address, password=args.redis_password
                    )
                    gcs_publisher = GcsPublisher(
                        address=gcs_utils.get_gcs_address_from_redis(redis_client)
                    )
                redis_client = None
            else:
                redis_client = ray._private.services.create_redis_client(
                    args.redis_address, password=args.redis_password
                )
            ray._private.utils.publish_error_to_driver(
                redis_client,
                ray_constants.DASHBOARD_DIED_ERROR,
                message,
                redis_client=redis_client,
                gcs_publisher=gcs_publisher,
            )
View file
@@ -2,9 +2,9 @@ import asyncio
import logging

import ray.dashboard.consts as dashboard_consts
import ray.dashboard.memory_utils as memory_utils

# TODO(fyrestone): Not import from dashboard module.
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.dashboard.utils import Dict, Signal, async_loop_forever

logger = logging.getLogger(__name__)
@@ -132,9 +132,9 @@ class DataOrganizer:
                worker["errorCount"] = len(node_errs.get(str(pid), []))
                worker["coreWorkerStats"] = pid_to_worker_stats.get(pid, [])
                worker["language"] = pid_to_language.get(
                    pid, dashboard_consts.DEFAULT_LANGUAGE
                )
                worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID)

                await GlobalSignals.worker_info_fetched.send(node_id, worker)
@@ -143,8 +143,7 @@ class DataOrganizer:
    @classmethod
    async def get_node_info(cls, node_id):
        node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
        node_stats = dict(DataSource.node_stats.get(node_id, {}))
        node = DataSource.nodes.get(node_id, {})
        node_ip = DataSource.node_id_to_ip.get(node_id)
@@ -162,8 +161,8 @@ class DataOrganizer:
        view_data = node_stats.get("viewData", [])
        ray_stats = cls._extract_view_data(
            view_data, {"object_store_used_memory", "object_store_available_memory"}
        )

        node_info = node_physical_stats
        # Merge node stats to node physical stats under raylet
@@ -184,8 +183,7 @@ class DataOrganizer:
    @classmethod
    async def get_node_summary(cls, node_id):
        node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
        node_stats = dict(DataSource.node_stats.get(node_id, {}))
        node = DataSource.nodes.get(node_id, {})
@@ -193,8 +191,8 @@ class DataOrganizer:
        node_stats.pop("workersStats", None)
        view_data = node_stats.get("viewData", [])
        ray_stats = cls._extract_view_data(
            view_data, {"object_store_used_memory", "object_store_available_memory"}
        )
        node_stats.pop("viewData", None)

        node_summary = node_physical_stats
@@ -244,8 +242,9 @@ class DataOrganizer:
            actor = dict(actor)
            worker_id = actor["address"]["workerId"]
            core_worker_stats = DataSource.core_worker_stats.get(worker_id, {})
            actor_constructor = core_worker_stats.get(
                "actorTitle", "Unknown actor constructor"
            )
            actor["actorConstructor"] = actor_constructor
            actor.update(core_worker_stats)
@@ -275,8 +274,12 @@ class DataOrganizer:
    @classmethod
    async def get_actor_creation_tasks(cls):
        infeasible_tasks = sum(
            (
                list(node_stats.get("infeasibleTasks", []))
                for node_stats in DataSource.node_stats.values()
            ),
            [],
        )
        new_infeasible_tasks = []
        for task in infeasible_tasks:
            task = dict(task)
@@ -285,8 +288,12 @@ class DataOrganizer:
            new_infeasible_tasks.append(task)

        resource_pending_tasks = sum(
            (
                list(data.get("readyTasks", []))
                for data in DataSource.node_stats.values()
            ),
            [],
        )
        new_resource_pending_tasks = []
        for task in resource_pending_tasks:
            task = dict(task)
@@ -301,14 +308,17 @@ class DataOrganizer:
        return results

    @classmethod
    async def get_memory_table(
        cls,
        sort_by=memory_utils.SortingType.OBJECT_SIZE,
        group_by=memory_utils.GroupByType.STACK_TRACE,
    ):
        all_worker_stats = []
        for node_stats in DataSource.node_stats.values():
            all_worker_stats.extend(node_stats.get("coreWorkersStats", []))
        memory_information = memory_utils.construct_memory_table(
            all_worker_stats, group_by=group_by, sort_by=sort_by
        )
        return memory_information

    @staticmethod
View file
@@ -10,6 +10,7 @@ from queue import Queue
from distutils.version import LooseVersion

import grpc

try:
    from grpc import aio as aiogrpc
except ImportError:
@@ -23,8 +24,11 @@ import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray import ray_constants
from ray._private.gcs_pubsub import (
    gcs_pubsub_enabled,
    GcsAioErrorSubscriber,
    GcsAioLogSubscriber,
)
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer
@@ -42,33 +46,33 @@ aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
    ("grpc.enable_http_proxy", 0),
    ("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
    ("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
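For orientation, a hedged sketch of how an options tuple like this is typically applied when dialing a channel with grpcio (the target address below is a placeholder, not something this file defines):

import grpc

# Hypothetical example: grpcio accepts the same (key, value) options tuple
# when the channel is created.
channel = grpc.insecure_channel("127.0.0.1:6379", options=GRPC_CHANNEL_OPTIONS)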
async def get_gcs_address_with_retry(redis_client) -> str:
    while True:
        try:
            gcs_address = (
                await redis_client.get(dashboard_consts.GCS_SERVER_ADDRESS)
            ).decode()
            if not gcs_address:
                raise Exception("GCS address not found.")
            logger.info("Connect to GCS at %s", gcs_address)
            return gcs_address
        except Exception as ex:
            logger.error("Connect to GCS failed: %s, retry...", ex)
            await asyncio.sleep(dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)


class GCSHealthCheckThread(threading.Thread):
    def __init__(self, gcs_address: str):
        self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
            gcs_address, options=GRPC_CHANNEL_OPTIONS
        )
        self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
            self.grpc_gcs_channel
        )
        self.work_queue = Queue()

        super().__init__(daemon=True)
@@ -83,10 +87,10 @@ class GCSHealthCheckThread(threading.Thread):
        request = gcs_service_pb2.CheckAliveRequest()
        try:
            reply = self.gcs_heartbeat_info_stub.CheckAlive(
                request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
            )
            if reply.status.code != 0:
                logger.exception(f"Failed to CheckAlive: {reply.status.message}")
                return False
        except grpc.RpcError:  # Deadline Exceeded
            logger.exception("Got RpcError when checking GCS is alive")
@@ -95,9 +99,9 @@ class GCSHealthCheckThread(threading.Thread):
    async def check_once(self) -> bool:
        """Ask the thread to perform a healthcheck."""
        assert (
            threading.current_thread != self
        ), "caller shouldn't be from the same thread as GCSHealthCheckThread."

        future = Future()
        self.work_queue.put(future)
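A minimal, self-contained sketch of the hand-off pattern assumed here: callers enqueue a Future, the daemon thread fulfills it, and async code awaits it through wrap_future (illustrative, not the Ray implementation):

import asyncio
import threading
from queue import Queue
from concurrent.futures import Future


class BlockingProbeThread(threading.Thread):
    """Hypothetical worker: runs blocking probes on behalf of async callers."""

    def __init__(self):
        super().__init__(daemon=True)
        self.work_queue = Queue()

    def run(self):
        while True:
            future = self.work_queue.get()
            future.set_result(True)  # stand-in for a blocking RPC

    def check_once(self) -> Future:
        future = Future()
        self.work_queue.put(future)
        return future


async def main():
    probe = BlockingProbeThread()
    probe.start()
    # wrap_future bridges the thread's Future into the running event loop.
    print(await asyncio.wrap_future(probe.check_once()))


asyncio.run(main())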
@@ -105,8 +109,16 @@
class DashboardHead:
    def __init__(
        self,
        http_host,
        http_port,
        http_port_retries,
        gcs_address,
        redis_address,
        redis_password,
        log_dir,
    ):
        self.health_check_thread: GCSHealthCheckThread = None
        self._gcs_rpc_error_counter = 0
        # Public attributes are accessible for all head modules.
@@ -134,12 +146,12 @@ class DashboardHead:
        else:
            ip, port = gcs_address.split(":")

        self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
        grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
        self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
            self.server, f"{grpc_ip}:0"
        )
        logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)

    @async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
    async def _gcs_check_alive(self):
@@ -149,7 +161,8 @@ class DashboardHead:
        # Otherwise, the dashboard will always think that gcs is alive.
        try:
            is_alive = await asyncio.wait_for(
                check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
            )
        except asyncio.TimeoutError:
            logger.error("Failed to check gcs health, client timed out.")
            is_alive = False
@@ -158,13 +171,16 @@ class DashboardHead:
            self._gcs_rpc_error_counter = 0
        else:
            self._gcs_rpc_error_counter += 1
            if (
                self._gcs_rpc_error_counter
                > dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
            ):
                logger.error(
                    "Dashboard exiting because it received too many GCS RPC "
                    "errors count: %s, threshold is %s.",
                    self._gcs_rpc_error_counter,
                    dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
                )
                # TODO(fyrestone): Do not use ray.state in
                # PrometheusServiceDiscoveryWriter.
                # Currently, we use os._exit() here to avoid hanging at the ray
@@ -176,10 +192,12 @@ class DashboardHead:
        """Load dashboard head modules."""
        modules = []
        head_cls_list = dashboard_utils.get_all_modules(
            dashboard_utils.DashboardHeadModule
        )
        for cls in head_cls_list:
            logger.info(
                "Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
            )
            c = cls(self)
            dashboard_optional_utils.ClassMethodRouteTable.bind(c)
            modules.append(c)
@@ -192,15 +210,17 @@ class DashboardHead:
            return self.gcs_address
        else:
            try:
                self.aioredis_client = await dashboard_utils.get_aioredis_client(
                    self.redis_address,
                    self.redis_password,
                    dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
                    dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
                )
            except (socket.gaierror, ConnectionError):
                logger.error(
                    "Dashboard head exiting: " "Failed to connect to redis at %s",
                    self.redis_address,
                )
                sys.exit(-1)
            return await get_gcs_address_with_retry(self.aioredis_client)
@@ -209,22 +229,20 @@ class DashboardHead:
        # Create a http session for all modules.
        # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
        if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
            self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
        else:
            self.http_session = aiohttp.ClientSession()

        gcs_address = await self.get_gcs_address()
        # Dashboard will handle connection failure automatically
        self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
        internal_kv._initialize_internal_kv(self.gcs_client)
        self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
            gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
        )

        if gcs_pubsub_enabled():
            self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
            self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
            await self.gcs_error_subscriber.subscribe()
            await self.gcs_log_subscriber.subscribe()
@@ -248,7 +266,7 @@ class DashboardHead:
        # Http server should be initialized after all modules loaded.
        # working_dir uploads for job submission can be up to 100MiB.
        app = aiohttp.web.Application(client_max_size=100 * 1024 ** 2)
        app.add_routes(routes=routes.bound_routes())

        runner = aiohttp.web.AppRunner(app)
@@ -256,8 +274,7 @@ class DashboardHead:
        last_ex = None
        for i in range(1 + self.http_port_retries):
            try:
                site = aiohttp.web.TCPSite(runner, self.http_host, self.http_port)
                await site.start()
                break
            except OSError as e:
@@ -265,11 +282,14 @@ class DashboardHead:
                self.http_port += 1
                logger.warning("Try to use port %s: %s", self.http_port, e)
        else:
            raise Exception(
                f"Failed to find a valid port for dashboard after "
                f"{self.http_port_retries} retries: {last_ex}"
            )
        http_host, http_port, *_ = site._server.sockets[0].getsockname()
        http_host = (
            self.ip if ipaddress.ip_address(http_host).is_unspecified else http_host
        )
        logger.info("Dashboard head http address: %s:%s", http_host, http_port)

        # TODO: Use async version if performance is an issue
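A self-contained sketch of the retry-on-occupied-port idea from the loop above, written against plain sockets for illustration (assumed behavior, not the aiohttp code):

import socket


def bind_with_retries(host: str, start_port: int, retries: int) -> socket.socket:
    # Try start_port, start_port + 1, ... until one binds; re-raise on exhaustion.
    last_ex = None
    for offset in range(1 + retries):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind((host, start_port + offset))
            return sock
        except OSError as e:
            sock.close()
            last_ex = e
    raise OSError(f"no free port after {retries} retries") from last_ex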
@@ -277,16 +297,16 @@ class DashboardHead:
        internal_kv._internal_kv_put(
            ray_constants.DASHBOARD_ADDRESS,
            f"{http_host}:{http_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        )
        internal_kv._internal_kv_put(
            dashboard_consts.DASHBOARD_RPC_ADDRESS,
            f"{self.ip}:{self.grpc_port}",
            namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        )

        # Dump registered http routes.
        dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
        for r in dump_routes:
            logger.info(r)
        logger.info("Registered %s routes.", len(dump_routes))
@@ -299,6 +319,5 @@ class DashboardHead:
            DataOrganizer.purge(),
            DataOrganizer.organize(),
        ]
        await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))

        await self.server.wait_for_termination()
View file
@@ -23,8 +23,7 @@ class HttpServerAgent:
        # Create a http session for all modules.
        # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
        if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
            self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
        else:
            self.http_session = aiohttp.ClientSession()
@@ -47,25 +46,26 @@ class HttpServerAgent:
                    allow_methods="*",
                    allow_headers=("Content-Type", "X-Header"),
                )
            },
        )
        for route in list(app.router.routes()):
            cors.add(route)

        self.runner = aiohttp.web.AppRunner(app)
        await self.runner.setup()
        site = aiohttp.web.TCPSite(
            self.runner,
            "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0",
            self.listen_port,
        )
        await site.start()
        self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
        logger.info(
            "Dashboard agent http address: %s:%s", self.http_host, self.http_port
        )

        # Dump registered http routes.
        dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
        for r in dump_routes:
            logger.info(r)
        logger.info("Registered %s routes.", len(dump_routes))
View file
@@ -31,7 +31,7 @@ def cpu_percent():
    delta in total host cpu usage, averaged over host's cpus.

    Since deltas are not initially available, return 0.0 on first call.
    """  # noqa
    global last_system_usage
    global last_cpu_usage
    try:
@@ -43,12 +43,10 @@ def cpu_percent():
        else:
            cpu_delta = cpu_usage - last_cpu_usage
            # "System time passed." (Typically close to clock time.)
            system_delta = (system_usage - last_system_usage) / _host_num_cpus()

            quotient = cpu_delta / system_delta
            cpu_percent = round(quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
        last_system_usage = system_usage
        last_cpu_usage = cpu_usage
        # Computed percentage might be slightly above 100%.
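A worked example of that arithmetic with invented numbers:

# Suppose the container used 2 * 10**9 ns of CPU while the host advanced
# 8 * 10**9 ns per core, and the cgroup is allotted 0.5 CPUs:
cpu_delta = 2 * 10 ** 9
system_delta = 8 * 10 ** 9
quotient = cpu_delta / system_delta  # 0.25 of one host core
print(round(quotient * 100 / 0.5, 1))  # -> 50.0 (% of the 0.5-CPU quota)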
@@ -73,14 +71,14 @@ def _system_usage():

    See also the /proc/stat entry here:
    https://man7.org/linux/man-pages/man5/proc.5.html
    """  # noqa
    cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0]
    parts = cpu_summary_str.split()
    assert parts[0] == "cpu"
    usage_data = parts[1:8]
    total_clock_ticks = sum(int(entry) for entry in usage_data)
    # 100 clock ticks per second, 10^9 ns per second
    usage_ns = total_clock_ticks * 10 ** 7
    return usage_ns
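A one-line sanity check of the conversion comment above:

# One tick = 1/100 s = 10**7 ns, so e.g. 250 ticks equal 2.5 seconds of CPU:
assert 250 * 10 ** 7 == 2_500_000_000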
@@ -91,7 +89,8 @@ def _host_num_cpus():
    proc_stat_lines = open(PROC_STAT_PATH).read().split("\n")
    split_proc_stat_lines = [line.split() for line in proc_stat_lines]
    cpu_lines = [
        split_line
        for split_line in split_proc_stat_lines
        if len(split_line) > 0 and "cpu" in split_line[0]
    ]
    # Number of lines starting with a word including 'cpu', subtracting
View file
@@ -6,7 +6,7 @@ from typing import List

import ray
from ray._raylet import TaskID, ActorID, JobID
from ray.internal.internal_api import node_stats

import logging
@@ -69,8 +69,10 @@ def get_sorting_type(sort_by: str):
    elif sort_by == "REFERENCE_TYPE":
        return SortingType.REFERENCE_TYPE
    else:
        raise Exception(
            "The sort-by input provided is not one of\
                PID, OBJECT_SIZE, or REFERENCE_TYPE."
        )


def get_group_by_type(group_by: str):
@@ -81,13 +83,16 @@ def get_group_by_type(group_by: str):
    elif group_by == "STACK_TRACE":
        return GroupByType.STACK_TRACE
    else:
        raise Exception(
            "The group-by input provided is not one of\
                NODE_ADDRESS or STACK_TRACE."
        )


class MemoryTableEntry:
    def __init__(
        self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int
    ):
        # worker info
        self.is_driver = is_driver
        self.pid = pid
@@ -97,13 +102,13 @@ class MemoryTableEntry:
        self.object_size = int(object_ref.get("objectSize", -1))
        self.call_site = object_ref.get("callSite", "<Unknown>")
        self.object_ref = ray.ObjectRef(
            decode_object_ref_if_needed(object_ref["objectId"])
        )

        # reference info
        self.local_ref_count = int(object_ref.get("localRefCount", 0))
        self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
        self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0))
        self.contained_in_owned = [
            ray.ObjectRef(decode_object_ref_if_needed(object_ref))
            for object_ref in object_ref.get("containedInOwned", [])
@@ -113,9 +118,12 @@ class MemoryTableEntry:
    def is_valid(self) -> bool:
        # If the entry doesn't have a reference type or some invalid state,
        # (e.g., no object ref presented), it is considered invalid.
        if (
            not self.pinned_in_memory
            and self.local_ref_count == 0
            and self.submitted_task_ref_count == 0
            and len(self.contained_in_owned) == 0
        ):
            return False
        elif self.object_ref.is_nil():
            return False
@@ -153,10 +161,10 @@ class MemoryTableEntry:
        # are not all 'f', that means it is an actor creation
        # task, which is an actor handle.
        random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
        actor_random_bits = object_ref_hex[
            TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
        ]
        if random_bits == "f" * 16 and not actor_random_bits == "f" * 24:
            return True
        else:
            return False
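The same check restated as a small standalone function (the two width constants are assumed from the slice above; illustrative only):

TASKID_RANDOM_BITS_SIZE = 16  # assumed hex widths, taken from the slice above
ACTORID_RANDOM_BITS_SIZE = 24


def looks_like_actor_handle(object_ref_hex: str) -> bool:
    # Task bits all 'f' while actor bits are not all 'f' marks an
    # actor-creation task, i.e. the ObjectRef is an actor handle.
    random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
    actor_random_bits = object_ref_hex[
        TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
    ]
    return random_bits == "f" * 16 and actor_random_bits != "f" * 24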
@@ -175,7 +183,7 @@ class MemoryTableEntry:
            "contained_in_owned": [
                object_ref.hex() for object_ref in self.contained_in_owned
            ],
            "type": "Driver" if self.is_driver else "Worker",
        }

    def __str__(self):
@@ -186,10 +194,12 @@ class MemoryTableEntry:

class MemoryTable:
    def __init__(
        self,
        entries: List[MemoryTableEntry],
        group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
        sort_by_type: SortingType = SortingType.PID,
    ):
        self.table = entries
        # Group is a list of memory tables grouped by a group key.
        self.group = {}
@@ -247,7 +257,7 @@ class MemoryTable:
            "total_pinned_in_memory": total_pinned_in_memory,
            "total_used_by_pending_task": total_used_by_pending_task,
            "total_captured_in_objects": total_captured_in_objects,
            "total_actor_handles": total_actor_handles,
        }
        return self
@@ -278,7 +288,8 @@ class MemoryTable:
        # Build a group table.
        for group_key, entries in group.items():
            self.group[group_key] = MemoryTable(
                entries, group_by_type=None, sort_by_type=None
            )
        for group_key, group_memory_table in self.group.items():
            group_memory_table.summarize()
        return self
@@ -289,10 +300,10 @@ class MemoryTable:
            "group": {
                group_key: {
                    "entries": group_memory_table.get_entries(),
                    "summary": group_memory_table.summary,
                }
                for group_key, group_memory_table in self.group.items()
            },
        }

    def get_entries(self) -> List[dict]:
@@ -305,9 +316,11 @@ class MemoryTable:
        return self.__repr__()


def construct_memory_table(
    workers_stats: List,
    group_by: GroupByType = GroupByType.NODE_ADDRESS,
    sort_by=SortingType.OBJECT_SIZE,
) -> MemoryTable:
    memory_table_entries = []
    for core_worker_stats in workers_stats:
        pid = core_worker_stats["pid"]
@@ -320,11 +333,13 @@ def construct_memory_table(workers_stats: List,
            object_ref=object_ref,
            node_address=node_address,
            is_driver=is_driver,
            pid=pid,
        )
        if memory_table_entry.is_valid():
            memory_table_entries.append(memory_table_entry)
    memory_table = MemoryTable(
        memory_table_entries, group_by_type=group_by, sort_by_type=sort_by
    )
    return memory_table
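Hypothetical usage of this builder; `stats` is a placeholder for one node's stats dict fetched elsewhere in this module:

# Entries grouped per node, largest objects first, rendered as a plain dict.
table = construct_memory_table(
    stats["coreWorkersStats"],
    group_by=GroupByType.NODE_ADDRESS,
    sort_by=SortingType.OBJECT_SIZE,
).as_dict()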
@@ -337,7 +352,7 @@ def track_reference_size(group):
        "PINNED_IN_MEMORY": "total_pinned_in_memory",
        "USED_BY_PENDING_TASK": "total_used_by_pending_task",
        "CAPTURED_IN_OBJECT": "total_captured_in_objects",
        "ACTOR_HANDLE": "total_actor_handles",
    }
    for entry in group["entries"]:
        size = entry["object_size"]
@@ -348,51 +363,64 @@
    return d


def memory_summary(
    state,
    group_by="NODE_ADDRESS",
    sort_by="OBJECT_SIZE",
    line_wrap=True,
    unit="B",
    num_entries=None,
) -> str:
    from ray.dashboard.modules.node.node_head import node_stats_to_dict

    # Get terminal size
    import shutil

    size = shutil.get_terminal_size((80, 20)).columns
    line_wrap_threshold = 137

    # Unit conversions
    units = {"B": 10 ** 0, "KB": 10 ** 3, "MB": 10 ** 6, "GB": 10 ** 9}

    # Fetch core memory worker stats, store as a dictionary
    core_worker_stats = []
    for raylet in state.node_table():
        stats = node_stats_to_dict(
            node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
        )
        core_worker_stats.extend(stats["coreWorkersStats"])
        assert type(stats) is dict and "coreWorkersStats" in stats

    # Build memory table with "group_by" and "sort_by" parameters
    group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
    memory_table = construct_memory_table(
        core_worker_stats, group_by, sort_by
    ).as_dict()
    assert "summary" in memory_table and "group" in memory_table

    # Build memory summary
    mem = ""
    group_by, sort_by = group_by.name.lower().replace(
        "_", " "
    ), sort_by.name.lower().replace("_", " ")
    summary_labels = [
        "Mem Used by Objects",
        "Local References",
        "Pinned",
        "Pending Tasks",
        "Captured in Objects",
        "Actor Handles",
    ]
    summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"

    object_ref_labels = [
        "IP Address",
        "PID",
        "Type",
        "Call Site",
        "Size",
        "Reference Type",
        "Object Ref",
    ]
    object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
| {:<8} | {:<14} | {:<10}\n"
@@ -416,22 +444,21 @@ entries per group...\n\n\n"
            else:
                summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"

        mem += f"--- Summary for {group_by}: {key} ---\n"
        mem += summary_string.format(*summary_labels)
        mem += summary_string.format(*summary.values()) + "\n"

        # Memory table per group
        mem += f"--- Object references for {group_by}: {key} ---\n"
        mem += object_ref_string.format(*object_ref_labels)

        n = 1  # Counter for num entries per group
        for entry in group["entries"]:
            if num_entries is not None and n > num_entries:
                break
            entry["object_size"] = (
                str(entry["object_size"] / units[unit]) + f" {unit}"
                if entry["object_size"] > -1
                else "?"
            )
            num_lines = 1
            if size > line_wrap_threshold and line_wrap:
                call_site_length = 22
@@ -439,30 +466,36 @@ entries per group...\n\n\n"
                    entry["call_site"] = ["disabled"]
                else:
                    entry["call_site"] = [
                        entry["call_site"][i : i + call_site_length]
                        for i in range(0, len(entry["call_site"]), call_site_length)
                    ]
                num_lines = len(entry["call_site"])
            else:
                mem += "\n"
            object_ref_values = [
                entry["node_ip_address"],
                entry["pid"],
                entry["type"],
                entry["call_site"],
                entry["object_size"],
                entry["reference_type"],
                entry["object_ref"],
            ]
            for i in range(len(object_ref_values)):
                if not isinstance(object_ref_values[i], list):
                    object_ref_values[i] = [object_ref_values[i]]
                object_ref_values[i].extend(
                    ["" for x in range(num_lines - len(object_ref_values[i]))]
                )
            for i in range(num_lines):
                row = [elem[i] for elem in object_ref_values]
                mem += object_ref_string.format(*row)
            mem += "\n"
            n += 1

    mem += (
        "To record callsite information for each ObjectRef created, set "
        "env variable RAY_record_ref_creation_sites=1\n\n"
    )
    return mem
View file
@@ -4,6 +4,7 @@ import aiohttp.web

import ray._private.utils
from ray.dashboard.modules.actor import actor_utils

from aioredis.pubsub import Receiver

try:
    from grpc import aio as aiogrpc
except ImportError:
@@ -15,8 +16,7 @@ import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
from ray.dashboard.modules.actor import actor_consts
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
@@ -30,12 +30,22 @@ routes = dashboard_optional_utils.ClassMethodRouteTable

def actor_table_data_to_dict(message):
    orig_message = dashboard_utils.message_to_dict(
        message,
        {
            "actorId",
            "parentId",
            "jobId",
            "workerId",
            "rayletId",
            "actorCreationDummyObjectId",
            "callerId",
            "taskId",
            "parentTaskId",
            "sourceActorId",
            "placementGroupId",
        },
        including_default_value_fields=True,
    )
    # The complete schema for actor table is here:
    #     src/ray/protobuf/gcs.proto
    # It is super big and for dashboard, we don't need that much information.
@ -58,8 +68,7 @@ def actor_table_data_to_dict(message):
light_message["actorClass"] = actor_class light_message["actorClass"] = actor_class
if "functionDescriptor" in light_message["taskSpec"]: if "functionDescriptor" in light_message["taskSpec"]:
light_message["taskSpec"] = { light_message["taskSpec"] = {
"functionDescriptor": light_message["taskSpec"][ "functionDescriptor": light_message["taskSpec"]["functionDescriptor"]
"functionDescriptor"]
} }
else: else:
light_message.pop("taskSpec") light_message.pop("taskSpec")
@ -81,11 +90,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if change.new: if change.new:
# TODO(fyrestone): Handle exceptions. # TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new node_id, node_info = change.new
address = "{}:{}".format(node_info["nodeManagerAddress"], address = "{}:{}".format(
int(node_info["nodeManagerPort"])) node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
options = (("grpc.enable_http_proxy", 0), ) )
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel( channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True) address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub self._stubs[node_id] = stub
@ -96,7 +107,8 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
logger.info("Getting all actor info from GCS.") logger.info("Getting all actor info from GCS.")
request = gcs_service_pb2.GetAllActorInfoRequest() request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo( reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=5) request, timeout=5
)
if reply.status.code == 0: if reply.status.code == 0:
actors = {} actors = {}
for message in reply.actor_table_data: for message in reply.actor_table_data:
@ -110,24 +122,25 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
for actor_id, actor_table_data in actors.items(): for actor_id, actor_table_data in actors.items():
job_id = actor_table_data["jobId"] job_id = actor_table_data["jobId"]
node_id = actor_table_data["address"]["rayletId"] node_id = actor_table_data["address"]["rayletId"]
job_actors.setdefault(job_id, job_actors.setdefault(job_id, {})[actor_id] = actor_table_data
{})[actor_id] = actor_table_data
# Update only when node_id is not Nil. # Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID: if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault( node_actors.setdefault(node_id, {})[
node_id, {})[actor_id] = actor_table_data actor_id
] = actor_table_data
DataSource.job_actors.reset(job_actors) DataSource.job_actors.reset(job_actors)
DataSource.node_actors.reset(node_actors) DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.", logger.info("Received %d actor info from GCS.", len(actors))
len(actors))
break break
else: else:
raise Exception( raise Exception(
f"Failed to GetAllActorInfo: {reply.status.message}") f"Failed to GetAllActorInfo: {reply.status.message}"
)
except Exception: except Exception:
logger.exception("Error Getting all actor info from GCS.") logger.exception("Error Getting all actor info from GCS.")
await asyncio.sleep( await asyncio.sleep(
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS) actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS
)
state_keys = ("state", "address", "numRestarts", "timestamp", "pid") state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
@ -167,8 +180,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if actor_id is not None: if actor_id is not None:
# Convert to lower case hex ID. # Convert to lower case hex ID.
actor_id = actor_id.hex() actor_id = actor_id.hex()
process_actor_data_from_pubsub(actor_id, process_actor_data_from_pubsub(actor_id, actor_table_data)
actor_table_data)
except Exception: except Exception:
logger.exception("Error processing actor info from GCS.") logger.exception("Error processing actor info from GCS.")
@ -183,12 +195,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
async for sender, msg in receiver.iter(): async for sender, msg in receiver.iter():
try: try:
actor_id, actor_table_data = msg actor_id, actor_table_data = msg
actor_id = actor_id.decode("UTF-8")[len( actor_id = actor_id.decode("UTF-8")[
gcs_utils.TablePrefix_ACTOR_string + ":"):] len(gcs_utils.TablePrefix_ACTOR_string + ":") :
]
pubsub_message = gcs_utils.PubSubMessage.FromString( pubsub_message = gcs_utils.PubSubMessage.FromString(
actor_table_data) actor_table_data
)
actor_table_data = gcs_utils.ActorTableData.FromString( actor_table_data = gcs_utils.ActorTableData.FromString(
pubsub_message.data) pubsub_message.data
)
process_actor_data_from_pubsub(actor_id, actor_table_data) process_actor_data_from_pubsub(actor_id, actor_table_data)
except Exception: except Exception:
logger.exception("Error processing actor info from Redis.") logger.exception("Error processing actor info from Redis.")
@ -203,17 +218,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
actors.update(actor_creation_tasks) actors.update(actor_creation_tasks)
actor_groups = actor_utils.construct_actor_groups(actors) actor_groups = actor_utils.construct_actor_groups(actors)
return rest_response( return rest_response(
success=True, success=True, message="Fetched actor groups.", actor_groups=actor_groups
message="Fetched actor groups.", )
actor_groups=actor_groups)
@routes.get("/logical/actors") @routes.get("/logical/actors")
@dashboard_optional_utils.aiohttp_cache @dashboard_optional_utils.aiohttp_cache
async def get_all_actors(self, req) -> aiohttp.web.Response: async def get_all_actors(self, req) -> aiohttp.web.Response:
return rest_response( return rest_response(
success=True, success=True, message="All actors fetched.", actors=DataSource.actors
message="All actors fetched.", )
actors=DataSource.actors)
@routes.get("/logical/kill_actor") @routes.get("/logical/kill_actor")
async def kill_actor(self, req) -> aiohttp.web.Response: async def kill_actor(self, req) -> aiohttp.web.Response:
@ -224,15 +237,17 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
except KeyError: except KeyError:
return rest_response(success=False, message="Bad Request") return rest_response(success=False, message="Bad Request")
try: try:
options = (("grpc.enable_http_proxy", 0), ) options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel( channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True) f"{ip_address}:{port}", options=options, asynchronous=True
)
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel) stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
await stub.KillActor( await stub.KillActor(
core_worker_pb2.KillActorRequest( core_worker_pb2.KillActorRequest(
intended_actor_id=ray._private.utils.hex_to_binary( intended_actor_id=ray._private.utils.hex_to_binary(actor_id)
actor_id))) )
)
except aiogrpc.AioRpcError: except aiogrpc.AioRpcError:
# This always throws an exception because the worker # This always throws an exception because the worker
@ -240,13 +255,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
# before this handler, however it deletes the actor correctly. # before this handler, however it deletes the actor correctly.
pass pass
return rest_response( return rest_response(success=True, message=f"Killed actor with id {actor_id}")
success=True, message=f"Killed actor with id {actor_id}")
async def run(self, server): async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = \ self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(gcs_channel) gcs_channel
)
await asyncio.gather(self._update_actors()) await asyncio.gather(self._update_actors())
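The one-field-per-line set literals above are Black's layout for any collection that cannot fit on one line; the trailing comma it leaves behind (the "magic trailing comma") then keeps the collection exploded on future runs. A standalone illustration with a hypothetical `fields` set:

# With a trailing comma, Black keeps one element per line even though the
# literal would fit within its 88-column limit.
fields = {
    "actorId",
    "jobId",
}
# Deleting the trailing comma lets Black collapse it back to:
# fields = {"actorId", "jobId"}
assert "actorId" in fields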

View file

@@ -7,27 +7,29 @@ PYCLASSNAME_RE = re.compile(r"(.+?)\(")


def construct_actor_groups(actors):
    """actors is a dict from actor id to an actor or an
    actor creation task. The shared fields currently are
    "actorClass", "actorId", and "state" """
    actor_groups = _group_actors_by_python_class(actors)
    stats_by_group = {
        name: _get_actor_group_stats(group) for name, group in actor_groups.items()
    }

    summarized_actor_groups = {}
    for name, group in actor_groups.items():
        summarized_actor_groups[name] = {
            "entries": group,
            "summary": stats_by_group[name],
        }
    return summarized_actor_groups


def actor_classname_from_task_spec(task_spec):
    return (
        task_spec.get("functionDescriptor", {})
        .get("pythonFunctionDescriptor", {})
        .get("className", "Unknown actor class")
        .split(".")[-1]
    )
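When a fluent chain of calls exceeds the line length, Black wraps the whole expression in parentheses with one call per line, as in `actor_classname_from_task_spec` above. A self-contained sketch with hypothetical task-spec data:

task_spec = {
    "functionDescriptor": {"pythonFunctionDescriptor": {"className": "m.Foo"}}
}
name = (
    task_spec.get("functionDescriptor", {})
    .get("pythonFunctionDescriptor", {})
    .get("className", "Unknown actor class")
    .split(".")[-1]
)
assert name == "Foo"  # class name with any module prefix stripped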


def _group_actors_by_python_class(actors):

View file

@@ -50,8 +50,7 @@ def test_actor_groups(ray_start_with_dashboard):
    response = requests.get(webui_url + "/logical/actor_groups")
    response.raise_for_status()
    actor_groups_resp = response.json()
    assert actor_groups_resp["result"] is True, actor_groups_resp["msg"]
    actor_groups = actor_groups_resp["data"]["actorGroups"]
    assert "Foo" in actor_groups
    summary = actor_groups["Foo"]["summary"]
@@ -78,9 +77,13 @@ def test_actor_groups(ray_start_with_dashboard):
            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")
@@ -135,9 +138,13 @@ def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")
@@ -183,8 +190,9 @@ def test_kill_actor(ray_start_with_dashboard):
        params={
            "actorId": actor["actorId"],
            "ipAddress": actor["ipAddress"],
            "port": actor["port"],
        },
    )
    resp.raise_for_status()
    resp_json = resp.json()
    assert resp_json["result"] is True, "msg" in resp_json
@@ -199,19 +207,17 @@ def test_kill_actor(ray_start_with_dashboard):
            break
        except (KeyError, AssertionError) as e:
            last_exc = e
            time.sleep(0.1)
    assert last_exc is None


def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
    timeout = 5
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    address_info = ray_start_with_dashboard

    if gcs_pubsub.gcs_pubsub_enabled():
        sub = gcs_pubsub.GcsActorSubscriber(address=address_info["gcs_address"])
        sub.subscribe()
    else:
        address = address_info["redis_address"]
@@ -221,7 +227,8 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
        client = redis.StrictRedis(
            host=address[0],
            port=int(address[1]),
            password=ray_constants.REDIS_DEFAULT_PASSWORD,
        )

        sub = client.pubsub(ignore_subscribe_messages=True)
        sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)
@@ -245,8 +252,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
                    time.sleep(0.01)
                    continue
                pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
                actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
                if actor_data is None:
                    continue
                msgs.append(actor_data)
@@ -266,12 +272,22 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
    def actor_table_data_to_dict(message):
        return dashboard_utils.message_to_dict(
            message,
            {
                "actorId",
                "parentId",
                "jobId",
                "workerId",
                "rayletId",
                "actorCreationDummyObjectId",
                "callerId",
                "taskId",
                "parentTaskId",
                "sourceActorId",
                "placementGroupId",
            },
            including_default_value_fields=False,
        )

    non_state_keys = ("actorId", "jobId", "taskSpec")
@@ -287,23 +303,31 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
        # be published.
        elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
            assert actor_data_dict.keys() >= {
                "state",
                "address",
                "timestamp",
                "pid",
                "rayNamespace",
            }
        elif actor_data_dict["state"] == "PENDING_CREATION":
            assert actor_data_dict.keys() == {
                "state",
                "address",
                "actorId",
                "actorCreationDummyObjectId",
                "jobId",
                "ownerAddress",
                "taskSpec",
                "className",
                "serializedRuntimeEnv",
                "rayNamespace",
            }
        else:
            raise Exception("Unknown state: {}".format(actor_data_dict["state"]))


def test_nil_node(enable_test_module, disable_aiohttp_cache, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(webui_url)
    webui_url = format_web_url(webui_url)
@@ -334,9 +358,13 @@ def test_nil_node(enable_test_module, disable_aiohttp_cache,
            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")

View file

@@ -24,13 +24,11 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
        os.makedirs(self._event_dir, exist_ok=True)
        self._monitor: Union[asyncio.Task, None] = None
        self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
        self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
        logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)

    async def _connect_to_dashboard(self):
        """Connect to the dashboard. If the dashboard is not started, then
        this method will never return.

        Returns:
@@ -41,23 +39,24 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
                # TODO: Use async version if performance is an issue
                dashboard_rpc_address = internal_kv._internal_kv_get(
                    dashboard_consts.DASHBOARD_RPC_ADDRESS,
                    namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
                )
                if dashboard_rpc_address:
                    logger.info("Report events to %s", dashboard_rpc_address)
                    options = (("grpc.enable_http_proxy", 0),)
                    channel = utils.init_grpc_channel(
                        dashboard_rpc_address, options=options, asynchronous=True
                    )
                    return event_pb2_grpc.ReportEventServiceStub(channel)
            except Exception:
                logger.exception("Connect to dashboard failed.")
            await asyncio.sleep(
                event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS
            )

    @async_loop_forever(event_consts.EVENT_AGENT_REPORT_INTERVAL_SECONDS)
    async def report_events(self):
        """Report events from cached events queue. Reconnect to dashboard if
        report failed. Log error after retry EVENT_AGENT_RETRY_TIMES.

        This method will never return.
@@ -70,14 +69,15 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
                await self._stub.ReportEvents(request)
                break
            except Exception:
                logger.exception("Report event failed, reconnect to the " "dashboard.")
                self._stub = await self._connect_to_dashboard()
        else:
            data_str = str(data)
            limit = event_consts.LOG_ERROR_EVENT_STRING_LENGTH_LIMIT
            logger.error(
                "Report event failed: %s",
                data_str[:limit] + (data_str[limit:] and "..."),
            )

    async def run(self, server):
        # Connect to dashboard.
@@ -86,7 +86,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
        self._monitor = monitor_events(
            self._event_dir,
            lambda data: create_task(self._cached_events.put(data)),
            source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES,
        )

        # Start reporting events.
        await self.report_events()
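One seam worth noting in the log calls above: Black wraps long calls but never merges adjacent string literals, so an implicit concatenation like the one in the `logger.exception` call survives reformatting even when it would now fit on one line. A minimal demonstration:

# Adjacent literals are concatenated at compile time; Black leaves them as-is.
message = "Report event failed, reconnect to the " "dashboard."
assert message == "Report event failed, reconnect to the dashboard."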

View file

@@ -4,22 +4,20 @@ from ray.core.generated import event_pb2
LOG_ERROR_EVENT_STRING_LENGTH_LIMIT = 1000
RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS = 2
# Monitor events
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer("SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
SCAN_EVENT_START_OFFSET_SECONDS = -30 * 60
CONCURRENT_READ_LIMIT = 50
EVENT_READ_LINE_COUNT_LIMIT = 200
EVENT_READ_LINE_LENGTH_LIMIT = env_integer(
    "EVENT_READ_LINE_LENGTH_LIMIT", 2 * 1024 * 1024
)  # 2MB
# Report events
EVENT_AGENT_REPORT_INTERVAL_SECONDS = 0.1
EVENT_AGENT_RETRY_TIMES = 10
EVENT_AGENT_CACHE_SIZE = 10240
# Event sources
EVENT_HEAD_MONITOR_SOURCE_TYPES = [event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)]
EVENT_AGENT_MONITOR_SOURCE_TYPES = list(
    set(event_pb2.Event.SourceType.keys()) - set(EVENT_HEAD_MONITOR_SOURCE_TYPES)
)
EVENT_SOURCE_ALL = event_pb2.Event.SourceType.keys()

View file

@@ -24,8 +24,9 @@ JobEvents = OrderedDict
dashboard_utils._json_compatible_types.add(JobEvents)


class EventHead(
    dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer
):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
@@ -70,21 +71,24 @@ class EventHead(dashboard_utils.DashboardHeadModule,
                for job_id, job_events in DataSource.events.items()
            }
            return dashboard_optional_utils.rest_response(
                success=True, message="All events fetched.", events=all_events
            )

        job_events = DataSource.events.get(job_id, {})
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Job events fetched.",
            job_id=job_id,
            events=list(job_events.values()),
        )

    async def run(self, server):
        event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
        self._monitor = monitor_events(
            self._event_dir,
            lambda data: self._update_events(parse_event_strings(data)),
            source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES,
        )

    @staticmethod
    def is_minimal_module():
View file

@ -19,8 +19,7 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
source_files = {} source_files = {}
all_source_types = set(event_consts.EVENT_SOURCE_ALL) all_source_types = set(event_consts.EVENT_SOURCE_ALL)
for source_type in source_types or event_consts.EVENT_SOURCE_ALL: for source_type in source_types or event_consts.EVENT_SOURCE_ALL:
assert source_type in all_source_types, \ assert source_type in all_source_types, f"Invalid source type: {source_type}"
f"Invalid source type: {source_type}"
files = [] files = []
for n in event_log_names: for n in event_log_names:
if fnmatch.fnmatch(n, f"*{source_type}*"): if fnmatch.fnmatch(n, f"*{source_type}*"):
@ -35,9 +34,9 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
def _restore_newline(event_dict): def _restore_newline(event_dict):
try: try:
event_dict["message"] = event_dict["message"]\ event_dict["message"] = (
.replace("\\n", "\n")\ event_dict["message"].replace("\\n", "\n").replace("\\r", "\n")
.replace("\\r", "\n") )
except Exception: except Exception:
logger.exception("Restore newline for event failed: %s", event_dict) logger.exception("Restore newline for event failed: %s", event_dict)
return event_dict return event_dict
@ -61,13 +60,13 @@ def parse_event_strings(event_string_list):
ReadFileResult = collections.namedtuple( ReadFileResult = collections.namedtuple(
"ReadFileResult", ["fid", "size", "mtime", "position", "lines"]) "ReadFileResult", ["fid", "size", "mtime", "position", "lines"]
)
def _read_file(file, def _read_file(
pos, file, pos, n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT, closefd=True
n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT, ):
closefd=True):
with open(file, "rb", closefd=closefd) as f: with open(file, "rb", closefd=closefd) as f:
# The ino may be 0 on Windows. # The ino may be 0 on Windows.
stat = os.stat(f.fileno()) stat = os.stat(f.fileno())
@ -82,24 +81,25 @@ def _read_file(file,
if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT: if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT:
lines.append(mm[start:sep].decode("utf-8")) lines.append(mm[start:sep].decode("utf-8"))
else: else:
truncated_size = min( truncated_size = min(100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
logger.warning( logger.warning(
"Ignored long string: %s...(%s chars)", "Ignored long string: %s...(%s chars)",
mm[start:start + truncated_size].decode("utf-8"), mm[start : start + truncated_size].decode("utf-8"),
sep - start) sep - start,
)
start = sep + 1 start = sep + 1
return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines) return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines)
def monitor_events( def monitor_events(
event_dir, event_dir,
callback, callback,
scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS, scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS, start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
monitor_files=None, monitor_files=None,
source_types=None): source_types=None,
""" Monitor events in directory. New events will be read and passed to the ):
"""Monitor events in directory. New events will be read and passed to the
callback. callback.
Args: Args:
@ -121,20 +121,22 @@ def monitor_events(
monitor_files = {} monitor_files = {}
logger.info( logger.info(
"Monitor events logs modified after %s on %s, " "Monitor events logs modified after %s on %s, " "the source types are %s.",
"the source types are %s.", start_mtime, event_dir, "all" start_mtime,
if source_types is None else source_types) event_dir,
"all" if source_types is None else source_types,
)
MonitorFile = collections.namedtuple("MonitorFile", MonitorFile = collections.namedtuple("MonitorFile", ["size", "mtime", "position"])
["size", "mtime", "position"])
def _source_file_filter(source_file): def _source_file_filter(source_file):
stat = os.stat(source_file) stat = os.stat(source_file)
return stat.st_mtime > start_mtime return stat.st_mtime > start_mtime
def _read_monitor_file(file, pos): def _read_monitor_file(file, pos):
assert isinstance(file, str), \ assert isinstance(
f"File should be a str, but a {type(file)}({file}) found" file, str
), f"File should be a str, but a {type(file)}({file}) found"
fd = os.open(file, os.O_RDONLY) fd = os.open(file, os.O_RDONLY)
try: try:
stat = os.stat(fd) stat = os.stat(fd)
@ -145,12 +147,14 @@ def monitor_events(
fid = stat.st_ino or file fid = stat.st_ino or file
monitor_file = monitor_files.get(fid) monitor_file = monitor_files.get(fid)
if monitor_file: if monitor_file:
if (monitor_file.position == monitor_file.size if (
and monitor_file.size == stat.st_size monitor_file.position == monitor_file.size
and monitor_file.mtime == stat.st_mtime): and monitor_file.size == stat.st_size
and monitor_file.mtime == stat.st_mtime
):
logger.debug( logger.debug(
"Skip reading the file because " "Skip reading the file because " "there is no change: %s", file
"there is no change: %s", file) )
return [] return []
position = monitor_file.position position = monitor_file.position
else: else:
@ -169,22 +173,23 @@ def monitor_events(
@async_loop_forever(scan_interval_seconds, cancellable=True) @async_loop_forever(scan_interval_seconds, cancellable=True)
async def _scan_event_log_files(): async def _scan_event_log_files():
# Scan event files. # Scan event files.
source_files = await loop.run_in_executor(None, _get_source_files, source_files = await loop.run_in_executor(
event_dir, source_types, None, _get_source_files, event_dir, source_types, _source_file_filter
_source_file_filter) )
# Limit concurrent read to avoid fd exhaustion. # Limit concurrent read to avoid fd exhaustion.
semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT) semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT)
async def _concurrent_coro(filename): async def _concurrent_coro(filename):
async with semaphore: async with semaphore:
return await loop.run_in_executor(None, _read_monitor_file, return await loop.run_in_executor(None, _read_monitor_file, filename, 0)
filename, 0)
# Read files. # Read files.
await asyncio.gather(*[ await asyncio.gather(
_concurrent_coro(filename) *[
for filename in list(itertools.chain(*source_files.values())) _concurrent_coro(filename)
]) for filename in list(itertools.chain(*source_files.values()))
]
)
return create_task(_scan_event_log_files()) return create_task(_scan_event_log_files())
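A hedged usage sketch of `monitor_events` as reformatted above, mirroring the calls in the tests later in this commit: the callback receives lists of newly read event lines, and the returned asyncio task rescans the directory until cancelled. The directory path here is hypothetical:

import asyncio

async def watch_events():
    collected = []
    task = monitor_events(
        "/tmp/ray_session/logs/events",  # hypothetical event_*.log directory
        lambda lines: collected.extend(lines),
        scan_interval_seconds=0.5,
    )
    await asyncio.sleep(2)  # allow a few scan intervals to elapse
    task.cancel()
    return collected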

View file

@@ -23,7 +23,8 @@ from ray._private.test_utils import (
    wait_for_condition,
)
from ray.dashboard.modules.event.event_utils import (
    monitor_events,
)

logger = logging.getLogger(__name__)
@@ -32,7 +33,8 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
    return {
        "event_id": binary_to_hex(np.random.bytes(18)),
        "source_type": random.choice(event_pb2.Event.SourceType.keys())
        if source_type is None
        else source_type,
        "host_name": "po-dev.inc.alipay.net",
        "pid": random.randint(1, 65536),
        "label": "",
@@ -41,16 +43,18 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
        "severity": "INFO",
        "custom_fields": {
            "job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
            if job_id is None
            else job_id,
            "node_id": "",
            "task_id": "",
        },
    }


def _test_logger(name, log_file, max_bytes, backup_count):
    handler = logging.handlers.RotatingFileHandler(
        log_file, maxBytes=max_bytes, backupCount=backup_count
    )
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
@@ -63,15 +67,14 @@ def _test_logger(name, log_file, max_bytes, backup_count):


def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")
    job_id = ray.JobID.from_int(100).hex()
    source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
    source_type_raylet = event_pb2.Event.SourceType.Name(event_pb2.Event.RAYLET)
    test_count = 20

    for source_type in [source_type_gcs, source_type_raylet]:
@@ -80,10 +83,10 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
            __name__ + str(random.random()),
            test_log_file,
            max_bytes=2000,
            backup_count=1000,
        )
        for i in range(test_count):
            sample_event = _get_event(str(i), job_id=job_id, source_type=source_type)
            test_logger.info("%s", json.dumps(sample_event))

    def _check_events():
@@ -112,10 +115,11 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
    wait_for_condition(_check_events, timeout=15)


def test_event_message_limit(
    small_event_line_limit, disable_aiohttp_cache, ray_start_with_dashboard
):
    event_read_line_length_limit = small_event_line_limit
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    session_dir = ray_start_with_dashboard["session_dir"]
    event_dir = os.path.join(session_dir, "logs", "events")
@@ -148,8 +152,8 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
        except Exception:
            pass
    os.rename(
        os.path.join(event_dir, "tmp.log"), os.path.join(event_dir, "event_GCS.log")
    )

    def _check_events():
        try:
@@ -157,14 +161,14 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
            resp.raise_for_status()
            result = resp.json()
            all_events = result["data"]["events"]
            assert (
                len(all_events[job_id]) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
            )
            messages = [e["message"] for e in all_events[job_id]]
            for i in range(10):
                assert str(i) * message_len in messages
            assert "2" * (message_len + 1) not in messages
            assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT - 1) in messages
            return True
        except Exception as ex:
            logger.exception(ex)
@@ -179,15 +183,12 @@ async def test_monitor_events():
        common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
        common_log = os.path.join(temp_dir, f"event_{common}.log")
        test_logger = _test_logger(
            __name__ + str(random.random()), common_log, max_bytes=10, backup_count=10
        )
        test_events1 = []
        monitor_task = monitor_events(
            temp_dir, lambda x: test_events1.extend(x), scan_interval_seconds=0.01
        )
        assert not monitor_task.done()
        count = 10
@@ -206,7 +207,8 @@ async def test_monitor_events():
                if time.time() - start_time > timeout:
                    raise TimeoutError(
                        f"Timeout, read events: {sorted_events}, "
                        f"expect events: {expect_events}"
                    )
                if len(sorted_events) == len(expect_events):
                    if sorted_events == expect_events:
                        break
@@ -214,40 +216,37 @@ async def test_monitor_events():

        await asyncio.gather(
            _writer(count, read_events=test_events1),
            _check_events([str(i) for i in range(count)], read_events=test_events1),
        )
        monitor_task.cancel()

        test_events2 = []
        monitor_task = monitor_events(
            temp_dir, lambda x: test_events2.extend(x), scan_interval_seconds=0.1
        )

        await _check_events([str(i) for i in range(count)], read_events=test_events2)
        await _writer(count, count * 2, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 2)], read_events=test_events2
        )

        log_file_count = len(os.listdir(temp_dir))

        test_logger = _test_logger(
            __name__ + str(random.random()), common_log, max_bytes=1000, backup_count=10
        )
        assert len(os.listdir(temp_dir)) == log_file_count

        await _writer(count * 2, count * 3, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 3)], read_events=test_events2
        )
        await _writer(count * 3, count * 4, spin=False, read_events=test_events2)
        await _check_events(
            [str(i) for i in range(count * 4)], read_events=test_events2
        )

        # Test cancel monitor task.
        monitor_task.cancel()
@@ -255,8 +254,7 @@ async def test_monitor_events():
            await monitor_task
        assert monitor_task.done()
        assert len(os.listdir(temp_dir)) > 1, "Event log should have rollovers."


if __name__ == "__main__":

View file

@@ -7,21 +7,21 @@ import yaml

import click

from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf
from ray.dashboard.modules.job.common import JobStatus
from ray.dashboard.modules.job.sdk import JobSubmissionClient


def _get_sdk_client(
    address: Optional[str], create_cluster_if_needed: bool = False
) -> JobSubmissionClient:

    if address is None:
        if "RAY_ADDRESS" not in os.environ:
            raise ValueError(
                "Address must be specified using either the --address flag "
                "or RAY_ADDRESS environment variable."
            )
        address = os.environ["RAY_ADDRESS"]

    cli_logger.labeled_value("Job submission server address", address)
@@ -73,55 +73,67 @@ def job_cli_group():
    pass


@job_cli_group.command("submit", help="Submit a job to be executed on the cluster.")
@click.option(
    "--address",
    type=str,
    default=None,
    required=False,
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.option(
    "--job-id",
    type=str,
    default=None,
    required=False,
    help=("Job ID to specify for the job. " "If not provided, one will be generated."),
)
@click.option(
    "--runtime-env",
    type=str,
    default=None,
    required=False,
    help="Path to a local YAML file containing a runtime_env definition.",
)
@click.option(
    "--runtime-env-json",
    type=str,
    default=None,
    required=False,
    help="JSON-serialized runtime_env dictionary.",
)
@click.option(
    "--working-dir",
    type=str,
    default=None,
    required=False,
    help=(
        "Directory containing files that your job will run in. Can be a "
        "local directory or a remote URI to a .zip file (S3, GS, HTTP). "
        "If specified, this overrides the option in --runtime-env."
    ),
)
@click.option(
    "--no-wait",
    is_flag=True,
    type=bool,
    default=False,
    help="If set, will not stream logs and wait for the job to exit.",
)
@add_click_logging_options
@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
def job_submit(
    address: Optional[str],
    job_id: Optional[str],
    runtime_env: Optional[str],
    runtime_env_json: Optional[str],
    working_dir: Optional[str],
    entrypoint: Tuple[str],
    no_wait: bool,
):
    """Submits a job to be run on the cluster.

    Example:
@@ -132,8 +144,9 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    final_runtime_env = {}
    if runtime_env is not None:
        if runtime_env_json is not None:
            raise ValueError(
                "Only one of --runtime_env and " "--runtime-env-json can be provided."
            )
        with open(runtime_env, "r") as f:
            final_runtime_env = yaml.safe_load(f)
@@ -143,14 +156,14 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    if working_dir is not None:
        if "working_dir" in final_runtime_env:
            cli_logger.warning(
                "Overriding runtime_env working_dir with --working-dir option"
            )
        final_runtime_env["working_dir"] = working_dir

    job_id = client.submit_job(
        entrypoint=" ".join(entrypoint), job_id=job_id, runtime_env=final_runtime_env
    )

    _log_big_success_msg(f"Job '{job_id}' submitted successfully")
@@ -172,15 +185,16 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    # sdk version 0 does not have log streaming
    if not no_wait:
        if int(sdk_version) > 0:
            cli_logger.print(
                "Tailing logs until the job exits " "(disable with --no-wait):"
            )
            asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your Ray to the latest version "
                "for this feature."
            )


@job_cli_group.command("status", help="Get the status of a running job.")
@@ -189,8 +203,11 @@ def job_submit(address: Optional[str], job_id: Optional[str],
    type=str,
    default=None,
    required=False,
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_status(address: Optional[str], job_id: str):
@@ -209,14 +226,18 @@ def job_status(address: Optional[str], job_id: str):
    type=str,
    default=None,
    required=False,
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.option(
    "--no-wait",
    is_flag=True,
    type=bool,
    default=False,
    help="If set, will not wait for the job to exit.",
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_stop(address: Optional[str], no_wait: bool, job_id: str):
@@ -232,14 +253,13 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    if no_wait:
        return
    else:
        cli_logger.print(
            f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):"
        )

    while True:
        status = client.get_job_status(job_id)
        if status.status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}:
            _log_job_status(client, job_id)
            break
        else:
@@ -253,8 +273,11 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    type=str,
    default=None,
    required=False,
    help=(
        "Address of the Ray cluster to connect to. Can also be specified "
        "using the RAY_ADDRESS environment variable."
    ),
)
@click.argument("job-id", type=str)
@click.option(
    "-f",
@@ -262,7 +285,8 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
    is_flag=True,
    type=bool,
    default=False,
    help="If set, follow the logs (like `tail -f`).",
)
@add_click_logging_options
def job_logs(address: Optional[str], job_id: str, follow: bool):
    """Gets the logs of a job.
@@ -275,12 +299,12 @@ def job_logs(address: Optional[str], job_id: str, follow: bool):
    # sdk version 0 did not have log streaming
    if follow:
        if int(sdk_version) > 0:
            asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your Ray to the latest version "
                "for this feature."
            )
    else:
        print(client.get_job_logs(job_id), end="")
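The reshaped `job_submit` signature above is Black's standard treatment of a `def` that exceeds 88 columns: one parameter per line with a trailing comma. A schematic before/after with a hypothetical function:

# Before (too long for one line):
# def submit(address: str, job_id: str, runtime_env: str, working_dir: str, no_wait: bool):
#     ...

# After Black:
def submit(
    address: str,
    job_id: str,
    runtime_env: str,
    working_dir: str,
    no_wait: bool,
):
    ...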

View file

@@ -39,8 +39,10 @@ class JobStatusInfo:
    def __post_init__(self):
        if self.message is None:
            if self.status == JobStatus.PENDING:
                self.message = (
                    "Job has not started yet, likely waiting "
                    "for the runtime_env to be set up."
                )
            elif self.status == JobStatus.RUNNING:
                self.message = "Job is currently running."
            elif self.status == JobStatus.STOPPED:
@@ -55,6 +57,7 @@ class JobStatusStorageClient:
    """
    Handles formatting of status storage key given job id.
    """

    JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"

    def __init__(self):
@@ -69,12 +72,14 @@ class JobStatusStorageClient:
        _internal_kv_put(
            self.JOB_STATUS_KEY.format(job_id=job_id),
            pickle.dumps(status),
            namespace=ray_constants.KV_NAMESPACE_JOB,
        )

    def get_status(self, job_id: str) -> Optional[JobStatusInfo]:
        pickled_status = _internal_kv_get(
            self.JOB_STATUS_KEY.format(job_id=job_id),
            namespace=ray_constants.KV_NAMESPACE_JOB,
        )
        if pickled_status is None:
            return None
        else:
@@ -87,18 +92,16 @@ def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
    # We need to strip the gcs:// prefix and .zip suffix to make it
    # possible to pass the package_uri over HTTP.
    protocol, package_name = parse_uri(package_uri)
    return protocol.value, package_name[: -len(".zip")]


def http_uri_components_to_uri(protocol: str, package_name: str) -> str:
    if package_name.endswith(".zip"):
        raise ValueError(f"package_name ({package_name}) should not end in .zip")
    return f"{protocol}://{package_name}.zip"


def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any:
    return request_type(**json_data)
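A hedged usage sketch of `validate_request_type`, assuming (as the validation below implies) that `JobSubmitRequest`'s optional fields default to None; the payload values are hypothetical:

req = validate_request_type(
    {"entrypoint": "python script.py", "metadata": {"owner": "ci"}},
    JobSubmitRequest,
)
assert req.entrypoint == "python script.py"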
@@ -124,8 +127,7 @@ class JobSubmitRequest:
    def __post_init__(self):
        if not isinstance(self.entrypoint, str):
            raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}")

        if self.job_id is not None and not isinstance(self.job_id, str):
            raise TypeError(
@@ -141,21 +143,21 @@ class JobSubmitRequest:
                for k in self.runtime_env.keys():
                    if not isinstance(k, str):
                        raise TypeError(
                            f"runtime_env keys must be strings, got {type(k)}"
                        )

        if self.metadata is not None:
            if not isinstance(self.metadata, dict):
                raise TypeError(f"metadata must be a dict, got {type(self.metadata)}")
            else:
                for k in self.metadata.keys():
                    if not isinstance(k, str):
                        raise TypeError(f"metadata keys must be strings, got {type(k)}")
                for v in self.metadata.values():
                    if not isinstance(v, str):
                        raise TypeError(
                            f"metadata values must be strings, got {type(v)}"
                        )


@dataclass
View file
@ -12,8 +12,7 @@ import ray
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_utils import use_gcs_for_bootstrap from ray._private.gcs_utils import use_gcs_for_bootstrap
from ray._private.runtime_env.packaging import (package_exists, from ray._private.runtime_env.packaging import package_exists, upload_package_to_gcs
upload_package_to_gcs)
from ray.dashboard.modules.job.common import ( from ray.dashboard.modules.job.common import (
CURRENT_VERSION, CURRENT_VERSION,
http_uri_components_to_uri, http_uri_components_to_uri,
@ -45,19 +44,20 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
if use_gcs_for_bootstrap(): if use_gcs_for_bootstrap():
address = self._dashboard_head.gcs_address address = self._dashboard_head.gcs_address
redis_pw = None redis_pw = None
logger.info( logger.info(f"Connecting to ray with address={address}")
f"Connecting to ray with address={address}")
else: else:
ip, port = self._dashboard_head.redis_address ip, port = self._dashboard_head.redis_address
redis_pw = self._dashboard_head.redis_password redis_pw = self._dashboard_head.redis_password
address = f"{ip}:{port}" address = f"{ip}:{port}"
logger.info( logger.info(
f"Connecting to ray with address={address}, " f"Connecting to ray with address={address}, "
f"redis_pw={redis_pw}") f"redis_pw={redis_pw}"
)
ray.init( ray.init(
address=address, address=address,
namespace=RAY_INTERNAL_JOBS_NAMESPACE, namespace=RAY_INTERNAL_JOBS_NAMESPACE,
_redis_password=redis_pw) _redis_password=redis_pw,
)
except Exception as e: except Exception as e:
ray.shutdown() ray.shutdown()
raise e from None raise e from None
@ -67,7 +67,8 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
logger.exception(f"Unexpected error in handler: {e}") logger.exception(f"Unexpected error in handler: {e}")
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code) status=aiohttp.web.HTTPInternalServerError.status_code,
)
return check return check
@ -77,8 +78,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
super().__init__(dashboard_head) super().__init__(dashboard_head)
self._job_manager = None self._job_manager = None
async def _parse_and_validate_request(self, req: Request, async def _parse_and_validate_request(
request_type: dataclass) -> Any: self, req: Request, request_type: dataclass
) -> Any:
"""Parse request and cast to request type. If parsing failed, return a """Parse request and cast to request type. If parsing failed, return a
Response object with status 400 and stacktrace instead. Response object with status 400 and stacktrace instead.
""" """
@ -88,7 +90,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
logger.info(f"Got invalid request type: {e}") logger.info(f"Got invalid request type: {e}")
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code) status=aiohttp.web.HTTPBadRequest.status_code,
)
def job_exists(self, job_id: str) -> bool: def job_exists(self, job_id: str) -> bool:
status = self._job_manager.get_job_status(job_id) status = self._job_manager.get_job_status(job_id)
@ -101,7 +104,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
resp = VersionResponse( resp = VersionResponse(
version=CURRENT_VERSION, version=CURRENT_VERSION,
ray_version=ray.__version__, ray_version=ray.__version__,
ray_commit=ray.__commit__) ray_commit=ray.__commit__,
)
return Response( return Response(
text=json.dumps(dataclasses.asdict(resp)), text=json.dumps(dataclasses.asdict(resp)),
content_type="application/json", content_type="application/json",
@ -113,12 +117,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def get_package(self, req: Request) -> Response: async def get_package(self, req: Request) -> Response:
package_uri = http_uri_components_to_uri( package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"], protocol=req.match_info["protocol"],
package_name=req.match_info["package_name"]) package_name=req.match_info["package_name"],
)
if not package_exists(package_uri): if not package_exists(package_uri):
return Response( return Response(
text=f"Package {package_uri} does not exist", text=f"Package {package_uri} does not exist",
status=aiohttp.web.HTTPNotFound.status_code) status=aiohttp.web.HTTPNotFound.status_code,
)
return Response() return Response()
@ -127,14 +133,16 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def upload_package(self, req: Request): async def upload_package(self, req: Request):
package_uri = http_uri_components_to_uri( package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"], protocol=req.match_info["protocol"],
package_name=req.match_info["package_name"]) package_name=req.match_info["package_name"],
)
logger.info(f"Uploading package {package_uri} to the GCS.") logger.info(f"Uploading package {package_uri} to the GCS.")
try: try:
upload_package_to_gcs(package_uri, await req.read()) upload_package_to_gcs(package_uri, await req.read())
except Exception: except Exception:
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code) status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response(status=aiohttp.web.HTTPOk.status_code) return Response(status=aiohttp.web.HTTPOk.status_code)
@ -153,17 +161,20 @@ class JobHead(dashboard_utils.DashboardHeadModule):
entrypoint=submit_request.entrypoint, entrypoint=submit_request.entrypoint,
job_id=submit_request.job_id, job_id=submit_request.job_id,
runtime_env=submit_request.runtime_env, runtime_env=submit_request.runtime_env,
metadata=submit_request.metadata) metadata=submit_request.metadata,
)
resp = JobSubmitResponse(job_id=job_id) resp = JobSubmitResponse(job_id=job_id)
except (TypeError, ValueError): except (TypeError, ValueError):
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code) status=aiohttp.web.HTTPBadRequest.status_code,
)
except Exception: except Exception:
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code) status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response( return Response(
text=json.dumps(dataclasses.asdict(resp)), text=json.dumps(dataclasses.asdict(resp)),
@ -178,7 +189,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id): if not self.job_exists(job_id):
return Response( return Response(
text=f"Job {job_id} does not exist", text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code) status=aiohttp.web.HTTPNotFound.status_code,
)
try: try:
stopped = self._job_manager.stop_job(job_id) stopped = self._job_manager.stop_job(job_id)
@ -186,11 +198,12 @@ class JobHead(dashboard_utils.DashboardHeadModule):
except Exception: except Exception:
return Response( return Response(
text=traceback.format_exc(), text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code) status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response( return Response(
text=json.dumps(dataclasses.asdict(resp)), text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
content_type="application/json") )
@routes.get("/api/jobs/{job_id}") @routes.get("/api/jobs/{job_id}")
@_init_ray_and_catch_exceptions @_init_ray_and_catch_exceptions
@ -199,13 +212,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id): if not self.job_exists(job_id):
return Response( return Response(
text=f"Job {job_id} does not exist", text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code) status=aiohttp.web.HTTPNotFound.status_code,
)
status: JobStatusInfo = self._job_manager.get_job_status(job_id) status: JobStatusInfo = self._job_manager.get_job_status(job_id)
resp = JobStatusResponse(status=status.status, message=status.message) resp = JobStatusResponse(status=status.status, message=status.message)
return Response( return Response(
text=json.dumps(dataclasses.asdict(resp)), text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
content_type="application/json") )
@routes.get("/api/jobs/{job_id}/logs") @routes.get("/api/jobs/{job_id}/logs")
@_init_ray_and_catch_exceptions @_init_ray_and_catch_exceptions
@ -214,12 +228,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id): if not self.job_exists(job_id):
return Response( return Response(
text=f"Job {job_id} does not exist", text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code) status=aiohttp.web.HTTPNotFound.status_code,
)
resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id)) resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id))
return Response( return Response(
text=json.dumps(dataclasses.asdict(resp)), text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
content_type="application/json") )
@routes.get("/api/jobs/{job_id}/logs/tail") @routes.get("/api/jobs/{job_id}/logs/tail")
@_init_ray_and_catch_exceptions @_init_ray_and_catch_exceptions
@ -228,7 +243,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id): if not self.job_exists(job_id):
return Response( return Response(
text=f"Job {job_id} does not exist", text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code) status=aiohttp.web.HTTPNotFound.status_code,
)
ws = aiohttp.web.WebSocketResponse() ws = aiohttp.web.WebSocketResponse()
await ws.prepare(req) await ws.prepare(req)
View file
@ -15,8 +15,12 @@ from ray.exceptions import RuntimeEnvSetupError
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.dashboard.modules.job.common import ( from ray.dashboard.modules.job.common import (
JobStatus, JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY, JobStatus,
JOB_NAME_METADATA_KEY) JobStatusInfo,
JobStatusStorageClient,
JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.utils import file_tail_iterator from ray.dashboard.modules.job.utils import file_tail_iterator
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -36,8 +40,8 @@ def generate_job_id() -> str:
""" """
rand = random.SystemRandom() rand = random.SystemRandom()
possible_characters = list( possible_characters = list(
set(string.ascii_letters + string.digits) - set(string.ascii_letters + string.digits)
{"I", "l", "o", "O", "0"} # No confusing characters - {"I", "l", "o", "O", "0"} # No confusing characters
) )
id_part = "".join(rand.choices(possible_characters, k=16)) id_part = "".join(rand.choices(possible_characters, k=16))
return f"raysubmit_{id_part}" return f"raysubmit_{id_part}"
@ -47,6 +51,7 @@ class JobLogStorageClient:
""" """
Disk storage for stdout / stderr of driver script logs. Disk storage for stdout / stderr of driver script logs.
""" """
JOB_LOGS_PATH = "job-driver-{job_id}.log" JOB_LOGS_PATH = "job-driver-{job_id}.log"
# Number of last N lines to put in job message upon failure. # Number of last N lines to put in job message upon failure.
NUM_LOG_LINES_ON_ERROR = 10 NUM_LOG_LINES_ON_ERROR = 10
@ -61,9 +66,9 @@ class JobLogStorageClient:
def tail_logs(self, job_id: str) -> Iterator[str]: def tail_logs(self, job_id: str) -> Iterator[str]:
return file_tail_iterator(self.get_log_file_path(job_id)) return file_tail_iterator(self.get_log_file_path(job_id))
def get_last_n_log_lines(self, def get_last_n_log_lines(
job_id: str, self, job_id: str, num_log_lines=NUM_LOG_LINES_ON_ERROR
num_log_lines=NUM_LOG_LINES_ON_ERROR) -> str: ) -> str:
log_tail_iter = self.tail_logs(job_id) log_tail_iter = self.tail_logs(job_id)
log_tail_deque = deque(maxlen=num_log_lines) log_tail_deque = deque(maxlen=num_log_lines)
for line in log_tail_iter: for line in log_tail_iter:
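get_last_n_log_lines relies on a bounded deque: with maxlen set, older entries are evicted as new ones are appended, so only the last N lines survive a full pass over the log. A self-contained sketch, assuming (as the iterator above is used) that None marks the end of currently available output:

from collections import deque

def last_n_log_lines(lines, n=10):
    tail = deque(maxlen=n)  # older entries fall off automatically
    for line in lines:
        if line is None:
            break  # assumed sentinel for "no more output right now"
        tail.append(line)
    return "".join(tail)

print(last_n_log_lines(iter([f"line {i}\n" for i in range(100)] + [None]), n=3))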
@ -80,7 +85,8 @@ class JobLogStorageClient:
""" """
return os.path.join( return os.path.join(
ray.worker._global_node.get_logs_dir_path(), ray.worker._global_node.get_logs_dir_path(),
self.JOB_LOGS_PATH.format(job_id=job_id)) self.JOB_LOGS_PATH.format(job_id=job_id),
)
class JobSupervisor: class JobSupervisor:
@ -95,8 +101,7 @@ class JobSupervisor:
SUBPROCESS_POLL_PERIOD_S = 0.1 SUBPROCESS_POLL_PERIOD_S = 0.1
def __init__(self, job_id: str, entrypoint: str, def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
user_metadata: Dict[str, str]):
self._job_id = job_id self._job_id = job_id
self._status_client = JobStatusStorageClient() self._status_client = JobStatusStorageClient()
self._log_client = JobLogStorageClient() self._log_client = JobLogStorageClient()
@ -104,10 +109,7 @@ class JobSupervisor:
self._entrypoint = entrypoint self._entrypoint = entrypoint
# Default metadata if not passed by the user. # Default metadata if not passed by the user.
self._metadata = { self._metadata = {JOB_ID_METADATA_KEY: job_id, JOB_NAME_METADATA_KEY: job_id}
JOB_ID_METADATA_KEY: job_id,
JOB_NAME_METADATA_KEY: job_id
}
self._metadata.update(user_metadata) self._metadata.update(user_metadata)
# fire and forget call from outer job manager to this actor # fire and forget call from outer job manager to this actor
@ -142,7 +144,8 @@ class JobSupervisor:
shell=True, shell=True,
start_new_session=True, start_new_session=True,
stdout=logs_file, stdout=logs_file,
stderr=subprocess.STDOUT) stderr=subprocess.STDOUT,
)
parent_pid = os.getpid() parent_pid = os.getpid()
# Create new pgid with new subprocess to execute driver command # Create new pgid with new subprocess to execute driver command
child_pid = child_process.pid child_pid = child_process.pid
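The supervisor starts the driver with start_new_session=True so the child leads its own process group, letting a stop request signal the whole driver tree at once. A POSIX-only sketch of that mechanism (Windows has no process groups in this sense, so the platform is an assumption here):

import os
import signal
import subprocess

child = subprocess.Popen("sleep 30 & sleep 30", shell=True, start_new_session=True)
pgid = os.getpgid(child.pid)
assert pgid != os.getpgid(os.getpid())  # the child is in its own group
os.killpg(pgid, signal.SIGKILL)         # signals the shell and its descendants
child.wait()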
@ -177,9 +180,10 @@ class JobSupervisor:
return 1 return 1
async def run( async def run(
self, self,
# Signal actor used in testing to capture PENDING -> RUNNING cases # Signal actor used in testing to capture PENDING -> RUNNING cases
_start_signal_actor: Optional[ActorHandle] = None): _start_signal_actor: Optional[ActorHandle] = None,
):
""" """
        Stop and start both happen asynchronously, coordinated by asyncio event         Stop and start both happen asynchronously, coordinated by asyncio event
and coroutine, respectively. and coroutine, respectively.
@ -190,26 +194,26 @@ class JobSupervisor:
3) Handle concurrent events of driver execution and 3) Handle concurrent events of driver execution and
""" """
cur_status = self._get_status() cur_status = self._get_status()
assert cur_status.status == JobStatus.PENDING, ( assert cur_status.status == JobStatus.PENDING, "Run should only be called once."
"Run should only be called once.")
if _start_signal_actor: if _start_signal_actor:
# Block in PENDING state until start signal received. # Block in PENDING state until start signal received.
await _start_signal_actor.wait.remote() await _start_signal_actor.wait.remote()
self._status_client.put_status(self._job_id, self._status_client.put_status(self._job_id, JobStatusInfo(JobStatus.RUNNING))
JobStatusInfo(JobStatus.RUNNING))
try: try:
# Set JobConfig for the child process (runtime_env, metadata). # Set JobConfig for the child process (runtime_env, metadata).
os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps({ os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps(
"runtime_env": self._runtime_env, {
"metadata": self._metadata, "runtime_env": self._runtime_env,
}) "metadata": self._metadata,
}
)
# Set RAY_ADDRESS to local Ray address, if it is not set. # Set RAY_ADDRESS to local Ray address, if it is not set.
os.environ[ os.environ[
ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = \ ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE
ray._private.services.get_ray_address_from_environment() ] = ray._private.services.get_ray_address_from_environment()
# Set PYTHONUNBUFFERED=1 to stream logs during the job instead of # Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
# only streaming them upon completion of the job. # only streaming them upon completion of the job.
os.environ["PYTHONUNBUFFERED"] = "1" os.environ["PYTHONUNBUFFERED"] = "1"
@ -218,8 +222,8 @@ class JobSupervisor:
polling_task = create_task(self._polling(child_process)) polling_task = create_task(self._polling(child_process))
finished, _ = await asyncio.wait( finished, _ = await asyncio.wait(
[polling_task, self._stop_event.wait()], [polling_task, self._stop_event.wait()], return_when=FIRST_COMPLETED
return_when=FIRST_COMPLETED) )
if self._stop_event.is_set(): if self._stop_event.is_set():
polling_task.cancel() polling_task.cancel()
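run() races the driver-polling task against the stop event and acts on whichever finishes first. A runnable sketch of that pattern; the 30-second sleep stands in for the driver subprocess, and both awaitables are wrapped in tasks explicitly since asyncio.wait() rejects bare coroutines on recent Python versions:

import asyncio

async def main():
    stop = asyncio.Event()
    polling = asyncio.create_task(asyncio.sleep(30, result=0))  # fake driver
    stopper = asyncio.create_task(stop.wait())
    asyncio.get_running_loop().call_later(0.1, stop.set)  # simulated stop request
    done, _ = await asyncio.wait({polling, stopper}, return_when=asyncio.FIRST_COMPLETED)
    if stop.is_set():
        polling.cancel()
        print("stopped before the driver finished")
    else:
        print("driver exited with", polling.result())

asyncio.run(main())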
@ -229,29 +233,29 @@ class JobSupervisor:
else: else:
# Child process finished execution and no stop event is set # Child process finished execution and no stop event is set
# at the same time # at the same time
assert len( assert len(finished) == 1, "Should have only one coroutine done"
finished) == 1, "Should have only one coroutine done"
[child_process_task] = finished [child_process_task] = finished
return_code = child_process_task.result() return_code = child_process_task.result()
if return_code == 0: if return_code == 0:
self._status_client.put_status(self._job_id, self._status_client.put_status(self._job_id, JobStatus.SUCCEEDED)
JobStatus.SUCCEEDED)
else: else:
log_tail = self._log_client.get_last_n_log_lines( log_tail = self._log_client.get_last_n_log_lines(self._job_id)
self._job_id)
if log_tail is not None and log_tail != "": if log_tail is not None and log_tail != "":
message = ("Job failed due to an application error, " message = (
"last available logs:\n" + log_tail) "Job failed due to an application error, "
"last available logs:\n" + log_tail
)
else: else:
message = None message = None
self._status_client.put_status( self._status_client.put_status(
self._job_id, self._job_id,
JobStatusInfo( JobStatusInfo(status=JobStatus.FAILED, message=message),
status=JobStatus.FAILED, message=message)) )
except Exception: except Exception:
logger.error( logger.error(
"Got unexpected exception while trying to execute driver " "Got unexpected exception while trying to execute driver "
f"command. {traceback.format_exc()}") f"command. {traceback.format_exc()}"
)
finally: finally:
# clean up actor after tasks are finished # clean up actor after tasks are finished
ray.actor.exit_actor() ray.actor.exit_actor()
@ -260,8 +264,7 @@ class JobSupervisor:
return self._status_client.get_status(self._job_id) return self._status_client.get_status(self._job_id)
def stop(self): def stop(self):
"""Set step_event and let run() handle the rest in its asyncio.wait(). """Set step_event and let run() handle the rest in its asyncio.wait()."""
"""
self._stop_event.set() self._stop_event.set()
@ -271,6 +274,7 @@ class JobManager:
    It does not provide persistence; all info will be lost if the cluster     It does not provide persistence; all info will be lost if the cluster
goes down. goes down.
""" """
JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}" JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
# Time that we will sleep while tailing logs if no new log line is # Time that we will sleep while tailing logs if no new log line is
# available. # available.
@ -300,11 +304,9 @@ class JobManager:
if key.startswith("node:"): if key.startswith("node:"):
return key return key
else: else:
raise ValueError( raise ValueError("Cannot find the node dictionary for current node.")
"Cannot find the node dictionary for current node.")
def _handle_supervisor_startup(self, job_id: str, def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]):
result: Optional[Exception]):
"""Handle the result of starting a job supervisor actor. """Handle the result of starting a job supervisor actor.
If started successfully, result should be None. Otherwise it should be If started successfully, result should be None. Otherwise it should be
@ -321,26 +323,30 @@ class JobManager:
job_id, job_id,
JobStatusInfo( JobStatusInfo(
status=JobStatus.FAILED, status=JobStatus.FAILED,
message=(f"runtime_env setup failed: {result}"))) message=(f"runtime_env setup failed: {result}"),
),
)
elif isinstance(result, Exception): elif isinstance(result, Exception):
logger.error( logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
f"Failed to start supervisor for job {job_id}: {result}.")
self._status_client.put_status( self._status_client.put_status(
job_id, job_id,
JobStatusInfo( JobStatusInfo(
status=JobStatus.FAILED, status=JobStatus.FAILED,
message=f"Error occurred while starting the job: {result}") message=f"Error occurred while starting the job: {result}",
),
) )
else: else:
assert False, "This should not be reached." assert False, "This should not be reached."
def submit_job(self, def submit_job(
*, self,
entrypoint: str, *,
job_id: Optional[str] = None, entrypoint: str,
runtime_env: Optional[Dict[str, Any]] = None, job_id: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None, runtime_env: Optional[Dict[str, Any]] = None,
_start_signal_actor: Optional[ActorHandle] = None) -> str: metadata: Optional[Dict[str, str]] = None,
_start_signal_actor: Optional[ActorHandle] = None,
) -> str:
""" """
Job execution happens asynchronously. Job execution happens asynchronously.
@ -390,8 +396,8 @@ class JobManager:
resources={ resources={
self._get_current_node_resource_key(): 0.001, self._get_current_node_resource_key(): 0.001,
}, },
runtime_env=runtime_env).remote(job_id, entrypoint, metadata runtime_env=runtime_env,
or {}) ).remote(job_id, entrypoint, metadata or {})
actor.run.remote(_start_signal_actor=_start_signal_actor) actor.run.remote(_start_signal_actor=_start_signal_actor)
def callback(result: Optional[Exception]): def callback(result: Optional[Exception]):
@ -441,7 +447,8 @@ class JobManager:
# updating GCS with latest status. # updating GCS with latest status.
last_status = self._status_client.get_status(job_id) last_status = self._status_client.get_status(job_id)
if last_status and last_status.status in { if last_status and last_status.status in {
JobStatus.PENDING, JobStatus.RUNNING JobStatus.PENDING,
JobStatus.RUNNING,
}: }:
self._status_client.put_status(job_id, JobStatus.FAILED) self._status_client.put_status(job_id, JobStatus.FAILED)
View file
@ -13,10 +13,19 @@ except ImportError:
requests = None requests = None
from ray._private.runtime_env.packaging import ( from ray._private.runtime_env.packaging import (
create_package, get_uri_for_directory, parse_uri) create_package,
get_uri_for_directory,
parse_uri,
)
from ray.dashboard.modules.job.common import ( from ray.dashboard.modules.job.common import (
JobSubmitRequest, JobSubmitResponse, JobStopResponse, JobStatusInfo, JobSubmitRequest,
JobStatusResponse, JobLogsResponse, uri_to_http_components) JobSubmitResponse,
JobStopResponse,
JobStatusInfo,
JobStatusResponse,
JobLogsResponse,
uri_to_http_components,
)
from ray.client_builder import _split_address from ray.client_builder import _split_address
@ -33,51 +42,49 @@ class ClusterInfo:
def get_job_submission_client_cluster_info( def get_job_submission_client_cluster_info(
address: str, address: str,
# For backwards compatibility # For backwards compatibility
*, *,
# only used in importlib case in parse_cluster_info, but needed # only used in importlib case in parse_cluster_info, but needed
# in function signature. # in function signature.
create_cluster_if_needed: Optional[bool] = False, create_cluster_if_needed: Optional[bool] = False,
cookies: Optional[Dict[str, Any]] = None, cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo: headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
"""Get address, cookies, and metadata used for JobSubmissionClient. """Get address, cookies, and metadata used for JobSubmissionClient.
Args: Args:
address (str): Address without the module prefix that is passed address (str): Address without the module prefix that is passed
to JobSubmissionClient. to JobSubmissionClient.
create_cluster_if_needed (bool): Indicates whether the cluster create_cluster_if_needed (bool): Indicates whether the cluster
of the address returned needs to be running. Ray doesn't of the address returned needs to be running. Ray doesn't
start a cluster before interacting with jobs, but other start a cluster before interacting with jobs, but other
implementations may do so. implementations may do so.
Returns: Returns:
ClusterInfo object consisting of address, cookies, and metadata ClusterInfo object consisting of address, cookies, and metadata
for JobSubmissionClient to use. for JobSubmissionClient to use.
""" """
return ClusterInfo( return ClusterInfo(
address="http://" + address, address="http://" + address, cookies=cookies, metadata=metadata, headers=headers
cookies=cookies, )
metadata=metadata,
headers=headers)
def parse_cluster_info( def parse_cluster_info(
address: str, address: str,
create_cluster_if_needed: bool = False, create_cluster_if_needed: bool = False,
cookies: Optional[Dict[str, Any]] = None, cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo: headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
module_string, inner_address = _split_address(address.rstrip("/")) module_string, inner_address = _split_address(address.rstrip("/"))
# If user passes in a raw HTTP(S) address, just pass it through. # If user passes in a raw HTTP(S) address, just pass it through.
if module_string == "http" or module_string == "https": if module_string == "http" or module_string == "https":
return ClusterInfo( return ClusterInfo(
address=address, address=address, cookies=cookies, metadata=metadata, headers=headers
cookies=cookies, )
metadata=metadata,
headers=headers)
# If user passes in a Ray address, convert it to HTTP. # If user passes in a Ray address, convert it to HTTP.
elif module_string == "ray": elif module_string == "ray":
return get_job_submission_client_cluster_info( return get_job_submission_client_cluster_info(
@ -85,7 +92,8 @@ def parse_cluster_info(
create_cluster_if_needed=create_cluster_if_needed, create_cluster_if_needed=create_cluster_if_needed,
cookies=cookies, cookies=cookies,
metadata=metadata, metadata=metadata,
headers=headers) headers=headers,
)
# Try to dynamically import the function to get cluster info. # Try to dynamically import the function to get cluster info.
else: else:
try: try:
@ -93,33 +101,40 @@ def parse_cluster_info(
except Exception: except Exception:
raise RuntimeError( raise RuntimeError(
f"Module: {module_string} does not exist.\n" f"Module: {module_string} does not exist.\n"
f"This module was parsed from Address: {address}") from None f"This module was parsed from Address: {address}"
) from None
assert "get_job_submission_client_cluster_info" in dir(module), ( assert "get_job_submission_client_cluster_info" in dir(module), (
f"Module: {module_string} does " f"Module: {module_string} does "
"not have `get_job_submission_client_cluster_info`.") "not have `get_job_submission_client_cluster_info`."
)
return module.get_job_submission_client_cluster_info( return module.get_job_submission_client_cluster_info(
inner_address, inner_address,
create_cluster_if_needed=create_cluster_if_needed, create_cluster_if_needed=create_cluster_if_needed,
cookies=cookies, cookies=cookies,
metadata=metadata, metadata=metadata,
headers=headers) headers=headers,
)
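parse_cluster_info dispatches on the address scheme: http(s):// passes through unchanged, ray:// is rewritten to the dashboard's HTTP endpoint, and any other prefix is treated as a module to import dynamically. A sketch of the scheme split, with partition() standing in for Ray's private _split_address helper:

def split_address(address):
    module, _, inner = address.partition("://")
    return module, inner

for addr in ("http://127.0.0.1:8265", "ray://127.0.0.1:10001", "my_module://cluster-a"):
    module, inner = split_address(addr.rstrip("/"))
    print(module, "->", inner)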
class JobSubmissionClient: class JobSubmissionClient:
def __init__(self, def __init__(
address: str, self,
create_cluster_if_needed=False, address: str,
cookies: Optional[Dict[str, Any]] = None, create_cluster_if_needed=False,
metadata: Optional[Dict[str, Any]] = None, cookies: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None): metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
):
if requests is None: if requests is None:
raise RuntimeError( raise RuntimeError(
"The Ray jobs CLI & SDK require the ray[default] " "The Ray jobs CLI & SDK require the ray[default] "
"installation: `pip install 'ray[default']``") "installation: `pip install 'ray[default']``"
)
cluster_info = parse_cluster_info(address, create_cluster_if_needed, cluster_info = parse_cluster_info(
cookies, metadata, headers) address, create_cluster_if_needed, cookies, metadata, headers
)
self._address = cluster_info.address self._address = cluster_info.address
self._cookies = cluster_info.cookies self._cookies = cluster_info.cookies
self._default_metadata = cluster_info.metadata or {} self._default_metadata = cluster_info.metadata or {}
@ -136,38 +151,43 @@ class JobSubmissionClient:
raise RuntimeError( raise RuntimeError(
"Jobs API not supported on the Ray cluster. " "Jobs API not supported on the Ray cluster. "
"Please ensure the cluster is running " "Please ensure the cluster is running "
"Ray 1.9 or higher.") "Ray 1.9 or higher."
)
r.raise_for_status() r.raise_for_status()
# TODO(edoakes): check the version if/when we break compatibility. # TODO(edoakes): check the version if/when we break compatibility.
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
raise ConnectionError( raise ConnectionError(
f"Failed to connect to Ray at address: {self._address}.") f"Failed to connect to Ray at address: {self._address}."
)
def _raise_error(self, r: "requests.Response"): def _raise_error(self, r: "requests.Response"):
raise RuntimeError( raise RuntimeError(
f"Request failed with status code {r.status_code}: {r.text}.") f"Request failed with status code {r.status_code}: {r.text}."
)
def _do_request(self, def _do_request(
method: str, self,
endpoint: str, method: str,
*, endpoint: str,
data: Optional[bytes] = None, *,
json_data: Optional[dict] = None) -> Optional[object]: data: Optional[bytes] = None,
json_data: Optional[dict] = None,
) -> Optional[object]:
url = self._address + endpoint url = self._address + endpoint
logger.debug( logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
f"Sending request to {url} with json data: {json_data or {}}.")
return requests.request( return requests.request(
method, method,
url, url,
cookies=self._cookies, cookies=self._cookies,
data=data, data=data,
json=json_data, json=json_data,
headers=self._headers) headers=self._headers,
)
def _package_exists( def _package_exists(
self, self,
package_uri: str, package_uri: str,
) -> bool: ) -> bool:
protocol, package_name = uri_to_http_components(package_uri) protocol, package_name = uri_to_http_components(package_uri)
r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}") r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
@ -181,11 +201,13 @@ class JobSubmissionClient:
else: else:
self._raise_error(r) self._raise_error(r)
def _upload_package(self, def _upload_package(
package_uri: str, self,
package_path: str, package_uri: str,
include_parent_dir: Optional[bool] = False, package_path: str,
excludes: Optional[List[str]] = None) -> bool: include_parent_dir: Optional[bool] = False,
excludes: Optional[List[str]] = None,
) -> bool:
logger.info(f"Uploading package {package_uri}.") logger.info(f"Uploading package {package_uri}.")
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
protocol, package_name = uri_to_http_components(package_uri) protocol, package_name = uri_to_http_components(package_uri)
@ -194,26 +216,27 @@ class JobSubmissionClient:
package_path, package_path,
package_file, package_file,
include_parent_dir=include_parent_dir, include_parent_dir=include_parent_dir,
excludes=excludes) excludes=excludes,
)
try: try:
r = self._do_request( r = self._do_request(
"PUT", "PUT",
f"/api/packages/{protocol}/{package_name}", f"/api/packages/{protocol}/{package_name}",
data=package_file.read_bytes()) data=package_file.read_bytes(),
)
if r.status_code != 200: if r.status_code != 200:
self._raise_error(r) self._raise_error(r)
finally: finally:
package_file.unlink() package_file.unlink()
def _upload_package_if_needed(self, def _upload_package_if_needed(
package_path: str, self, package_path: str, excludes: Optional[List[str]] = None
excludes: Optional[List[str]] = None) -> str: ) -> str:
package_uri = get_uri_for_directory(package_path, excludes=excludes) package_uri = get_uri_for_directory(package_path, excludes=excludes)
if not self._package_exists(package_uri): if not self._package_exists(package_uri):
self._upload_package(package_uri, package_path, excludes=excludes) self._upload_package(package_uri, package_path, excludes=excludes)
else: else:
logger.info( logger.info(f"Package {package_uri} already exists, skipping upload.")
f"Package {package_uri} already exists, skipping upload.")
return package_uri return package_uri
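_upload_package_if_needed works because package URIs are content-addressed: the same directory contents always map to the same URI, so an already-uploaded package can be skipped safely. A toy illustration of that idea; the hashing scheme below is invented for the example and differs from what get_uri_for_directory actually computes:

import hashlib
import os

def directory_uri(path):
    digest = hashlib.md5()
    for root, _, files in sorted(os.walk(path)):
        for name in sorted(files):
            full = os.path.join(root, name)
            digest.update(full.encode())
            with open(full, "rb") as f:
                digest.update(f.read())
    return f"gcs://_ray_pkg_{digest.hexdigest()}.zip"

print(directory_uri("."))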
@ -230,7 +253,8 @@ class JobSubmissionClient:
if not is_uri: if not is_uri:
logger.debug("working_dir is not a URI, attempting to upload.") logger.debug("working_dir is not a URI, attempting to upload.")
package_uri = self._upload_package_if_needed( package_uri = self._upload_package_if_needed(
working_dir, excludes=runtime_env.get("excludes", None)) working_dir, excludes=runtime_env.get("excludes", None)
)
runtime_env["working_dir"] = package_uri runtime_env["working_dir"] = package_uri
def get_version(self) -> str: def get_version(self) -> str:
@ -241,12 +265,12 @@ class JobSubmissionClient:
self._raise_error(r) self._raise_error(r)
def submit_job( def submit_job(
self, self,
*, *,
entrypoint: str, entrypoint: str,
job_id: Optional[str] = None, job_id: Optional[str] = None,
runtime_env: Optional[Dict[str, Any]] = None, runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
runtime_env = runtime_env or {} runtime_env = runtime_env or {}
metadata = metadata or {} metadata = metadata or {}
@ -257,11 +281,11 @@ class JobSubmissionClient:
entrypoint=entrypoint, entrypoint=entrypoint,
job_id=job_id, job_id=job_id,
runtime_env=runtime_env, runtime_env=runtime_env,
metadata=metadata) metadata=metadata,
)
logger.debug(f"Submitting job with job_id={job_id}.") logger.debug(f"Submitting job with job_id={job_id}.")
r = self._do_request( r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))
"POST", "/api/jobs/", json_data=dataclasses.asdict(req))
if r.status_code == 200: if r.status_code == 200:
return JobSubmitResponse(**r.json()).job_id return JobSubmitResponse(**r.json()).job_id
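Putting the client methods together, a usage sketch; the dashboard address and the entrypoint are illustrative assumptions, not values from this diff:

from ray.dashboard.modules.job.sdk import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(
    entrypoint="echo hello",
    runtime_env={"pip": ["pip-install-test"]},
)
print(client.get_job_status(job_id))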
@ -269,8 +293,8 @@ class JobSubmissionClient:
self._raise_error(r) self._raise_error(r)
def stop_job( def stop_job(
self, self,
job_id: str, job_id: str,
) -> bool: ) -> bool:
logger.debug(f"Stopping job with job_id={job_id}.") logger.debug(f"Stopping job with job_id={job_id}.")
r = self._do_request("POST", f"/api/jobs/{job_id}/stop") r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
@ -281,15 +305,14 @@ class JobSubmissionClient:
self._raise_error(r) self._raise_error(r)
def get_job_status( def get_job_status(
self, self,
job_id: str, job_id: str,
) -> JobStatusInfo: ) -> JobStatusInfo:
r = self._do_request("GET", f"/api/jobs/{job_id}") r = self._do_request("GET", f"/api/jobs/{job_id}")
if r.status_code == 200: if r.status_code == 200:
response = JobStatusResponse(**r.json()) response = JobStatusResponse(**r.json())
return JobStatusInfo( return JobStatusInfo(status=response.status, message=response.message)
status=response.status, message=response.message)
else: else:
self._raise_error(r) self._raise_error(r)
@ -304,7 +327,8 @@ class JobSubmissionClient:
async def tail_job_logs(self, job_id: str) -> Iterator[str]: async def tail_job_logs(self, job_id: str) -> Iterator[str]:
async with aiohttp.ClientSession(cookies=self._cookies) as session: async with aiohttp.ClientSession(cookies=self._cookies) as session:
ws = await session.ws_connect( ws = await session.ws_connect(
f"{self._address}/api/jobs/{job_id}/logs/tail") f"{self._address}/api/jobs/{job_id}/logs/tail"
)
while True: while True:
msg = await ws.receive() msg = await ws.receive()
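tail_job_logs is an async generator fed by a websocket, so it is consumed with async for. A hedged usage sketch (address and entrypoint again assumed for illustration):

import asyncio

from ray.dashboard.modules.job.sdk import JobSubmissionClient

async def main():
    client = JobSubmissionClient("http://127.0.0.1:8265")
    job_id = client.submit_job(entrypoint="echo hello && sleep 1 && echo world")
    async for chunk in client.tail_job_logs(job_id):
        print(chunk, end="")

asyncio.run(main())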
View file
@ -7,13 +7,13 @@ we ended up using job submission API call's runtime_env instead of scripts
def run(): def run():
import ray import ray
import os import os
ray.init( ray.init(
address=os.environ["RAY_ADDRESS"], address=os.environ["RAY_ADDRESS"],
runtime_env={ runtime_env={
"env_vars": { "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"}
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN" },
} )
})
@ray.remote @ray.remote
def foo(): def foo():
View file
@ -21,15 +21,14 @@ def conda_env(env_name):
# Clean up created conda env upon test exit to prevent leaking # Clean up created conda env upon test exit to prevent leaking
del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"] del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"]
subprocess.run( subprocess.run(
f"conda env remove -y --name {env_name}", f"conda env remove -y --name {env_name}", shell=True, stdout=subprocess.PIPE
shell=True, )
stdout=subprocess.PIPE)
def _compatibility_script_path(file_name: str) -> str: def _compatibility_script_path(file_name: str) -> str:
return os.path.join( return os.path.join(
os.path.dirname(__file__), "backwards_compatibility_scripts", os.path.dirname(__file__), "backwards_compatibility_scripts", file_name
file_name) )
class TestBackwardsCompatibility: class TestBackwardsCompatibility:
@ -48,8 +47,7 @@ class TestBackwardsCompatibility:
shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501 shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501
try: try:
subprocess.check_output( subprocess.check_output(shell_cmd, shell=True, stderr=subprocess.STDOUT)
shell_cmd, shell=True, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
logger.error(str(e)) logger.error(str(e))
logger.error(e.stdout.decode()) logger.error(e.stdout.decode())
View file
@ -34,8 +34,7 @@ def mock_sdk_client():
if "RAY_ADDRESS" in os.environ: if "RAY_ADDRESS" in os.environ:
del os.environ["RAY_ADDRESS"] del os.environ["RAY_ADDRESS"]
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient" with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient") as mock_client:
) as mock_client:
# In python 3.6 it will fail with error # In python 3.6 it will fail with error
        # 'async for' requires an object with __aiter__ method, got MagicMock         # 'async for' requires an object with __aiter__ method, got MagicMock
mock_client().tail_job_logs.return_value = AsyncIterator(range(10)) mock_client().tail_job_logs.return_value = AsyncIterator(range(10))
@ -52,9 +51,7 @@ def runtime_env_formats():
"working_dir": "s3://bogus.zip", "working_dir": "s3://bogus.zip",
"conda": "conda_env", "conda": "conda_env",
"pip": ["pip-install-test"], "pip": ["pip-install-test"],
"env_vars": { "env_vars": {"hi": "hi2"},
"hi": "hi2"
}
} }
yaml_file = path / "env.yaml" yaml_file = path / "env.yaml"
@ -86,14 +83,13 @@ class TestSubmit:
# Test passing address via command line. # Test passing address via command line.
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group, ["submit", "--address=arg_addr", "--", "echo hello"]
["submit", "--address=arg_addr", "--", "echo hello"]) )
assert mock_sdk_client.called_with("arg_addr") assert mock_sdk_client.called_with("arg_addr")
assert result.exit_code == 0 assert result.exit_code == 0
# Test passing address via env var. # Test passing address via env var.
with set_env_var("RAY_ADDRESS", "env_addr"): with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group, result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
["submit", "--", "echo hello"])
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_sdk_client.called_with("env_addr") assert mock_sdk_client.called_with("env_addr")
# Test passing no address. # Test passing no address.
@ -106,24 +102,22 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value mock_client_instance = mock_sdk_client.return_value
with set_env_var("RAY_ADDRESS", "env_addr"): with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group, result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
["submit", "--", "echo hello"])
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env={}) assert mock_client_instance.called_with(runtime_env={})
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group,
["submit", "--", "--working-dir", "blah", "--", "echo hello"]) ["submit", "--", "--working-dir", "blah", "--", "echo hello"],
)
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with( assert mock_client_instance.called_with(runtime_env={"working_dir": "blah"})
runtime_env={"working_dir": "blah"})
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group, ["submit", "--", "--working-dir='.'", "--", "echo hello"]
["submit", "--", "--working-dir='.'", "--", "echo hello"]) )
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with( assert mock_client_instance.called_with(runtime_env={"working_dir": "."})
runtime_env={"working_dir": "."})
def test_runtime_env(self, mock_sdk_client, runtime_env_formats): def test_runtime_env(self, mock_sdk_client, runtime_env_formats):
runner = CliRunner() runner = CliRunner()
@ -133,39 +127,64 @@ class TestSubmit:
with set_env_var("RAY_ADDRESS", "env_addr"): with set_env_var("RAY_ADDRESS", "env_addr"):
# Test passing via file. # Test passing via file.
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group, ["submit", "--runtime-env", env_yaml, "--", "echo hello"]
["submit", "--runtime-env", env_yaml, "--", "echo hello"]) )
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict) assert mock_client_instance.called_with(runtime_env=env_dict)
# Test passing via json. # Test passing via json.
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group,
["submit", "--runtime-env-json", env_json, "--", "echo hello"]) ["submit", "--runtime-env-json", env_json, "--", "echo hello"],
)
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict) assert mock_client_instance.called_with(runtime_env=env_dict)
# Test passing both throws an error. # Test passing both throws an error.
result = runner.invoke(job_cli_group, [ result = runner.invoke(
"submit", "--runtime-env", env_yaml, "--runtime-env-json", job_cli_group,
env_json, "--", "echo hello" [
]) "submit",
"--runtime-env",
env_yaml,
"--runtime-env-json",
env_json,
"--",
"echo hello",
],
)
assert result.exit_code == 1 assert result.exit_code == 1
assert "Only one of" in str(result.exception) assert "Only one of" in str(result.exception)
# Test overriding working_dir. # Test overriding working_dir.
env_dict.update(working_dir=".") env_dict.update(working_dir=".")
result = runner.invoke(job_cli_group, [ result = runner.invoke(
"submit", "--runtime-env", env_yaml, "--working-dir", ".", job_cli_group,
"--", "echo hello" [
]) "submit",
"--runtime-env",
env_yaml,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict) assert mock_client_instance.called_with(runtime_env=env_dict)
result = runner.invoke(job_cli_group, [ result = runner.invoke(
"submit", "--runtime-env-json", env_json, "--working-dir", ".", job_cli_group,
"--", "echo hello" [
]) "submit",
"--runtime-env-json",
env_json,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict) assert mock_client_instance.called_with(runtime_env=env_dict)
@ -174,18 +193,18 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value mock_client_instance = mock_sdk_client.return_value
with set_env_var("RAY_ADDRESS", "env_addr"): with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group, result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
["submit", "--", "echo hello"])
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(job_id=None) assert mock_client_instance.called_with(job_id=None)
result = runner.invoke( result = runner.invoke(
job_cli_group, job_cli_group, ["submit", "--", "--job-id=my_job_id", "echo hello"]
["submit", "--", "--job-id=my_job_id", "echo hello"]) )
assert result.exit_code == 0 assert result.exit_code == 0
assert mock_client_instance.called_with(job_id="my_job_id") assert mock_client_instance.called_with(job_id="my_job_id")
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
sys.exit(pytest.main(["-v", __file__])) sys.exit(pytest.main(["-v", __file__]))
View file
@ -60,25 +60,27 @@ class TestRayAddress:
def test_empty_ray_address(self, ray_start_stop): def test_empty_ray_address(self, ray_start_stop):
with set_env_var("RAY_ADDRESS", None): with set_env_var("RAY_ADDRESS", None):
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"], ["ray", "job", "submit", "--", "echo hello"], stderr=subprocess.PIPE
stderr=subprocess.PIPE) )
stderr = completed_process.stderr.decode("utf-8") stderr = completed_process.stderr.decode("utf-8")
            # The current dashboard module raises no exception from requests.             # The current dashboard module raises no exception from requests.
assert ("Address must be specified using either the " assert (
"--address flag or RAY_ADDRESS environment") in stderr "Address must be specified using either the "
"--address flag or RAY_ADDRESS environment"
) in stderr
def test_ray_client_address(self, ray_start_stop): def test_ray_client_address(self, ray_start_stop):
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"], ["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout assert "hello" in stdout
assert "succeeded" in stdout assert "succeeded" in stdout
def test_valid_http_ray_address(self, ray_start_stop): def test_valid_http_ray_address(self, ray_start_stop):
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"], ["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout assert "hello" in stdout
assert "succeeded" in stdout assert "succeeded" in stdout
@ -87,8 +89,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"): with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"):
with ray_cluster_manager(): with ray_cluster_manager():
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"], ["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout assert "hello" in stdout
assert "succeeded" in stdout assert "succeeded" in stdout
@ -97,8 +99,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"): with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"):
with ray_cluster_manager(): with ray_cluster_manager():
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"], ["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout assert "hello" in stdout
assert "succeeded" in stdout assert "succeeded" in stdout
@ -109,7 +111,8 @@ class TestJobSubmit:
"""Should tail logs and wait for process to exit.""" """Should tail logs and wait for process to exit."""
cmd = "sleep 1 && echo hello && sleep 1 && echo hello" cmd = "sleep 1 && echo hello && sleep 1 && echo hello"
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE) ["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello\nhello" in stdout assert "hello\nhello" in stdout
assert "succeeded" in stdout assert "succeeded" in stdout
@ -118,8 +121,8 @@ class TestJobSubmit:
"""Should exit immediately w/o printing logs.""" """Should exit immediately w/o printing logs."""
cmd = "echo hello && sleep 1000" cmd = "echo hello && sleep 1000"
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", "--", cmd], ["ray", "job", "submit", "--no-wait", "--", cmd], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "hello" not in stdout assert "hello" not in stdout
assert "Tailing logs until the job exits" not in stdout assert "Tailing logs until the job exits" not in stdout
@ -130,13 +133,13 @@ class TestJobStop:
"""Should wait until the job is stopped.""" """Should wait until the job is stopped."""
cmd = "sleep 1000" cmd = "sleep 1000"
job_id = "test_basic_stop" job_id = "test_basic_stop"
completed_process = subprocess.run([ completed_process = subprocess.run(
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", ["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
cmd )
])
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "stop", job_id], stdout=subprocess.PIPE) ["ray", "job", "stop", job_id], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" in stdout assert "Waiting for job" in stdout
assert f"Job '{job_id}' was stopped" in stdout assert f"Job '{job_id}' was stopped" in stdout
@ -145,14 +148,13 @@ class TestJobStop:
"""Should not wait until the job is stopped.""" """Should not wait until the job is stopped."""
cmd = "echo hello && sleep 1000" cmd = "echo hello && sleep 1000"
job_id = "test_stop_no_wait" job_id = "test_stop_no_wait"
completed_process = subprocess.run([ completed_process = subprocess.run(
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", ["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
cmd )
])
completed_process = subprocess.run( completed_process = subprocess.run(
["ray", "job", "stop", "--no-wait", job_id], ["ray", "job", "stop", "--no-wait", job_id], stdout=subprocess.PIPE
stdout=subprocess.PIPE) )
stdout = completed_process.stdout.decode("utf-8") stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" not in stdout assert "Waiting for job" not in stdout
assert f"Job '{job_id}' was stopped" not in stdout assert f"Job '{job_id}' was stopped" not in stdout
View file
@ -24,82 +24,61 @@ class TestJobSubmitRequestValidation:
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.job_id is None assert r.job_id is None
r = validate_request_type({ r = validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "job_id": "123"}, JobSubmitRequest
"job_id": "123" )
}, JobSubmitRequest)
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.job_id == "123" assert r.job_id == "123"
with pytest.raises(TypeError, match="must be a string"): with pytest.raises(TypeError, match="must be a string"):
validate_request_type({ validate_request_type({"entrypoint": 123, "job_id": 1}, JobSubmitRequest)
"entrypoint": 123,
"job_id": 1
}, JobSubmitRequest)
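These tests lean on the fact that validate_request_type only needs to call request_type(**json_data): the dataclass constructor rejects unknown keys with a TypeError, and __post_init__ handles the type checks. A trimmed stand-in (MiniSubmitRequest is hypothetical; the real JobSubmitRequest validates more fields):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class MiniSubmitRequest:
    entrypoint: str
    job_id: Optional[str] = None

    def __post_init__(self):
        if not isinstance(self.entrypoint, str):
            raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}")

def validate_request_type(json_data: Dict[str, Any], request_type) -> Any:
    return request_type(**json_data)

print(validate_request_type({"entrypoint": "echo hi"}, MiniSubmitRequest))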
def test_validate_runtime_env(self): def test_validate_runtime_env(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest) r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.runtime_env is None assert r.runtime_env is None
r = validate_request_type({ r = validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "runtime_env": {"hi": "hi2"}}, JobSubmitRequest
"runtime_env": { )
"hi": "hi2"
}
}, JobSubmitRequest)
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.runtime_env == {"hi": "hi2"} assert r.runtime_env == {"hi": "hi2"}
with pytest.raises(TypeError, match="must be a dict"): with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({ validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "runtime_env": 123}, JobSubmitRequest
"runtime_env": 123 )
}, JobSubmitRequest)
with pytest.raises(TypeError, match="keys must be strings"): with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({ validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "runtime_env": {1: "hi"}}, JobSubmitRequest
"runtime_env": { )
1: "hi"
}
}, JobSubmitRequest)
def test_validate_metadata(self): def test_validate_metadata(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest) r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.metadata is None assert r.metadata is None
r = validate_request_type({ r = validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "metadata": {"hi": "hi2"}}, JobSubmitRequest
"metadata": { )
"hi": "hi2"
}
}, JobSubmitRequest)
assert r.entrypoint == "abc" assert r.entrypoint == "abc"
assert r.metadata == {"hi": "hi2"} assert r.metadata == {"hi": "hi2"}
with pytest.raises(TypeError, match="must be a dict"): with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({ validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "metadata": 123}, JobSubmitRequest
"metadata": 123 )
}, JobSubmitRequest)
with pytest.raises(TypeError, match="keys must be strings"): with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({ validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "metadata": {1: "hi"}}, JobSubmitRequest
"metadata": { )
1: "hi"
}
}, JobSubmitRequest)
with pytest.raises(TypeError, match="values must be strings"): with pytest.raises(TypeError, match="values must be strings"):
validate_request_type({ validate_request_type(
"entrypoint": "abc", {"entrypoint": "abc", "metadata": {"hi": 1}}, JobSubmitRequest
"metadata": { )
"hi": 1
}
}, JobSubmitRequest)
def test_uri_to_http_and_back(): def test_uri_to_http_and_back():
@ -127,4 +106,5 @@ def test_uri_to_http_and_back():
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
sys.exit(pytest.main(["-v", __file__])) sys.exit(pytest.main(["-v", __file__]))
@@ -8,11 +8,17 @@ import pytest

import ray
from ray.dashboard.tests.conftest import *  # noqa
from ray.tests.conftest import _ray_start
from ray._private.test_utils import (
    format_web_url,
    wait_for_condition,
    wait_until_server_available,
)
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
from ray.dashboard.modules.job.sdk import (
    ClusterInfo,
    JobSubmissionClient,
    parse_cluster_info,
)
from unittest.mock import patch

logger = logging.getLogger(__name__)

@@ -50,8 +56,8 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:

@pytest.fixture(
    scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"]
)
def working_dir_option(request):
    if request.param == "no_working_dir":
        yield {

@@ -81,9 +87,7 @@ def working_dir_option(request):

                f.write("from test_module.test import run_test\n")

            yield {
                "runtime_env": {"working_dir": tmp_dir},
                "entrypoint": "python test.py",
                "expected_logs": "Hello from test_module!\n",
            }

@@ -104,7 +108,8 @@ def test_submit_job(job_sdk_client, working_dir_option):

    job_id = client.submit_job(
        entrypoint=working_dir_option["entrypoint"],
        runtime_env=working_dir_option["runtime_env"],
    )

    wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)

@@ -133,7 +138,8 @@ def test_http_bad_request(job_sdk_client):

def test_invalid_runtime_env(job_sdk_client):
    client = job_sdk_client
    job_id = client.submit_job(
        entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
    )

    wait_for_condition(_check_job_failed, client=client, job_id=job_id)
    status = client.get_job_status(job_id)

@@ -143,8 +149,8 @@ def test_invalid_runtime_env(job_sdk_client):

def test_runtime_env_setup_failure(job_sdk_client):
    client = job_sdk_client
    job_id = client.submit_job(
        entrypoint="echo hello", runtime_env={"working_dir": "s3://does_not_exist.zip"}
    )

    wait_for_condition(_check_job_failed, client=client, job_id=job_id)
    status = client.get_job_status(job_id)

@@ -168,8 +174,8 @@ raise RuntimeError('Intentionally failed.')

            file.write(driver_script)

        job_id = client.submit_job(
            entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
        )

        wait_for_condition(_check_job_failed, client=client, job_id=job_id)
        logs = client.get_job_logs(job_id)

@@ -196,8 +202,8 @@ raise RuntimeError('Intentionally failed.')

            file.write(driver_script)

        job_id = client.submit_job(
            entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
        )
        assert client.stop_job(job_id) is True
        wait_for_condition(_check_job_stopped, client=client, job_id=job_id)

@@ -206,28 +212,31 @@ def test_job_metadata(job_sdk_client):

    client = job_sdk_client

    print_metadata_cmd = (
        'python -c"'
        "import ray;"
        "ray.init();"
        "job_config=ray.worker.global_worker.core_worker.get_job_config();"
        "print(dict(sorted(job_config.metadata.items())))"
        '"'
    )

    job_id = client.submit_job(
        entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
    )

    wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)

    assert (
        str(
            {
                "job_name": job_id,
                "job_submission_id": job_id,
                "key1": "val1",
                "key2": "val2",
            }
        )
        in client.get_job_logs(job_id)
    )


def test_pass_job_id(job_sdk_client):

@@ -261,19 +270,19 @@ def test_submit_optional_args(job_sdk_client):

        json_data={"entrypoint": "ls"},
    )

    wait_for_condition(_check_job_succeeded, client=client, job_id=r.json()["job_id"])


def test_missing_resources(job_sdk_client):
    """Check that 404s are raised for resources that don't exist."""
    client = job_sdk_client

    conditions = [
        ("GET", "/api/jobs/fake_job_id"),
        ("GET", "/api/jobs/fake_job_id/logs"),
        ("POST", "/api/jobs/fake_job_id/stop"),
        ("GET", "/api/packages/fake_package_uri"),
    ]

    for method, route in conditions:
        assert client._do_request(method, route).status_code == 404

@@ -287,7 +296,7 @@ def test_version_endpoint(job_sdk_client):

    assert r.json() == {
        "version": CURRENT_VERSION,
        "ray_version": ray.__version__,
        "ray_commit": ray.__commit__,
    }

@@ -306,26 +315,31 @@ def test_request_headers(job_sdk_client):

        cookies=None,
        data=None,
        json={"entrypoint": "ls"},
        headers={"Connection": "keep-alive", "Authorization": "TOK:<MY_TOKEN>"},
    )


@pytest.mark.parametrize(
    "address",
    [
        "http://127.0.0.1",
        "https://127.0.0.1",
        "ray://127.0.0.1",
        "fake_module://127.0.0.1",
    ],
)
def test_parse_cluster_info(address: str):
    if address.startswith("ray"):
        assert parse_cluster_info(address, False) == ClusterInfo(
            address="http" + address[address.index("://") :],
            cookies=None,
            metadata=None,
            headers=None,
        )
    elif address.startswith("http") or address.startswith("https"):
        assert parse_cluster_info(address, False) == ClusterInfo(
            address=address, cookies=None, metadata=None, headers=None
        )
    else:
        with pytest.raises(RuntimeError):
            parse_cluster_info(address, False)

@@ -347,8 +361,8 @@ for i in range(100):

            f.write(driver_script)

        job_id = client.submit_job(
            entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
        )

        i = 0
        async for lines in client.tail_job_logs(job_id):
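Nearly every test in this file funnels through wait_for_condition from ray._private.test_utils, which this diff never shows. As a rough sketch of the contract these call sites assume (the defaults and error text here are assumptions, not the exact helper):

import time


def wait_for_condition(
    condition_predictor, timeout=10, retry_interval_ms=100, **kwargs
):
    """Poll condition_predictor(**kwargs) until it returns True or time runs out."""
    start = time.time()
    while time.time() - start <= timeout:
        if condition_predictor(**kwargs):
            return
        time.sleep(retry_interval_ms / 1000.0)
    raise RuntimeError("The condition wasn't met before the timeout expired.")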
@@ -8,8 +8,11 @@ import signal

import pytest

import ray
from ray.dashboard.modules.job.common import (
    JobStatus,
    JOB_ID_METADATA_KEY,
    JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
from ray._private.test_utils import SignalActor, wait_for_condition

@@ -28,7 +31,8 @@ def job_manager(shared_ray_instance):

def _driver_script_path(file_name: str) -> str:
    return os.path.join(
        os.path.dirname(__file__), "subprocess_driver_scripts", file_name
    )


def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):

@@ -36,12 +40,15 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):

    pid_file = os.path.join(tmp_dir, "pid")

    # Write subprocess pid to pid_file and block until tmp_file is present.
    wait_for_file_cmd = (
        f"echo $$ > {pid_file} && "
        f"until [ -f {tmp_file} ]; "
        "do echo 'Waiting...' && sleep 1; "
        "done"
    )
    job_id = job_manager.submit_job(
        entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
    )

    status = job_manager.get_job_status(job_id)
    if start_signal_actor:

@@ -50,11 +57,9 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):

        logs = job_manager.get_job_logs(job_id)
        assert logs == ""
    else:
        wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)

        wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))

    return pid_file, tmp_file, job_id

@@ -63,25 +68,19 @@ def check_job_succeeded(job_manager, job_id):

    status = job_manager.get_job_status(job_id)
    if status.status == JobStatus.FAILED:
        raise RuntimeError(f"Job failed! {status.message}")
    assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED}
    return status.status == JobStatus.SUCCEEDED


def check_job_failed(job_manager, job_id):
    status = job_manager.get_job_status(job_id)
    assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
    return status.status == JobStatus.FAILED


def check_job_stopped(job_manager, job_id):
    status = job_manager.get_job_status(job_id)
    assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
    return status.status == JobStatus.STOPPED

@@ -111,12 +110,10 @@ def test_generate_job_id():

def test_pass_job_id(job_manager):
    job_id = "my_custom_id"

    returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
    assert returned_id == job_id

    wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)

    # Check that the same job_id is rejected.
    with pytest.raises(RuntimeError):

@@ -127,23 +124,20 @@ class TestShellScriptExecution:

    def test_submit_basic_echo(self, job_manager):
        job_id = job_manager.submit_job(entrypoint="echo hello")

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert job_manager.get_job_logs(job_id) == "hello\n"

    def test_submit_stderr(self, job_manager):
        job_id = job_manager.submit_job(entrypoint="echo error 1>&2")

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert job_manager.get_job_logs(job_id) == "error\n"

    def test_submit_ls_grep(self, job_manager):
        grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
        job_id = job_manager.submit_job(entrypoint=grep_cmd)

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"

    def test_subprocess_exception(self, job_manager):

@@ -161,8 +155,7 @@ class TestShellScriptExecution:

            status = job_manager.get_job_status(job_id)
            if status.status != JobStatus.FAILED:
                return False
            if "Exception: Script failed with exception !" not in status.message:
                return False

            return job_manager._get_actor_for_job(job_id) is None

@@ -172,14 +165,13 @@ class TestShellScriptExecution:

    def test_submit_with_s3_runtime_env(self, job_manager):
        job_id = job_manager.submit_job(
            entrypoint="python script.py",
            runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert (
            job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
        )


class TestRuntimeEnv:

@@ -193,14 +185,10 @@ class TestRuntimeEnv:

        """
        job_id = job_manager.submit_job(
            entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
            runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert job_manager.get_job_logs(job_id) == "233\n"

    def test_multiple_runtime_envs(self, job_manager):

@@ -208,28 +196,32 @@ class TestRuntimeEnv:

        job_id_1 = job_manager.submit_job(
            entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
            runtime_env={
                "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
            },
        )

        wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id_1
        )
        logs = job_manager.get_job_logs(job_id_1)
        assert (
            "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
        )  # noqa: E501

        job_id_2 = job_manager.submit_job(
            entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
            runtime_env={
                "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
            },
        )

        wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id_2
        )
        logs = job_manager.get_job_logs(job_id_2)
        assert (
            "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
        )  # noqa: E501

    def test_env_var_and_driver_job_config_warning(self, job_manager):
        """Ensure we got error message from worker.py and job logs

@@ -238,17 +230,15 @@ class TestRuntimeEnv:

        job_id = job_manager.submit_job(
            entrypoint=f"python {_driver_script_path('override_env_var.py')}",
            runtime_env={
                "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
            },
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        logs = job_manager.get_job_logs(job_id)
        assert logs.startswith(
            "Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
        )
        assert "JOB_1_VAR" in logs

    def test_failed_runtime_env_validation(self, job_manager):

@@ -257,7 +247,8 @@ class TestRuntimeEnv:

        """
        run_cmd = f"python {_driver_script_path('override_env_var.py')}"
        job_id = job_manager.submit_job(
            entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
        )

        status = job_manager.get_job_status(job_id)
        assert status.status == JobStatus.FAILED

@@ -269,11 +260,10 @@ class TestRuntimeEnv:

        """
        run_cmd = f"python {_driver_script_path('override_env_var.py')}"
        job_id = job_manager.submit_job(
            entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
        )

        wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

        status = job_manager.get_job_status(job_id)
        assert "runtime_env setup failed" in status.message

@@ -283,69 +273,67 @@ class TestRuntimeEnv:

            return str(dict(sorted(d.items())))

        print_metadata_cmd = (
            'python -c"'
            "import ray;"
            "ray.init();"
            "job_config=ray.worker.global_worker.core_worker.get_job_config();"
            "print(dict(sorted(job_config.metadata.items())))"
            '"'
        )

        # Check that we default to only the job ID and job name.
        job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert dict_to_str(
            {JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
        ) in job_manager.get_job_logs(job_id)

        # Check that we can pass custom metadata.
        job_id = job_manager.submit_job(
            entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert (
            dict_to_str(
                {
                    JOB_NAME_METADATA_KEY: job_id,
                    JOB_ID_METADATA_KEY: job_id,
                    "key1": "val1",
                    "key2": "val2",
                }
            )
            in job_manager.get_job_logs(job_id)
        )

        # Check that we can override job name.
        job_id = job_manager.submit_job(
            entrypoint=print_metadata_cmd,
            metadata={JOB_NAME_METADATA_KEY: "custom_name"},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        assert dict_to_str(
            {JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
        ) in job_manager.get_job_logs(job_id)


class TestAsyncAPI:
    def test_status_and_logs_while_blocking(self, job_manager):
        with tempfile.TemporaryDirectory() as tmp_dir:
            pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
            with open(pid_file, "r") as file:
                pid = int(file.read())
                assert psutil.pid_exists(pid), "driver subprocess should be running"

            # Signal the job to exit by writing to the file.
            with open(tmp_file, "w") as f:
                print("hello", file=f)

            wait_for_condition(
                check_job_succeeded, job_manager=job_manager, job_id=job_id
            )
            # Ensure driver subprocess gets cleaned up after job reached
            # termination state
            wait_for_condition(check_subprocess_cleaned, pid=pid)

@@ -356,7 +344,8 @@ class TestAsyncAPI:

            assert job_manager.stop_job(job_id) is True
            wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )
            # Assert re-stopping a stopped job also returns False
            wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
            # Assert stopping non-existent job returns False

@@ -375,13 +364,11 @@ class TestAsyncAPI:

            pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
            with open(pid_file, "r") as file:
                pid = int(file.read())
                assert psutil.pid_exists(pid), "driver subprocess should be running"

            actor = job_manager._get_actor_for_job(job_id)
            ray.kill(actor, no_restart=True)

            wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

            # Ensure driver subprocess gets cleaned up after job reached
            # termination state

@@ -398,16 +385,18 @@ class TestAsyncAPI:

        with tempfile.TemporaryDirectory() as tmp_dir:
            pid_file, _, job_id = _run_hanging_command(
                job_manager, tmp_dir, start_signal_actor=start_signal_actor
            )

            assert not os.path.exists(pid_file), (
                "driver subprocess should NOT be running while job is " "still PENDING."
            )

            assert job_manager.stop_job(job_id) is True
            # Send run signal to unblock run function
            ray.get(start_signal_actor.send.remote())

            wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )

    def test_kill_job_actor_in_pending(self, job_manager):
        """

@@ -420,16 +409,16 @@ class TestAsyncAPI:

        with tempfile.TemporaryDirectory() as tmp_dir:
            pid_file, _, job_id = _run_hanging_command(
                job_manager, tmp_dir, start_signal_actor=start_signal_actor
            )

            assert not os.path.exists(pid_file), (
                "driver subprocess should NOT be running while job is " "still PENDING."
            )

            actor = job_manager._get_actor_for_job(job_id)
            ray.kill(actor, no_restart=True)

            wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

    def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
        """

@@ -442,12 +431,12 @@ class TestAsyncAPI:

            pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
            with open(pid_file, "r") as file:
                pid = int(file.read())
                assert psutil.pid_exists(pid), "driver subprocess should be running"

            assert job_manager.stop_job(job_id) is True
            wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )

            # Ensure driver subprocess gets cleaned up after job reached
            # termination state

@@ -455,11 +444,9 @@ class TestAsyncAPI:

class TestTailLogs:
    async def _tail_and_assert_logs(
        self, job_id, job_manager, expected_log="", num_iteration=5
    ):
        i = 0
        async for lines in job_manager.tail_job_logs(job_id):
            assert all(s == expected_log for s in lines.strip().split("\n"))

@@ -470,8 +457,7 @@ class TestTailLogs:

    @pytest.mark.asyncio
    async def test_unknown_job(self, job_manager):
        with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
            async for _ in job_manager.tail_job_logs("unknown"):
                pass

@@ -482,33 +468,31 @@ class TestTailLogs:

        with tempfile.TemporaryDirectory() as tmp_dir:
            _, tmp_file, job_id = _run_hanging_command(
                job_manager, tmp_dir, start_signal_actor=start_signal_actor
            )

            # TODO(edoakes): check we get no logs before actor starts (not sure
            # how to timeout the iterator call).
            assert job_manager.get_job_status(job_id).status == JobStatus.PENDING

            # Signal job to start.
            ray.get(start_signal_actor.send.remote())

            await self._tail_and_assert_logs(
                job_id, job_manager, expected_log="Waiting...", num_iteration=5
            )

            # Signal the job to exit by writing to the file.
            with open(tmp_file, "w") as f:
                print("hello", file=f)

            async for lines in job_manager.tail_job_logs(job_id):
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(
                check_job_succeeded, job_manager=job_manager, job_id=job_id
            )

    @pytest.mark.asyncio
    async def test_failed_job(self, job_manager):

@@ -517,22 +501,18 @@ class TestTailLogs:

            pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)

            await self._tail_and_assert_logs(
                job_id, job_manager, expected_log="Waiting...", num_iteration=5
            )

            # Kill the job unexpectedly.
            with open(pid_file, "r") as f:
                os.kill(int(f.read()), signal.SIGKILL)

            async for lines in job_manager.tail_job_logs(job_id):
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)

    @pytest.mark.asyncio
    async def test_stopped_job(self, job_manager):

@@ -541,21 +521,19 @@ class TestTailLogs:

            _, _, job_id = _run_hanging_command(job_manager, tmp_dir)

            await self._tail_and_assert_logs(
                job_id, job_manager, expected_log="Waiting...", num_iteration=5
            )

            # Stop the job via the API.
            job_manager.stop_job(job_id)

            async for lines in job_manager.tail_job_logs(job_id):
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )


def test_logs_streaming(job_manager):

@@ -568,7 +546,7 @@ while True:

    time.sleep(1)
"""

    stream_logs_cmd = f'python -c "{stream_logs_script}"'

    job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)

    wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
@@ -12,7 +12,7 @@ def tmp():

        yield f.name


class TestIterLine:
    def test_invalid_type(self):
        with pytest.raises(TypeError, match="path must be a string"):
            next(file_tail_iterator(1))
@@ -19,10 +19,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):

        self._proxy_session = aiohttp.ClientSession(auto_decompress=False)
        log_utils.register_mimetypes()
        routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
        GlobalSignals.node_info_fetched.append(self.insert_log_url_to_node_info)
        GlobalSignals.node_summary_fetched.append(self.insert_log_url_to_node_info)

    async def insert_log_url_to_node_info(self, node_info):
        node_id = node_info.get("raylet", {}).get("nodeId")

@@ -33,7 +31,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):

            return
        agent_http_port, _ = agent_port
        log_url = self.LOG_URL_TEMPLATE.format(
            ip=node_info.get("ip"), port=agent_http_port
        )
        node_info["logUrl"] = log_url

    @routes.get("/log_index")

@@ -43,15 +42,16 @@ class LogHead(dashboard_utils.DashboardHeadModule):

        for node_id, ports in DataSource.agents.items():
            ip = DataSource.node_id_to_ip[node_id]
            agent_ips.append(ip)
            url_list.append(self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
        if self._dashboard_head.ip not in agent_ips:
            url_list.append(
                self.LOG_URL_TEMPLATE.format(
                    ip=self._dashboard_head.ip, port=self._dashboard_head.http_port
                )
            )
        return aiohttp.web.Response(
            text=self._directory_as_html(url_list), content_type="text/html"
        )

    @routes.get("/log_proxy")
    async def get_log_from_proxy(self, req) -> aiohttp.web.StreamResponse:

@@ -60,9 +60,11 @@ class LogHead(dashboard_utils.DashboardHeadModule):

            raise Exception("url is None.")
        body = await req.read()
        async with self._proxy_session.request(
            req.method, url, data=body, headers=req.headers
        ) as r:
            sr = aiohttp.web.StreamResponse(
                status=r.status, reason=r.reason, headers=req.headers
            )
            sr.content_length = r.content_length
            sr.content_type = r.content_type
            sr.charset = r.charset
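For orientation, the two routes above can be exercised directly against a running dashboard. A hypothetical client-side snippet; the URL assumes Ray's default dashboard port, so adjust for your cluster:

import requests

webui_url = "http://127.0.0.1:8265"  # assumed default dashboard address

# /log_index renders an HTML listing with one log URL per node agent.
index = requests.get(f"{webui_url}/log_index")
index.raise_for_status()

# /log_proxy streams an arbitrary log URL back through the head node.
proxied = requests.get(
    f"{webui_url}/log_proxy", params={"url": f"{webui_url}/logs/dashboard.log"}
)
proxied.raise_for_status()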
@@ -40,8 +40,7 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):

    test_log_text = "test_log_text"
    ray.get(write_log.remote(test_log_text))
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)
    node_id = ray_start_with_dashboard["node_id"]

@@ -82,8 +81,8 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):

    # Test range request.
    response = requests.get(
        webui_url + "/logs/dashboard.log", headers={"Range": "bytes=44-52"}
    )
    response.raise_for_status()
    assert response.text == "Dashboard"

@@ -100,16 +99,19 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):

            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")


def test_log_proxy(ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True

    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

@@ -122,21 +124,27 @@ def test_log_proxy(ray_start_with_dashboard):

            # Test range request.
            response = requests.get(
                f"{webui_url}/log_proxy?url={webui_url}/logs/dashboard.log",
                headers={"Range": "bytes=44-52"},
            )
            response.raise_for_status()
            assert response.text == "Dashboard"
            # Test 404.
            response = requests.get(
                f"{webui_url}/log_proxy?" f"url={webui_url}/logs/not_exist_file.log"
            )
            assert response.status_code == 404
            break
        except Exception as ex:
            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")
@@ -9,8 +9,10 @@ import ray._private.utils

import ray._private.gcs_utils as gcs_utils
from ray import ray_constants
from ray.dashboard.modules.node import node_consts
from ray.dashboard.modules.node.node_consts import (
    MAX_LOGS_TO_CACHE,
    LOG_PRUNE_THREASHOLD,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.consts as dashboard_consts

@@ -28,13 +30,21 @@ routes = dashboard_optional_utils.ClassMethodRouteTable

def gcs_node_info_to_dict(message):
    return dashboard_utils.message_to_dict(
        message, {"nodeId"}, including_default_value_fields=True
    )


def node_stats_to_dict(message):
    decode_keys = {
        "actorId",
        "jobId",
        "taskId",
        "parentTaskId",
        "sourceActorId",
        "callerId",
        "rayletId",
        "workerId",
        "placementGroupId",
    }
    core_workers_stats = message.core_workers_stats
    message.ClearField("core_workers_stats")

@@ -42,7 +52,8 @@ def node_stats_to_dict(message):

    result = dashboard_utils.message_to_dict(message, decode_keys)
    result["coreWorkersStats"] = [
        dashboard_utils.message_to_dict(
            m, decode_keys, including_default_value_fields=True
        )
        for m in core_workers_stats
    ]
    return result

@@ -66,11 +77,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if change.new:
            # TODO(fyrestone): Handle exceptions.
            node_id, node_info = change.new
            address = "{}:{}".format(
                node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
            )
            options = (("grpc.enable_http_proxy", 0),)
            channel = ray._private.utils.init_grpc_channel(
                address, options, asynchronous=True
            )
            stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
            self._stubs[node_id] = stub

@@ -81,8 +94,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            A dict of information about the nodes in the cluster.
        """
        request = gcs_service_pb2.GetAllNodeInfoRequest()
        reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=2)
        if reply.status.code == 0:
            result = {}
            for node_info in reply.node_info_list:

@@ -116,11 +128,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        agents = dict(DataSource.agents)
        for node_id in alive_node_ids:
            key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
            # TODO: Use async version if performance is an issue
            agent_port = ray.experimental.internal_kv._internal_kv_get(
                key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
            )
            if agent_port:
                agents[node_id] = json.loads(agent_port)
        for node_id in agents.keys() - set(alive_node_ids):

@@ -142,9 +154,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if view == "summary":
            all_node_summary = await DataOrganizer.get_all_node_summary()
            return dashboard_optional_utils.rest_response(
                success=True, message="Node summary fetched.", summary=all_node_summary
            )
        elif view == "details":
            all_node_details = await DataOrganizer.get_all_node_details()
            return dashboard_optional_utils.rest_response(

@@ -160,10 +171,12 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            return dashboard_optional_utils.rest_response(
                success=True,
                message="Node hostname list fetched.",
                host_name_list=list(alive_hostnames),
            )
        else:
            return dashboard_optional_utils.rest_response(
                success=False, message=f"Unknown view {view}"
            )

    @routes.get("/nodes/{node_id}")
    @dashboard_optional_utils.aiohttp_cache

@@ -171,7 +184,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        node_id = req.match_info.get("node_id")
        node_info = await DataOrganizer.get_node_info(node_id)
        return dashboard_optional_utils.rest_response(
            success=True, message="Node details fetched.", detail=node_info
        )

    @routes.get("/memory/memory_table")
    async def get_memory_table(self, req) -> aiohttp.web.Response:

@@ -187,7 +201,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        return dashboard_optional_utils.rest_response(
            success=True,
            message="Fetched memory table",
            memory_table=memory_table.as_dict(),
        )

    @routes.get("/memory/set_fetch")
    async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:

@@ -198,11 +213,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            self._collect_memory_info = False
        else:
            return dashboard_optional_utils.rest_response(
                success=False, message=f"Unknown argument to set_fetch {should_fetch}"
            )
        return dashboard_optional_utils.rest_response(
            success=True, message=f"Successfully set fetching to {should_fetch}"
        )

    @routes.get("/node_logs")
    async def get_logs(self, req) -> aiohttp.web.Response:

@@ -212,7 +227,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if pid:
            node_logs = {str(pid): node_logs.get(pid, [])}
        return dashboard_optional_utils.rest_response(
            success=True, message="Fetched logs.", logs=node_logs
        )

    @routes.get("/node_errors")
    async def get_errors(self, req) -> aiohttp.web.Response:

@@ -222,7 +238,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if pid:
            node_errors = {str(pid): node_errors.get(pid, [])}
        return dashboard_optional_utils.rest_response(
            success=True, message="Fetched errors.", errors=node_errors
        )

    @async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
    async def _update_node_stats(self):

@@ -234,8 +251,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            try:
                reply = await stub.GetNodeStats(
                    node_manager_pb2.GetNodeStatsRequest(
                        include_memory_info=self._collect_memory_info
                    ),
                    timeout=2,
                )
                reply_dict = node_stats_to_dict(reply)
                DataSource.node_stats[node_id] = reply_dict
            except Exception:

@@ -263,8 +282,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if self._dashboard_head.gcs_log_subscriber:
            while True:
                try:
                    log_batch = await self._dashboard_head.gcs_log_subscriber.poll()
                    if log_batch is None:
                        continue
                    process_log_batch(log_batch)

@@ -296,11 +314,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            ip = match.group(2)
            errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
            pid_errors = list(errs_for_ip.get(pid, []))
            pid_errors.append(
                {
                    "message": message,
                    "timestamp": error_data.timestamp,
                    "type": error_data.type,
                }
            )
            errs_for_ip[pid] = pid_errors
            DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
            logger.info(f"Received error entry for {ip} {pid}")

@@ -308,8 +328,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

        if self._dashboard_head.gcs_error_subscriber:
            while True:
                try:
                    (
                        _,
                        error_data,
                    ) = await self._dashboard_head.gcs_error_subscriber.poll()
                    if error_data is None:
                        continue
                    process_error(error_data)

@@ -328,20 +350,23 @@ class NodeHead(dashboard_utils.DashboardHeadModule):

            try:
                _, data = msg
                pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
                error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
                process_error(error_data)
            except Exception:
                logger.exception("Error receiving error info from Redis.")

    async def run(self, server):
        gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
        self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
            gcs_channel
        )

        await asyncio.gather(
            self._update_nodes(),
            self._update_node_stats(),
            self._update_log_info(),
            self._update_error_info(),
        )

    @staticmethod
    def is_minimal_module():
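The _update_node_stats coroutine above is driven by the async_loop_forever decorator, whose definition is not part of this diff. A plausible sketch of the shape these call sites assume (the exact error handling is an assumption) is:

import asyncio
import functools
import logging

logger = logging.getLogger(__name__)


def async_loop_forever(interval_seconds):
    def wrapper(coro):
        @functools.wraps(coro)
        async def looped(*args, **kwargs):
            while True:
                try:
                    await coro(*args, **kwargs)
                except Exception:
                    # Keep the loop alive; one failed poll should not kill
                    # the dashboard module.
                    logger.exception(f"Error looping coroutine {coro}.")
                await asyncio.sleep(interval_seconds)

        return looped

    return wrapper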
@@ -10,19 +10,23 @@ import ray

import threading
from datetime import datetime, timedelta
from ray.cluster_utils import Cluster
from ray.dashboard.modules.node.node_consts import (
    LOG_PRUNE_THREASHOLD,
    MAX_LOGS_TO_CACHE,
)
from ray.dashboard.tests.conftest import *  # noqa
from ray._private.test_utils import (
    format_web_url,
    wait_until_server_available,
    wait_for_condition,
    wait_until_succeeded_without_exception,
)

logger = logging.getLogger(__name__)


def test_nodes_update(enable_test_module, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

@@ -44,8 +48,7 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard):

            assert len(dump_data["agents"]) == 1
            assert len(dump_data["nodeIdToIp"]) == 1
            assert len(dump_data["nodeIdToHostname"]) == 1
            assert dump_data["nodes"].keys() == dump_data["nodeIdToHostname"].keys()

            response = requests.get(webui_url + "/test/notified_agents")
            response.raise_for_status()

@@ -77,8 +80,7 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):

    actor_pids = [actor.getpid.remote() for actor in actors]
    actor_pids = set(ray.get(actor_pids))

    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)
    node_id = ray_start_with_dashboard["node_id"]

@@ -134,15 +136,19 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):

            last_ex = ex
        finally:
            if time.time() > start_time + timeout_seconds:
                ex_stack = (
                    traceback.format_exception(
                        type(last_ex), last_ex, last_ex.__traceback__
                    )
                    if last_ex
                    else []
                )
                ex_stack = "".join(ex_stack)
                raise Exception(f"Timed out while testing, {ex_stack}")


def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])

    @ray.remote
    class ActorWithObjs:

@@ -156,8 +162,7 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):

    actors = [ActorWithObjs.remote() for _ in range(2)]  # noqa
    results = ray.get([actor.get_obj.remote() for actor in actors])  # noqa
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
    resp = requests.get(webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
    resp.raise_for_status()

    def check_mem_table():

@@ -172,11 +177,12 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):

        assert summary["totalLocalRefCount"] == 3

    assert wait_until_succeeded_without_exception(
        check_mem_table, (AssertionError,), timeout_ms=10000
    )


def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
@ -220,21 +226,25 @@ def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex last_ex = ex
finally: finally:
if time.time() > start_time + timeout_seconds: if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception( ex_stack = (
type(last_ex), last_ex, traceback.format_exception(
last_ex.__traceback__) if last_ex else [] type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack) ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}") raise Exception(f"Timed out while testing, {ex_stack}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"include_dashboard": True )
}], indirect=True) def test_multi_nodes_info(
def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache, enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
ray_start_cluster_head): ):
cluster: Cluster = ray_start_cluster_head cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True) assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url webui_url = cluster.webui_url
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
cluster.add_node() cluster.add_node()
@ -269,13 +279,13 @@ def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"include_dashboard": True )
}], indirect=True) def test_multi_node_churn(
def test_multi_node_churn(enable_test_module, disable_aiohttp_cache, enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
ray_start_cluster_head): ):
cluster: Cluster = ray_start_cluster_head cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True) assert wait_until_server_available(cluster.webui_url) is True
webui_url = format_web_url(cluster.webui_url) webui_url = format_web_url(cluster.webui_url)
def cluster_chaos_monkey(): def cluster_chaos_monkey():
@ -315,13 +325,11 @@ def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"include_dashboard": True )
}], indirect=True) def test_logs(enable_test_module, disable_aiohttp_cache, ray_start_cluster_head):
def test_logs(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
cluster = ray_start_cluster_head cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True) assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url webui_url = cluster.webui_url
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
nodes = ray.nodes() nodes = ray.nodes()
@ -348,21 +356,18 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
def check_logs(): def check_logs():
node_logs_response = requests.get( node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip}) f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status() node_logs_response.raise_for_status()
node_logs = node_logs_response.json() node_logs = node_logs_response.json()
assert node_logs["result"] assert node_logs["result"]
assert type(node_logs["data"]["logs"]) is dict assert type(node_logs["data"]["logs"]) is dict
assert all( assert all(pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
assert len(node_logs["data"]["logs"][la2_pid]) == 1 assert len(node_logs["data"]["logs"][la2_pid]) == 1
actor_one_logs_response = requests.get( actor_one_logs_response = requests.get(
f"{webui_url}/node_logs", f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
params={ )
"ip": node_ip,
"pid": str(la_pid)
})
actor_one_logs_response.raise_for_status() actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json() actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"] assert actor_one_logs["result"]
@ -370,19 +375,19 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
assert len(actor_one_logs["data"]["logs"][la_pid]) == 4 assert len(actor_one_logs["data"]["logs"][la_pid]) == 4
assert wait_until_succeeded_without_exception( assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000) check_logs, (AssertionError,), timeout_ms=1000
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"include_dashboard": True )
}], indirect=True) def test_logs_clean_up(
def test_logs_clean_up(enable_test_module, disable_aiohttp_cache, enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
ray_start_cluster_head): ):
"""Check if logs from the dead pids are GC'ed. """Check if logs from the dead pids are GC'ed."""
"""
cluster = ray_start_cluster_head cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True) assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url webui_url = cluster.webui_url
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
nodes = ray.nodes() nodes = ray.nodes()
@ -406,38 +411,41 @@ def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
def check_logs(): def check_logs():
node_logs_response = requests.get( node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip}) f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status() node_logs_response.raise_for_status()
node_logs = node_logs_response.json() node_logs = node_logs_response.json()
assert node_logs["result"] assert node_logs["result"]
assert la_pid in node_logs["data"]["logs"] assert la_pid in node_logs["data"]["logs"]
assert wait_until_succeeded_without_exception( assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000) check_logs, (AssertionError,), timeout_ms=1000
)
ray.kill(la) ray.kill(la)
def check_logs_not_exist(): def check_logs_not_exist():
node_logs_response = requests.get( node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip}) f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status() node_logs_response.raise_for_status()
node_logs = node_logs_response.json() node_logs = node_logs_response.json()
assert node_logs["result"] assert node_logs["result"]
assert la_pid not in node_logs["data"]["logs"] assert la_pid not in node_logs["data"]["logs"]
assert wait_until_succeeded_without_exception( assert wait_until_succeeded_without_exception(
check_logs_not_exist, (AssertionError, ), timeout_ms=10000) check_logs_not_exist, (AssertionError,), timeout_ms=10000
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ray_start_cluster_head", [{ "ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
"include_dashboard": True )
}], indirect=True) def test_logs_max_count(
def test_logs_max_count(enable_test_module, disable_aiohttp_cache, enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
ray_start_cluster_head): ):
"""Test that each Ray worker cannot cache more than 1000 logs at a time. """Test that each Ray worker cannot cache more than 1000 logs at a time."""
"""
cluster = ray_start_cluster_head cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True) assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url webui_url = cluster.webui_url
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
nodes = ray.nodes() nodes = ray.nodes()
@ -461,7 +469,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
def check_logs(): def check_logs():
node_logs_response = requests.get( node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip}) f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status() node_logs_response.raise_for_status()
node_logs = node_logs_response.json() node_logs = node_logs_response.json()
assert node_logs["result"] assert node_logs["result"]
@ -472,11 +481,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
actor_one_logs_response = requests.get( actor_one_logs_response = requests.get(
f"{webui_url}/node_logs", f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
params={ )
"ip": node_ip,
"pid": str(la_pid)
})
actor_one_logs_response.raise_for_status() actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json() actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"] assert actor_one_logs["result"]
@ -486,7 +492,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
assert wait_until_succeeded_without_exception( assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=10000) check_logs, (AssertionError,), timeout_ms=10000
)
if __name__ == "__main__": if __name__ == "__main__":
View file
@ -39,10 +39,12 @@ try:
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
gpustat = None gpustat = None
if log_once("gpustat_import_warning"): if log_once("gpustat_import_warning"):
warnings.warn("`gpustat` package is not installed. GPU monitoring is " warnings.warn(
"not available. To have full functionality of the " "`gpustat` package is not installed. GPU monitoring is "
"dashboard please install `pip install ray[" "not available. To have full functionality of the "
"default]`.)") "dashboard please install `pip install ray["
"default]`.)"
)
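Note: the gpustat block above is the standard optional-dependency guard: import if available, warn once otherwise, and degrade gracefully. A minimal version; gpustat's new_query()/entry API is assumed from the surrounding code:

import warnings

try:
    import gpustat
except ImportError:
    gpustat = None
    warnings.warn("`gpustat` is not installed; GPU monitoring is disabled.")

def gpu_utilizations():
    # Return an empty list rather than failing when the package is absent.
    if gpustat is None:
        return []
    return [gpu.entry for gpu in gpustat.new_query().gpus]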
def recursive_asdict(o): def recursive_asdict(o):
@ -68,68 +70,81 @@ def jsonify_asdict(o) -> str:
# A list of gauges to record and export metrics. # A list of gauges to record and export metrics.
METRICS_GAUGES = { METRICS_GAUGES = {
"node_cpu_utilization": Gauge("node_cpu_utilization", "node_cpu_utilization": Gauge(
"Total CPU usage on a ray node", "node_cpu_utilization", "Total CPU usage on a ray node", "percentage", ["ip"]
"percentage", ["ip"]), ),
"node_cpu_count": Gauge("node_cpu_count", "node_cpu_count": Gauge(
"Total CPUs available on a ray node", "cores", "node_cpu_count", "Total CPUs available on a ray node", "cores", ["ip"]
["ip"]), ),
"node_mem_used": Gauge("node_mem_used", "Memory usage on a ray node", "node_mem_used": Gauge(
"bytes", ["ip"]), "node_mem_used", "Memory usage on a ray node", "bytes", ["ip"]
"node_mem_available": Gauge("node_mem_available", ),
"Memory available on a ray node", "bytes", "node_mem_available": Gauge(
["ip"]), "node_mem_available", "Memory available on a ray node", "bytes", ["ip"]
"node_mem_total": Gauge("node_mem_total", "Total memory on a ray node", ),
"bytes", ["ip"]), "node_mem_total": Gauge(
"node_gpus_available": Gauge("node_gpus_available", "node_mem_total", "Total memory on a ray node", "bytes", ["ip"]
"Total GPUs available on a ray node", ),
"percentage", ["ip"]), "node_gpus_available": Gauge(
"node_gpus_utilization": Gauge("node_gpus_utilization", "node_gpus_available",
"Total GPUs usage on a ray node", "Total GPUs available on a ray node",
"percentage", ["ip"]), "percentage",
"node_gram_used": Gauge("node_gram_used", ["ip"],
"Total GPU RAM usage on a ray node", "bytes", ),
["ip"]), "node_gpus_utilization": Gauge(
"node_gram_available": Gauge("node_gram_available", "node_gpus_utilization", "Total GPUs usage on a ray node", "percentage", ["ip"]
"Total GPU RAM available on a ray node", ),
"bytes", ["ip"]), "node_gram_used": Gauge(
"node_disk_usage": Gauge("node_disk_usage", "node_gram_used", "Total GPU RAM usage on a ray node", "bytes", ["ip"]
"Total disk usage (bytes) on a ray node", "bytes", ),
["ip"]), "node_gram_available": Gauge(
"node_disk_free": Gauge("node_disk_free", "node_gram_available", "Total GPU RAM available on a ray node", "bytes", ["ip"]
"Total disk free (bytes) on a ray node", "bytes", ),
["ip"]), "node_disk_usage": Gauge(
"node_disk_usage", "Total disk usage (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_free": Gauge(
"node_disk_free", "Total disk free (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_utilization_percentage": Gauge( "node_disk_utilization_percentage": Gauge(
"node_disk_utilization_percentage", "node_disk_utilization_percentage",
"Total disk utilization (percentage) on a ray node", "percentage", "Total disk utilization (percentage) on a ray node",
["ip"]), "percentage",
"node_network_sent": Gauge("node_network_sent", "Total network sent", ["ip"],
"bytes", ["ip"]), ),
"node_network_received": Gauge("node_network_received", "node_network_sent": Gauge(
"Total network received", "bytes", ["ip"]), "node_network_sent", "Total network sent", "bytes", ["ip"]
),
"node_network_received": Gauge(
"node_network_received", "Total network received", "bytes", ["ip"]
),
"node_network_send_speed": Gauge( "node_network_send_speed": Gauge(
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]), "node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]
"node_network_receive_speed": Gauge("node_network_receive_speed", ),
"Network receive speed", "bytes/sec", "node_network_receive_speed": Gauge(
["ip"]), "node_network_receive_speed", "Network receive speed", "bytes/sec", ["ip"]
"raylet_cpu": Gauge("raylet_cpu", "CPU usage of the raylet on a node.", ),
"percentage", ["ip", "pid"]), "raylet_cpu": Gauge(
"raylet_mem": Gauge("raylet_mem", "Memory usage of the raylet on a node", "raylet_cpu", "CPU usage of the raylet on a node.", "percentage", ["ip", "pid"]
"mb", ["ip", "pid"]), ),
"cluster_active_nodes": Gauge("cluster_active_nodes", "raylet_mem": Gauge(
"Active nodes on the cluster", "count", "raylet_mem", "Memory usage of the raylet on a node", "mb", ["ip", "pid"]
["node_type"]), ),
"cluster_failed_nodes": Gauge("cluster_failed_nodes", "cluster_active_nodes": Gauge(
"Failed nodes on the cluster", "count", "cluster_active_nodes", "Active nodes on the cluster", "count", ["node_type"]
["node_type"]), ),
"cluster_pending_nodes": Gauge("cluster_pending_nodes", "cluster_failed_nodes": Gauge(
"Pending nodes on the cluster", "count", "cluster_failed_nodes", "Failed nodes on the cluster", "count", ["node_type"]
["node_type"]), ),
"cluster_pending_nodes": Gauge(
"cluster_pending_nodes", "Pending nodes on the cluster", "count", ["node_type"]
),
} }
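Note: METRICS_GAUGES maps metric names to Gauge definitions that are later paired with a value and tags as Records. The shape can be sketched with plain dataclasses; these stand-ins are not the real Ray metric types:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Gauge:
    name: str
    description: str
    unit: str
    tag_keys: List[str]

@dataclass
class Record:
    gauge: Gauge
    value: float
    tags: Dict[str, str] = field(default_factory=dict)

GAUGES = {
    "node_cpu_utilization": Gauge(
        "node_cpu_utilization", "Total CPU usage on a ray node", "percentage", ["ip"]
    ),
}

record = Record(GAUGES["node_cpu_utilization"], value=57.4, tags={"ip": "10.0.0.1"})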
class ReporterAgent(dashboard_utils.DashboardAgentModule, class ReporterAgent(
reporter_pb2_grpc.ReporterServiceServicer): dashboard_utils.DashboardAgentModule, reporter_pb2_grpc.ReporterServiceServicer
):
"""A monitor process for monitoring Ray nodes. """A monitor process for monitoring Ray nodes.
Attributes: Attributes:
@ -145,37 +160,39 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
cpu_count = ray._private.utils.get_num_cpus() cpu_count = ray._private.utils.get_num_cpus()
self._cpu_counts = (cpu_count, cpu_count) self._cpu_counts = (cpu_count, cpu_count)
else: else:
self._cpu_counts = (psutil.cpu_count(), self._cpu_counts = (psutil.cpu_count(), psutil.cpu_count(logical=False))
psutil.cpu_count(logical=False))
self._ip = dashboard_agent.ip self._ip = dashboard_agent.ip
if not use_gcs_for_bootstrap(): if not use_gcs_for_bootstrap():
self._redis_address, _ = dashboard_agent.redis_address self._redis_address, _ = dashboard_agent.redis_address
self._is_head_node = (self._ip == self._redis_address) self._is_head_node = self._ip == self._redis_address
else: else:
self._is_head_node = ( self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
self._ip == dashboard_agent.gcs_address.split(":")[0])
self._hostname = socket.gethostname() self._hostname = socket.gethostname()
self._workers = set() self._workers = set()
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv) self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
self._metrics_agent = MetricsAgent( self._metrics_agent = MetricsAgent(
"127.0.0.1" if self._ip == "127.0.0.1" else "", "127.0.0.1" if self._ip == "127.0.0.1" else "",
dashboard_agent.metrics_export_port) dashboard_agent.metrics_export_port,
self._key = f"{reporter_consts.REPORTER_PREFIX}" \ )
f"{self._dashboard_agent.node_id}" self._key = (
f"{reporter_consts.REPORTER_PREFIX}" f"{self._dashboard_agent.node_id}"
)
async def GetProfilingStats(self, request, context): async def GetProfilingStats(self, request, context):
pid = request.pid pid = request.pid
duration = request.duration duration = request.duration
profiling_file_path = os.path.join( profiling_file_path = os.path.join(
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt") ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt"
)
sudo = "sudo" if ray._private.utils.get_user() != "root" else "" sudo = "sudo" if ray._private.utils.get_user() != "root" else ""
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
f"{sudo} $(which py-spy) record " f"{sudo} $(which py-spy) record "
f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope", f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
shell=True) shell=True,
)
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode != 0: if process.returncode != 0:
profiling_stats = "" profiling_stats = ""
@ -183,14 +200,14 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
with open(profiling_file_path, "r") as f: with open(profiling_file_path, "r") as f:
profiling_stats = f.read() profiling_stats = f.read()
return reporter_pb2.GetProfilingStatsReply( return reporter_pb2.GetProfilingStatsReply(
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr) profiling_stats=profiling_stats, std_out=stdout, std_err=stderr
)
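Note: GetProfilingStats shells out to py-spy through an asyncio subprocess and reads the resulting speedscope file back. Stripped of the Ray plumbing, the pattern is roughly this sketch, assuming py-spy is on PATH and the paths are illustrative:

import asyncio

async def profile(pid: int, duration: int, out_path: str) -> str:
    proc = await asyncio.create_subprocess_shell(
        f"py-spy record -o {out_path} -p {pid} -d {duration} -f speedscope",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        # Hand back the error text instead of raising on a missing file.
        return stderr.decode()
    with open(out_path) as f:
        return f.read()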
async def ReportOCMetrics(self, request, context): async def ReportOCMetrics(self, request, context):
# This function receives a GRPC containing OpenCensus (OC) metrics # This function receives a GRPC containing OpenCensus (OC) metrics
# from a Ray process, then exposes those metrics to Prometheus. # from a Ray process, then exposes those metrics to Prometheus.
try: try:
self._metrics_agent.record_metric_points_from_protobuf( self._metrics_agent.record_metric_points_from_protobuf(request.metrics)
request.metrics)
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return reporter_pb2.ReportOCMetricsReply() return reporter_pb2.ReportOCMetricsReply()
@ -227,10 +244,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
for gpu in gpus: for gpu in gpus:
# Note the keys in this dict have periods which throws # Note the keys in this dict have periods which throws
# off javascript so we change .s to _s # off javascript so we change .s to _s
gpu_data = { gpu_data = {"_".join(key.split(".")): val for key, val in gpu.entry.items()}
"_".join(key.split(".")): val
for key, val in gpu.entry.items()
}
gpu_utilizations.append(gpu_data) gpu_utilizations.append(gpu_data)
return gpu_utilizations return gpu_utilizations
@ -245,8 +259,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
@staticmethod @staticmethod
def _get_network_stats(): def _get_network_stats():
ifaces = [ ifaces = [
v for k, v in psutil.net_io_counters(pernic=True).items() v for k, v in psutil.net_io_counters(pernic=True).items() if k[0] == "e"
if k[0] == "e"
] ]
sent = sum((iface.bytes_sent for iface in ifaces)) sent = sum((iface.bytes_sent for iface in ifaces))
@ -266,8 +279,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if IN_KUBERNETES_POD: if IN_KUBERNETES_POD:
# If in a K8s pod, disable disk display by passing in dummy values. # If in a K8s pod, disable disk display by passing in dummy values.
return { return {
"/": psutil._common.sdiskusage( "/": psutil._common.sdiskusage(total=1, used=0, free=1, percent=0.0)
total=1, used=0, free=1, percent=0.0)
} }
root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep
tmp = ray._private.utils.get_user_temp_dir() tmp = ray._private.utils.get_user_temp_dir()
@ -286,14 +298,18 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
self._workers.update(workers) self._workers.update(workers)
self._workers.discard(psutil.Process()) self._workers.discard(psutil.Process())
return [ return [
w.as_dict(attrs=[ w.as_dict(
"pid", attrs=[
"create_time", "pid",
"cpu_percent", "create_time",
"cpu_times", "cpu_percent",
"cmdline", "cpu_times",
"memory_info", "cmdline",
]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE "memory_info",
]
)
for w in self._workers
if w.status() != psutil.STATUS_ZOMBIE
] ]
@staticmethod @staticmethod
@ -318,14 +334,16 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if raylet_proc is None: if raylet_proc is None:
return {} return {}
else: else:
return raylet_proc.as_dict(attrs=[ return raylet_proc.as_dict(
"pid", attrs=[
"create_time", "pid",
"cpu_percent", "create_time",
"cpu_times", "cpu_percent",
"cmdline", "cpu_times",
"memory_info", "cmdline",
]) "memory_info",
]
)
def _get_load_avg(self): def _get_load_avg(self):
if sys.platform == "win32": if sys.platform == "win32":
@ -345,8 +363,10 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
then, prev_network_stats = self._network_stats_hist[0] then, prev_network_stats = self._network_stats_hist[0]
prev_send, prev_recv = prev_network_stats prev_send, prev_recv = prev_network_stats
now_send, now_recv = network_stats now_send, now_recv = network_stats
network_speed_stats = ((now_send - prev_send) / (now - then), network_speed_stats = (
(now_recv - prev_recv) / (now - then)) (now_send - prev_send) / (now - then),
(now_recv - prev_recv) / (now - then),
)
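Note: the network speed above is just the delta between two counter snapshots divided by the elapsed time. A synchronous psutil equivalent:

import time
import psutil

def network_speed(window: float = 1.0):
    # Average bytes/sec sent and received over `window` seconds.
    before = psutil.net_io_counters()
    start = time.time()
    time.sleep(window)
    after = psutil.net_io_counters()
    elapsed = time.time() - start
    return (
        (after.bytes_sent - before.bytes_sent) / elapsed,
        (after.bytes_recv - before.bytes_recv) / elapsed,
    )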
return { return {
"now": now, "now": now,
"hostname": self._hostname, "hostname": self._hostname,
@ -379,7 +399,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record( Record(
gauge=METRICS_GAUGES["cluster_active_nodes"], gauge=METRICS_GAUGES["cluster_active_nodes"],
value=active_node_count, value=active_node_count,
tags={"node_type": node_type})) tags={"node_type": node_type},
)
)
failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"] failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"]
failed_nodes_dict = {} failed_nodes_dict = {}
@ -394,7 +416,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record( Record(
gauge=METRICS_GAUGES["cluster_failed_nodes"], gauge=METRICS_GAUGES["cluster_failed_nodes"],
value=failed_node_count, value=failed_node_count,
tags={"node_type": node_type})) tags={"node_type": node_type},
)
)
pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"] pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"]
pending_nodes_dict = {} pending_nodes_dict = {}
@ -409,35 +433,36 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record( Record(
gauge=METRICS_GAUGES["cluster_pending_nodes"], gauge=METRICS_GAUGES["cluster_pending_nodes"],
value=pending_node_count, value=pending_node_count,
tags={"node_type": node_type})) tags={"node_type": node_type},
)
)
# -- CPU per node -- # -- CPU per node --
cpu_usage = float(stats["cpu"]) cpu_usage = float(stats["cpu"])
cpu_record = Record( cpu_record = Record(
gauge=METRICS_GAUGES["node_cpu_utilization"], gauge=METRICS_GAUGES["node_cpu_utilization"],
value=cpu_usage, value=cpu_usage,
tags={"ip": ip}) tags={"ip": ip},
)
cpu_count, _ = stats["cpus"] cpu_count, _ = stats["cpus"]
cpu_count_record = Record( cpu_count_record = Record(
gauge=METRICS_GAUGES["node_cpu_count"], gauge=METRICS_GAUGES["node_cpu_count"], value=cpu_count, tags={"ip": ip}
value=cpu_count, )
tags={"ip": ip})
# -- Mem per node -- # -- Mem per node --
mem_total, mem_available, _, mem_used = stats["mem"] mem_total, mem_available, _, mem_used = stats["mem"]
mem_used_record = Record( mem_used_record = Record(
gauge=METRICS_GAUGES["node_mem_used"], gauge=METRICS_GAUGES["node_mem_used"], value=mem_used, tags={"ip": ip}
value=mem_used, )
tags={"ip": ip})
mem_available_record = Record( mem_available_record = Record(
gauge=METRICS_GAUGES["node_mem_available"], gauge=METRICS_GAUGES["node_mem_available"],
value=mem_available, value=mem_available,
tags={"ip": ip}) tags={"ip": ip},
)
mem_total_record = Record( mem_total_record = Record(
gauge=METRICS_GAUGES["node_mem_total"], gauge=METRICS_GAUGES["node_mem_total"], value=mem_total, tags={"ip": ip}
value=mem_total, )
tags={"ip": ip})
# -- GPU per node -- # -- GPU per node --
gpus = stats["gpus"] gpus = stats["gpus"]
@ -455,23 +480,29 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
gpus_available_record = Record( gpus_available_record = Record(
gauge=METRICS_GAUGES["node_gpus_available"], gauge=METRICS_GAUGES["node_gpus_available"],
value=gpus_available, value=gpus_available,
tags={"ip": ip}) tags={"ip": ip},
)
gpus_utilization_record = Record( gpus_utilization_record = Record(
gauge=METRICS_GAUGES["node_gpus_utilization"], gauge=METRICS_GAUGES["node_gpus_utilization"],
value=gpus_utilization, value=gpus_utilization,
tags={"ip": ip}) tags={"ip": ip},
)
gram_used_record = Record( gram_used_record = Record(
gauge=METRICS_GAUGES["node_gram_used"], gauge=METRICS_GAUGES["node_gram_used"], value=gram_used, tags={"ip": ip}
value=gram_used, )
tags={"ip": ip})
gram_available_record = Record( gram_available_record = Record(
gauge=METRICS_GAUGES["node_gram_available"], gauge=METRICS_GAUGES["node_gram_available"],
value=gram_available, value=gram_available,
tags={"ip": ip}) tags={"ip": ip},
records_reported.extend([ )
gpus_available_record, gpus_utilization_record, records_reported.extend(
gram_used_record, gram_available_record [
]) gpus_available_record,
gpus_utilization_record,
gram_used_record,
gram_available_record,
]
)
# -- Disk per node -- # -- Disk per node --
used, free = 0, 0 used, free = 0, 0
@ -480,39 +511,42 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
free += entry.free free += entry.free
disk_utilization = float(used / (used + free)) * 100 disk_utilization = float(used / (used + free)) * 100
disk_usage_record = Record( disk_usage_record = Record(
gauge=METRICS_GAUGES["node_disk_usage"], gauge=METRICS_GAUGES["node_disk_usage"], value=used, tags={"ip": ip}
value=used, )
tags={"ip": ip})
disk_free_record = Record( disk_free_record = Record(
gauge=METRICS_GAUGES["node_disk_free"], gauge=METRICS_GAUGES["node_disk_free"], value=free, tags={"ip": ip}
value=free, )
tags={"ip": ip})
disk_utilization_percentage_record = Record( disk_utilization_percentage_record = Record(
gauge=METRICS_GAUGES["node_disk_utilization_percentage"], gauge=METRICS_GAUGES["node_disk_utilization_percentage"],
value=disk_utilization, value=disk_utilization,
tags={"ip": ip}) tags={"ip": ip},
)
# -- Network speed (send/receive) stats per node -- # -- Network speed (send/receive) stats per node --
network_stats = stats["network"] network_stats = stats["network"]
network_sent_record = Record( network_sent_record = Record(
gauge=METRICS_GAUGES["node_network_sent"], gauge=METRICS_GAUGES["node_network_sent"],
value=network_stats[0], value=network_stats[0],
tags={"ip": ip}) tags={"ip": ip},
)
network_received_record = Record( network_received_record = Record(
gauge=METRICS_GAUGES["node_network_received"], gauge=METRICS_GAUGES["node_network_received"],
value=network_stats[1], value=network_stats[1],
tags={"ip": ip}) tags={"ip": ip},
)
# -- Network speed (send/receive) per node -- # -- Network speed (send/receive) per node --
network_speed_stats = stats["network_speed"] network_speed_stats = stats["network_speed"]
network_send_speed_record = Record( network_send_speed_record = Record(
gauge=METRICS_GAUGES["node_network_send_speed"], gauge=METRICS_GAUGES["node_network_send_speed"],
value=network_speed_stats[0], value=network_speed_stats[0],
tags={"ip": ip}) tags={"ip": ip},
)
network_receive_speed_record = Record( network_receive_speed_record = Record(
gauge=METRICS_GAUGES["node_network_receive_speed"], gauge=METRICS_GAUGES["node_network_receive_speed"],
value=network_speed_stats[1], value=network_speed_stats[1],
tags={"ip": ip}) tags={"ip": ip},
)
raylet_stats = stats["raylet"] raylet_stats = stats["raylet"]
if raylet_stats: if raylet_stats:
@ -522,29 +556,34 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
raylet_cpu_record = Record( raylet_cpu_record = Record(
gauge=METRICS_GAUGES["raylet_cpu"], gauge=METRICS_GAUGES["raylet_cpu"],
value=raylet_cpu_usage, value=raylet_cpu_usage,
tags={ tags={"ip": ip, "pid": raylet_pid},
"ip": ip, )
"pid": raylet_pid
})
# -- raylet mem -- # -- raylet mem --
raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6 raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6
raylet_mem_record = Record( raylet_mem_record = Record(
gauge=METRICS_GAUGES["raylet_mem"], gauge=METRICS_GAUGES["raylet_mem"],
value=raylet_mem_usage, value=raylet_mem_usage,
tags={ tags={"ip": ip, "pid": raylet_pid},
"ip": ip, )
"pid": raylet_pid
})
records_reported.extend([raylet_cpu_record, raylet_mem_record]) records_reported.extend([raylet_cpu_record, raylet_mem_record])
records_reported.extend([ records_reported.extend(
cpu_record, cpu_count_record, mem_used_record, [
mem_available_record, mem_total_record, disk_usage_record, cpu_record,
disk_free_record, disk_utilization_percentage_record, cpu_count_record,
network_sent_record, network_received_record, mem_used_record,
network_send_speed_record, network_receive_speed_record mem_available_record,
]) mem_total_record,
disk_usage_record,
disk_free_record,
disk_utilization_percentage_record,
network_sent_record,
network_received_record,
network_send_speed_record,
network_receive_speed_record,
]
)
return records_reported return records_reported
async def _perform_iteration(self, publish): async def _perform_iteration(self, publish):
@ -552,9 +591,13 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
while True: while True:
try: try:
formatted_status_string = internal_kv._internal_kv_get( formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS) DEBUG_AUTOSCALING_STATUS
cluster_stats = json.loads(formatted_status_string.decode( )
)) if formatted_status_string else {} cluster_stats = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
stats = self._get_all_stats() stats = self._get_all_stats()
records_reported = self._record_stats(stats, cluster_stats) records_reported = self._record_stats(stats, cluster_stats)
@ -563,8 +606,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
except Exception: except Exception:
logger.exception("Error publishing node physical stats.") logger.exception("Error publishing node physical stats.")
await asyncio.sleep( await asyncio.sleep(reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
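Note: _perform_iteration is a resilient reporting loop: collect, record, publish, log any failure, sleep, repeat. Distilled, with placeholder get_stats/publish callables:

import asyncio

async def report_forever(get_stats, publish, interval_s: float) -> None:
    while True:
        try:
            await publish("stats", get_stats())
        except Exception:
            pass  # a real agent logs the exception and keeps going
        await asyncio.sleep(interval_s)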
async def run(self, server): async def run(self, server):
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server) reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
@ -573,17 +615,20 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if gcs_addr is None: if gcs_addr is None:
aioredis_client = await aioredis.create_redis_pool( aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address, address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password) password=self._dashboard_agent.redis_password,
)
gcs_addr = await aioredis_client.get("GcsServerAddress") gcs_addr = await aioredis_client.get("GcsServerAddress")
gcs_addr = gcs_addr.decode() gcs_addr = gcs_addr.decode()
publisher = GcsAioPublisher(address=gcs_addr) publisher = GcsAioPublisher(address=gcs_addr)
async def publish(key: str, data: str): async def publish(key: str, data: str):
await publisher.publish_resource_usage(key, data) await publisher.publish_resource_usage(key, data)
else: else:
aioredis_client = await aioredis.create_redis_pool( aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address, address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password) password=self._dashboard_agent.redis_password,
)
async def publish(key: str, data: str): async def publish(key: str, data: str):
await aioredis_client.publish(key, data) await aioredis_client.publish(key, data)
View file
@ -3,4 +3,5 @@ import ray.ray_constants as ray_constants
REPORTER_PREFIX = "RAY_REPORTER:" REPORTER_PREFIX = "RAY_REPORTER:"
# The reporter will report its statistics this often (milliseconds). # The reporter will report its statistics this often (milliseconds).
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer( REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
"REPORTER_UPDATE_INTERVAL_MS", 2500) "REPORTER_UPDATE_INTERVAL_MS", 2500
)
View file
@ -9,13 +9,14 @@ import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \ from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
GcsAioResourceUsageSubscriber
import ray._private.services import ray._private.services
import ray._private.utils import ray._private.utils
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS, from ray.ray_constants import (
DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_ERROR) DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR,
)
from ray.core.generated import reporter_pb2 from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc from ray.core.generated import reporter_pb2_grpc
import ray.experimental.internal_kv as internal_kv import ray.experimental.internal_kv as internal_kv
@ -40,9 +41,10 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
if change.new: if change.new:
node_id, ports = change.new node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id] ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0), ) options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel( channel = ray._private.utils.init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True) f"{ip}:{ports[1]}", options=options, asynchronous=True
)
stub = reporter_pb2_grpc.ReporterServiceStub(channel) stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub self._stubs[ip] = stub
@ -53,13 +55,16 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
duration = int(req.query["duration"]) duration = int(req.query["duration"])
reporter_stub = self._stubs[ip] reporter_stub = self._stubs[ip]
reply = await reporter_stub.GetProfilingStats( reply = await reporter_stub.GetProfilingStats(
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)) reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)
profiling_info = (json.loads(reply.profiling_stats) )
if reply.profiling_stats else reply.std_out) profiling_info = (
json.loads(reply.profiling_stats)
if reply.profiling_stats
else reply.std_out
)
return dashboard_optional_utils.rest_response( return dashboard_optional_utils.rest_response(
success=True, success=True, message="Profiling success.", profiling_info=profiling_info
message="Profiling success.", )
profiling_info=profiling_info)
@routes.get("/api/ray_config") @routes.get("/api/ray_config")
async def get_ray_config(self, req) -> aiohttp.web.Response: async def get_ray_config(self, req) -> aiohttp.web.Response:
@ -75,12 +80,12 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
) )
except FileNotFoundError: except FileNotFoundError:
return dashboard_optional_utils.rest_response( return dashboard_optional_utils.rest_response(
success=False, success=False, message="Invalid config, could not load YAML."
message="Invalid config, could not load YAML.") )
payload = { payload = {
"min_workers": cfg.get("min_workers", "unspecified"), "min_workers": cfg.get("min_workers", "unspecified"),
"max_workers": cfg.get("max_workers", "unspecified") "max_workers": cfg.get("max_workers", "unspecified"),
} }
try: try:
@ -115,18 +120,18 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
""" """
assert ray.experimental.internal_kv._internal_kv_initialized() assert ray.experimental.internal_kv._internal_kv_initialized()
legacy_status = internal_kv._internal_kv_get( legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
DEBUG_AUTOSCALING_STATUS_LEGACY) formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
formatted_status_string = internal_kv._internal_kv_get( formatted_status = (
DEBUG_AUTOSCALING_STATUS) json.loads(formatted_status_string.decode())
formatted_status = json.loads(formatted_status_string.decode() if formatted_status_string
) if formatted_status_string else {} else {}
)
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR) error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
return dashboard_optional_utils.rest_response( return dashboard_optional_utils.rest_response(
success=True, success=True,
message="Got cluster status.", message="Got cluster status.",
autoscaling_status=legacy_status.decode() autoscaling_status=legacy_status.decode() if legacy_status else None,
if legacy_status else None,
autoscaling_error=error.decode() if error else None, autoscaling_error=error.decode() if error else None,
cluster_status=formatted_status if formatted_status else None, cluster_status=formatted_status if formatted_status else None,
) )
@ -148,8 +153,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1] node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data DataSource.node_physical_stats[node_id] = data
except Exception: except Exception:
logger.exception("Error receiving node physical stats " logger.exception(
"from reporter agent.") "Error receiving node physical stats " "from reporter agent."
)
else: else:
receiver = Receiver() receiver = Receiver()
aioredis_client = self._dashboard_head.aioredis_client aioredis_client = self._dashboard_head.aioredis_client
@ -165,8 +171,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1] node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data DataSource.node_physical_stats[node_id] = data
except Exception: except Exception:
logger.exception("Error receiving node physical stats " logger.exception(
"from reporter agent.") "Error receiving node physical stats " "from reporter agent."
)
@staticmethod @staticmethod
def is_minimal_module(): def is_minimal_module():
View file
@ -10,9 +10,13 @@ from ray import ray_constants
from ray.dashboard.tests.conftest import * # noqa from ray.dashboard.tests.conftest import * # noqa
from ray.dashboard.utils import Bunch from ray.dashboard.utils import Bunch
from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent
from ray._private.test_utils import (format_web_url, RayTestTimeoutException, from ray._private.test_utils import (
wait_until_server_available, format_web_url,
wait_for_condition, fetch_prometheus) RayTestTimeoutException,
wait_until_server_available,
wait_for_condition,
fetch_prometheus,
)
try: try:
import prometheus_client import prometheus_client
@ -34,7 +38,7 @@ def test_profiling(shutdown_only):
actor_pid = ray.get(c.getpid.remote()) actor_pid = ray.get(c.getpid.remote())
webui_url = addresses["webui_url"] webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True) assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
start_time = time.time() start_time = time.time()
@ -44,14 +48,16 @@ def test_profiling(shutdown_only):
if time.time() - start_time > 15: if time.time() - start_time > 15:
raise RayTestTimeoutException( raise RayTestTimeoutException(
"Timed out while collecting profiling stats, " "Timed out while collecting profiling stats, "
f"launch_profiling: {launch_profiling}") f"launch_profiling: {launch_profiling}"
)
launch_profiling = requests.get( launch_profiling = requests.get(
webui_url + "/api/launch_profiling", webui_url + "/api/launch_profiling",
params={ params={
"ip": ray.nodes()[0]["NodeManagerAddress"], "ip": ray.nodes()[0]["NodeManagerAddress"],
"pid": actor_pid, "pid": actor_pid,
"duration": 5 "duration": 5,
}).json() },
).json()
if launch_profiling["result"]: if launch_profiling["result"]:
profiling_info = launch_profiling["data"]["profilingInfo"] profiling_info = launch_profiling["data"]["profilingInfo"]
break break
@ -72,13 +78,12 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
actor_pids = set(actor_pids) actor_pids = set(actor_pids)
webui_url = addresses["webui_url"] webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True) assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
def _check_workers(): def _check_workers():
try: try:
resp = requests.get(webui_url + resp = requests.get(webui_url + "/test/dump?key=node_physical_stats")
"/test/dump?key=node_physical_stats")
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
assert result["result"] is True assert result["result"] is True
@ -101,8 +106,7 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
wait_for_condition(_check_workers, timeout=10) wait_for_condition(_check_workers, timeout=10)
@pytest.mark.skipif( @pytest.mark.skipif(prometheus_client is None, reason="prometheus_client not installed")
prometheus_client is None, reason="prometheus_client not installed")
def test_prometheus_physical_stats_record(enable_test_module, shutdown_only): def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
addresses = ray.init(include_dashboard=True, num_cpus=1) addresses = ray.init(include_dashboard=True, num_cpus=1)
metrics_export_port = addresses["metrics_export_port"] metrics_export_port = addresses["metrics_export_port"]
@ -110,29 +114,31 @@ def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
prom_addresses = [f"{addr}:{metrics_export_port}"] prom_addresses = [f"{addr}:{metrics_export_port}"]
def test_case_stats_exist(): def test_case_stats_exist():
components_dict, metric_names, metric_samples = fetch_prometheus( components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
prom_addresses) return all(
return all([ [
"ray_node_cpu_utilization" in metric_names, "ray_node_cpu_utilization" in metric_names,
"ray_node_cpu_count" in metric_names, "ray_node_cpu_count" in metric_names,
"ray_node_mem_used" in metric_names, "ray_node_mem_used" in metric_names,
"ray_node_mem_available" in metric_names, "ray_node_mem_available" in metric_names,
"ray_node_mem_total" in metric_names, "ray_node_mem_total" in metric_names,
"ray_raylet_cpu" in metric_names, "ray_raylet_mem" in metric_names, "ray_raylet_cpu" in metric_names,
"ray_node_disk_usage" in metric_names, "ray_raylet_mem" in metric_names,
"ray_node_disk_free" in metric_names, "ray_node_disk_usage" in metric_names,
"ray_node_disk_utilization_percentage" in metric_names, "ray_node_disk_free" in metric_names,
"ray_node_network_sent" in metric_names, "ray_node_disk_utilization_percentage" in metric_names,
"ray_node_network_received" in metric_names, "ray_node_network_sent" in metric_names,
"ray_node_network_send_speed" in metric_names, "ray_node_network_received" in metric_names,
"ray_node_network_receive_speed" in metric_names "ray_node_network_send_speed" in metric_names,
]) "ray_node_network_receive_speed" in metric_names,
]
)
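Note: fetch_prometheus (a Ray test helper) amounts to scraping each metrics endpoint and collecting sample names from the Prometheus text format. A dependency-free approximation; the address is illustrative:

import urllib.request

def fetch_metric_names(address: str) -> set:
    # Comment lines start with '#'; samples start with the metric name,
    # optionally followed by '{labels}'.
    body = urllib.request.urlopen(f"http://{address}/metrics").read().decode()
    names = set()
    for line in body.splitlines():
        if line and not line.startswith("#"):
            names.add(line.split("{", 1)[0].split(" ", 1)[0])
    return names

# assert "ray_node_cpu_utilization" in fetch_metric_names("127.0.0.1:8080")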
def test_case_ip_correct(): def test_case_ip_correct():
components_dict, metric_names, metric_samples = fetch_prometheus( components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
prom_addresses)
raylet_proc = ray.worker._global_node.all_processes[ raylet_proc = ray.worker._global_node.all_processes[
ray_constants.PROCESS_TYPE_RAYLET][0] ray_constants.PROCESS_TYPE_RAYLET
][0]
raylet_pid = None raylet_pid = None
# Find the raylet pid recorded in the tag. # Find the raylet pid recorded in the tag.
for sample in metric_samples: for sample in metric_samples:
@ -159,24 +165,25 @@ def test_report_stats():
"cpu": 57.4, "cpu": 57.4,
"cpus": (8, 4), "cpus": (8, 4),
"mem": (17179869184, 5723353088, 66.7, 9234341888), "mem": (17179869184, 5723353088, 66.7, 9234341888),
"workers": [{ "workers": [
"memory_info": Bunch( {
rss=55934976, vms=7026937856, pfaults=15354, pageins=0), "memory_info": Bunch(
"cpu_percent": 0.0, rss=55934976, vms=7026937856, pfaults=15354, pageins=0
"cmdline": [ ),
"ray::IDLE", "", "", "", "", "", "", "", "", "", "", "" "cpu_percent": 0.0,
], "cmdline": ["ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""],
"create_time": 1614826391.338613, "create_time": 1614826391.338613,
"pid": 7174, "pid": 7174,
"cpu_times": Bunch( "cpu_times": Bunch(
user=0.607899328, user=0.607899328,
system=0.274044032, system=0.274044032,
children_user=0.0, children_user=0.0,
children_system=0.0) children_system=0.0,
}], ),
}
],
"raylet": { "raylet": {
"memory_info": Bunch( "memory_info": Bunch(rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
"cpu_percent": 0.0, "cpu_percent": 0.0,
"cmdline": ["fake raylet cmdline"], "cmdline": ["fake raylet cmdline"],
"create_time": 1614826390.274854, "create_time": 1614826390.274854,
@ -185,22 +192,18 @@ def test_report_stats():
user=0.03683138, user=0.03683138,
system=0.035913716, system=0.035913716,
children_user=0.0, children_user=0.0,
children_system=0.0) children_system=0.0,
),
}, },
"bootTime": 1612934656.0, "bootTime": 1612934656.0,
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45, "loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45, 0.44)),
0.44)),
"disk": { "disk": {
"/": Bunch( "/": Bunch(
total=250790436864, total=250790436864, used=11316781056, free=22748921856, percent=33.2
used=11316781056, ),
free=22748921856,
percent=33.2),
"/tmp": Bunch( "/tmp": Bunch(
total=250790436864, total=250790436864, used=209532035072, free=22748921856, percent=90.2
used=209532035072, ),
free=22748921856,
percent=90.2)
}, },
"gpus": [], "gpus": [],
"network": (13621160960, 11914936320), "network": (13621160960, 11914936320),
@ -209,13 +212,10 @@ def test_report_stats():
cluster_stats = { cluster_stats = {
"autoscaler_report": { "autoscaler_report": {
"active_nodes": { "active_nodes": {"head_node": 1, "worker-node-0": 2},
"head_node": 1,
"worker-node-0": 2
},
"failed_nodes": [], "failed_nodes": [],
"pending_launches": {}, "pending_launches": {},
"pending_nodes": [] "pending_nodes": [],
} }
} }
@ -226,11 +226,9 @@ def test_report_stats():
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats) records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
assert len(records) == 14 assert len(records) == 14
# Test stats with gpus # Test stats with gpus
test_stats["gpus"] = [{ test_stats["gpus"] = [
"utilization_gpu": 1, {"utilization_gpu": 1, "memory_used": 100, "memory_total": 1000}
"memory_used": 100, ]
"memory_total": 1000
}]
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats) records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
assert len(records) == 18 assert len(records) == 18
# Test stats without autoscaler report # Test stats without autoscaler report
View file
@ -12,10 +12,11 @@ from ray.core.generated import runtime_env_agent_pb2
from ray.core.generated import runtime_env_agent_pb2_grpc from ray.core.generated import runtime_env_agent_pb2_grpc
from ray.core.generated import agent_manager_pb2 from ray.core.generated import agent_manager_pb2
import ray.dashboard.utils as dashboard_utils import ray.dashboard.utils as dashboard_utils
import ray.dashboard.modules.runtime_env.runtime_env_consts \ import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
as runtime_env_consts from ray.experimental.internal_kv import (
from ray.experimental.internal_kv import _internal_kv_initialized, \ _internal_kv_initialized,
_initialize_internal_kv _initialize_internal_kv,
)
from ray._private.ray_logging import setup_component_logger from ray._private.ray_logging import setup_component_logger
from ray._private.runtime_env.pip import PipManager from ray._private.runtime_env.pip import PipManager
from ray._private.runtime_env.conda import CondaManager
@@ -42,8 +43,10 @@ class CreatedEnvResult:
    result: str


class RuntimeEnvAgent(
    dashboard_utils.DashboardAgentModule,
    runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
):
    """An RPC server to create and delete runtime envs.

    Attributes:
@@ -86,32 +89,33 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
        return self._per_job_logger_cache[job_id]

    async def CreateRuntimeEnv(self, request, context):
        async def _setup_runtime_env(
            serialized_runtime_env, serialized_allocated_resource_instances
        ):
            # This function will be run inside a thread
            def run_setup_with_logger():
                runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
                allocated_resource: dict = json.loads(
                    serialized_allocated_resource_instances or "{}"
                )
                # Use a separate logger for each job.
                per_job_logger = self.get_or_create_logger(request.job_id)
                # TODO(chenk008): Add log about allocated_resource to
                # avoid lint error. That will be moved to cgroup plugin.
                per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
                context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
                self._pip_manager.setup(runtime_env, context, logger=per_job_logger)
                self._conda_manager.setup(runtime_env, context, logger=per_job_logger)
                self._py_modules_manager.setup(
                    runtime_env, context, logger=per_job_logger
                )
                self._working_dir_manager.setup(
                    runtime_env, context, logger=per_job_logger
                )
                self._container_manager.setup(
                    runtime_env, context, logger=per_job_logger
                )

                # Add the mapping of URIs -> the serialized environment to be
                # used for cache invalidation.
@@ -133,14 +137,15 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
                # Run setup function from all the plugins
                for plugin_class_path, config in runtime_env.plugins():
                    logger.debug(f"Setting up runtime env plugin {plugin_class_path}")
                    plugin_class = import_attr(plugin_class_path)
                    # TODO(simon): implement uri support
                    plugin_class.create(
                        "uri not implemented", json.loads(config), context
                    )
                    plugin_class.modify_context(
                        "uri not implemented", json.loads(config), context
                    )

                return context
@@ -159,18 +164,24 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
            result = self._env_cache[serialized_env]
            if result.success:
                context = result.result
                logger.info(
                    "Runtime env already created successfully. "
                    f"Env: {serialized_env}, context: {context}"
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                    serialized_runtime_env_context=context,
                )
            else:
                error_message = result.result
                logger.info(
                    "Runtime env already failed. "
                    f"Env: {serialized_env}, err: {error_message}"
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                    error_message=error_message,
                )

        if SLEEP_FOR_TESTING_S:
            logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
@@ -182,8 +193,8 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
        for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
            try:
                runtime_env_context = await _setup_runtime_env(
                    serialized_env, request.serialized_allocated_resource_instances
                )
                break
            except Exception as ex:
                logger.exception("Runtime env creation failed.")
@@ -195,22 +206,25 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
            logger.error(
                "Runtime env creation failed for %d times, "
                "don't retry any more.",
                runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
            )
            self._env_cache[serialized_env] = CreatedEnvResult(False, error_message)
            return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message=error_message,
            )

        serialized_context = runtime_env_context.serialize()
        self._env_cache[serialized_env] = CreatedEnvResult(True, serialized_context)
        logger.info(
            "Successfully created runtime env: %s, the context: %s",
            serialized_env,
            serialized_context,
        )
        return runtime_env_agent_pb2.CreateRuntimeEnvReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
            serialized_runtime_env_context=serialized_context,
        )

    async def DeleteURIs(self, request, context):
        logger.info(f"Got request to delete URIs: {request.uris}.")
@@ -239,20 +253,21 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
            else:
                raise ValueError(
                    "RuntimeEnvAgent received DeleteURI request "
                    f"for unsupported plugin {plugin}. URI: {uri}"
                )

        if failed_uris:
            return runtime_env_agent_pb2.DeleteURIsReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="Local files for URI(s) " f"{failed_uris} not found.",
            )
        else:
            return runtime_env_agent_pb2.DeleteURIsReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK
            )

    async def run(self, server):
        runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(self, server)

    @staticmethod
    def is_minimal_module():
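Every hunk above is the same mechanical change: Black wraps anything past 88 columns and, when the arguments cannot share one line, explodes them one per line with a trailing comma. A minimal standalone sketch of that rewrite (it assumes only that the black package is installed; the source string is lifted from the hunk above):

import black

src = (
    "reply = runtime_env_agent_pb2.CreateRuntimeEnvReply("
    "status=agent_manager_pb2.AGENT_RPC_STATUS_OK, "
    "serialized_runtime_env_context=serialized_context)\n"
)
# format_str applies the same rules the CI check enforces on checked-in files:
# the call no longer fits in 88 columns, so each keyword argument lands on its
# own line with a trailing comma before the closing parenthesis.
print(black.format_str(src, mode=black.Mode(line_length=88)))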

View file

@@ -1,7 +1,7 @@
import ray.ray_constants as ray_constants

RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES", 3)
RUNTIME_ENV_RETRY_INTERVAL_MS = ray_constants.env_integer(
    "RUNTIME_ENV_RETRY_INTERVAL_MS", 1000
)
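env_integer above reads an integer override from the environment, falling back to a default. An equivalent standalone sketch (a simplified stand-in, not Ray's implementation):

import os

def env_integer(key: str, default: int) -> int:
    # Use the environment value when present, otherwise the default.
    value = os.environ.get(key)
    return int(value) if value is not None else default

RUNTIME_ENV_RETRY_TIMES = env_integer("RUNTIME_ENV_RETRY_TIMES", 3)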

View file

@@ -8,13 +8,19 @@ from ray import ray_constants
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.experimental.internal_kv import (
    _internal_kv_initialized,
    _internal_kv_get,
    _internal_kv_list,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.runtime_env.validation import ParsedRuntimeEnv
from ray.dashboard.modules.job.common import (
    JobStatusInfo,
    JobStatusStorageClient,
    JOB_ID_METADATA_KEY,
)

import json
import aiohttp.web
@@ -32,7 +38,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        self._job_status_client = JobStatusStorageClient()
        # For offloading CPU intensive work.
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2, thread_name_prefix="api_head"
        )

    @routes.get("/api/actors/kill")
    async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
@@ -41,7 +48,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        no_restart = req.query.get("no_restart", False) in ("true", "True")
        if not actor_id:
            return dashboard_optional_utils.rest_response(
                success=False, message="actor_id is required."
            )

        request = gcs_service_pb2.KillActorViaGcsRequest()
        request.actor_id = bytes.fromhex(actor_id)
@@ -49,31 +57,36 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        request.no_restart = no_restart
        await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)

        message = (
            f"Force killed actor with id {actor_id}"
            if force_kill
            else f"Requested actor with id {actor_id} to terminate. "
            + "It will exit once running tasks complete"
        )

        return dashboard_optional_utils.rest_response(success=True, message=message)

    @routes.get("/api/snapshot")
    async def snapshot(self, req):
        job_data, actor_data, serve_data, session_name = await asyncio.gather(
            self.get_job_info(),
            self.get_actor_info(),
            self.get_serve_info(),
            self.get_session_name(),
        )
        snapshot = {
            "jobs": job_data,
            "actors": actor_data,
            "deployments": serve_data,
            "session_name": session_name,
            "ray_version": ray.__version__,
            "ray_commit": ray.__commit__,
        }
        return dashboard_optional_utils.rest_response(
            success=True, message="hello", snapshot=snapshot
        )

    def _get_job_status(self, metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
        # If a job submission ID has been added to a job, the status is
        # guaranteed to be returned.
        job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
@@ -91,8 +104,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
                "namespace": job_table_entry.config.ray_namespace,
                "metadata": metadata,
                "runtime_env": ParsedRuntimeEnv.deserialize(
                    job_table_entry.config.runtime_env_info.serialized_runtime_env
                ),
            }
            status = self._get_job_status(metadata)
            entry = {
@@ -111,8 +124,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        # TODO (Alex): GCS still needs to return actors from dead jobs.
        request = gcs_service_pb2.GetAllActorInfoRequest()
        request.show_dead_jobs = True
        reply = await self._gcs_actor_info_stub.GetAllActorInfo(request, timeout=5)
        actors = {}
        for actor_table_entry in reply.actor_table_data:
            actor_id = actor_table_entry.actor_id.hex()
@@ -120,37 +132,33 @@ class APIHead(dashboard_utils.DashboardHeadModule):
            entry = {
                "job_id": actor_table_entry.job_id.hex(),
                "state": gcs_pb2.ActorTableData.ActorState.Name(
                    actor_table_entry.state
                ),
                "name": actor_table_entry.name,
                "namespace": actor_table_entry.ray_namespace,
                "runtime_env": runtime_env,
                "start_time": actor_table_entry.start_time,
                "end_time": actor_table_entry.end_time,
                "is_detached": actor_table_entry.is_detached,
                "resources": dict(actor_table_entry.task_spec.required_resources),
                "actor_class": actor_table_entry.class_name,
                "current_worker_id": actor_table_entry.address.worker_id.hex(),
                "current_raylet_id": actor_table_entry.address.raylet_id.hex(),
                "ip_address": actor_table_entry.address.ip_address,
                "port": actor_table_entry.address.port,
                "metadata": dict(),
            }
            actors[actor_id] = entry

        deployments = await self.get_serve_info()
        for _, deployment_info in deployments.items():
            for replica_actor_id, actor_info in deployment_info["actors"].items():
                if replica_actor_id in actors:
                    serve_metadata = dict()
                    serve_metadata["replica_tag"] = actor_info["replica_tag"]
                    serve_metadata["deployment_name"] = deployment_info["name"]
                    serve_metadata["version"] = actor_info["version"]
                    actors[replica_actor_id]["metadata"]["serve"] = serve_metadata
        return actors

    async def get_serve_info(self) -> Dict[str, Any]:
@@ -168,22 +176,21 @@ class APIHead(dashboard_utils.DashboardHeadModule):
        # TODO: Convert to async GRPC, if CPU usage is not a concern.
        def get_deployments():
            serve_keys = _internal_kv_list(
                SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
            )
            serve_snapshot_keys = filter(
                lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
            )

            deployments_per_controller: List[Dict[str, Any]] = []
            for key in serve_snapshot_keys:
                val_bytes = _internal_kv_get(
                    key, namespace=ray_constants.KV_NAMESPACE_SERVE
                ) or "{}".encode("utf-8")
                deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
            # Merge the deployments dicts of all controllers.
            deployments: Dict[str, Any] = {
                k: v for d in deployments_per_controller for k, v in d.items()
            }
            # Replace the keys (deployment names) with their hashes to prevent
            # collisions caused by the automatic conversion to camelcase by the
@@ -194,24 +201,27 @@ class APIHead(dashboard_utils.DashboardHeadModule):
            }

        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_deployments
        )

    async def get_session_name(self):
        # TODO(yic): Convert to async GRPC.
        def get_session():
            return ray.experimental.internal_kv._internal_kv_get(
                "session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
            ).decode()

        return await asyncio.get_event_loop().run_in_executor(
            executor=self._thread_pool, func=get_session
        )

    async def run(self, server):
        self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel
        )
        self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel
        )

    @staticmethod
    def is_minimal_module():
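get_deployments and get_session above share one offloading pattern: the blocking internal-KV reads run on the module's ThreadPoolExecutor so the event loop stays responsive. A self-contained sketch of that pattern (blocking_lookup is a hypothetical stand-in for the KV read):

import asyncio
import concurrent.futures

pool = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="api_head")

def blocking_lookup():
    return "session_name"  # stands in for a blocking internal-KV read

async def handler():
    # Run the blocking call on the pool and await its result without
    # stalling other coroutines.
    return await asyncio.get_running_loop().run_in_executor(pool, blocking_lookup)

print(asyncio.run(handler()))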

View file

@@ -36,15 +36,14 @@ def _actor_killed_loop(worker_pid: str, timeout_secs=3) -> bool:
    return dead


def _kill_actor_using_dashboard_gcs(webui_url: str, actor_id: str, force_kill=False):
    resp = requests.get(
        webui_url + KILL_ACTOR_ENDPOINT,
        params={
            "actor_id": actor_id,
            "force_kill": force_kill,
        },
    )
    resp.raise_for_status()
    resp_json = resp.json()
    assert resp_json["result"] is True, "msg" in resp_json

View file

@@ -8,8 +8,11 @@ import pprint
import pytest
import requests

from ray._private.test_utils import (
    format_web_url,
    wait_for_condition,
    wait_until_server_available,
)
from ray.dashboard import dashboard
from ray.dashboard.tests.conftest import *  # noqa
from ray.dashboard.modules.job.sdk import JobSubmissionClient
@@ -22,25 +25,23 @@ def _get_snapshot(address: str):
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
    return data


def test_successful_job_status(
    ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
    address = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(address)
    address = format_web_url(address)

    entrypoint_cmd = (
        'python -c"' "import ray;" "ray.init();" "import time;" "time.sleep(5);" '"'
    )
    client = JobSubmissionClient(address)
    job_id = client.submit_job(entrypoint=entrypoint_cmd)
@@ -49,11 +50,8 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
        data = _get_snapshot(address)
        for job_entry in data["data"]["snapshot"]["jobs"].values():
            if job_entry["status"] is not None:
                assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
                assert job_entry["status"] in {"PENDING", "RUNNING", "SUCCEEDED"}
                assert job_entry["statusMessage"] is not None
                return job_entry["status"] == "SUCCEEDED"
@@ -62,20 +60,23 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
    wait_for_condition(wait_for_job_to_succeed, timeout=30)


def test_failed_job_status(
    ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
    address = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(address)
    address = format_web_url(address)

    entrypoint_cmd = (
        'python -c"'
        "import ray;"
        "ray.init();"
        "import time;"
        "time.sleep(5);"
        "import sys;"
        "sys.exit(1);"
        '"'
    )
    client = JobSubmissionClient(address)
    job_id = client.submit_job(entrypoint=entrypoint_cmd)
@@ -83,8 +84,7 @@ def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
        data = _get_snapshot(address)
        for job_entry in data["data"]["snapshot"]["jobs"].values():
            if job_entry["status"] is not None:
                assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
                assert job_entry["status"] in {"PENDING", "RUNNING", "FAILED"}
                assert job_entry["statusMessage"] is not None
                return job_entry["status"] == "FAILED"
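_get_snapshot above validates every response body against snapshot_schema.json before any assertion runs, so schema drift fails fast. The same check in miniature (a hypothetical inline schema stands in for the schema file):

import jsonschema

schema = {
    "type": "object",
    "properties": {"result": {"type": "boolean"}},
    "required": ["result"],
}
jsonschema.validate(instance={"result": True}, schema=schema)  # passes silently
# A non-conforming instance such as {} would raise jsonschema.ValidationError.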

View file

@@ -34,11 +34,14 @@ ray.get(a.ping.remote())
"""
    address = ray_start_with_dashboard["address"]
    detached_driver = driver_template.format(
        address=address, lifetime="'detached'", name="'abc'"
    )
    named_driver = driver_template.format(
        address=address, lifetime="None", name="'xyz'"
    )
    unnamed_driver = driver_template.format(
        address=address, lifetime="None", name="None"
    )

    run_string_as_driver(detached_driver)
    run_string_as_driver(named_driver)
@@ -50,8 +53,8 @@ ray.get(a.ping.remote())
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
@@ -72,10 +75,7 @@ ray.get(a.ping.remote())
    assert data["data"]["snapshot"]["rayVersion"] == ray.__version__


@pytest.mark.parametrize("ray_start_with_dashboard", [{"num_cpus": 4}], indirect=True)
def test_serve_snapshot(ray_start_with_dashboard):
    """Test detached and nondetached Serve instances running concurrently."""
@@ -115,8 +115,7 @@ my_func_deleted.delete()
    my_func_nondetached.deploy()

    assert requests.get("http://127.0.0.1:8123/my_func_nondetached").text == "hello"

    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)
@@ -124,15 +123,16 @@ my_func_deleted.delete()
    response.raise_for_status()
    data = response.json()
    schema_path = os.path.join(
        os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
    )
    pprint.pprint(data)
    jsonschema.validate(instance=data, schema=json.load(open(schema_path)))

    assert len(data["data"]["snapshot"]["deployments"]) == 3

    entry = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func".encode()).hexdigest()
    ]
    assert entry["name"] == "my_func"
    assert entry["version"] is None
    assert entry["namespace"] == "serve"
@@ -145,14 +145,14 @@ my_func_deleted.delete()
    assert len(entry["actors"]) == 1
    actor_id = next(iter(entry["actors"]))
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
    assert metadata["deploymentName"] == "my_func"
    assert metadata["version"] is None
    assert len(metadata["replicaTag"]) > 0

    entry_deleted = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func_deleted".encode()).hexdigest()
    ]
    assert entry_deleted["name"] == "my_func_deleted"
    assert entry_deleted["version"] == "v1"
    assert entry_deleted["namespace"] == "serve"
@@ -163,8 +163,9 @@ my_func_deleted.delete()
    assert entry_deleted["startTime"] > 0
    assert entry_deleted["endTime"] > entry_deleted["startTime"]

    entry_nondetached = data["data"]["snapshot"]["deployments"][
        hashlib.sha1("my_func_nondetached".encode()).hexdigest()
    ]
    assert entry_nondetached["name"] == "my_func_nondetached"
    assert entry_nondetached["version"] == "v1"
    assert entry_nondetached["namespace"] == "default_test_namespace"
@@ -177,8 +178,7 @@ my_func_deleted.delete()
    assert len(entry_nondetached["actors"]) == 1
    actor_id = next(iter(entry_nondetached["actors"]))
    metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
    assert metadata["deploymentName"] == "my_func_nondetached"
    assert metadata["version"] == "v1"
    assert len(metadata["replicaTag"]) > 0
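The snapshot keys above are SHA-1 hashes of deployment names, which sidesteps the automatic camelcase conversion that could otherwise mangle or collide names in the response. The keying in miniature:

import hashlib

def snapshot_key(deployment_name: str) -> str:
    # Stable 40-character hex key, immune to camelcase conversion.
    return hashlib.sha1(deployment_name.encode()).hexdigest()

assert snapshot_key("my_func") != snapshot_key("my_func_deleted")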

View file

@@ -13,7 +13,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable


@dashboard_utils.dashboard_module(
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestAgent(dashboard_utils.DashboardAgentModule):
    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
@@ -25,8 +26,7 @@ class TestAgent(dashboard_utils.DashboardAgentModule):
    @routes.get("/test/http_get_from_agent")
    async def get_url(self, req) -> aiohttp.web.Response:
        url = req.query.get("url")
        result = await test_utils.http_get(self._dashboard_agent.http_session, url)
        return aiohttp.web.json_response(result)

    @routes.head("/test/route_head")

View file

@@ -15,7 +15,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable


@dashboard_utils.dashboard_module(
    enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestHead(dashboard_utils.DashboardHeadModule):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
@@ -62,26 +63,28 @@ class TestHead(dashboard_utils.DashboardHeadModule):
            return dashboard_optional_utils.rest_response(
                success=True,
                message="Fetch all data from datacenter success.",
                **all_data,
            )
        else:
            data = dict(DataSource.__dict__.get(key))
            return dashboard_optional_utils.rest_response(
                success=True,
                message=f"Fetch {key} from datacenter success.",
                **{key: data},
            )

    @routes.get("/test/notified_agents")
    async def get_notified_agents(self, req) -> aiohttp.web.Response:
        return dashboard_optional_utils.rest_response(
            success=True,
            message="Fetch notified agents success.",
            **self._notified_agents,
        )

    @routes.get("/test/http_get")
    async def get_url(self, req) -> aiohttp.web.Response:
        url = req.query.get("url")
        result = await test_utils.http_get(self._dashboard_head.http_session, url)
        return aiohttp.web.json_response(result)

    @routes.get("/test/aiohttp_cache/{sub_path}")
@@ -89,14 +92,16 @@ class TestHead(dashboard_utils.DashboardHeadModule):
    async def test_aiohttp_cache(self, req) -> aiohttp.web.Response:
        value = req.query["value"]
        return dashboard_optional_utils.rest_response(
            success=True, message="OK", value=value, timestamp=time.time()
        )

    @routes.get("/test/aiohttp_cache_lru/{sub_path}")
    @dashboard_optional_utils.aiohttp_cache(ttl_seconds=60, maxsize=5)
    async def test_aiohttp_cache_lru(self, req) -> aiohttp.web.Response:
        value = req.query.get("value")
        return dashboard_optional_utils.rest_response(
            success=True, message="OK", value=value, timestamp=time.time()
        )

    @routes.get("/test/file")
    async def test_file(self, req) -> aiohttp.web.FileResponse:

View file

@@ -4,8 +4,7 @@ import copy
import os

import aiohttp.web

import ray.dashboard.modules.tune.tune_consts as tune_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.utils import async_loop_forever
@@ -45,19 +44,17 @@ class TuneController(dashboard_utils.DashboardHeadModule):
    @routes.get("/tune/info")
    async def tune_info(self, req) -> aiohttp.web.Response:
        stats = self.get_stats()
        return rest_response(success=True, message="Fetched tune info", result=stats)

    @routes.get("/tune/availability")
    async def get_availability(self, req) -> aiohttp.web.Response:
        availability = {
            "available": ExperimentAnalysis is not None,
            "trials_available": self._trials_available,
        }
        return rest_response(
            success=True, message="Fetched tune availability", result=availability
        )

    @routes.get("/tune/set_experiment")
    async def set_tune_experiment(self, req) -> aiohttp.web.Response:
@@ -66,25 +63,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
        if err:
            return rest_response(success=False, error=err)
        return rest_response(
            success=True, message="Successfully set experiment", **experiment
        )

    @routes.get("/tune/enable_tensorboard")
    async def enable_tensorboard(self, req) -> aiohttp.web.Response:
        self._enable_tensorboard()
        if not self._tensor_board_dir:
            return rest_response(success=False, message="Error enabling tensorboard")
        return rest_response(success=True, message="Enabled tensorboard")

    def get_stats(self):
        tensor_board_info = {
            "tensorboard_current": self._logdir == self._tensor_board_dir,
            "tensorboard_enabled": self._tensor_board_dir != "",
        }
        return {
            "trial_records": copy.deepcopy(self._trial_records),
            "errors": copy.deepcopy(self._errors),
            "tensorboard": tensor_board_info,
        }

    def set_experiment(self, experiment):
@@ -104,7 +101,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
    def collect_errors(self, df):
        sub_dirs = os.listdir(self._logdir)
        trial_names = filter(
            lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs
        )
        for trial in trial_names:
            error_path = os.path.join(self._logdir, trial, "error.txt")
            if os.path.isfile(error_path):
@@ -114,7 +112,7 @@ class TuneController(dashboard_utils.DashboardHeadModule):
                self._errors[str(trial)] = {
                    "text": text,
                    "job_id": os.path.basename(self._logdir),
                    "trial_id": "No Trial ID",
                }
                other_data = df[df["logdir"].str.contains(trial)]
                if len(other_data) > 0:
@@ -175,12 +173,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
        # list of static attributes for trial
        default_names = {
            "logdir",
            "time_this_iter_s",
            "done",
            "episodes_total",
            "training_iteration",
            "timestamp",
            "timesteps_total",
            "experiment_id",
            "date",
            "timestamp",
            "time_total_s",
            "pid",
            "hostname",
            "node_ip",
            "time_since_restore",
            "timesteps_since_restore",
            "iterations_since_restore",
            "experiment_tag",
            "trial_id",
        }

        # filter attributes into floats, metrics, and config variables
@@ -196,7 +207,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
        for trial, details in trial_details.items():
            ts = os.path.getctime(details["logdir"])
            formatted_time = datetime.datetime.fromtimestamp(ts).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
            details["start_time"] = formatted_time
            details["params"] = {}
            details["metrics"] = {}

View file

@@ -25,7 +25,7 @@ except AttributeError:
# All third-party dependencies that are not included in the minimal Ray
# installation must be included in this file. This allows us to determine if
# the agent has the necessary dependencies to be started.
from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
from ray.dashboard.utils import to_google_style, CustomEncoder

logger = logging.getLogger(__name__)
@@ -68,12 +68,15 @@ class ClassMethodRouteTable:
        def _wrapper(handler):
            if path in cls._bind_map[method]:
                bind_info = cls._bind_map[method][path]
                raise Exception(
                    f"Duplicated route path: {path}, "
                    f"previous one registered at "
                    f"{bind_info.filename}:{bind_info.lineno}"
                )

            bind_info = cls._BindInfo(
                handler.__code__.co_filename, handler.__code__.co_firstlineno, None
            )

            @functools.wraps(handler)
            async def _handler_route(*args) -> aiohttp.web.Response:
@@ -86,8 +89,7 @@ class ClassMethodRouteTable:
                    return await handler(bind_info.instance, req)
                except Exception:
                    logger.exception("Handle %s %s failed.", method, path)
                    return rest_response(success=False, message=traceback.format_exc())

            cls._bind_map[method][path] = bind_info
            _handler_route.__route_method__ = method
@@ -132,18 +134,19 @@ class ClassMethodRouteTable:
    def bind(cls, instance):
        def predicate(o):
            if inspect.ismethod(o):
                return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
            return False

        handler_routes = inspect.getmembers(instance, predicate)
        for _, h in handler_routes:
            cls._bind_map[h.__func__.__route_method__][
                h.__func__.__route_path__
            ].instance = instance


def rest_response(
    success, message, convert_google_style=True, **kwargs
) -> aiohttp.web.Response:
    # In the dev context we allow a dev server running on a
    # different port to consume the API, meaning we need to allow
    # cross-origin access
@@ -155,24 +158,24 @@ def rest_response(success, message, convert_google_style=True,
        {
            "result": success,
            "msg": message,
            "data": to_google_style(kwargs) if convert_google_style else kwargs,
        },
        dumps=functools.partial(json.dumps, cls=CustomEncoder),
        headers=headers,
    )


# The cache value type used by aiohttp_cache.
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
# The methods with no request body used by aiohttp_cache.
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}


def aiohttp_cache(
    ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
    maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
    enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
):
    assert maxsize > 0
    cache = collections.OrderedDict()
@@ -195,8 +198,7 @@ def aiohttp_cache(
            value = cache.get(key)
            if value is not None:
                cache.move_to_end(key)
                if not value.task.done() or value.expiration >= time.time():
                    # Update task not done or the data is not expired.
                    return aiohttp.web.Response(**value.data)
@@ -205,15 +207,16 @@ def aiohttp_cache(
                    response = task.result()
                except Exception:
                    response = rest_response(
                        success=False, message=traceback.format_exc()
                    )
                data = {
                    "status": response.status,
                    "headers": dict(response.headers),
                    "body": response.body,
                }
                cache[key] = _AiohttpCacheValue(
                    data, time.time() + ttl_seconds, task
                )
                cache.move_to_end(key)
                if len(cache) > maxsize:
                    cache.popitem(last=False)
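aiohttp_cache above combines TTL and LRU eviction on a single OrderedDict: move_to_end refreshes recency on every hit, and popitem(last=False) drops the oldest entry once the cache exceeds maxsize. The same idea as a synchronous sketch (a simplification of the decorator above; it assumes hashable positional arguments):

import collections
import time

def ttl_lru_cached(fn, ttl_seconds=30.0, maxsize=128):
    cache = collections.OrderedDict()

    def wrapper(*args):
        hit = cache.get(args)
        if hit is not None:
            value, expiration = hit
            cache.move_to_end(args)  # refresh LRU position
            if expiration >= time.time():
                return value  # fresh hit
        value = fn(*args)
        cache[args] = (value, time.time() + ttl_seconds)
        cache.move_to_end(args)
        if len(cache) > maxsize:
            cache.popitem(last=False)  # evict the least-recently-used entry
        return value

    return wrapper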

View file

@@ -19,11 +19,14 @@ import requests
from ray import ray_constants
from ray._private.test_utils import (
    format_web_url,
    wait_for_condition,
    wait_until_server_available,
    run_string_as_driver,
    wait_until_succeeded_without_exception,
)
from ray._private.gcs_pubsub import gcs_pubsub_enabled
from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
from ray.dashboard import dashboard
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
@@ -43,7 +46,8 @@ def make_gcs_client(address_info):
        client = redis.StrictRedis(
            host=address[0],
            port=int(address[1]),
            password=ray_constants.REDIS_DEFAULT_PASSWORD,
        )
        gcs_client = ray._private.gcs_utils.GcsClient.create_from_redis(client)
    else:
        address = address_info["gcs_address"]
@@ -73,17 +77,14 @@ cleanup_test_files()


@pytest.mark.parametrize(
    "ray_start_with_dashboard",
    [{"_system_config": {"agent_register_timeout_ms": 5000}}],
    indirect=True,
)
def test_basic(ray_start_with_dashboard):
    """Dashboard test that starts a Ray cluster with a dashboard server running,
    then hits the dashboard API and asserts that it receives sensible data."""
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    address_info = ray_start_with_dashboard
    node_id = address_info["node_id"]
    gcs_client = make_gcs_client(address_info)
@@ -92,11 +93,12 @@ def test_basic(ray_start_with_dashboard):
    all_processes = ray.worker._global_node.all_processes
    assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
    assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
    dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
    dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
    assert dashboard_proc.status() in [
        psutil.STATUS_RUNNING,
        psutil.STATUS_SLEEPING,
        psutil.STATUS_DISK_SLEEP,
    ]
    raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
    raylet_proc = psutil.Process(raylet_proc_info.process.pid)
@@ -140,9 +142,7 @@ def test_basic(ray_start_with_dashboard):
    logger.info("Test agent register is OK.")
    wait_for_condition(lambda: _search_agent(raylet_proc.children()))
    assert dashboard_proc.status() in [psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING]
    agent_proc = _search_agent(raylet_proc.children())
    agent_pid = agent_proc.pid
@@ -161,40 +161,39 @@ def test_basic(ray_start_with_dashboard):
    # Check kv keys are set.
    logger.info("Check kv keys are set.")
    dashboard_address = ray.experimental.internal_kv._internal_kv_get(
        ray_constants.DASHBOARD_ADDRESS, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
    )
    assert dashboard_address is not None
    dashboard_rpc_address = ray.experimental.internal_kv._internal_kv_get(
        dashboard_consts.DASHBOARD_RPC_ADDRESS,
        namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
    )
    assert dashboard_rpc_address is not None
    key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
    agent_ports = ray.experimental.internal_kv._internal_kv_get(
        key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
    )
    assert agent_ports is not None


@pytest.mark.parametrize(
    "ray_start_with_dashboard",
    [
        {"dashboard_host": "127.0.0.1"},
        {"dashboard_host": "0.0.0.0"},
        {"dashboard_host": "::"},
    ],
    indirect=True,
)
def test_dashboard_address(ray_start_with_dashboard):
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_ip = webui_url.split(":")[0]
    assert not ipaddress.ip_address(webui_ip).is_unspecified
    assert webui_ip in ["127.0.0.1", ray_start_with_dashboard["node_ip_address"]]


def test_http_get(enable_test_module, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)
@@ -205,8 +204,7 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
    while True:
        time.sleep(3)
        try:
            response = requests.get(webui_url + "/test/http_get?url=" + target_url)
            response.raise_for_status()
            try:
                dump_info = response.json()
@@ -221,8 +219,8 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
            http_port, grpc_port = ports
            response = requests.get(
                f"http://{ip}:{http_port}" f"/test/http_get_from_agent?url={target_url}"
            )
            response.raise_for_status()
            try:
                dump_info = response.json()
@@ -239,10 +237,10 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):


def test_class_method_route_table(enable_test_module):
    head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
    agent_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardAgentModule
    )
    test_head_cls = None
    for cls in head_cls_list:
        if cls.__name__ == "TestHead":
@@ -274,28 +272,23 @@ def test_class_method_route_table(enable_test_module):
    assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
    assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
    assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
    assert any(_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
    assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)

    # Test bind()
    bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
    assert len(bound_routes) == 0
    dashboard_optional_utils.ClassMethodRouteTable.bind(
        test_agent_cls.__new__(test_agent_cls)
    )
    bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
    assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
    assert all(not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)

    # Static def should be in bound routes.
    routes.static("/test/route_static", "/path")
    bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
    assert any(_has_static(r, "/path", "/test/route_static") for r in bound_routes)

    # Test duplicated routes should raise exception.
    try:
@@ -358,10 +351,10 @@ def test_async_loop_forever():


def test_dashboard_module_decorator(enable_test_module):
    head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
    agent_cls_list = dashboard_utils.get_all_modules(
        dashboard_utils.DashboardAgentModule
    )

    assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
    assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)
@@ -385,8 +378,7 @@ print("success")


def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)
@@ -397,8 +389,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
        time.sleep(1)
        try:
            for x in range(10):
                response = requests.get(webui_url + "/test/aiohttp_cache/t1?value=1")
                response.raise_for_status()
                timestamp = response.json()["data"]["timestamp"]
                value1_timestamps.append(timestamp)
@@ -412,8 +403,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    sub_path_timestamps = []
    for x in range(10):
        response = requests.get(webui_url + f"/test/aiohttp_cache/tt{x}?value=1")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        sub_path_timestamps.append(timestamp)
@@ -421,8 +411,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    volatile_value_timestamps = []
    for x in range(10):
        response = requests.get(webui_url + f"/test/aiohttp_cache/tt?value={x}")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        volatile_value_timestamps.append(timestamp)
@@ -436,8 +425,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    volatile_value_timestamps = []
    for x in range(10):
        response = requests.get(webui_url + f"/test/aiohttp_cache_lru/tt{x % 4}")
        response.raise_for_status()
        timestamp = response.json()["data"]["timestamp"]
        volatile_value_timestamps.append(timestamp)
@@ -446,8 +434,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
    volatile_value_timestamps = []
    data = collections.defaultdict(set)
for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]: for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]:
response = requests.get(webui_url + response = requests.get(webui_url + f"/test/aiohttp_cache_lru/t1?value={x}")
f"/test/aiohttp_cache_lru/t1?value={x}")
response.raise_for_status() response.raise_for_status()
timestamp = response.json()["data"]["timestamp"] timestamp = response.json()["data"]["timestamp"]
data[x].add(timestamp) data[x].add(timestamp)
@ -458,8 +445,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
def test_get_cluster_status(ray_start_with_dashboard): def test_get_cluster_status(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
is True)
address_info = ray_start_with_dashboard address_info = ray_start_with_dashboard
webui_url = address_info["webui_url"] webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
@ -478,14 +464,15 @@ def test_get_cluster_status(ray_start_with_dashboard):
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"] assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
assert wait_until_succeeded_without_exception( assert wait_until_succeeded_without_exception(
get_cluster_status, (requests.RequestException, )) get_cluster_status, (requests.RequestException,)
)
gcs_client = make_gcs_client(address_info) gcs_client = make_gcs_client(address_info)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client) ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
ray.experimental.internal_kv._internal_kv_put( ray.experimental.internal_kv._internal_kv_put(
DEBUG_AUTOSCALING_STATUS_LEGACY, "hello") DEBUG_AUTOSCALING_STATUS_LEGACY, "hello"
ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR, )
"world") ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR, "world")
response = requests.get(f"{webui_url}/api/cluster_status") response = requests.get(f"{webui_url}/api/cluster_status")
response.raise_for_status() response.raise_for_status()
@ -508,20 +495,19 @@ def test_immutable_types():
assert immutable_dict == dashboard_utils.ImmutableDict(d) assert immutable_dict == dashboard_utils.ImmutableDict(d)
assert immutable_dict == d assert immutable_dict == d
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
assert dashboard_utils.ImmutableList( assert (
immutable_dict["list"]) == immutable_dict["list"] dashboard_utils.ImmutableList(immutable_dict["list"]) == immutable_dict["list"]
)
assert "512" in d assert "512" in d
assert "512" in d["list"][0] assert "512" in d["list"][0]
assert "512" in d["dict"] assert "512" in d["dict"]
# Test type conversion # Test type conversion
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
assert type(list( assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
# Test json dumps / loads # Test json dumps / loads
json_str = json.dumps( json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
deserialized_immutable_dict = json.loads(json_str) deserialized_immutable_dict = json.loads(json_str)
assert type(deserialized_immutable_dict) == dict assert type(deserialized_immutable_dict) == dict
assert type(deserialized_immutable_dict["list"]) == list assert type(deserialized_immutable_dict["list"]) == list
@ -577,7 +563,7 @@ def test_immutable_types():
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only): def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
address_info = ray.init(num_cpus=1, include_dashboard=True) address_info = ray.init(num_cpus=1, include_dashboard=True)
assert (wait_until_server_available(address_info["webui_url"]) is True) assert wait_until_server_available(address_info["webui_url"]) is True
webui_url = address_info["webui_url"] webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url) webui_url = format_web_url(webui_url)
@ -588,11 +574,8 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
time.sleep(1) time.sleep(1)
try: try:
response = requests.get( response = requests.get(
webui_url + "/test/dump", webui_url + "/test/dump", proxies={"http": None, "https": None}
proxies={ )
"http": None,
"https": None
})
response.raise_for_status() response.raise_for_status()
try: try:
response.json() response.json()
@ -609,8 +592,7 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
def test_dashboard_port_conflict(ray_start_with_dashboard): def test_dashboard_port_conflict(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
is True)
address_info = ray_start_with_dashboard address_info = ray_start_with_dashboard
gcs_client = make_gcs_client(address_info) gcs_client = make_gcs_client(address_info)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client) ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
@ -618,11 +600,15 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
temp_dir = "/tmp/ray" temp_dir = "/tmp/ray"
log_dir = "/tmp/ray/session_latest/logs" log_dir = "/tmp/ray/session_latest/logs"
dashboard_cmd = [ dashboard_cmd = [
sys.executable, dashboard.__file__, f"--host={host}", f"--port={port}", sys.executable,
f"--temp-dir={temp_dir}", f"--log-dir={log_dir}", dashboard.__file__,
f"--host={host}",
f"--port={port}",
f"--temp-dir={temp_dir}",
f"--log-dir={log_dir}",
f"--redis-address={address_info['redis_address']}", f"--redis-address={address_info['redis_address']}",
f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}", f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}",
f"--gcs-address={address_info['gcs_address']}" f"--gcs-address={address_info['gcs_address']}",
] ]
logger.info("The dashboard should be exit: %s", dashboard_cmd) logger.info("The dashboard should be exit: %s", dashboard_cmd)
p = subprocess.Popen(dashboard_cmd) p = subprocess.Popen(dashboard_cmd)
@ -638,7 +624,8 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
try: try:
dashboard_url = ray.experimental.internal_kv._internal_kv_get( dashboard_url = ray.experimental.internal_kv._internal_kv_get(
ray_constants.DASHBOARD_ADDRESS, ray_constants.DASHBOARD_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD) namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
if dashboard_url: if dashboard_url:
new_port = int(dashboard_url.split(b":")[-1]) new_port = int(dashboard_url.split(b":")[-1])
assert new_port > int(port) assert new_port > int(port)
@ -651,8 +638,7 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard): def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
is True)
all_processes = ray.worker._global_node.all_processes all_processes = ray.worker._global_node.all_processes
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0] dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
@ -661,7 +647,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
gcs_server_proc = psutil.Process(gcs_server_info.process.pid) gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
assert dashboard_proc.status() in [ assert dashboard_proc.status() in [
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
] ]
gcs_server_proc.kill() gcs_server_proc.kill()
View file
@@ -1,7 +1,12 @@
 import ray
 from ray.dashboard.memory_utils import (
-    ReferenceType, decode_object_ref_if_needed, MemoryTableEntry, MemoryTable,
-    SortingType)
+    ReferenceType,
+    decode_object_ref_if_needed,
+    MemoryTableEntry,
+    MemoryTable,
+    SortingType,
+)
+
 """Memory Table Unit Test"""

 NODE_ADDRESS = "127.0.0.1"
@@ -14,15 +19,17 @@ DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
 OBJECT_SIZE = 100


-def build_memory_entry(*,
-                       local_ref_count,
-                       pinned_in_memory,
-                       submitted_task_reference_count,
-                       contained_in_owned,
-                       object_size,
-                       pid,
-                       object_id=OBJECT_ID,
-                       node_address=NODE_ADDRESS):
+def build_memory_entry(
+    *,
+    local_ref_count,
+    pinned_in_memory,
+    submitted_task_reference_count,
+    contained_in_owned,
+    object_size,
+    pid,
+    object_id=OBJECT_ID,
+    node_address=NODE_ADDRESS
+):
     object_ref = {
         "objectId": object_id,
         "callSite": "(task call) /Users:458",
@@ -30,18 +37,16 @@ def build_memory_entry(*,
         "localRefCount": local_ref_count,
         "pinnedInMemory": pinned_in_memory,
         "submittedTaskRefCount": submitted_task_reference_count,
-        "containedInOwned": contained_in_owned
+        "containedInOwned": contained_in_owned,
     }
     return MemoryTableEntry(
-        object_ref=object_ref,
-        node_address=node_address,
-        is_driver=IS_DRIVER,
-        pid=pid)
+        object_ref=object_ref, node_address=node_address, is_driver=IS_DRIVER, pid=pid
+    )


-def build_local_reference_entry(object_size=OBJECT_SIZE,
-                                pid=PID,
-                                node_address=NODE_ADDRESS):
+def build_local_reference_entry(
+    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
+):
     return build_memory_entry(
         local_ref_count=1,
         pinned_in_memory=False,
@@ -49,12 +54,13 @@ def build_local_reference_entry(object_size=OBJECT_SIZE,
         contained_in_owned=[],
         object_size=object_size,
         pid=pid,
-        node_address=node_address)
+        node_address=node_address,
+    )


-def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
-                                     pid=PID,
-                                     node_address=NODE_ADDRESS):
+def build_used_by_pending_task_entry(
+    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
+):
     return build_memory_entry(
         local_ref_count=0,
         pinned_in_memory=False,
@@ -62,12 +68,13 @@ def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
         contained_in_owned=[],
         object_size=object_size,
         pid=pid,
-        node_address=node_address)
+        node_address=node_address,
+    )


-def build_captured_in_object_entry(object_size=OBJECT_SIZE,
-                                   pid=PID,
-                                   node_address=NODE_ADDRESS):
+def build_captured_in_object_entry(
+    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
+):
     return build_memory_entry(
         local_ref_count=0,
         pinned_in_memory=False,
@@ -75,12 +82,13 @@ def build_captured_in_object_entry(object_size=OBJECT_SIZE,
         contained_in_owned=[OBJECT_ID],
         object_size=object_size,
         pid=pid,
-        node_address=node_address)
+        node_address=node_address,
+    )


-def build_actor_handle_entry(object_size=OBJECT_SIZE,
-                             pid=PID,
-                             node_address=NODE_ADDRESS):
+def build_actor_handle_entry(
+    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
+):
     return build_memory_entry(
         local_ref_count=1,
         pinned_in_memory=False,
@@ -89,12 +97,13 @@ def build_actor_handle_entry(object_size=OBJECT_SIZE,
         object_size=object_size,
         pid=pid,
         node_address=node_address,
-        object_id=ACTOR_ID)
+        object_id=ACTOR_ID,
+    )


-def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
-                                 pid=PID,
-                                 node_address=NODE_ADDRESS):
+def build_pinned_in_memory_entry(
+    object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
+):
     return build_memory_entry(
         local_ref_count=0,
         pinned_in_memory=True,
@@ -102,28 +111,36 @@ def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
         contained_in_owned=[],
         object_size=object_size,
         pid=pid,
-        node_address=node_address)
+        node_address=node_address,
+    )


-def build_entry(object_size=OBJECT_SIZE,
-                pid=PID,
-                node_address=NODE_ADDRESS,
-                reference_type=ReferenceType.PINNED_IN_MEMORY):
+def build_entry(
+    object_size=OBJECT_SIZE,
+    pid=PID,
+    node_address=NODE_ADDRESS,
+    reference_type=ReferenceType.PINNED_IN_MEMORY,
+):
     if reference_type == ReferenceType.USED_BY_PENDING_TASK:
         return build_used_by_pending_task_entry(
-            pid=pid, object_size=object_size, node_address=node_address)
+            pid=pid, object_size=object_size, node_address=node_address
+        )
     elif reference_type == ReferenceType.LOCAL_REFERENCE:
         return build_local_reference_entry(
-            pid=pid, object_size=object_size, node_address=node_address)
+            pid=pid, object_size=object_size, node_address=node_address
+        )
     elif reference_type == ReferenceType.PINNED_IN_MEMORY:
         return build_pinned_in_memory_entry(
-            pid=pid, object_size=object_size, node_address=node_address)
+            pid=pid, object_size=object_size, node_address=node_address
+        )
     elif reference_type == ReferenceType.ACTOR_HANDLE:
         return build_actor_handle_entry(
-            pid=pid, object_size=object_size, node_address=node_address)
+            pid=pid, object_size=object_size, node_address=node_address
+        )
     elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
         return build_captured_in_object_entry(
-            pid=pid, object_size=object_size, node_address=node_address)
+            pid=pid, object_size=object_size, node_address=node_address
+        )


 def test_invalid_memory_entry():
@@ -133,7 +150,8 @@ def test_invalid_memory_entry():
         submitted_task_reference_count=0,
         contained_in_owned=[],
         object_size=OBJECT_SIZE,
-        pid=PID)
+        pid=PID,
+    )
     assert memory_entry.is_valid() is False
     memory_entry = build_memory_entry(
         local_ref_count=0,
@@ -141,7 +159,8 @@ def test_invalid_memory_entry():
         submitted_task_reference_count=0,
         contained_in_owned=[],
         object_size=-1,
-        pid=PID)
+        pid=PID,
+    )
     assert memory_entry.is_valid() is False


@@ -149,7 +168,8 @@ def test_valid_reference_memory_entry():
     memory_entry = build_local_reference_entry()
     assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
     assert memory_entry.object_ref == ray.ObjectRef(
-        decode_object_ref_if_needed(OBJECT_ID))
+        decode_object_ref_if_needed(OBJECT_ID)
+    )
     assert memory_entry.is_valid() is True


@@ -178,15 +198,14 @@ def test_memory_table_summary():
         build_captured_in_object_entry(),
         build_actor_handle_entry(),
         build_local_reference_entry(),
-        build_local_reference_entry()
+        build_local_reference_entry(),
     ]
     memory_table = MemoryTable(entries)
     assert len(memory_table.group) == 1
     assert memory_table.summary["total_actor_handles"] == 1
     assert memory_table.summary["total_captured_in_objects"] == 1
     assert memory_table.summary["total_local_ref_count"] == 2
-    assert memory_table.summary[
-        "total_object_size"] == len(entries) * OBJECT_SIZE
+    assert memory_table.summary["total_object_size"] == len(entries) * OBJECT_SIZE
     assert memory_table.summary["total_pinned_in_memory"] == 1
     assert memory_table.summary["total_used_by_pending_task"] == 1
@@ -202,14 +221,13 @@ def test_memory_table_sort_by_pid():

 def test_memory_table_sort_by_reference_type():
     unsort = [
-        ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
-        ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
+        ReferenceType.USED_BY_PENDING_TASK,
+        ReferenceType.LOCAL_REFERENCE,
+        ReferenceType.LOCAL_REFERENCE,
+        ReferenceType.PINNED_IN_MEMORY,
     ]
-    entries = [
-        build_entry(reference_type=reference_type) for reference_type in unsort
-    ]
-    memory_table = MemoryTable(
-        entries, sort_by_type=SortingType.REFERENCE_TYPE)
+    entries = [build_entry(reference_type=reference_type) for reference_type in unsort]
+    memory_table = MemoryTable(entries, sort_by_type=SortingType.REFERENCE_TYPE)
     sort = sorted(unsort)
     for reference_type, entry in zip(sort, memory_table.table):
         assert reference_type == entry.reference_type
@@ -231,7 +249,7 @@ def test_group_by():
         build_entry(node_address=node_second, pid=2),
         build_entry(node_address=node_second, pid=1),
         build_entry(node_address=node_first, pid=2),
-        build_entry(node_address=node_first, pid=1)
+        build_entry(node_address=node_first, pid=1),
     ]
     memory_table = MemoryTable(entries)
@@ -250,4 +268,5 @@ def test_group_by():
 if __name__ == "__main__":
     import sys
     import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
View file
@@ -18,8 +18,7 @@ import aiosignal  # noqa: F401
 from google.protobuf.json_format import MessageToDict
 from frozenlist import FrozenList  # noqa: F401

-from ray._private.utils import (binary_to_hex,
-                                check_dashboard_dependencies_installed)
+from ray._private.utils import binary_to_hex, check_dashboard_dependencies_installed

 try:
     create_task = asyncio.create_task
@@ -97,23 +96,26 @@ def get_all_modules(module_type):
     """
     logger.info(f"Get all modules by type: {module_type.__name__}")
     import ray.dashboard.modules
-    should_only_load_minimal_modules = (
-        not check_dashboard_dependencies_installed())
+
+    should_only_load_minimal_modules = not check_dashboard_dependencies_installed()

     for module_loader, name, ispkg in pkgutil.walk_packages(
-            ray.dashboard.modules.__path__,
-            ray.dashboard.modules.__name__ + "."):
+        ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
+    ):
         try:
             importlib.import_module(name)
         except ModuleNotFoundError as e:
-            logger.info(f"Module {name} cannot be loaded because "
-                        "we cannot import all dependencies. Download "
-                        "`pip install ray[default]` for the full "
-                        f"dashboard functionality. Error: {e}")
+            logger.info(
+                f"Module {name} cannot be loaded because "
+                "we cannot import all dependencies. Download "
+                "`pip install ray[default]` for the full "
+                f"dashboard functionality. Error: {e}"
+            )
             if not should_only_load_minimal_modules:
                 logger.info(
                     "Although `pip install ray[default] is downloaded, "
-                    "module couldn't be imported`")
+                    "module couldn't be imported`"
+                )
                 raise e

     imported_modules = []
@@ -202,7 +204,8 @@ def message_to_dict(message, decode_keys=None, **kwargs):
     if decode_keys:
         return _decode_keys(
-            MessageToDict(message, use_integers_for_enums=False, **kwargs))
+            MessageToDict(message, use_integers_for_enums=False, **kwargs)
+        )
     else:
         return MessageToDict(message, use_integers_for_enums=False, **kwargs)
@@ -251,8 +254,9 @@ class Change:
         self.new = new

     def __str__(self):
-        return f"Change(owner: {type(self.owner)}), " \
-            f"old: {self.old}, new: {self.new}"
+        return (
+            f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}"
+        )


 class NotifyQueue:
@@ -289,10 +293,7 @@ https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
     | None              | null          |
     +-------------------+---------------+
     """
-    _json_compatible_types = {
-        dict, list, tuple, str, int, float, bool,
-        type(None), bytes
-    }
+    _json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes}

     def is_immutable(self):
@@ -318,8 +319,7 @@ class Immutable(metaclass=ABCMeta):

 class ImmutableList(Immutable, Sequence):
-    """Makes a :class:`list` immutable.
-    """
+    """Makes a :class:`list` immutable."""

     __slots__ = ("_list", "_proxy")
@@ -332,7 +332,7 @@ class ImmutableList(Immutable, Sequence):
         self._proxy = [None] * len(list_value)

     def __reduce_ex__(self, protocol):
-        return type(self), (self._list, )
+        return type(self), (self._list,)

     def mutable(self):
         return self._list
@@ -366,8 +366,7 @@ class ImmutableList(Immutable, Sequence):

 class ImmutableDict(Immutable, Mapping):
-    """Makes a :class:`dict` immutable.
-    """
+    """Makes a :class:`dict` immutable."""

     __slots__ = ("_dict", "_proxy")
@@ -380,7 +379,7 @@ class ImmutableDict(Immutable, Mapping):
         self._proxy = {}

     def __reduce_ex__(self, protocol):
-        return type(self), (self._dict, )
+        return type(self), (self._dict,)

     def mutable(self):
         return self._dict
@@ -443,21 +442,23 @@ class Dict(ImmutableDict, MutableMapping):
         if len(self.signal) and old != value:
             if old is None:
                 co = self.signal.send(
-                    Change(owner=self, new=Dict.ChangeItem(key, value)))
+                    Change(owner=self, new=Dict.ChangeItem(key, value))
+                )
             else:
                 co = self.signal.send(
                     Change(
                         owner=self,
                         old=Dict.ChangeItem(key, old),
-                        new=Dict.ChangeItem(key, value)))
+                        new=Dict.ChangeItem(key, value),
+                    )
+                )
             NotifyQueue.put(co)

     def __delitem__(self, key):
         old = self._dict.pop(key, None)
         self._proxy.pop(key, None)
         if len(self.signal) and old is not None:
-            co = self.signal.send(
-                Change(owner=self, old=Dict.ChangeItem(key, old)))
+            co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old)))
             NotifyQueue.put(co)

     def reset(self, d):
@@ -482,12 +483,15 @@ def async_loop_forever(interval_seconds, cancellable=False):
                 await coro(*args, **kwargs)
             except asyncio.CancelledError as ex:
                 if cancellable:
-                    logger.info(f"An async loop forever coroutine "
-                                f"is cancelled {coro}.")
+                    logger.info(
+                        f"An async loop forever coroutine " f"is cancelled {coro}."
+                    )
                     raise ex
                 else:
-                    logger.exception(f"Can not cancel the async loop "
-                                     f"forever coroutine {coro}.")
+                    logger.exception(
+                        f"Can not cancel the async loop "
+                        f"forever coroutine {coro}."
+                    )
             except Exception:
                 logger.exception(f"Error looping coroutine {coro}.")
             await asyncio.sleep(interval_seconds)
@@ -497,15 +501,18 @@ def async_loop_forever(interval_seconds, cancellable=False):
     return _wrapper


-async def get_aioredis_client(redis_address, redis_password,
-                              retry_interval_seconds, retry_times):
+async def get_aioredis_client(
+    redis_address, redis_password, retry_interval_seconds, retry_times
+):
     for x in range(retry_times):
         try:
             return await aioredis.create_redis_pool(
-                address=redis_address, password=redis_password)
+                address=redis_address, password=redis_password
+            )
         except (socket.gaierror, ConnectionError) as ex:
             logger.error("Connect to Redis failed: %s, retry...", ex)
             await asyncio.sleep(retry_interval_seconds)
     # Raise exception from create_redis_pool
     return await aioredis.create_redis_pool(
-        address=redis_address, password=redis_password)
+        address=redis_address, password=redis_password
+    )
View file
@@ -2,6 +2,7 @@ from collections import Counter
 import sys
 import time
 import ray
+
 """ This script is meant to be run from a pod in the same Kubernetes namespace
 as your Ray cluster.
 """
@@ -11,8 +12,9 @@ as your Ray cluster.
 def gethostname(x):
     import platform
     import time
+
     time.sleep(0.01)
-    return x + (platform.node(), )
+    return x + (platform.node(),)


 def wait_for_nodes(expected):
@@ -22,8 +24,11 @@ def wait_for_nodes(expected):
         node_keys = [key for key in resources if "node" in key]
         num_nodes = sum(resources[node_key] for node_key in node_keys)
         if num_nodes < expected:
-            print("{} nodes have joined so far, waiting for {} more.".format(
-                num_nodes, expected - num_nodes))
+            print(
+                "{} nodes have joined so far, waiting for {} more.".format(
+                    num_nodes, expected - num_nodes
+                )
+            )
             sys.stdout.flush()
             time.sleep(1)
         else:
@@ -36,9 +41,7 @@ def main():
     # Check that objects can be transferred from each node to each other node.
     for i in range(10):
         print("Iteration {}".format(i))
-        results = [
-            gethostname.remote(gethostname.remote(())) for _ in range(100)
-        ]
+        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
         print(Counter(ray.get(results)))
         sys.stdout.flush()
View file
@@ -2,6 +2,7 @@ from collections import Counter
 import sys
 import time
 import ray
+
 """ Run this script locally to execute a Ray program on your Ray cluster on
 Kubernetes.
@@ -18,8 +19,9 @@ LOCAL_PORT = 10001
 def gethostname(x):
     import platform
     import time
+
     time.sleep(0.01)
-    return x + (platform.node(), )
+    return x + (platform.node(),)


 def wait_for_nodes(expected):
@@ -29,8 +31,11 @@ def wait_for_nodes(expected):
         node_keys = [key for key in resources if "node" in key]
         num_nodes = sum(resources[node_key] for node_key in node_keys)
         if num_nodes < expected:
-            print("{} nodes have joined so far, waiting for {} more.".format(
-                num_nodes, expected - num_nodes))
+            print(
+                "{} nodes have joined so far, waiting for {} more.".format(
+                    num_nodes, expected - num_nodes
+                )
+            )
             sys.stdout.flush()
             time.sleep(1)
         else:
@@ -43,9 +48,7 @@ def main():
     # Check that objects can be transferred from each node to each other node.
     for i in range(10):
         print("Iteration {}".format(i))
-        results = [
-            gethostname.remote(gethostname.remote(())) for _ in range(100)
-        ]
+        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
         print(Counter(ray.get(results)))
         sys.stdout.flush()
View file
@@ -10,8 +10,9 @@ import ray
 def gethostname(x):
     import platform
     import time
+
     time.sleep(0.01)
-    return x + (platform.node(), )
+    return x + (platform.node(),)


 def wait_for_nodes(expected):
@@ -21,8 +22,11 @@ def wait_for_nodes(expected):
         node_keys = [key for key in resources if "node" in key]
         num_nodes = sum(resources[node_key] for node_key in node_keys)
         if num_nodes < expected:
-            print("{} nodes have joined so far, waiting for {} more.".format(
-                num_nodes, expected - num_nodes))
+            print(
+                "{} nodes have joined so far, waiting for {} more.".format(
+                    num_nodes, expected - num_nodes
+                )
+            )
             sys.stdout.flush()
             time.sleep(1)
         else:
@@ -35,9 +39,7 @@ def main():
     # Check that objects can be transferred from each node to each other node.
     for i in range(10):
         print("Iteration {}".format(i))
-        results = [
-            gethostname.remote(gethostname.remote(())) for _ in range(100)
-        ]
+        results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
         print(Counter(ray.get(results)))
         sys.stdout.flush()
View file
@@ -25,24 +25,24 @@ if __name__ == "__main__":
         "--exp-name",
         type=str,
         required=True,
-        help="The job name and path to logging file (exp_name.log).")
+        help="The job name and path to logging file (exp_name.log).",
+    )
     parser.add_argument(
-        "--num-nodes",
-        "-n",
-        type=int,
-        default=1,
-        help="Number of nodes to use.")
+        "--num-nodes", "-n", type=int, default=1, help="Number of nodes to use."
+    )
     parser.add_argument(
         "--node",
         "-w",
         type=str,
         help="The specified nodes to use. Same format as the "
-        "return of 'sinfo'. Default: ''.")
+        "return of 'sinfo'. Default: ''.",
+    )
     parser.add_argument(
         "--num-gpus",
         type=int,
         default=0,
-        help="Number of GPUs to use in each node. (Default: 0)")
+        help="Number of GPUs to use in each node. (Default: 0)",
+    )
     parser.add_argument(
         "--partition",
         "-p",
@@ -51,14 +51,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--load-env",
         type=str,
-        help="The script to load your environment ('module load cuda/10.1')")
+        help="The script to load your environment ('module load cuda/10.1')",
+    )
     parser.add_argument(
         "--command",
         type=str,
         required=True,
         help="The command you wish to execute. For example: "
         " --command 'python test.py'. "
-        "Note that the command must be a string.")
+        "Note that the command must be a string.",
+    )
     args = parser.parse_args()

     if args.node:
@@ -67,11 +69,13 @@ if __name__ == "__main__":
     else:
         node_info = ""

-    job_name = "{}_{}".format(args.exp_name,
-                              time.strftime("%m%d-%H%M", time.localtime()))
+    job_name = "{}_{}".format(
+        args.exp_name, time.strftime("%m%d-%H%M", time.localtime())
+    )

-    partition_option = "#SBATCH --partition={}".format(
-        args.partition) if args.partition else ""
+    partition_option = (
+        "#SBATCH --partition={}".format(args.partition) if args.partition else ""
+    )

     # ===== Modified the template script =====
     with open(template_file, "r") as f:
@@ -84,10 +88,10 @@ if __name__ == "__main__":
     text = text.replace(LOAD_ENV, str(args.load_env))
     text = text.replace(GIVEN_NODE, node_info)
     text = text.replace(
-        "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
-        "PRODUCTION!",
+        "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
         "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
-        "RUNNABLE!")
+        "RUNNABLE!",
+    )

     # ===== Save the script =====
     script_file = "{}.sh".format(job_name)
@@ -99,5 +103,7 @@ if __name__ == "__main__":
     subprocess.Popen(["sbatch", script_file])
     print(
         "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
-            script_file, "{}.log".format(job_name)))
+            script_file, "{}.log".format(job_name)
+        )
+    )
     sys.exit(0)
View file
@@ -89,7 +89,7 @@ myst_enable_extensions = [
 ]

 external_toc_exclude_missing = False
-external_toc_path = '_toc.yml'
+external_toc_path = "_toc.yml"

 # There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
 # of imports.
@@ -112,7 +112,8 @@ versionwarning_messages = {
         "<b>Got questions?</b> Join "
         f'<a href="{FORUM_LINK}">the Ray Community forum</a> '
         "for Q&A on all things Ray, as well as to share and learn use cases "
-        "and best practices with the Ray community."),
+        "and best practices with the Ray community."
+    ),
 }

 versionwarning_body_selector = "#main-content"
@@ -189,11 +190,16 @@ exclude_patterns += sphinx_gallery_conf["examples_dirs"]
 # If "DOC_LIB" is found, only build that top-level navigation item.
 build_one_lib = os.getenv("DOC_LIB")

-all_toc_libs = [
-    f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path
-]
+all_toc_libs = [f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path]
 all_toc_libs += [
-    "cluster", "tune", "data", "raysgd", "train", "rllib", "serve", "workflows"
+    "cluster",
+    "tune",
+    "data",
+    "raysgd",
+    "train",
+    "rllib",
+    "serve",
+    "workflows",
 ]
 if build_one_lib and build_one_lib in all_toc_libs:
     all_toc_libs.remove(build_one_lib)
@@ -405,7 +411,8 @@ def setup(app):
     # Custom JS
     app.add_js_file(
         "https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js",
-        defer="defer")
+        defer="defer",
+    )
     app.add_js_file("js/docsearch.js", defer="defer")
     # Custom Sphinx directives
     app.add_directive("customgalleryitem", CustomGalleryItemDirective)
View file
@@ -6,13 +6,17 @@ from docutils import nodes
 import os
 import sphinx_gallery
 import urllib
+
 # Note: the scipy import has to stay here, it's used implicitly down the line
 import scipy.stats  # noqa: F401
 import scipy.linalg  # noqa: F401

 __all__ = [
-    "CustomGalleryItemDirective", "fix_xgb_lgbm_docs", "MOCK_MODULES",
-    "CHILD_MOCK_MODULES", "update_context"
+    "CustomGalleryItemDirective",
+    "fix_xgb_lgbm_docs",
+    "MOCK_MODULES",
+    "CHILD_MOCK_MODULES",
+    "update_context",
 ]

 try:
@@ -60,7 +64,7 @@ class CustomGalleryItemDirective(Directive):
     option_spec = {
         "tooltip": directives.unchanged,
         "figure": directives.unchanged,
-        "description": directives.unchanged
+        "description": directives.unchanged,
     }

     has_content = False
@@ -73,8 +77,9 @@ class CustomGalleryItemDirective(Directive):
             if len(self.options["tooltip"]) > 195:
                 tooltip = tooltip[:195] + "..."
         else:
-            raise ValueError("Need to provide :tooltip: under "
-                             "`.. customgalleryitem::`.")
+            raise ValueError(
+                "Need to provide :tooltip: under " "`.. customgalleryitem::`."
+            )

         # Generate `thumbnail` used in the gallery.
         if "figure" in self.options:
@@ -95,11 +100,13 @@ class CustomGalleryItemDirective(Directive):
         if "description" in self.options:
             description = self.options["description"]
         else:
-            raise ValueError("Need to provide :description: under "
-                             "`customgalleryitem::`.")
+            raise ValueError(
+                "Need to provide :description: under " "`customgalleryitem::`."
+            )

         thumbnail_rst = GALLERY_TEMPLATE.format(
-            tooltip=tooltip, thumbnail=thumbnail, description=description)
+            tooltip=tooltip, thumbnail=thumbnail, description=description
+        )
         thumbnail = StringList(thumbnail_rst.split("\n"))
         thumb = nodes.paragraph()
         self.state.nested_parse(thumbnail, self.content_offset, thumb)
@@ -146,29 +153,30 @@ def fix_xgb_lgbm_docs(app, what, name, obj, options, lines):

 # Taken from https://github.com/edx/edx-documentation
-FEEDBACK_FORM_FMT = "https://github.com/ray-project/ray/issues/new?" \
-                    "title={title}&labels=docs&body={body}"
+FEEDBACK_FORM_FMT = (
+    "https://github.com/ray-project/ray/issues/new?"
+    "title={title}&labels=docs&body={body}"
+)


 def feedback_form_url(project, page):
     """Create a URL for feedback on a particular page in a project."""
     return FEEDBACK_FORM_FMT.format(
-        title=urllib.parse.quote(
-            "[docs] Issue on `{page}.rst`".format(page=page)),
+        title=urllib.parse.quote("[docs] Issue on `{page}.rst`".format(page=page)),
         body=urllib.parse.quote(
             "# Documentation Problem/Question/Comment\n"
             "<!-- Describe your issue/question/comment below. -->\n"
             "<!-- If there are typos or errors in the docs, feel free "
             "to create a pull-request. -->\n"
             "\n\n\n\n"
-            "(Created directly from the docs)\n"),
+            "(Created directly from the docs)\n"
+        ),
     )


 def update_context(app, pagename, templatename, context, doctree):
     """Update the page rendering context to include ``feedback_form_url``."""
-    context["feedback_form_url"] = feedback_form_url(app.config.project,
-                                                     pagename)
+    context["feedback_form_url"] = feedback_form_url(app.config.project, pagename)


 MOCK_MODULES = [
@@ -187,8 +195,7 @@ MOCK_MODULES = [
     "horovod.ray.runner",
     "horovod.ray.utils",
     "hyperopt",
-    "hyperopt.hp"
-    "kubernetes",
+    "hyperopt.hp" "kubernetes",
     "mlflow",
     "modin",
     "mxnet",
View file
@@ -59,13 +59,16 @@ from ray.data.datasource.datasource import RandomIntRowDatasource
 # Lets see how we implement such pipeline using Ray Dataset:


-def create_shuffle_pipeline(training_data_dir: str, num_epochs: int,
-                            num_shards: int) -> List[DatasetPipeline]:
+def create_shuffle_pipeline(
+    training_data_dir: str, num_epochs: int, num_shards: int
+) -> List[DatasetPipeline]:

-    return ray.data.read_parquet(training_data_dir) \
-        .repeat(num_epochs) \
-        .random_shuffle_each_window() \
+    return (
+        ray.data.read_parquet(training_data_dir)
+        .repeat(num_epochs)
+        .random_shuffle_each_window()
         .split(num_shards, equal=True)
+    )


 ############################################################################
@@ -117,7 +120,8 @@ parser = argparse.ArgumentParser()
 parser.add_argument(
     "--large-scale-test",
     action="store_true",
-    help="Run large scale test (500GiB of data).")
+    help="Run large scale test (500GiB of data).",
+)

 args, _ = parser.parse_known_args()
@@ -142,13 +146,15 @@ if not args.large_scale_test:
         ray.data.read_datasource(
             RandomIntRowDatasource(),
             n=size_bytes // 8 // NUM_COLUMNS,
-            num_columns=NUM_COLUMNS).write_parquet(tmpdir)
+            num_columns=NUM_COLUMNS,
+        ).write_parquet(tmpdir)
         return tmpdir

     example_files_dir = generate_example_files(SIZE_100MiB)

-    splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS,
-                                     NUM_TRAINING_WORKERS)
+    splits = create_shuffle_pipeline(
+        example_files_dir, NUM_EPOCHS, NUM_TRAINING_WORKERS
+    )

     training_workers = [
         TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)
@@ -198,18 +204,22 @@ if not args.large_scale_test:
 # generated data.


-def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int,
-                                  num_columns: int,
-                                  num_shards: int) -> List[DatasetPipeline]:
+def create_large_shuffle_pipeline(
+    data_size_bytes: int, num_epochs: int, num_columns: int, num_shards: int
+) -> List[DatasetPipeline]:
     # _spread_resource_prefix is used to ensure tasks are evenly spread to all
     # CPU nodes.
-    return ray.data.read_datasource(
-        RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns,
-        num_columns=num_columns,
-        _spread_resource_prefix="node:") \
-        .repeat(num_epochs) \
-        .random_shuffle_each_window(_spread_resource_prefix="node:") \
+    return (
+        ray.data.read_datasource(
+            RandomIntRowDatasource(),
+            n=data_size_bytes // 8 // num_columns,
+            num_columns=num_columns,
+            _spread_resource_prefix="node:",
+        )
+        .repeat(num_epochs)
+        .random_shuffle_each_window(_spread_resource_prefix="node:")
         .split(num_shards, equal=True)
+    )


 #################################################################################
@@ -229,19 +239,18 @@ if args.large_scale_test:
     # waiting for cluster nodes to come up.
     while len(ray.nodes()) < TOTAL_NUM_NODES:
-        print(
-            f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}"
-        )
+        print(f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}")
         time.sleep(5)

-    splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS,
-                                           NUM_COLUMNS, NUM_TRAINING_WORKERS)
+    splits = create_large_shuffle_pipeline(
+        SIZE_500GiB, NUM_EPOCHS, NUM_COLUMNS, NUM_TRAINING_WORKERS
+    )

     # Note we set num_gpus=1 for workers so that
     # the workers will only run on GPU nodes.
     training_workers = [
-        TrainingWorker.options(num_gpus=1) \
-            .remote(rank, shard) for rank, shard in enumerate(splits)
+        TrainingWorker.options(num_gpus=1).remote(rank, shard)
+        for rank, shard in enumerate(splits)
     ]

     start = time.time()
View file
@@ -4,21 +4,30 @@ import pyximport

 pyximport.install(setup_args={"include_dirs": numpy.get_include()})

-from .cython_simple import simple_func, fib, fib_int, \
-    fib_cpdef, fib_cdef, simple_class
+from .cython_simple import simple_func, fib, fib_int, fib_cpdef, fib_cdef, simple_class
 from .masked_log import masked_log

-from .cython_blas import \
-    compute_self_corr_for_voxel_sel, \
-    compute_kernel_matrix, \
-    compute_single_self_corr_syrk, \
-    compute_single_self_corr_gemm, \
-    compute_corr_vectors, \
-    compute_single_matrix_multiplication
+from .cython_blas import (
+    compute_self_corr_for_voxel_sel,
+    compute_kernel_matrix,
+    compute_single_self_corr_syrk,
+    compute_single_self_corr_gemm,
+    compute_corr_vectors,
+    compute_single_matrix_multiplication,
+)

 __all__ = [
-    "simple_func", "fib", "fib_int", "fib_cpdef", "fib_cdef", "simple_class",
-    "masked_log", "compute_self_corr_for_voxel_sel", "compute_kernel_matrix",
-    "compute_single_self_corr_syrk", "compute_single_self_corr_gemm",
-    "compute_corr_vectors", "compute_single_matrix_multiplication"
+    "simple_func",
+    "fib",
+    "fib_int",
+    "fib_cpdef",
+    "fib_cdef",
+    "simple_class",
+    "masked_log",
+    "compute_self_corr_for_voxel_sel",
+    "compute_kernel_matrix",
+    "compute_single_self_corr_syrk",
+    "compute_single_self_corr_gemm",
+    "compute_corr_vectors",
+    "compute_single_matrix_multiplication",
 ]
View file
@@ -94,11 +94,11 @@ def example8():

     # See cython_blas.pyx for argument documentation
     mat = np.array(
-        [[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32)
+        [[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32
+    )
     result = np.zeros((2, 2), np.float32, order="C")

-    run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0,
-             result, 2)
+    run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0, result, 2)


 if __name__ == "__main__":
View file
@@ -13,6 +13,7 @@ include_dirs = [numpy.get_include()]
 # dependencies
 try:
     import scipy  # noqa
+
     modules.append("cython_blas.pyx")
     install_requires.append("scipy")
 except ImportError as e:  # noqa
@@ -27,4 +28,5 @@ setup(
     packages=[pkg_dir],
     ext_modules=cythonize(modules),
     install_requires=install_requires,
-    include_dirs=include_dirs)
+    include_dirs=include_dirs,
+)
View file
@@ -56,4 +56,5 @@ class CythonTest(unittest.TestCase):
 if __name__ == "__main__":
     import pytest
     import sys
+
     sys.exit(pytest.main(["-v", __file__]))
View file
@ -65,31 +65,34 @@ from ray.util.dask import ray_dask_get
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--address", type=str, default="auto", help="The address to use for Ray.") "--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument( parser.add_argument(
"--smoke-test", "--smoke-test",
action="store_true", action="store_true",
help="Read a smaller dataset for quick testing purposes.") help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument( parser.add_argument(
"--num-actors", "--num-actors", type=int, default=4, help="Sets number of actors for training."
type=int, )
default=4,
help="Sets number of actors for training.")
parser.add_argument( parser.add_argument(
"--cpus-per-actor", "--cpus-per-actor",
type=int, type=int,
    default=6,
    help="The number of CPUs per actor for training.",
)
parser.add_argument(
    "--num-actors-inference",
    type=int,
    default=16,
    help="Sets number of actors for inference.",
)
parser.add_argument(
    "--cpus-per-actor-inference",
    type=int,
    default=2,
    help="The number of CPUs per actor for inference.",
)

# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()

@@ -125,12 +128,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
    # Test dataset with only 10,000 records.
    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
    # Full dataset. This may take a couple of minutes to load.
    FILE_URL = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases"
        "/00280/HIGGS.csv.gz"
    )

colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]

dask.config.set(scheduler=ray_dask_get)

@@ -192,7 +196,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
        dtrain=train_set,
        evals=[(test_set, "eval")],
        evals_result=evals_result,
        ray_params=ray_params,
    )
    train_end_time = time.time()
    train_duration = train_end_time - train_start_time

@@ -200,8 +205,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
    model_path = "model.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))

    return bst, evals_result

@@ -221,8 +225,12 @@ config = {
}

bst, evals_result = train_xgboost(
    config,
    train_df,
    eval_df,
    LABEL_COLUMN,
    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")

###############################################################################

@@ -260,13 +268,12 @@ def tune_xgboost(train_df, test_df, target_column):
        "eval_metric": ["logloss", "error"],
        "eta": tune.loguniform(1e-4, 1e-1),
        "subsample": tune.uniform(0.5, 1.0),
        "max_depth": tune.randint(1, 9),
    }

    ray_params = RayParams(
        max_actor_restarts=1, cpus_per_actor=cpus_per_actor, num_actors=num_actors
    )

    tune_start_time = time.time()

@@ -276,19 +283,21 @@ def tune_xgboost(train_df, test_df, target_column):
            train_df=train_df,
            test_df=test_df,
            target_column=target_column,
            ray_params=ray_params,
        ),
        # Use the `get_tune_resources` helper function to set the resources.
        resources_per_trial=ray_params.get_tune_resources(),
        config=config,
        num_samples=10,
        metric="eval-error",
        mode="min",
    )

    tune_end_time = time.time()
    tune_duration = tune_end_time - tune_start_time
    print(f"Total time taken: {tune_duration} seconds.")

    accuracy = 1.0 - analysis.best_result["eval-error"]
    print(f"Best model parameters: {analysis.best_config}")
    print(f"Best model total accuracy: {accuracy:.4f}")

@@ -315,7 +324,8 @@ results = predict(
    bst,
    inference_df,
    ray_params=RayParams(
        cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
    ),
)

print(results)
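The exploded argument lists in the hunks above follow Black's "magic trailing comma". A minimal, hand-written sketch of the rule (assuming Black's default 88-character line length; the values are illustrative only, not from the example above):

# A call that fits within the line limit stays on one line, no trailing comma.
params = dict(cpus_per_actor=2, num_actors=4)

# Once a call must be split (or already ends in a trailing comma), Black puts
# one argument per line and keeps the trailing comma, so later diffs only
# touch the argument that actually changes.
message = "{} started {} actors in {}s".format(
    "Success!",
    4,
    12.5,
)
print(params, message)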
View file
@@ -59,7 +59,8 @@ def make_and_upload_dataset(dir_path):
        shift=0.0,
        scale=1.0,
        shuffle=False,
        random_state=seed,
    )

    # turn into dataframe with column names
    col_names = ["feature_%0d" % i for i in range(1, d + 1, 1)]

@@ -91,10 +92,8 @@ def make_and_upload_dataset(dir_path):
        path = os.path.join(data_path, f"data_{i:05d}.parquet.snappy")
        if not os.path.exists(path):
            tmp_df = create_data_chunk(
                n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=True
            )
            tmp_df.to_parquet(path, compression="snappy", index=False)
        print(f"Wrote {path} to disk...")
        # todo: at large enough scale we might want to upload the rest after

@@ -108,10 +107,8 @@ def make_and_upload_dataset(dir_path):
        path = os.path.join(inference_path, f"data_{i:05d}.parquet.snappy")
        if not os.path.exists(path):
            tmp_df = create_data_chunk(
                n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=False
            )
            tmp_df.to_parquet(path, compression="snappy", index=False)
        print(f"Wrote {path} to disk...")
        # todo: at large enough scale we might want to upload the rest after

@@ -124,8 +121,9 @@ def make_and_upload_dataset(dir_path):
def read_dataset(path: str) -> ray.data.Dataset:
    print(f"reading data from {path}")
    return ray.data.read_parquet(path, _spread_resource_prefix="node:").random_shuffle(
        _spread_resource_prefix="node:"
    )


class DataPreprocessor:

@@ -141,20 +139,20 @@ class DataPreprocessor:
        # columns.
        self.standard_stats = None

    def preprocess_train_data(
        self, ds: ray.data.Dataset
    ) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
        print("\n\nPreprocessing training dataset.\n")
        return self._preprocess(ds, False)

    def preprocess_inference_data(self, df: ray.data.Dataset) -> ray.data.Dataset:
        print("\n\nPreprocessing inference dataset.\n")
        return self._preprocess(df, True)[0]

    def _preprocess(
        self, ds: ray.data.Dataset, inferencing: bool
    ) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
        print("\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")

        def batch_transformer(df: pd.DataFrame):
            # Disable chained assignment warning.
@@ -165,25 +163,27 @@ class DataPreprocessor:
            # Add new column.
            df["new_col"] = (
                df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]
            ) / 3.0

            # Transform column.
            df["feature_1"] = 2.0 * df["feature_1"] + 0.1

            return df

        ds = ds.map_batches(batch_transformer, batch_format="pandas")

        print(
            "\nStep 2: Precalculating fruit-grouped mean for new column and "
            "for one-hot encoding (latter only uses fruit groups)\n"
        )
        agg_ds = ds.groupby("fruit").mean("feature_1")
        fruit_means = {r["fruit"]: r["mean(feature_1)"] for r in agg_ds.take_all()}

        print(
            "\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
            "fruit; one-hot encode fruit column\n"
        )

        if inferencing:
            assert self.fruits is not None

@@ -192,8 +192,7 @@ class DataPreprocessor:
            self.fruits = list(fruit_means.keys())

        fruit_one_hots = {
            fruit: collections.defaultdict(int, fruit=1) for fruit in self.fruits
        }

        def batch_transformer(df: pd.DataFrame):

@@ -224,12 +223,12 @@ class DataPreprocessor:
        # Split into 90% training set, 10% test set.
        train_ds, test_ds = ds.split_at_indices([split_index])

        print(
            "\nStep 4b: Precalculate training dataset stats for "
            "standard scaling\n"
        )
        # Calculate stats needed for standard scaling feature columns.
        feature_columns = [col for col in train_ds.schema().names if col != "label"]
        standard_aggs = [
            agg(on=col) for col in feature_columns for agg in (Mean, Std)
        ]

@@ -252,30 +251,29 @@ class DataPreprocessor:
        if inferencing:
            # Apply standard scaling to inference dataset.
            inference_ds = ds.map_batches(batch_standard_scaler, batch_format="pandas")
            return inference_ds, None
        else:
            # Apply standard scaling to both training dataset and test dataset.
            train_ds = train_ds.map_batches(
                batch_standard_scaler, batch_format="pandas"
            )
            test_ds = test_ds.map_batches(batch_standard_scaler, batch_format="pandas")
            return train_ds, test_ds


def inference(
    dataset, model_cls: type, batch_size: int, result_path: str, use_gpu: bool
):
    print("inferencing...")
    num_gpus = 1 if use_gpu else 0
    dataset.map_batches(
        model_cls,
        compute="actors",
        batch_size=batch_size,
        num_gpus=num_gpus,
        num_cpus=0,
    ).write_parquet(result_path)


"""
@@ -295,8 +293,7 @@ P1:
class Net(nn.Module):
    def __init__(self, n_layers, n_features, num_hidden, dropout_every, drop_prob):
        super().__init__()
        self.n_layers = n_layers
        self.dropout_every = dropout_every

@@ -406,8 +403,9 @@ def train_func(config):
    print("Defining model, loss, and optimizer...")

    # Setup device.
    device = torch.device(
        f"cuda:{train.local_rank()}" if use_gpu and torch.cuda.is_available() else "cpu"
    )
    print(f"Device: {device}")

@@ -415,7 +413,8 @@ def train_func(config):
    train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
    test_dataset = train.get_dataset_shard("test_dataset")
    test_torch_dataset = test_dataset.to_torch(
        label_column="label", batch_size=batch_size
    )

    net = Net(
        n_layers=num_layers,

@@ -436,30 +435,37 @@ def train_func(config):
        train_dataset = next(train_dataset_epoch_iterator)
        train_torch_dataset = train_dataset.to_torch(
            label_column="label", batch_size=batch_size
        )

        train_running_loss, train_num_correct, train_num_total = train_epoch(
            train_torch_dataset, net, device, criterion, optimizer
        )
        train_acc = train_num_correct / train_num_total
        print(
            f"epoch [{epoch + 1}]: training accuracy: "
            f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
        )

        test_running_loss, test_num_correct, test_num_total = test_epoch(
            test_torch_dataset, net, device, criterion
        )
        test_acc = test_num_correct / test_num_total
        print(
            f"epoch [{epoch + 1}]: testing accuracy: "
            f"{test_num_correct} / {test_num_total} = {test_acc:.4f}"
        )

        # Record and log stats.
        train.report(
            train_acc=train_acc,
            train_loss=train_running_loss,
            test_acc=test_acc,
            test_loss=test_running_loss,
        )

        # Checkpoint model.
        module = net.module if isinstance(net, DistributedDataParallel) else net
        train.save_checkpoint(model_state_dict=module.state_dict())

        if train.world_rank() == 0:
@@ -469,46 +475,44 @@ def train_func(config):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dir-path", default=".", type=str, help="Path to read and write data from"
    )
    parser.add_argument(
        "--use-s3",
        action="store_true",
        default=False,
        help="Use data from s3 for testing.",
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="The address to use for Ray. `auto` if running through `ray submit`.",
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        help="The number of Ray workers to use for distributed training",
    )
    parser.add_argument(
        "--large-dataset", action="store_true", default=False, help="Use 500GB dataset"
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Use GPU for training."
    )
    parser.add_argument(
        "--mlflow-register-model",
        action="store_true",
        help="Whether to use mlflow model registry. If set, a local MLflow "
        "tracking server is expected to have already been started.",
    )

    args = parser.parse_args()
    smoke_test = args.smoke_test
@@ -553,8 +557,11 @@ if __name__ == "__main__":
        if len(list(count)) == 0:
            print("please run `python make_and_upload_dataset.py` first")
            sys.exit(1)
        data_path = (
            "s3://cuj-big-data/big-data/"
            if large_dataset
            else "s3://cuj-big-data/data/"
        )
        inference_path = "s3://cuj-big-data/inference/"
        inference_output_path = "s3://cuj-big-data/output/"
    else:

@@ -562,20 +569,19 @@ if __name__ == "__main__":
        inference_path = os.path.join(dir_path, "inference")
        inference_output_path = "/tmp"

        if len(os.listdir(data_path)) <= 1 or len(os.listdir(inference_path)) <= 1:
            print("please run `python make_and_upload_dataset.py` first")
            sys.exit(1)

    if smoke_test:
        # Only read a single file.
        data_path = os.path.join(data_path, "data_00000.parquet.snappy")
        inference_path = os.path.join(inference_path, "data_00000.parquet.snappy")

    preprocessor = DataPreprocessor()
    train_dataset, test_dataset = preprocessor.preprocess_train_data(
        read_dataset(data_path)
    )

    num_columns = len(train_dataset.schema().names)
    # remove label column and internal Arrow column.

@@ -589,14 +595,12 @@ if __name__ == "__main__":
    DROPOUT_PROB = 0.2

    # Random global shuffle
    train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window(
        _spread_resource_prefix="node:"
    )
    del train_dataset

    datasets = {"train_dataset": train_dataset_pipeline, "test_dataset": test_dataset}

    config = {
        "use_gpu": use_gpu,

@@ -606,7 +610,7 @@ if __name__ == "__main__":
        "num_layers": NUM_LAYERS,
        "dropout_every": DROPOUT_EVERY,
        "dropout_prob": DROPOUT_PROB,
        "num_features": num_features,
    }

    # Create 2 callbacks: one for Tensorboard Logging and one for MLflow

@@ -619,7 +623,8 @@ if __name__ == "__main__":
    callbacks = [
        TBXLoggerCallback(logdir=tbx_logdir),
        MLflowLoggerCallback(
            experiment_name="cuj-big-data-training", save_artifact=True
        ),
    ]

    # Remove CPU resource so Datasets can be scheduled.

@@ -629,19 +634,19 @@ if __name__ == "__main__":
        backend="torch",
        num_workers=num_workers,
        use_gpu=use_gpu,
        resources_per_worker=resources_per_worker,
    )
    trainer.start()
    results = trainer.run(
        train_func=train_func, config=config, callbacks=callbacks, dataset=datasets
    )
    model = results[0]
    trainer.shutdown()

    if args.mlflow_register_model:
        mlflow.pytorch.log_model(
            model, artifact_path="models", registered_model_name="torch_model"
        )

        # Get the latest model from mlflow model registry.
        client = mlflow.tracking.MlflowClient()

@@ -649,12 +654,14 @@ if __name__ == "__main__":
        # Get the info for the latest model.
        # By default, registered models are in stage "None".
        latest_model_info = client.get_latest_versions(
            registered_model_name, stages=["None"]
        )[0]
        latest_version = latest_model_info.version

        def load_model_func():
            model_uri = f"models:/torch_model/{latest_version}"
            return mlflow.pytorch.load_model(model_uri)

    else:
        state_dict = model.state_dict()

@@ -670,25 +677,30 @@ if __name__ == "__main__":
            n_features=num_features,
            num_hidden=num_hidden,
            dropout_every=dropout_every,
            drop_prob=dropout_prob,
        )
        model.load_state_dict(state_dict)
        return model

    class BatchInferModel:
        def __init__(self, load_model_func):
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.model = load_model_func().to(self.device)

        def __call__(self, batch) -> "pd.DataFrame":
            tensor = torch.FloatTensor(batch.to_pandas().values).to(self.device)
            return pd.DataFrame(self.model(tensor).cpu().detach().numpy())

    inference_dataset = preprocessor.preprocess_inference_data(
        read_dataset(inference_path)
    )
    inference(
        inference_dataset,
        BatchInferModel(load_model_func),
        100,
        inference_output_path,
        use_gpu,
    )

    end_time = time.time()
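Several hunks above replace backslash line continuations with the parenthesized expressions Black prefers. A rough hand-written sketch of the same transformation on a generic method chain (the names here are placeholders, not the Ray APIs above):

words = ["reading", "data"]
# Pre-Black style, with escaped newlines:
#     summary = " ".join(words) \
#         .strip() \
#         .upper()
# Black-compatible style, wrapping the chain in parentheses instead:
summary = (
    " ".join(words)
    .strip()
    .upper()
)
print(summary)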
View file
@@ -14,20 +14,23 @@ class MyActor:
        self.counter = Counter(
            "num_requests",
            description="Number of requests processed by the actor.",
            tag_keys=("actor_name",),
        )
        self.counter.set_default_tags({"actor_name": name})

        self.gauge = Gauge(
            "curr_count",
            description="Current count held by the actor. Goes up and down.",
            tag_keys=("actor_name",),
        )
        self.gauge.set_default_tags({"actor_name": name})

        self.histogram = Histogram(
            "request_latency",
            description="Latencies of requests in ms.",
            boundaries=[0.1, 1],
            tag_keys=("actor_name",),
        )
        self.histogram.set_default_tags({"actor_name": name})

    def process_request(self, num):
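The tag_keys changes above look like whitespace noise, but the comma is load-bearing: ("actor_name") is just a parenthesized string, while ("actor_name",) is a one-element tuple. Black only normalizes the space before the closing parenthesis; a quick standalone check:

assert isinstance(("actor_name"), str)
assert isinstance(("actor_name",), tuple)
print("single-element tuples need the trailing comma")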
View file
@@ -46,7 +46,8 @@ class LinearModel(object):
        y_ = tf.placeholder(tf.float32, [None, shape[1]])
        self.y_ = y_
        cross_entropy = tf.reduce_mean(
            -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
        )
        self.cross_entropy = cross_entropy
        self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
        self.sess = tf.Session()

@@ -54,24 +55,20 @@ class LinearModel(object):
        # Ray's TensorFlowVariables to automatically create methods to modify
        # the weights.
        self.variables = ray.experimental.tf_utils.TensorFlowVariables(
            cross_entropy, self.sess
        )

    def loss(self, xs, ys):
        """Computes the loss of the network."""
        return float(
            self.sess.run(self.cross_entropy, feed_dict={self.x: xs, self.y_: ys})
        )

    def grad(self, xs, ys):
        """Computes the gradients of the network."""
        return self.sess.run(
            self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys}
        )


@ray.remote

@@ -143,4 +140,5 @@ if __name__ == "__main__":
    # Use L-BFGS to minimize the loss function.
    print("Running L-BFGS.")
    result = scipy.optimize.fmin_l_bfgs_b(
        full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True
    )
View file
@@ -26,8 +26,7 @@ class RayDistributedActor:
        """
        # Set the init_method and rank of the process for distributed training.
        print("Ray worker at {url} rank {rank}".format(url=url, rank=world_rank))
        self.url = url
        self.world_rank = world_rank
        args.distributed_rank = world_rank

@@ -55,8 +54,10 @@ class RayDistributedActor:
                n_cpus = int(ray.cluster_resources()["CPU"])
                if n_cpus > original_n_cpus:
                    raise Exception(
                        "New CPUs found (original %d CPUs, now %d CPUs)"
                        % (original_n_cpus, n_cpus)
                    )
        else:
            original_n_gpus = args.distributed_world_size

@@ -65,8 +66,9 @@ class RayDistributedActor:
                n_gpus = int(ray.cluster_resources().get("GPU", 0))
                if n_gpus > original_n_gpus:
                    raise Exception(
                        "New GPUs found (original %d GPUs, now %d GPUs)"
                        % (original_n_gpus, n_gpus)
                    )

        fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint

@@ -103,8 +105,7 @@ def run_fault_tolerant_loop():
    set_batch_size(args)

    # Set up Ray distributed actors.
    Actor = ray.remote(num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
    workers = [Actor.remote() for i in range(args.distributed_world_size)]

    # Get the IP address and a free port of actor 0, which is used for

@@ -116,8 +117,7 @@ def run_fault_tolerant_loop():
    # Start the remote processes, and check whether any of the processes
    # fail. If so, restart all the processes.
    unfinished = [
        worker.run.remote(address, i, args) for i, worker in enumerate(workers)
    ]
    try:
        while len(unfinished) > 0:

@@ -135,10 +135,8 @@ def add_ray_args(parser):
    """Add ray and fault-tolerance related parser arguments to the parser."""
    group = parser.add_argument_group("Ray related arguments")
    group.add_argument(
        "--ray-address", default="auto", type=str, help="address for ray initialization"
    )
    group.add_argument(
        "--fix-batch-size",
        default=None,

@@ -147,7 +145,8 @@ def add_ray_args(parser):
        help="fix the actual batch size (max_sentences * update_freq "
        "* n_GPUs) to be the fixed input values by adjusting update_freq "
        "according to actual n_GPUs; the batch size is fixed to B_i for "
        "epoch i; all epochs >N are fixed to B_N",
    )
    return group

@@ -168,13 +167,13 @@ def set_batch_size(args):
    """Fixes the total batch_size to be agnostic to the GPU count."""
    if args.fix_batch_size is not None:
        args.update_freq = [
            math.ceil(batch_size / (args.max_sentences * args.distributed_world_size))
            for batch_size in args.fix_batch_size
        ]
        print(
            "Training on %d GPUs, max_sentences=%d, update_freq=%s"
            % (args.distributed_world_size, args.max_sentences, repr(args.update_freq))
        )


if __name__ == "__main__":
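The raise/print rewrites above illustrate how Black wraps long binary expressions: the break comes before the operator, so the % lands at the start of the continuation line. A self-contained sketch with hypothetical counts:

original_n_gpus, n_gpus = 2, 4  # hypothetical values for illustration
message = (
    "New GPUs found (original %d GPUs, now %d GPUs)"
    % (original_n_gpus, n_gpus)
)
print(message)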
View file
@@ -62,31 +62,34 @@ import ray
parser = argparse.ArgumentParser()
parser.add_argument(
    "--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument(
    "--smoke-test",
    action="store_true",
    help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument(
    "--num-actors", type=int, default=4, help="Sets number of actors for training."
)
parser.add_argument(
    "--cpus-per-actor",
    type=int,
    default=8,
    help="The number of CPUs per actor for training.",
)
parser.add_argument(
    "--num-actors-inference",
    type=int,
    default=16,
    help="Sets number of actors for inference.",
)
parser.add_argument(
    "--cpus-per-actor-inference",
    type=int,
    default=2,
    help="The number of CPUs per actor for inference.",
)

# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()

@@ -119,12 +122,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
    # Test dataset with only 10,000 records.
    FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
    # Full dataset. This may take a couple of minutes to load.
    FILE_URL = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases"
        "/00280/HIGGS.csv.gz"
    )

colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]

@@ -182,7 +186,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
        evals_result=evals_result,
        verbose_eval=False,
        num_boost_round=100,
        ray_params=ray_params,
    )

    train_end_time = time.time()
    train_duration = train_end_time - train_start_time

@@ -190,8 +195,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
    model_path = "model.xgb"
    bst.save_model(model_path)
    print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))

    return bst, evals_result

@@ -208,8 +212,12 @@ config = {
}

bst, evals_result = train_xgboost(
    config,
    df_train,
    df_validation,
    LABEL_COLUMN,
    RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")

###############################################################################

@@ -227,7 +235,8 @@ results = predict(
    bst,
    inference_df,
    ray_params=RayParams(
        cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
    ),
)

print(results)
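The FILE_URL edits above rely on implicit string concatenation: adjacent string literals are merged at compile time, so Black is free to keep them on one line or split them across lines without changing the resulting value. A small check with a hypothetical URL:

url = "https://example.com/" "data.csv"
assert url == "https://example.com/data.csv"
print(url)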
View file
@@ -12,10 +12,12 @@ class NewsServer(object):
    def __init__(self):
        self.conn = sqlite3.connect("newsreader.db")
        c = self.conn.cursor()
        c.execute(
            """CREATE TABLE IF NOT EXISTS news
                 (title text, link text,
                  description text, published timestamp,
                  feed url, liked bool)"""
        )
        self.conn.commit()

    def retrieve_feed(self, url):

@@ -24,36 +26,41 @@ class NewsServer(object):
        items = []
        c = self.conn.cursor()
        for item in feed.items:
            items.append(
                {
                    "title": item.title,
                    "link": item.link,
                    "description": item.description,
                    "description_text": item.description,
                    "pubDate": str(item.pub_date),
                }
            )
            c.execute(
                """INSERT INTO news (title, link, description,
                   published, feed, liked) values
                   (?, ?, ?, ?, ?, ?)""",
                (
                    item.title,
                    item.link,
                    item.description,
                    item.pub_date,
                    feed.link,
                    False,
                ),
            )
            self.conn.commit()

        return {
            "channel": {"title": feed.title, "link": feed.link, "url": feed.link},
            "items": items,
        }

    def like_item(self, url, is_faved):
        c = self.conn.cursor()
        if is_faved:
            c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url,))
        else:
            c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url,))

        self.conn.commit()

@@ -77,9 +84,7 @@ def dispatcher():
        result = ray.get(method.remote(*method_args))
        return jsonify(result)
    else:
        return jsonify({"error": "method_name '" + method_name + "' not found"})


if __name__ == "__main__":
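The items.append() rewrite above shows how Black handles a call whose single argument is an oversized dict literal: the braces move onto their own lines one indent level in, and the last entry gains a trailing comma. A sketch with placeholder values:

items = []
items.append(
    {
        "title": "example title",
        "link": "https://example.com",
        "pubDate": "2022-01-29",
    }
)
print(items)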
View file
@@ -44,16 +44,16 @@ num_evaluations = 10
# A function for generating random hyperparameters.
def generate_hyperparameters():
    return {
        "learning_rate": 10 ** np.random.uniform(-5, 1),
        "batch_size": np.random.randint(1, 100),
        "momentum": np.random.uniform(0, 1),
    }


def get_data_loaders(batch_size):
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since

@@ -61,16 +61,16 @@ def get_data_loaders(batch_size):
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=batch_size,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST("~/data", train=False, transform=mnist_transforms),
            batch_size=batch_size,
            shuffle=True,
        )
    return train_loader, test_loader

@@ -152,9 +152,8 @@ def evaluate_hyperparameters(config):
    model = ConvNet()
    train_loader, test_loader = get_data_loaders(config["batch_size"])
    optimizer = optim.SGD(
        model.parameters(), lr=config["learning_rate"], momentum=config["momentum"]
    )
    train(model, optimizer, train_loader)
    return test(model, test_loader)

@@ -202,22 +201,33 @@ while remaining_ids:
    hyperparameters = hyperparameters_mapping[result_id]
    accuracy = ray.get(result_id)
    print(
        """We achieve accuracy {:.3}% with
      learning_rate: {:.2}
      batch_size: {}
      momentum: {:.2}
      """.format(
            100 * accuracy,
            hyperparameters["learning_rate"],
            hyperparameters["batch_size"],
            hyperparameters["momentum"],
        )
    )
    if accuracy > best_accuracy:
        best_hyperparameters = hyperparameters
        best_accuracy = accuracy

# Record the best performing set of hyperparameters.
print(
    """Best accuracy over {} trials was {:.3} with
      learning_rate: {:.2}
      batch_size: {}
      momentum: {:.2}
      """.format(
        num_evaluations,
        100 * best_accuracy,
        best_hyperparameters["learning_rate"],
        best_hyperparameters["batch_size"],
        best_hyperparameters["momentum"],
    )
)
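Note the change from 10**np.random.uniform(-5, 1) to 10 ** np.random.uniform(-5, 1): the Black release used for this commit spaces out the power operator. (Later Black releases hug simple powers such as x**2 again, so output can differ by version.) A runnable sketch:

import numpy as np

learning_rate = 10 ** np.random.uniform(-5, 1)
print(f"sampled learning rate: {learning_rate:.5f}")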
View file
@@ -39,8 +39,8 @@ import ray
def get_data_loader():
    """Safely downloads data. Returns training/validation set dataloader."""
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since

@@ -48,16 +48,16 @@ def get_data_loader():
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST("~/data", train=False, transform=mnist_transforms),
            batch_size=128,
            shuffle=True,
        )
    return train_loader, test_loader

@@ -75,7 +75,7 @@ def evaluate(model, test_loader):
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100.0 * correct / total


#######################################################################

@@ -144,8 +144,7 @@ class ParameterServer(object):
    def apply_gradients(self, *gradients):
        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.model.set_gradients(summed_gradients)

@@ -215,9 +214,7 @@ test_loader = get_data_loader()[1]
print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations):
    gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
    # Calculate update after all gradients are available.
    current_weights = ps.apply_gradients.remote(*gradients)
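The evaluate() hunk above also shows Black's numeric-literal normalization: a float written with a bare trailing dot gains an explicit zero. The runtime value is identical, as a quick check confirms:

assert 100. == 100.0
correct, total = 3, 4  # hypothetical counts for illustration
print(100.0 * correct / total)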
View file
@@ -197,7 +197,7 @@ class Model(object):
        """Applies the gradients to the model parameters with RMSProp."""
        for k, v in self.weights.items():
            g = grad_buffer[k]
            rmsprop_cache[k] = decay * rmsprop_cache[k] + (1 - decay) * g ** 2
            self.weights[k] += lr * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)

@@ -278,20 +278,24 @@ for i in range(1, 1 + iterations):
    gradient_ids = []
    # Launch tasks to compute gradients from multiple rollouts in parallel.
    start_time = time.time()
    gradient_ids = [actor.compute_gradient.remote(model_id) for actor in actors]
    for batch in range(batch_size):
        [grad_id], gradient_ids = ray.wait(gradient_ids)
        grad, reward_sum = ray.get(grad_id)
        # Accumulate the gradient over batch.
        for k in model.weights:
            grad_buffer[k] += grad[k]
        running_reward = (
            reward_sum
            if running_reward is None
            else running_reward * 0.99 + reward_sum * 0.01
        )
    end_time = time.time()
    print(
        "Batch {} computed {} rollouts in {} seconds, "
        "running mean is {}".format(
            i, batch_size, end_time - start_time, running_reward
        )
    )
    model.update(grad_buffer, rmsprop_cache, learning_rate, decay_rate)
    zero_grads(grad_buffer)
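The running_reward rewrite above is Black's treatment of a conditional expression that no longer fits on one line: it is wrapped in parentheses and split at if/else. A standalone sketch with made-up reward values:

running_reward = None
reward_sum = 21.0  # hypothetical reward for illustration
running_reward = (
    reward_sum
    if running_reward is None
    else running_reward * 0.99 + reward_sum * 0.01
)
print(running_reward)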
View file
@@ -23,6 +23,7 @@ from typing import Tuple
from time import sleep

import ray

# For typing purposes
from ray.actor import ActorHandle
from tqdm import tqdm
View file
@@ -8,12 +8,11 @@ import wikipedia
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-mappers", help="number of mapper actors used", default=3, type=int
)
parser.add_argument(
    "--num-reducers", help="number of reducer actors used", default=4, type=int
)


@ray.remote

@@ -36,8 +35,11 @@ class Mapper(object):
        while self.num_articles_processed < article_index + 1:
            self.get_new_article()
        # Return the word counts from within a given character range.
        return [
            (k, v)
            for k, v in self.word_counts[article_index].items()
            if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]
        ]


@ray.remote

@@ -51,8 +53,7 @@ class Reducer(object):
        # Get the word counts for this Reducer's keys from all of the Mappers
        # and aggregate the results.
        count_ids = [
            mapper.get_range.remote(article_index, self.keys) for mapper in self.mappers
        ]

        while len(count_ids) > 0:

@@ -87,9 +88,9 @@ if __name__ == "__main__":
        streams.append(Stream([line.strip() for line in f.readlines()]))

    # Partition the keys among the reducers.
    chunks = np.array_split(
        [chr(i) for i in range(ord("a"), ord("z") + 1)], args.num_reducers
    )
    keys = [[chunk[0], chunk[-1]] for chunk in chunks]

    # Create a number of mappers.

@@ -103,14 +104,12 @@ if __name__ == "__main__":
    while True:
        print("article index = {}".format(article_index))
        wordcounts = {}
        counts = ray.get(
            [reducer.next_reduce_result.remote(article_index) for reducer in reducers]
        )
        for count in counts:
            wordcounts.update(count)
        most_frequent_words = heapq.nlargest(10, wordcounts, key=wordcounts.get)
        for word in most_frequent_words:
            print(" ", word, wordcounts[word])
        article_index += 1
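The ray.get() rewrite above suggests how Black splits a call whose single argument is a long list comprehension: it breaks after the opening bracket rather than inside the comprehension. A sketch using a plain function as a stand-in for ray.get() (the names are hypothetical):

def collect(refs):
    # Stand-in for ray.get(); simply materializes the list.
    return list(refs)

reducers = range(4)
counts = collect(
    [reducer * 10 for reducer in reducers]  # stands in for remote result refs
)
print(counts)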
View file
@@ -125,12 +125,12 @@ class MNISTDataInterface(object):
        self.data_dir = data_dir
        self.max_days = max_days

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.dataset = MNIST(
            self.data_dir, train=True, download=True, transform=transform
        )

    def _get_day_slice(self, day=0):
        if day < 0:

@@ -154,8 +154,7 @@ class MNISTDataInterface(object):
        end = self._get_day_slice(day)

        available_data = Subset(self.dataset, list(range(start, end)))
        train_n = int(0.8 * (end - start))  # 80% train data, 20% validation data

        return random_split(available_data, [train_n, end - start - train_n])

@@ -223,13 +222,15 @@ def test(model, data_loader, device=None):
# will take care of creating the model and optimizer and repeatedly
# call the ``train`` function to train the model. Also, this function
# will report the training progress back to Tune.
def train_mnist(
    config,
    start_model=None,
    checkpoint_dir=None,
    num_epochs=10,
    use_gpus=False,
    data_fn=None,
    day=0,
):
    # Create model
    use_cuda = use_gpus and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

@@ -237,7 +238,8 @@ def train_mnist(config,
    # Create optimizer
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    # Load checkpoint, or load start model if no checkpoint has been
    # passed and a start model is specified

@@ -248,8 +250,7 @@ def train_mnist(config,
        load_dir = start_model

    if load_dir:
        model_state, optimizer_state = torch.load(os.path.join(load_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

@@ -257,18 +258,22 @@ def train_mnist(config,
    train_dataset, validation_dataset = data_fn(day=day)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config["batch_size"], shuffle=True
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=config["batch_size"], shuffle=True
    )

    for i in range(num_epochs):
        train(model, optimizer, train_loader, device)
        acc = test(model, validation_loader, device)
        if i == num_epochs - 1:
            with tune.checkpoint_dir(step=i) as checkpoint_dir:
                torch.save(
                    (model.state_dict(), optimizer.state_dict()),
                    os.path.join(checkpoint_dir, "checkpoint"),
                )
            tune.report(mean_accuracy=acc, done=True)
        else:
            tune.report(mean_accuracy=acc)
@@ -286,7 +291,7 @@ def train_mnist(config,
# until the given day. Our search space can thus also contain parameters
# that affect the model complexity (such as the layer size), since it
# does not have to be compatible with an existing model.
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):
    data_interface = MNISTDataInterface("~/data", max_days=10)
    num_examples = data_interface._get_day_slice(day)

@@ -302,11 +307,13 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
        mode="max",
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    reporter = CLIReporter(
        parameter_columns=["layer_size", "lr", "momentum", "batch_size"],
        metric_columns=["mean_accuracy", "training_iteration"],
    )

    analysis = tune.run(
        partial(

@@ -315,17 +322,16 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
            data_fn=data_interface.get_data,
            num_epochs=num_epochs,
            use_gpus=True if gpus_per_trial > 0 else False,
            day=day,
        ),
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        verbose=0,
        name="tune_serve_mnist_fromscratch",
    )

    best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
    best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
@@ -344,33 +350,35 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
# layer size parameter. Since we continue to train an existing model,
# we cannot change the layer size mid-training, so we just continue
# to use the existing one.
def tune_from_existing(
    start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0
):
    data_interface = MNISTDataInterface("/tmp/mnist_data", max_days=10)
    num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(
        day - 1
    )

    config = start_config.copy()
    config.update(
        {
            "batch_size": tune.choice([16, 32, 64]),
            "lr": tune.loguniform(1e-4, 1e-1),
            "momentum": tune.uniform(0.1, 0.9),
        }
    )

    scheduler = ASHAScheduler(
        metric="mean_accuracy",
        mode="max",
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    reporter = CLIReporter(
        parameter_columns=["lr", "momentum", "batch_size"],
        metric_columns=["mean_accuracy", "training_iteration"],
    )

    analysis = tune.run(
        partial(

@@ -379,17 +387,16 @@ def tune_from_existing(start_model,
            data_fn=data_interface.get_incremental_data,
            num_epochs=num_epochs,
            use_gpus=True if gpus_per_trial > 0 else False,
            day=day,
        ),
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        verbose=0,
        name="tune_serve_mnist_fromexisting",
    )

    best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
    best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
@ -423,8 +430,8 @@ class MNISTDeployment:
model = ConvNet(layer_size=self.config["layer_size"]).to(self.device) model = ConvNet(layer_size=self.config["layer_size"]).to(self.device)
model_state, optimizer_state = torch.load( model_state, optimizer_state = torch.load(
os.path.join(self.checkpoint_dir, "checkpoint"), os.path.join(self.checkpoint_dir, "checkpoint"), map_location=self.device
map_location=self.device) )
model.load_state_dict(model_state) model.load_state_dict(model_state)
self.model = model self.model = model
@ -442,12 +449,12 @@ class MNISTDeployment:
# active model. We call this directory ``model_dir``. Every time we # active model. We call this directory ``model_dir``. Every time we
# would like to update our model, we copy the checkpoint of the new # would like to update our model, we copy the checkpoint of the new
# model to this directory. We then update the deployment to the new version. # model to this directory. We then update the deployment to the new version.
def serve_new_model(model_dir, checkpoint, config, metrics, day, def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):
use_gpu=False):
print("Serving checkpoint: {}".format(checkpoint)) print("Serving checkpoint: {}".format(checkpoint))
checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint, checkpoint_path = _move_checkpoint_to_model_dir(
config, metrics) model_dir, checkpoint, config, metrics
)
serve.start(detached=True) serve.start(detached=True)
MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu) MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)
@ -482,8 +489,7 @@ def get_current_model(model_dir):
checkpoint_path = os.path.join(model_dir, "checkpoint") checkpoint_path = os.path.join(model_dir, "checkpoint")
meta_path = os.path.join(model_dir, "meta.json") meta_path = os.path.join(model_dir, "meta.json")
if not os.path.exists(checkpoint_path) or \ if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):
not os.path.exists(meta_path):
return None, None, None return None, None, None
with open(meta_path, "rt") as fp: with open(meta_path, "rt") as fp:
@ -559,28 +565,33 @@ if __name__ == "__main__":
"--from_scratch", "--from_scratch",
action="store_true", action="store_true",
help="Train and select best model from scratch", help="Train and select best model from scratch",
default=False) default=False,
)
parser.add_argument( parser.add_argument(
"--from_existing", "--from_existing",
action="store_true", action="store_true",
help="Train and select best model from existing model", help="Train and select best model from existing model",
default=False) default=False,
)
parser.add_argument( parser.add_argument(
"--day", "--day",
help="Indicate the day to simulate the amount of data available to us", help="Indicate the day to simulate the amount of data available to us",
type=int, type=int,
default=0) default=0,
)
parser.add_argument( parser.add_argument(
"--query", help="Query endpoint with example", type=int, default=-1) "--query", help="Query endpoint with example", type=int, default=-1
)
parser.add_argument( parser.add_argument(
"--smoke-test", "--smoke-test",
action="store_true", action="store_true",
help="Finish quickly for testing", help="Finish quickly for testing",
default=False) default=False,
)
args = parser.parse_args() args = parser.parse_args()
@ -600,20 +611,23 @@ if __name__ == "__main__":
# Query our model # Query our model
response = requests.post( response = requests.post(
"http://localhost:8000/mnist", "http://localhost:8000/mnist", json={"images": [data[0].numpy().tolist()]}
json={"images": [data[0].numpy().tolist()]}) )
try: try:
pred = response.json()["result"][0] pred = response.json()["result"][0]
except: # noqa: E722 except: # noqa: E722
pred = -1 pred = -1
print("Querying model with example #{}. " print(
"Label = {}, Response = {}, Correct = {}".format( "Querying model with example #{}. "
args.query, label, pred, label == pred)) "Label = {}, Response = {}, Correct = {}".format(
args.query, label, pred, label == pred
)
)
sys.exit(0) sys.exit(0)
gpus_per_trial = 0.5 if not args.smoke_test else 0. gpus_per_trial = 0.5 if not args.smoke_test else 0.0
serve_gpu = True if gpus_per_trial > 0 else False serve_gpu = True if gpus_per_trial > 0 else False
num_samples = 8 if not args.smoke_test else 1 num_samples = 8 if not args.smoke_test else 1
num_epochs = 10 if not args.smoke_test else 1 num_epochs = 10 if not args.smoke_test else 1
@ -621,23 +635,22 @@ if __name__ == "__main__":
if args.from_scratch: # train everyday from scratch if args.from_scratch: # train everyday from scratch
print("Start training job from scratch on day {}.".format(args.day)) print("Start training job from scratch on day {}.".format(args.day))
acc, config, best_checkpoint, num_examples = tune_from_scratch( acc, config, best_checkpoint, num_examples = tune_from_scratch(
num_samples, num_epochs, gpus_per_trial, day=args.day) num_samples, num_epochs, gpus_per_trial, day=args.day
print("Trained day {} from scratch on {} samples. " )
"Best accuracy: {:.4f}. Best config: {}".format( print(
args.day, num_examples, acc, config)) "Trained day {} from scratch on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model( serve_new_model(
model_dir, model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
best_checkpoint, )
config,
acc,
args.day,
use_gpu=serve_gpu)
if args.from_existing: if args.from_existing:
old_checkpoint, old_config, old_acc = get_current_model(model_dir) old_checkpoint, old_config, old_acc = get_current_model(model_dir)
if not old_checkpoint or not old_config or not old_acc: if not old_checkpoint or not old_config or not old_acc:
print("No existing model found. Train one with --from_scratch " print("No existing model found. Train one with --from_scratch " "first.")
"first.")
sys.exit(1) sys.exit(1)
acc, config, best_checkpoint, num_examples = tune_from_existing( acc, config, best_checkpoint, num_examples = tune_from_existing(
old_checkpoint, old_checkpoint,
@ -645,17 +658,17 @@ if __name__ == "__main__":
num_samples, num_samples,
num_epochs, num_epochs,
gpus_per_trial, gpus_per_trial,
day=args.day) day=args.day,
print("Trained day {} from existing on {} samples. " )
"Best accuracy: {:.4f}. Best config: {}".format( print(
args.day, num_examples, acc, config)) "Trained day {} from existing on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model( serve_new_model(
model_dir, model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
best_checkpoint, )
config,
acc,
args.day,
use_gpu=serve_gpu)
####################################################################### #######################################################################
# That's it! We now have an end-to-end workflow to train and update a # That's it! We now have an end-to-end workflow to train and update a

@@ -38,6 +38,7 @@ To start out, change the import statement to get tune-scikit-learn's grid search
 """
 # Keep this here for https://github.com/ray-project/ray/issues/11547
 from sklearn.model_selection import GridSearchCV
+
 # Replace above line with:
 from ray.tune.sklearn import TuneGridSearchCV
@@ -60,7 +61,8 @@ X, y = make_classification(
     n_informative=50,
     n_redundant=0,
     n_classes=10,
-    class_sep=2.5)
+    class_sep=2.5,
+)
 x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=1000)
 # Example parameters to tune from SGDClassifier
@@ -70,9 +72,11 @@ parameter_grid = {"alpha": [1e-4, 1e-1, 1], "epsilon": [0.01, 0.1]}
 # As you can see, the setup here is exactly how you would do it for Scikit-Learn. Now, let's try fitting a model.
 tune_search = TuneGridSearchCV(
-    SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10)
+    SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10
+)
 import time  # Just to compare fit times
+
 start = time.time()
 tune_search.fit(x_train, y_train)
 end = time.time()
@@ -93,6 +97,7 @@ print("Tune GridSearch Fit Time:", end - start)
 # Try running this compared to the GridSearchCV equivalent, and see the speedup for yourself!
 from sklearn.model_selection import GridSearchCV
+
 # n_jobs=-1 enables use of all cores like Tune does
 sklearn_search = GridSearchCV(SGDClassifier(), parameter_grid, n_jobs=-1)
@@ -120,7 +125,7 @@ import numpy as np
 digits = datasets.load_digits()
 x = digits.data
 y = digits.target
-x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.2)
+x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
 clf = SGDClassifier()
 parameter_grid = {"alpha": (1e-4, 1), "epsilon": (0.01, 0.1)}

@@ -8,8 +8,9 @@ import ray
 def get_host_name(x):
     import platform
     import time
+
     time.sleep(0.01)
-    return x + (platform.node(), )
+    return x + (platform.node(),)
 def wait_for_nodes(expected):
@@ -17,8 +18,11 @@ def wait_for_nodes(expected):
     while True:
         num_nodes = len(ray.nodes())
         if num_nodes < expected:
-            print("{} nodes have joined so far, waiting for {} more.".format(
-                num_nodes, expected - num_nodes))
+            print(
+                "{} nodes have joined so far, waiting for {} more.".format(
+                    num_nodes, expected - num_nodes
+                )
+            )
             sys.stdout.flush()
             time.sleep(1)
         else:
@@ -31,9 +35,7 @@ def main():
     # Check that objects can be transferred from each node to each other node.
     for i in range(10):
         print("Iteration {}".format(i))
-        results = [
-            get_host_name.remote(get_host_name.remote(())) for _ in range(100)
-        ]
+        results = [get_host_name.remote(get_host_name.remote(())) for _ in range(100)]
         print(Counter(ray.get(results)))
         sys.stdout.flush()

@@ -20,8 +20,9 @@ def setup_logging() -> None:
     setup_component_logger(
         logging_level=ray_constants.LOGGER_LEVEL,  # info
         logging_format=ray_constants.LOGGER_FORMAT,
-        log_dir=os.path.join(ray._private.utils.get_ray_temp_dir(),
-                             ray.node.SESSION_LATEST, "logs"),
+        log_dir=os.path.join(
+            ray._private.utils.get_ray_temp_dir(), ray.node.SESSION_LATEST, "logs"
+        ),
         filename=ray_constants.MONITOR_LOG_FILE_NAME,  # monitor.log
         max_bytes=ray_constants.LOGGING_ROTATE_BYTES,
         backup_count=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
@@ -47,11 +48,11 @@ if __name__ == "__main__":
         required=False,
         type=str,
         default=None,
-        help="The password to use for Redis")
+        help="The password to use for Redis",
+    )
     args = parser.parse_args()
-    cluster_name = yaml.safe_load(
-        open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
+    cluster_name = yaml.safe_load(open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
     head_ip = get_node_ip_address()
     Monitor(
         address=f"{head_ip}:6379",

Some files were not shown because too many files have changed in this diff.
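
Editor's note, not part of the commit: every hunk above is a mechanical Black reformat, and the same transformation can be reproduced with Black's library API. A minimal sketch, assuming the `black` package is installed; the sample input string is illustrative and not taken from Ray's CI:

import black

# Hypothetical input: an over-long call written in the pre-Black style.
src = (
    'parser.add_argument("--day", '
    'help="Indicate the day to simulate the amount of data available to us", '
    'type=int, default=0)\n'
)

# Calls that exceed Black's default 88-character line length are exploded
# one argument per line, and a trailing comma is added (the "magic trailing
# comma") so the exploded layout stays stable on later runs.
print(black.format_str(src, mode=black.Mode()))

The printed result matches the `--day` hunk above: the call opens on its own line, each argument gets a line of its own, and `default=0,` keeps the trailing comma before the closing parenthesis.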