[CI] Format Python code with Black (#21975)

See #21316 and #21311 for the motivation behind these changes.
Balaji Veeramani 2022-01-29 18:41:57 -08:00 committed by GitHub
parent 95877be8ee
commit 7f1bacc7dc
1637 changed files with 75167 additions and 59155 deletions
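For context, Black is an opinionated formatter: a reformat-only commit like this one is normally produced by running the black command-line tool over the repository (for example `black .`, with `black --check` used in CI to enforce the result) and committing whatever it rewrites. The snippet below is a minimal sketch of the same operation through Black's Python API, assuming the black package is installed; the string being formatted is a made-up example in the spirit of the hunks below, not a line from the Ray codebase, and the exact invocation and configuration used by Ray's CI are not part of this diff.

# Minimal sketch: apply Black programmatically to one snippet.
# black.Mode() uses Black's defaults, including the 88-character line length.
import black

source = (
    'print(f"Success! Started {n} actors in "\n'
    '      f"{t}s. ({rate} actors/s)")\n'
)

formatted = black.format_str(source, mode=black.Mode())
print(formatted)

Black's default 88-character line length is consistent with the wrapping visible throughout the hunks that follow.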


@ -39,14 +39,16 @@ def perform_auth():
resp = requests.get(
"https://vop4ss7n22.execute-api.us-west-2.amazonaws.com/endpoint/",
auth=auth,
params={"job_id": os.environ["BUILDKITE_JOB_ID"]})
params={"job_id": os.environ["BUILDKITE_JOB_ID"]},
)
return resp
def handle_docker_login(resp):
pwd = resp.json()["docker_password"]
subprocess.call(
["docker", "login", "--username", "raytravisbot", "--password", pwd])
["docker", "login", "--username", "raytravisbot", "--password", pwd]
)
def gather_paths(dir_path) -> List[str]:
@ -86,7 +88,7 @@ def upload_paths(paths, resp, destination):
"branch_wheels": f"{branch}/{sha}/{fn}",
"jars": f"jars/latest/{current_os}/{fn}",
"branch_jars": f"jars/{branch}/{sha}/{current_os}/{fn}",
"logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}"
"logs": f"bazel_events/{branch}/{sha}/{bk_job_id}/{fn}",
}[destination]
of["file"] = open(path, "rb")
r = requests.post(c["url"], files=of)
@ -95,14 +97,19 @@ def upload_paths(paths, resp, destination):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Helper script to upload files to S3 bucket")
description="Helper script to upload files to S3 bucket"
)
parser.add_argument("--path", type=str, required=False)
parser.add_argument("--destination", type=str)
args = parser.parse_args()
assert args.destination in {
"branch_jars", "branch_wheels", "jars", "logs", "wheels",
"docker_login"
"branch_jars",
"branch_wheels",
"jars",
"logs",
"wheels",
"docker_login",
}
assert "BUILDKITE_JOB_ID" in os.environ
assert "BUILDKITE_COMMIT" in os.environ


@ -51,8 +51,10 @@ del monitor_actor
test_utils.wait_for_condition(no_resource_leaks)
rate = MAX_ACTORS_IN_CLUSTER / (end_time - start_time)
print(f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
f"{end_time - start_time}s. ({rate} actors/s)")
print(
f"Success! Started {MAX_ACTORS_IN_CLUSTER} actors in "
f"{end_time - start_time}s. ({rate} actors/s)"
)
if "TEST_OUTPUT_JSON" in os.environ:
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@ -62,6 +64,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
"time": end_time - start_time,
"success": "1",
"_peak_memory": round(used_gb, 2),
"_peak_process_memory": usage
"_peak_process_memory": usage,
}
json.dump(results, out_file)


@ -77,8 +77,10 @@ del monitor_actor
test_utils.wait_for_condition(no_resource_leaks)
rate = MAX_PLACEMENT_GROUPS / (end_time - start_time)
print(f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
f"{end_time - start_time}s. ({rate} pgs/s)")
print(
f"Success! Started {MAX_PLACEMENT_GROUPS} pgs in "
f"{end_time - start_time}s. ({rate} pgs/s)"
)
if "TEST_OUTPUT_JSON" in os.environ:
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@ -88,6 +90,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
"time": end_time - start_time,
"success": "1",
"_peak_memory": round(used_gb, 2),
"_peak_process_memory": usage
"_peak_process_memory": usage,
}
json.dump(results, out_file)


@ -16,9 +16,7 @@ def test_max_running_tasks(num_tasks):
def task():
time.sleep(sleep_time)
refs = [
task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")
]
refs = [task.remote() for _ in tqdm.trange(num_tasks, desc="Launching tasks")]
max_cpus = ray.cluster_resources()["CPU"]
min_cpus_available = max_cpus
@ -48,8 +46,7 @@ def no_resource_leaks():
@click.command()
@click.option(
"--num-tasks", required=True, type=int, help="Number of tasks to launch.")
@click.option("--num-tasks", required=True, type=int, help="Number of tasks to launch.")
def test(num_tasks):
ray.init(address="auto")
@ -66,8 +63,10 @@ def test(num_tasks):
test_utils.wait_for_condition(no_resource_leaks)
rate = num_tasks / (end_time - start_time - sleep_time)
print(f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
f"({rate} tasks/s)")
print(
f"Success! Started {num_tasks} tasks in {end_time - start_time}s. "
f"({rate} tasks/s)"
)
if "TEST_OUTPUT_JSON" in os.environ:
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@ -77,7 +76,7 @@ def test(num_tasks):
"time": end_time - start_time,
"success": "1",
"_peak_memory": round(used_gb, 2),
"_peak_process_memory": usage
"_peak_process_memory": usage,
}
json.dump(results, out_file)


@ -25,10 +25,12 @@ class SimpleActor:
def start_tasks(num_task, num_cpu_per_task, task_duration):
ray.get([
simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
for _ in range(num_task)
])
ray.get(
[
simple_task.options(num_cpus=num_cpu_per_task).remote(task_duration)
for _ in range(num_task)
]
)
def measure(f):
@ -40,13 +42,16 @@ def measure(f):
def start_actor(num_actors, num_actors_per_nodes, job):
resources = {"node": floor(1.0 / num_actors_per_nodes)}
submission_cost, actors = measure(lambda: [
SimpleActor.options(resources=resources, num_cpus=0).remote(job)
for _ in range(num_actors)])
ready_cost, _ = measure(
lambda: ray.get([actor.ready.remote() for actor in actors]))
submission_cost, actors = measure(
lambda: [
SimpleActor.options(resources=resources, num_cpus=0).remote(job)
for _ in range(num_actors)
]
)
ready_cost, _ = measure(lambda: ray.get([actor.ready.remote() for actor in actors]))
actor_job_cost, _ = measure(
lambda: ray.get([actor.do_job.remote() for actor in actors]))
lambda: ray.get([actor.do_job.remote() for actor in actors])
)
return (submission_cost, ready_cost, actor_job_cost)
@ -54,33 +59,32 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(prog="Test Scheduling")
# Task workloads
parser.add_argument(
"--total-num-task",
type=int,
help="Total number of tasks.",
required=False)
"--total-num-task", type=int, help="Total number of tasks.", required=False
)
parser.add_argument(
"--num-cpu-per-task",
type=int,
help="Resources needed for tasks.",
required=False)
required=False,
)
parser.add_argument(
"--task-duration-s",
type=int,
help="How long does each task execute.",
required=False,
default=1)
default=1,
)
# Actor workloads
parser.add_argument(
"--total-num-actors",
type=int,
help="Total number of actors.",
required=True)
"--total-num-actors", type=int, help="Total number of actors.", required=True
)
parser.add_argument(
"--num-actors-per-nodes",
type=int,
help="How many actors to allocate for each nodes.",
required=True)
required=True,
)
ray.init(address="auto")
@ -92,13 +96,14 @@ if __name__ == "__main__":
job = None
if args.total_num_task is not None:
if args.num_cpu_per_task is None:
args.num_cpu_per_task = floor(
1.0 * total_cpus / args.total_num_task)
args.num_cpu_per_task = floor(1.0 * total_cpus / args.total_num_task)
job = lambda: start_tasks( # noqa: E731
args.total_num_task, args.num_cpu_per_task, args.task_duration_s)
args.total_num_task, args.num_cpu_per_task, args.task_duration_s
)
submission_cost, ready_cost, actor_job_cost = start_actor(
args.total_num_actors, args.num_actors_per_nodes, job)
args.total_num_actors, args.num_actors_per_nodes, job
)
output = os.environ.get("TEST_OUTPUT_JSON")
@ -118,6 +123,7 @@ if __name__ == "__main__":
if output is not None:
from pathlib import Path
p = Path(output)
p.write_text(json.dumps(result))


@ -12,8 +12,7 @@ def num_alive_nodes():
@click.command()
@click.option(
"--num-nodes", required=True, type=int, help="The target number of nodes")
@click.option("--num-nodes", required=True, type=int, help="The target number of nodes")
def wait_cluster(num_nodes: int):
ray.init(address="auto")
while num_alive_nodes() != num_nodes:


@ -9,7 +9,7 @@ from time import perf_counter
from tqdm import tqdm
NUM_NODES = 50
OBJECT_SIZE = 2**30
OBJECT_SIZE = 2 ** 30
def num_alive_nodes():
@ -60,6 +60,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
"broadcast_time": end - start,
"object_size": OBJECT_SIZE,
"num_nodes": NUM_NODES,
"success": "1"
"success": "1",
}
json.dump(results, out_file)

View file

@ -13,7 +13,7 @@ MAX_ARGS = 10000
MAX_RETURNS = 3000
MAX_RAY_GET_ARGS = 10000
MAX_QUEUED_TASKS = 1_000_000
MAX_RAY_GET_SIZE = 100 * 2**30
MAX_RAY_GET_SIZE = 100 * 2 ** 30
def assert_no_leaks():
@ -189,8 +189,7 @@ print(f"Many args time: {args_time} ({MAX_ARGS} args)")
print(f"Many returns time: {returns_time} ({MAX_RETURNS} returns)")
print(f"Ray.get time: {get_time} ({MAX_RAY_GET_ARGS} args)")
print(f"Queued task time: {queued_time} ({MAX_QUEUED_TASKS} tasks)")
print(f"Ray.get large object time: {large_object_time} "
f"({MAX_RAY_GET_SIZE} bytes)")
print(f"Ray.get large object time: {large_object_time} " f"({MAX_RAY_GET_SIZE} bytes)")
if "TEST_OUTPUT_JSON" in os.environ:
out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
@ -205,6 +204,6 @@ if "TEST_OUTPUT_JSON" in os.environ:
"num_queued": MAX_QUEUED_TASKS,
"large_object_time": large_object_time,
"large_object_size": MAX_RAY_GET_SIZE,
"success": "1"
"success": "1",
}
json.dump(results, out_file)


@ -66,7 +66,7 @@ def get_remote_url(remote):
def replace_suffix(base, old_suffix, new_suffix=""):
if base.endswith(old_suffix):
base = base[:len(base) - len(old_suffix)] + new_suffix
base = base[: len(base) - len(old_suffix)] + new_suffix
return base
@ -199,12 +199,21 @@ def monitor():
expected_line = "{}\t{}".format(expected_sha, ref)
if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
remote, expected_line)
logger.info(
"Not monitoring %s on %s due to keep-alive on: %s",
ref,
remote,
expected_line,
)
return
logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
get_remote_url(remote), ref, expected_line)
logger.info(
"Monitoring %s (%s) for changes in %s: %s",
remote,
get_remote_url(remote),
ref,
expected_line,
)
for to_wait in yield_poll_schedule():
time.sleep(to_wait)
@ -217,12 +226,21 @@ def monitor():
status = ex.returncode
if status == 2:
logger.info("Terminating job as %s has been deleted on %s: %s",
ref, remote, expected_line)
logger.info(
"Terminating job as %s has been deleted on %s: %s",
ref,
remote,
expected_line,
)
break
elif status != 0:
logger.error("Error %d: unable to check %s on %s: %s", status, ref,
remote, expected_line)
logger.error(
"Error %d: unable to check %s on %s: %s",
status,
ref,
remote,
expected_line,
)
else:
prev = expected_line
expected_line = detect_spurious_commit(line, expected_line, remote)
@ -230,14 +248,24 @@ def monitor():
logger.info(
"Terminating job as %s has been updated on %s\n"
" from:\t%s\n"
" to: \t%s", ref, remote, expected_line, line)
" to: \t%s",
ref,
remote,
expected_line,
line,
)
time.sleep(1) # wait for CI to flush output
break
if expected_line != prev:
logger.info(
"%s appeared to spuriously change on %s\n"
" from:\t%s\n"
" to: \t%s", ref, remote, prev, expected_line)
" to: \t%s",
ref,
remote,
prev,
expected_line,
)
return terminate_my_process_group()
@ -259,9 +287,8 @@ def main(program, *args):
if __name__ == "__main__":
logging.basicConfig(
format="%(levelname)s: %(message)s",
stream=sys.stderr,
level=logging.DEBUG)
format="%(levelname)s: %(message)s", stream=sys.stderr, level=logging.DEBUG
)
try:
raise SystemExit(main(*sys.argv) or 0)
except KeyboardInterrupt:


@ -53,7 +53,10 @@ def maybe_fetch_buildkite_token():
"secretsmanager", region_name="us-west-2"
).get_secret_value(
SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
"buildkite/ro-token")["SecretString"]
"buildkite/ro-token"
)[
"SecretString"
]
def escape(v: Any):
@ -85,26 +88,26 @@ def env_str(env: Dict[str, Any]):
def script_str(v: Any):
if isinstance(v, bool):
return f"\"{int(v)}\""
return f'"{int(v)}"'
elif isinstance(v, Number):
return f"\"{v}\""
return f'"{v}"'
elif isinstance(v, list):
return "(" + " ".join(f"\"{shlex.quote(w)}\"" for w in v) + ")"
return "(" + " ".join(f'"{shlex.quote(w)}"' for w in v) + ")"
else:
return f"\"{shlex.quote(v)}\""
return f'"{shlex.quote(v)}"'
class ReproSession:
plugin_default_env = {
"docker": {
"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False
}
"docker": {"BUILDKITE_PLUGIN_DOCKER_MOUNT_BUILDKITE_AGENT": False}
}
def __init__(self,
buildkite_token: str,
instance_name: Optional[str] = None,
logger: Optional[logging.Logger] = None):
def __init__(
self,
buildkite_token: str,
instance_name: Optional[str] = None,
logger: Optional[logging.Logger] = None,
):
self.logger = logger or logging.getLogger(self.__class__.__name__)
self.bk = Buildkite()
@ -139,12 +142,15 @@ class ReproSession:
# https://buildkite.com/ray-project/ray-builders-pr/
# builds/19635#55a0d71a-831e-4f68-b668-2b10c6f65ee6
pattern = re.compile(
"https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)")
"https://buildkite.com/([^/]+)/([^/]+)/builds/([0-9]+)#(.+)"
)
org, pipeline, build_id, job_id = pattern.match(session_url).groups()
self.logger.debug(f"Parsed session URL: {session_url}. "
f"Got org='{org}', pipeline='{pipeline}', "
f"build_id='{build_id}', job_id='{job_id}'.")
self.logger.debug(
f"Parsed session URL: {session_url}. "
f"Got org='{org}', pipeline='{pipeline}', "
f"build_id='{build_id}', job_id='{job_id}'."
)
self.org = org
self.pipeline = pipeline
@ -155,7 +161,8 @@ class ReproSession:
assert self.bk
self.env = self.bk.jobs().get_job_environment_variables(
self.org, self.pipeline, self.build_id, self.job_id)["env"]
self.org, self.pipeline, self.build_id, self.job_id
)["env"]
if overwrite:
self.env.update(overwrite)
@ -166,33 +173,30 @@ class ReproSession:
assert self.env
if not self.aws_instance_name:
self.aws_instance_name = (
f"repro_ci_{self.build_id}_{self.job_id[:8]}")
self.aws_instance_name = f"repro_ci_{self.build_id}_{self.job_id[:8]}"
self.logger.info(
f"No instance name provided, using {self.aws_instance_name}")
f"No instance name provided, using {self.aws_instance_name}"
)
instance_type = self.env["BUILDKITE_AGENT_META_DATA_AWS_INSTANCE_TYPE"]
instance_ami = self.env["BUILDKITE_AGENT_META_DATA_AWS_AMI_ID"]
instance_sg = "sg-0ccfca2ef191c04ae"
instance_block_device_mappings = [{
"DeviceName": "/dev/xvda",
"Ebs": {
"VolumeSize": 500
}
}]
instance_block_device_mappings = [
{"DeviceName": "/dev/xvda", "Ebs": {"VolumeSize": 500}}
]
# Check if instance exists:
running_instances = self.ec2_resource.instances.filter(Filters=[{
"Name": "tag:repro_name",
"Values": [self.aws_instance_name]
}, {
"Name": "instance-state-name",
"Values": ["running"]
}])
running_instances = self.ec2_resource.instances.filter(
Filters=[
{"Name": "tag:repro_name", "Values": [self.aws_instance_name]},
{"Name": "instance-state-name", "Values": ["running"]},
]
)
self.logger.info(
f"Check if instance with name {self.aws_instance_name} "
f"already exists...")
f"already exists..."
)
for instance in running_instances:
self.aws_instance_id = instance.id
@ -201,8 +205,8 @@ class ReproSession:
return
self.logger.info(
f"Instance with name {self.aws_instance_name} not found, "
f"creating...")
f"Instance with name {self.aws_instance_name} not found, " f"creating..."
)
# Else, not running, yet, start.
instance = self.ec2_resource.create_instances(
@ -211,20 +215,18 @@ class ReproSession:
InstanceType=instance_type,
KeyName=self.ssh_key_name,
SecurityGroupIds=[instance_sg],
TagSpecifications=[{
"ResourceType": "instance",
"Tags": [{
"Key": "repro_name",
"Value": self.aws_instance_name
}]
}],
TagSpecifications=[
{
"ResourceType": "instance",
"Tags": [{"Key": "repro_name", "Value": self.aws_instance_name}],
}
],
MinCount=1,
MaxCount=1,
)[0]
self.aws_instance_id = instance.id
self.logger.info(
f"Created new instance with ID {self.aws_instance_id}")
self.logger.info(f"Created new instance with ID {self.aws_instance_id}")
def aws_wait_for_instance(self):
assert self.aws_instance_id
@ -234,28 +236,32 @@ class ReproSession:
repro_instance_state = None
while repro_instance_state != "running":
detail = self.ec2_client.describe_instances(
InstanceIds=[self.aws_instance_id], )
repro_instance_state = \
detail["Reservations"][0]["Instances"][0]["State"]["Name"]
InstanceIds=[self.aws_instance_id],
)
repro_instance_state = detail["Reservations"][0]["Instances"][0]["State"][
"Name"
]
if repro_instance_state != "running":
time.sleep(2)
self.aws_instance_ip = detail["Reservations"][0]["Instances"][0][
"PublicIpAddress"]
"PublicIpAddress"
]
def aws_stop_instance(self):
assert self.aws_instance_id
self.ec2_client.terminate_instances(
InstanceIds=[self.aws_instance_id], )
InstanceIds=[self.aws_instance_id],
)
def print_stop_command(self):
click.secho("To stop this instance in the future, run this: ")
click.secho(
f"aws ec2 terminate-instances "
f"--instance-ids={self.aws_instance_id}",
bold=True)
f"aws ec2 terminate-instances " f"--instance-ids={self.aws_instance_id}",
bold=True,
)
def create_new_ssh_client(self):
assert self.aws_instance_ip
@ -264,7 +270,8 @@ class ReproSession:
self.ssh.close()
self.logger.info(
"Creating SSH client and waiting for SSH to become available...")
"Creating SSH client and waiting for SSH to become available..."
)
ssh = paramiko.client.SSHClient()
ssh.load_system_host_keys()
@ -275,7 +282,8 @@ class ReproSession:
ssh.connect(
self.aws_instance_ip,
username=self.ssh_user,
key_filename=os.path.expanduser(self.ssh_key_file))
key_filename=os.path.expanduser(self.ssh_key_file),
)
break
except paramiko.ssh_exception.NoValidConnectionsError:
self.logger.info("SSH not ready, yet, sleeping 5 seconds")
@ -291,8 +299,7 @@ class ReproSession:
result = {}
def exec():
stdin, stdout, stderr = self.ssh.exec_command(
command, get_pty=True)
stdin, stdout, stderr = self.ssh.exec_command(command, get_pty=True)
output = ""
for line in stdout.readlines():
@ -321,12 +328,13 @@ class ReproSession:
return result.get("output", "")
def execute_ssh_command(
self,
command: str,
env: Optional[Dict[str, str]] = None,
as_script: bool = False,
quiet: bool = False,
command_wrapper: Optional[Callable[[str], str]] = None) -> str:
self,
command: str,
env: Optional[Dict[str, str]] = None,
as_script: bool = False,
quiet: bool = False,
command_wrapper: Optional[Callable[[str], str]] = None,
) -> str:
assert self.ssh
if not command_wrapper:
@ -360,23 +368,25 @@ class ReproSession:
return output
def execute_ssh_commands(self,
commands: List[str],
env: Optional[Dict[str, str]] = None,
quiet: bool = False):
def execute_ssh_commands(
self,
commands: List[str],
env: Optional[Dict[str, str]] = None,
quiet: bool = False,
):
for command in commands:
self.execute_ssh_command(command, env=env, quiet=quiet)
def execute_docker_command(self,
command: str,
env: Optional[Dict[str, str]] = None,
quiet: bool = False):
def execute_docker_command(
self, command: str, env: Optional[Dict[str, str]] = None, quiet: bool = False
):
def command_wrapper(s):
escaped = s.replace("'", "'\"'\"'")
return f"docker exec -it ray_container /bin/bash -ci '{escaped}'"
self.execute_ssh_command(
command, env=env, quiet=quiet, command_wrapper=command_wrapper)
command, env=env, quiet=quiet, command_wrapper=command_wrapper
)
def prepare_instance(self):
self.create_new_ssh_client()
@ -387,8 +397,9 @@ class ReproSession:
self.logger.info("Preparing instance (installing docker etc.)")
commands = [
"sudo yum install -y docker", "sudo service docker start",
f"sudo usermod -aG docker {self.ssh_user}"
"sudo yum install -y docker",
"sudo service docker start",
f"sudo usermod -aG docker {self.ssh_user}",
]
self.execute_ssh_commands(commands, quiet=True)
self.create_new_ssh_client()
@ -398,13 +409,18 @@ class ReproSession:
def docker_login(self):
self.logger.info("Logging into docker...")
credentials = boto3.client(
"ecr", region_name="us-west-2").get_authorization_token()
token = base64.b64decode(credentials["authorizationData"][0][
"authorizationToken"]).decode("utf-8").replace("AWS:", "")
"ecr", region_name="us-west-2"
).get_authorization_token()
token = (
base64.b64decode(credentials["authorizationData"][0]["authorizationToken"])
.decode("utf-8")
.replace("AWS:", "")
)
endpoint = credentials["authorizationData"][0]["proxyEndpoint"]
self.execute_ssh_command(
f"docker login -u AWS -p {token} {endpoint}", quiet=True)
f"docker login -u AWS -p {token} {endpoint}", quiet=True
)
def fetch_buildkite_plugins(self):
assert self.env
@ -415,8 +431,9 @@ class ReproSession:
for collection in plugins:
for plugin, options in collection.items():
plugin_url, plugin_version = plugin.split("#")
if not plugin_url.startswith(
"http://") or not plugin_url.startswith("https://"):
if not plugin_url.startswith("http://") or not plugin_url.startswith(
"https://"
):
plugin_url = f"https://{plugin_url}"
plugin_name = plugin_url.split("/")[-1].rstrip(".git")
@ -432,7 +449,7 @@ class ReproSession:
"version": plugin_version,
"dir": plugin_dir,
"env": plugin_env,
"details": {}
"details": {},
}
def get_plugin_env(self, plugin_short: str, options: Dict[str, Any]):
@ -457,30 +474,33 @@ class ReproSession:
self.execute_ssh_command(
f"[ ! -e {plugin_dir} ] && git clone --depth 1 "
f"--branch {plugin_version} {plugin_url} {plugin_dir}",
quiet=True)
quiet=True,
)
def load_plugin_details(self, plugin: str):
assert plugin in self.plugins
plugin_dir = self.plugins[plugin]["dir"]
yaml_str = self.execute_ssh_command(
f"cat {plugin_dir}/plugin.yml", quiet=True)
yaml_str = self.execute_ssh_command(f"cat {plugin_dir}/plugin.yml", quiet=True)
details = yaml.safe_load(yaml_str)
self.plugins[plugin]["details"] = details
return details
def execute_plugin_hook(self,
plugin: str,
hook: str,
env: Optional[Dict[str, Any]] = None,
script_command: Optional[str] = None):
def execute_plugin_hook(
self,
plugin: str,
hook: str,
env: Optional[Dict[str, Any]] = None,
script_command: Optional[str] = None,
):
assert plugin in self.plugins
self.logger.info(
f"Executing Buildkite hook for plugin {plugin}: {hook}. "
f"This pulls a Docker image and could take a while.")
f"This pulls a Docker image and could take a while."
)
plugin_dir = self.plugins[plugin]["dir"]
plugin_env = self.plugins[plugin]["env"].copy()
@ -500,21 +520,23 @@ class ReproSession:
def print_buildkite_command(self, skipped: bool = False):
print("-" * 80)
print("These are the commands you need to execute to fully reproduce "
"the run")
print(
"These are the commands you need to execute to fully reproduce " "the run"
)
print("-" * 80)
print(self.env["BUILDKITE_COMMAND"])
print("-" * 80)
if skipped and self.skipped_commands:
print("Some of the commands above have already been run. "
"Remaining commands:")
print(
"Some of the commands above have already been run. "
"Remaining commands:"
)
print("-" * 80)
print("\n".join(self.skipped_commands))
print("-" * 80)
def run_buildkite_command(self,
command_filter: Optional[List[str]] = None):
def run_buildkite_command(self, command_filter: Optional[List[str]] = None):
commands = self.env["BUILDKITE_COMMAND"].split("\n")
regexes = [re.compile(cf) for cf in command_filter or []]
@ -537,15 +559,18 @@ class ReproSession:
f"grep -q 'source ~/.env' $HOME/.bashrc "
f"|| echo 'source ~/.env' >> $HOME/.bashrc; "
f"echo 'export {escaped}' > $HOME/.env",
quiet=True)
quiet=True,
)
def attach_to_container(self):
self.logger.info("Attaching to AWS instance...")
ssh_command = (f"ssh -ti {self.ssh_key_file} "
f"-o StrictHostKeyChecking=no "
f"-o ServerAliveInterval=30 "
f"{self.ssh_user}@{self.aws_instance_ip} "
f"'docker exec -it ray_container bash -l'")
ssh_command = (
f"ssh -ti {self.ssh_key_file} "
f"-o StrictHostKeyChecking=no "
f"-o ServerAliveInterval=30 "
f"{self.ssh_user}@{self.aws_instance_ip} "
f"'docker exec -it ray_container bash -l'"
)
subprocess.run(ssh_command, shell=True)
@ -555,29 +580,32 @@ class ReproSession:
@click.option("-n", "--instance-name", default=None)
@click.option("-c", "--commands", is_flag=True, default=False)
@click.option("-f", "--filters", multiple=True, default=[])
def main(session_url: Optional[str],
instance_name: Optional[str] = None,
commands: bool = False,
filters: Optional[List[str]] = None):
def main(
session_url: Optional[str],
instance_name: Optional[str] = None,
commands: bool = False,
filters: Optional[List[str]] = None,
):
random.seed(1235)
logger = logging.getLogger("main")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter("[%(levelname)s %(asctime)s] "
"%(filename)s: %(lineno)d "
"%(message)s"))
logging.Formatter(
"[%(levelname)s %(asctime)s] " "%(filename)s: %(lineno)d " "%(message)s"
)
)
logger.addHandler(handler)
maybe_fetch_buildkite_token()
repro = ReproSession(
os.environ["BUILDKITE_TOKEN"],
instance_name=instance_name,
logger=logger)
os.environ["BUILDKITE_TOKEN"], instance_name=instance_name, logger=logger
)
session_url = session_url or click.prompt(
"Please copy and paste the Buildkite job build URI here")
"Please copy and paste the Buildkite job build URI here"
)
repro.set_session(session_url)
@ -610,13 +638,16 @@ def main(session_url: Optional[str],
"BUILDKITE_PLUGIN_DOCKER_TTY": "0",
"BUILDKITE_PLUGIN_DOCKER_MOUNT_CHECKOUT": "0",
},
script_command=("sed -E 's/"
"docker run/"
"docker run "
"--cap-add=SYS_PTRACE "
"--name ray_container "
"-d/g' | "
"bash -l"))
script_command=(
"sed -E 's/"
"docker run/"
"docker run "
"--cap-add=SYS_PTRACE "
"--name ray_container "
"-d/g' | "
"bash -l"
),
)
repro.create_new_ssh_client()


@ -54,7 +54,8 @@ def get_target_expansion_query(targets, tests_only, exclude_manual):
if exclude_manual:
query = '{} except tests(attr("tags", "manual", set({})))'.format(
query, included_targets)
query, included_targets
)
return query
@ -82,17 +83,16 @@ def get_targets_for_shard(targets, index, count):
def main():
parser = argparse.ArgumentParser(
description="Expand and shard Bazel targets.")
parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--tests_only", action="store_true")
parser.add_argument("--exclude_manual", action="store_true")
parser.add_argument(
"--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1))
"--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 1)
)
parser.add_argument(
"--count",
type=int,
default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1))
"--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
)
parser.add_argument("targets", nargs="+")
args, extra_args = parser.parse_known_args()
args.targets = list(args.targets) + list(extra_args)
@ -100,11 +100,11 @@ def main():
if args.index >= args.count:
parser.error("--index must be between 0 and {}".format(args.count - 1))
query = get_target_expansion_query(args.targets, args.tests_only,
args.exclude_manual)
query = get_target_expansion_query(
args.targets, args.tests_only, args.exclude_manual
)
expanded_targets = run_bazel_query(query, args.debug)
my_targets = get_targets_for_shard(expanded_targets, args.index,
args.count)
my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
print(" ".join(my_targets))
return 0


@ -14,10 +14,10 @@ from collections import defaultdict, OrderedDict
def textproto_format(space, key, value, json_encoder):
"""Rewrites a key-value pair from textproto as JSON."""
if value.startswith(b"\""):
if value.startswith(b'"'):
evaluated = ast.literal_eval(value.decode("utf-8"))
value = json_encoder.encode(evaluated).encode("utf-8")
return b"%s[\"%s\", %s]" % (space, key, value)
return b'%s["%s", %s]' % (space, key, value)
def textproto_split(input_lines, json_encoder):
@ -50,19 +50,21 @@ def textproto_split(input_lines, json_encoder):
pieces = re.split(b"(\\r|\\n)", full_line, 1)
pieces[1:] = [b"".join(pieces[1:])]
[line, tail] = pieces
next_line = pat_open.sub(b"\\1[\"\\2\",\\3[", line)
outputs.append(b"" if not prev_comma else b"]"
if next_line.endswith(b"}") else b",")
next_line = pat_open.sub(b'\\1["\\2",\\3[', line)
outputs.append(
b"" if not prev_comma else b"]" if next_line.endswith(b"}") else b","
)
next_line = pat_close.sub(b"]", next_line)
next_line = pat_line.sub(
lambda m: textproto_format(*(m.groups() + (json_encoder, ))),
next_line)
lambda m: textproto_format(*(m.groups() + (json_encoder,))), next_line
)
outputs.append(prev_tail + next_line)
if line == b"}":
yield b"".join(outputs)
del outputs[:]
prev_comma = line != b"}" and (next_line.endswith(b"]")
or next_line.endswith(b"\""))
prev_comma = line != b"}" and (
next_line.endswith(b"]") or next_line.endswith(b'"')
)
prev_tail = tail
if len(outputs) > 0:
yield b"".join(outputs)
@ -80,13 +82,14 @@ class Bazel(object):
def __init__(self, program=None):
if program is None:
program = os.getenv("BAZEL_EXECUTABLE", "bazel")
self.argv = (program, )
self.extra_args = ("--show_progress=no", )
self.argv = (program,)
self.extra_args = ("--show_progress=no",)
def _call(self, command, *args):
return subprocess.check_output(
self.argv + (command, ) + args[:1] + self.extra_args + args[1:],
stdin=subprocess.PIPE)
self.argv + (command,) + args[:1] + self.extra_args + args[1:],
stdin=subprocess.PIPE,
)
def info(self, *args):
result = OrderedDict()
@ -248,8 +251,7 @@ def shellcheck(bazel_aquery, *shellcheck_argv):
def main(program, command, *command_args):
result = 0
if command == textproto2json.__name__:
result = textproto2json(sys.stdin.buffer, sys.stdout.buffer,
*command_args)
result = textproto2json(sys.stdin.buffer, sys.stdout.buffer, *command_args)
elif command == shellcheck.__name__:
result = shellcheck(*command_args)
elif command == preclean.__name__:


@ -20,21 +20,16 @@ DOCKER_CLIENT = None
PYTHON_WHL_VERSION = "cp3"
DOCKER_HUB_DESCRIPTION = {
"base-deps": ("Internal Image, refer to "
"https://hub.docker.com/r/rayproject/ray"),
"ray-deps": ("Internal Image, refer to "
"https://hub.docker.com/r/rayproject/ray"),
"base-deps": (
"Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"
),
"ray-deps": ("Internal Image, refer to " "https://hub.docker.com/r/rayproject/ray"),
"ray": "Official Docker Images for Ray, the distributed computing API.",
"ray-ml": "Developer ready Docker Image for Ray.",
"ray-worker-container": "Internal Image for CI test",
}
PY_MATRIX = {
"py36": "3.6.12",
"py37": "3.7.7",
"py38": "3.8.5",
"py39": "3.9.5"
}
PY_MATRIX = {"py36": "3.6.12", "py37": "3.7.7", "py38": "3.8.5", "py39": "3.9.5"}
BASE_IMAGES = {
"cu112": "nvidia/cuda:11.2.0-cudnn8-devel-ubuntu18.04",
@ -50,7 +45,7 @@ CUDA_FULL = {
"cu111": "CUDA 11.1",
"cu110": "CUDA 11.0",
"cu102": "CUDA 10.2",
"cu101": "CUDA 10.1"
"cu101": "CUDA 10.1",
}
# The CUDA version to use for the ML Docker image.
@ -62,8 +57,7 @@ IMAGE_NAMES = list(DOCKER_HUB_DESCRIPTION.keys())
def _get_branch():
branch = (os.environ.get("TRAVIS_BRANCH")
or os.environ.get("BUILDKITE_BRANCH"))
branch = os.environ.get("TRAVIS_BRANCH") or os.environ.get("BUILDKITE_BRANCH")
if not branch:
print("Branch not found!")
print(os.environ)
@ -94,8 +88,7 @@ def _get_root_dir():
def _get_commit_sha():
sha = (os.environ.get("TRAVIS_COMMIT")
or os.environ.get("BUILDKITE_COMMIT") or "")
sha = os.environ.get("TRAVIS_COMMIT") or os.environ.get("BUILDKITE_COMMIT") or ""
if len(sha) < 6:
print("INVALID SHA FOUND")
return "ERROR"
@ -105,8 +98,9 @@ def _get_commit_sha():
def _configure_human_version():
global _get_branch
global _get_commit_sha
fake_branch_name = input("Provide a 'branch name'. For releases, it "
"should be `releases/x.x.x`")
fake_branch_name = input(
"Provide a 'branch name'. For releases, it " "should be `releases/x.x.x`"
)
_get_branch = lambda: fake_branch_name # noqa: E731
fake_sha = input("Provide a SHA (used for tag value)")
_get_commit_sha = lambda: fake_sha # noqa: E731
@ -115,38 +109,44 @@ def _configure_human_version():
def _get_wheel_name(minor_version_number):
if minor_version_number:
matches = [
file for file in glob.glob(
file
for file in glob.glob(
f"{_get_root_dir()}/.whl/ray-*{PYTHON_WHL_VERSION}"
f"{minor_version_number}*-manylinux*")
f"{minor_version_number}*-manylinux*"
)
if "+" not in file # Exclude dbg, asan builds
]
assert len(matches) == 1, (
f"Found ({len(matches)}) matches for 'ray-*{PYTHON_WHL_VERSION}"
f"{minor_version_number}*-manylinux*' instead of 1.\n"
f"wheel matches: {matches}")
f"wheel matches: {matches}"
)
return os.path.basename(matches[0])
else:
matches = glob.glob(
f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
matches = glob.glob(f"{_get_root_dir()}/.whl/*{PYTHON_WHL_VERSION}*-manylinux*")
return [os.path.basename(i) for i in matches]
def _check_if_docker_files_modified():
stdout = subprocess.check_output([
sys.executable, f"{_get_curr_dir()}/determine_tests_to_run.py",
"--output=json"
])
stdout = subprocess.check_output(
[
sys.executable,
f"{_get_curr_dir()}/determine_tests_to_run.py",
"--output=json",
]
)
affected_env_var_list = json.loads(stdout)
affected = ("RAY_CI_DOCKER_AFFECTED" in affected_env_var_list or
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list)
affected = (
"RAY_CI_DOCKER_AFFECTED" in affected_env_var_list
or "RAY_CI_PYTHON_DEPENDENCIES_AFFECTED" in affected_env_var_list
)
print(f"Docker affected: {affected}")
return affected
def _build_docker_image(image_name: str,
py_version: str,
image_type: str,
no_cache=True):
def _build_docker_image(
image_name: str, py_version: str, image_type: str, no_cache=True
):
"""Builds Docker image with the provided info.
image_name (str): The name of the image to build. Must be one of
@ -161,23 +161,27 @@ def _build_docker_image(image_name: str,
if image_name not in IMAGE_NAMES:
raise ValueError(
f"The provided image name {image_name} is not "
f"recognized. Image names must be one of {IMAGE_NAMES}")
f"recognized. Image names must be one of {IMAGE_NAMES}"
)
if py_version not in PY_MATRIX.keys():
raise ValueError(f"The provided python version {py_version} is not "
f"recognized. Python version must be one of"
f" {PY_MATRIX.keys()}")
raise ValueError(
f"The provided python version {py_version} is not "
f"recognized. Python version must be one of"
f" {PY_MATRIX.keys()}"
)
if image_type not in BASE_IMAGES.keys():
raise ValueError(f"The provided CUDA version {image_type} is not "
f"recognized. CUDA version must be one of"
f" {image_type.keys()}")
raise ValueError(
f"The provided CUDA version {image_type} is not "
f"recognized. CUDA version must be one of"
f" {image_type.keys()}"
)
# TODO(https://github.com/ray-project/ray/issues/16599):
# remove below after supporting ray-ml images with Python 3.9
if image_name == "ray-ml" and py_version == "py39":
print(f"{image_name} image is currently unsupported with "
"Python 3.9")
print(f"{image_name} image is currently unsupported with " "Python 3.9")
return
build_args = {}
@ -212,7 +216,7 @@ def _build_docker_image(image_name: str,
labels = {
"image-name": image_name,
"python-version": PY_MATRIX[py_version],
"ray-commit": _get_commit_sha()
"ray-commit": _get_commit_sha(),
}
if image_type in CUDA_FULL:
labels["cuda-version"] = CUDA_FULL[image_type]
@ -222,7 +226,8 @@ def _build_docker_image(image_name: str,
tag=tagged_name,
nocache=no_cache,
labels=labels,
buildargs=build_args)
buildargs=build_args,
)
cmd_output = []
try:
@ -230,12 +235,15 @@ def _build_docker_image(image_name: str,
current_iter = start
for line in output:
cmd_output.append(line.decode("utf-8"))
if datetime.datetime.now(
) - current_iter >= datetime.timedelta(minutes=5):
if datetime.datetime.now() - current_iter >= datetime.timedelta(
minutes=5
):
current_iter = datetime.datetime.now()
elapsed = datetime.datetime.now() - start
print(f"Still building {tagged_name} after "
f"{elapsed.seconds} seconds")
print(
f"Still building {tagged_name} after "
f"{elapsed.seconds} seconds"
)
if elapsed >= datetime.timedelta(minutes=15):
print("Additional build output:")
print(*cmd_output, sep="\n")
@ -259,8 +267,10 @@ def _build_docker_image(image_name: str,
def copy_wheels(human_build):
if human_build:
print("Please download images using:\n"
"`pip download --python-version <py_version> ray==<ray_version>")
print(
"Please download images using:\n"
"`pip download --python-version <py_version> ray==<ray_version>"
)
root_dir = _get_root_dir()
wheels = _get_wheel_name(None)
for wheel in wheels:
@ -268,7 +278,8 @@ def copy_wheels(human_build):
ray_dst = os.path.join(root_dir, "docker/ray/.whl/")
ray_dep_dst = os.path.join(root_dir, "docker/ray-deps/.whl/")
ray_worker_container_dst = os.path.join(
root_dir, "docker/ray-worker-container/.whl/")
root_dir, "docker/ray-worker-container/.whl/"
)
os.makedirs(ray_dst, exist_ok=True)
shutil.copy(source, ray_dst)
os.makedirs(ray_dep_dst, exist_ok=True)
@ -282,8 +293,7 @@ def check_staleness(repository, tag):
age = DOCKER_CLIENT.api.inspect_image(f"{repository}:{tag}")["Created"]
short_date = datetime.datetime.strptime(age.split("T")[0], "%Y-%m-%d")
is_stale = (
datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
is_stale = (datetime.datetime.now() - short_date) > datetime.timedelta(days=14)
return is_stale
@ -292,28 +302,23 @@ def build_for_all_versions(image_name, py_versions, image_types, **kwargs):
for py_version in py_versions:
for image_type in image_types:
_build_docker_image(
image_name,
py_version=py_version,
image_type=image_type,
**kwargs)
image_name, py_version=py_version, image_type=image_type, **kwargs
)
def build_base_images(py_versions, image_types):
build_for_all_versions(
"base-deps", py_versions, image_types, no_cache=False)
build_for_all_versions(
"ray-deps", py_versions, image_types, no_cache=False)
build_for_all_versions("base-deps", py_versions, image_types, no_cache=False)
build_for_all_versions("ray-deps", py_versions, image_types, no_cache=False)
def build_or_pull_base_images(py_versions: List[str],
image_types: List[str],
rebuild_base_images: bool = True) -> bool:
def build_or_pull_base_images(
py_versions: List[str], image_types: List[str], rebuild_base_images: bool = True
) -> bool:
"""Returns images to tag and build."""
repositories = ["rayproject/base-deps", "rayproject/ray-deps"]
tags = [
f"nightly-{py_version}-{image_type}"
for py_version, image_type in itertools.product(
py_versions, image_types)
for py_version, image_type in itertools.product(py_versions, image_types)
]
try:
@ -339,12 +344,15 @@ def build_or_pull_base_images(py_versions: List[str],
def prep_ray_ml():
root_dir = _get_root_dir()
requirement_files = glob.glob(
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True)
f"{_get_root_dir()}/python/**/requirements*.txt", recursive=True
)
for fl in requirement_files:
shutil.copy(fl, os.path.join(root_dir, "docker/ray-ml/"))
# Install atari roms script
shutil.copy(f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
os.path.join(root_dir, "docker/ray-ml/"))
shutil.copy(
f"{_get_root_dir()}/rllib/utils/install_atari_roms.sh",
os.path.join(root_dir, "docker/ray-ml/"),
)
def _get_docker_creds() -> Tuple[str, str]:
@ -377,10 +385,13 @@ def _tag_and_push(full_image_name, old_tag, new_tag, merge_build=False):
DOCKER_CLIENT.api.tag(
image=f"{full_image_name}:{old_tag}",
repository=full_image_name,
tag=new_tag)
tag=new_tag,
)
if not merge_build:
print("This is a PR Build! On a merge build, we would normally push"
f"to: {full_image_name}:{new_tag}")
print(
"This is a PR Build! On a merge build, we would normally push"
f"to: {full_image_name}:{new_tag}"
)
else:
_docker_push(full_image_name, new_tag)
@ -395,16 +406,17 @@ def _create_new_tags(all_tags, old_str, new_str):
# For non-release builds, push "nightly" & "sha"
# For release builds, push "nightly" & "latest" & "x.x.x"
def push_and_tag_images(py_versions: List[str],
image_types: List[str],
push_base_images: bool,
merge_build: bool = False):
def push_and_tag_images(
py_versions: List[str],
image_types: List[str],
push_base_images: bool,
merge_build: bool = False,
):
date_tag = datetime.datetime.now().strftime("%Y-%m-%d")
sha_tag = _get_commit_sha()
if _release_build():
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*",
_get_branch()).group(0)
release_name = re.search("[0-9]+\.[0-9]+\.[0-9].*", _get_branch()).group(0)
date_tag = release_name
sha_tag = release_name
@ -423,16 +435,19 @@ def push_and_tag_images(py_versions: List[str],
for py_name in py_versions:
for image_type in image_types:
if image_name == "ray-ml" and image_type != ML_CUDA_VERSION:
print("ML Docker image is not built for the following "
f"device type: {image_type}")
print(
"ML Docker image is not built for the following "
f"device type: {image_type}"
)
continue
# TODO(https://github.com/ray-project/ray/issues/16599):
# remove below after supporting ray-ml images with Python 3.9
if image_name in ["ray-ml"
] and PY_MATRIX[py_name].startswith("3.9"):
print(f"{image_name} image is currently "
f"unsupported with Python 3.9")
if image_name in ["ray-ml"] and PY_MATRIX[py_name].startswith("3.9"):
print(
f"{image_name} image is currently "
f"unsupported with Python 3.9"
)
continue
tag = f"nightly-{py_name}-{image_type}"
@ -445,20 +460,19 @@ def push_and_tag_images(py_versions: List[str],
for old_tag in tag_mapping.keys():
if "cpu" in old_tag:
new_tags = _create_new_tags(
tag_mapping[old_tag], old_str="-cpu", new_str="")
tag_mapping[old_tag], old_str="-cpu", new_str=""
)
tag_mapping[old_tag].extend(new_tags)
elif ML_CUDA_VERSION in old_tag:
new_tags = _create_new_tags(
tag_mapping[old_tag],
old_str=f"-{ML_CUDA_VERSION}",
new_str="-gpu")
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str="-gpu"
)
tag_mapping[old_tag].extend(new_tags)
if image_name == "ray-ml":
new_tags = _create_new_tags(
tag_mapping[old_tag],
old_str=f"-{ML_CUDA_VERSION}",
new_str="")
tag_mapping[old_tag], old_str=f"-{ML_CUDA_VERSION}", new_str=""
)
tag_mapping[old_tag].extend(new_tags)
# No Python version specified should refer to DEFAULT_PYTHON_VERSION
@ -467,7 +481,8 @@ def push_and_tag_images(py_versions: List[str],
new_tags = _create_new_tags(
tag_mapping[old_tag],
old_str=f"-{DEFAULT_PYTHON_VERSION}",
new_str="")
new_str="",
)
tag_mapping[old_tag].extend(new_tags)
# For all tags, create Date/Sha tags
@ -475,7 +490,8 @@ def push_and_tag_images(py_versions: List[str],
new_tags = _create_new_tags(
tag_mapping[old_tag],
old_str="nightly",
new_str=date_tag if "-deps" in image_name else sha_tag)
new_str=date_tag if "-deps" in image_name else sha_tag,
)
tag_mapping[old_tag].extend(new_tags)
# Sanity checking.
@ -511,7 +527,8 @@ def push_and_tag_images(py_versions: List[str],
full_image_name,
old_tag=old_tag,
new_tag=new_tag,
merge_build=merge_build)
merge_build=merge_build,
)
# Push infra here:
@ -527,9 +544,9 @@ def push_readmes(merge_build: bool):
"DOCKER_PASS": password,
"PUSHRM_FILE": f"/myvol/docker/{image}/README.md",
"PUSHRM_DEBUG": 1,
"PUSHRM_SHORT": tag_line
"PUSHRM_SHORT": tag_line,
}
cmd_string = (f"rayproject/{image}")
cmd_string = f"rayproject/{image}"
print(
DOCKER_CLIENT.containers.run(
@ -546,7 +563,9 @@ def push_readmes(merge_build: bool):
detach=False,
stderr=True,
stdout=True,
tty=False))
tty=False,
)
)
# Build base-deps/ray-deps only on file change, 2 weeks, per release
@ -566,63 +585,73 @@ if __name__ == "__main__":
choices=list(PY_MATRIX.keys()),
default="py37",
nargs="*",
help="Which python versions to build. "
"Must be in (py36, py37, py38, py39)")
help="Which python versions to build. " "Must be in (py36, py37, py38, py39)",
)
parser.add_argument(
"--device-types",
choices=list(BASE_IMAGES.keys()),
default=None,
nargs="*",
help="Which device types (CPU/CUDA versions) to build images for. "
"If not specified, images will be built for all device types.")
"If not specified, images will be built for all device types.",
)
parser.add_argument(
"--build-type",
choices=BUILD_TYPES,
required=True,
help="Whether to bypass checking if docker is affected")
help="Whether to bypass checking if docker is affected",
)
parser.add_argument(
"--build-base",
dest="base",
action="store_true",
help="Whether to build base-deps & ray-deps")
help="Whether to build base-deps & ray-deps",
)
parser.add_argument("--no-build-base", dest="base", action="store_false")
parser.set_defaults(base=True)
parser.add_argument(
"--only-build-worker-container",
dest="only_build_worker_container",
action="store_true",
help="Whether only to build ray-worker-container")
help="Whether only to build ray-worker-container",
)
parser.set_defaults(only_build_worker_container=False)
args = parser.parse_args()
py_versions = args.py_versions
py_versions = py_versions if isinstance(py_versions,
list) else [py_versions]
py_versions = py_versions if isinstance(py_versions, list) else [py_versions]
image_types = args.device_types if args.device_types else list(
BASE_IMAGES.keys())
image_types = args.device_types if args.device_types else list(BASE_IMAGES.keys())
assert set(list(CUDA_FULL.keys()) + ["cpu"]) == set(BASE_IMAGES.keys())
# Make sure the python images and cuda versions we build here are
# consistent with the ones used with fix-latest-docker.sh script.
py_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
"python_versions.txt")
py_version_file = os.path.join(
_get_root_dir(), "docker/retag-lambda", "python_versions.txt"
)
with open(py_version_file) as f:
py_file_versions = f.read().splitlines()
assert set(PY_MATRIX.keys()) == set(py_file_versions), \
(PY_MATRIX.keys(), py_file_versions)
assert set(PY_MATRIX.keys()) == set(py_file_versions), (
PY_MATRIX.keys(),
py_file_versions,
)
cuda_version_file = os.path.join(_get_root_dir(), "docker/retag-lambda",
"cuda_versions.txt")
cuda_version_file = os.path.join(
_get_root_dir(), "docker/retag-lambda", "cuda_versions.txt"
)
with open(cuda_version_file) as f:
cuda_file_versions = f.read().splitlines()
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]),\
(BASE_IMAGES.keys(), cuda_file_versions + ["cpu"])
assert set(BASE_IMAGES.keys()) == set(cuda_file_versions + ["cpu"]), (
BASE_IMAGES.keys(),
cuda_file_versions + ["cpu"],
)
print("Building the following python versions: ",
[PY_MATRIX[py_version] for py_version in py_versions])
print(
"Building the following python versions: ",
[PY_MATRIX[py_version] for py_version in py_versions],
)
print("Building images for the following devices: ", image_types)
print("Building base images: ", args.base)
@ -639,9 +668,11 @@ if __name__ == "__main__":
if build_type == HUMAN:
# If manually triggered, request user for branch and SHA value to use.
_configure_human_version()
if (build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
or _check_if_docker_files_modified()
or args.only_build_worker_container):
if (
build_type in {HUMAN, MERGE, BUILDKITE, LOCAL}
or _check_if_docker_files_modified()
or args.only_build_worker_container
):
DOCKER_CLIENT = docker.from_env()
is_merge = build_type == MERGE
# Buildkite is authenticated in the background.
@ -652,11 +683,11 @@ if __name__ == "__main__":
DOCKER_CLIENT.api.login(username=username, password=password)
copy_wheels(build_type == HUMAN)
is_base_images_built = build_or_pull_base_images(
py_versions, image_types, args.base)
py_versions, image_types, args.base
)
if args.only_build_worker_container:
build_for_all_versions("ray-worker-container", py_versions,
image_types)
build_for_all_versions("ray-worker-container", py_versions, image_types)
# TODO Currently don't push ray_worker_container
else:
# Build Ray Docker images.
@ -668,15 +699,19 @@ if __name__ == "__main__":
prep_ray_ml()
# Only build ML Docker for the ML_CUDA_VERSION
build_for_all_versions(
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION])
"ray-ml", py_versions, image_types=[ML_CUDA_VERSION]
)
if build_type in {MERGE, PR}:
valid_branch = _valid_branch()
if (not valid_branch) and is_merge:
print(f"Invalid Branch found: {_get_branch()}")
push_and_tag_images(py_versions, image_types,
is_base_images_built, valid_branch
and is_merge)
push_and_tag_images(
py_versions,
image_types,
is_base_images_built,
valid_branch and is_merge,
)
# TODO(ilr) Re-Enable Push READMEs by using a normal password
# (not auth token :/)

View file

@ -20,7 +20,8 @@ def build_multinode_image(source_image: str, target_image: str):
f.write("RUN sudo apt install -y openssh-server\n")
subprocess.check_output(
f"docker build -t {target_image} .", shell=True, cwd=tempdir)
f"docker build -t {target_image} .", shell=True, cwd=tempdir
)
shutil.rmtree(tempdir)


@ -25,9 +25,7 @@ def perform_check(raw_xml_string: str):
missing_owners = []
for rule in tree.findall("rule"):
test_name = rule.attrib["name"]
tags = [
child.attrib["value"] for child in rule.find("list").getchildren()
]
tags = [child.attrib["value"] for child in rule.find("list").getchildren()]
team_owner = [t for t in tags if t.startswith("team")]
if len(team_owner) == 0:
missing_owners.append(test_name)
@ -36,7 +34,8 @@ def perform_check(raw_xml_string: str):
if len(missing_owners):
raise Exception(
f"Cannot find owner for tests {missing_owners}, please add "
"`team:*` to the tags.")
"`team:*` to the tags."
)
print(owners)


@ -19,11 +19,7 @@ exit_with_error = False
def check_import(file):
check_to_lines = {
"import ray": -1,
"import psutil": -1,
"import setproctitle": -1
}
check_to_lines = {"import ray": -1, "import psutil": -1, "import setproctitle": -1}
with io.open(file, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
@ -37,8 +33,10 @@ def check_import(file):
# It will not match the following
# - submodule import: `import ray.constants as ray_constants`
# - submodule import: `from ray import xyz`
if re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$",
line) and check_to_lines[check] == -1:
if (
re.search(r"^\s*" + check + r"(\s*|\s+# noqa F401.*)$", line)
and check_to_lines[check] == -1
):
check_to_lines[check] = i
for import_lib in ["import psutil", "import setproctitle"]:
@ -48,8 +46,8 @@ def check_import(file):
if import_ray_line == -1 or import_ray_line > import_psutil_line:
print(
"{}:{}".format(str(file), import_psutil_line + 1),
"{} without explicitly import ray before it.".format(
import_lib))
"{} without explicitly import ray before it.".format(import_lib),
)
global exit_with_error
exit_with_error = True
@ -59,8 +57,7 @@ if __name__ == "__main__":
parser.add_argument("path", help="File path to check. e.g. '.' or './src'")
# TODO(simon): For the future, consider adding a feature to explicitly
# white-list the path instead of skipping them.
parser.add_argument(
"-s", "--skip", action="append", help="Skip certian directory")
parser.add_argument("-s", "--skip", action="append", help="Skip certian directory")
args = parser.parse_args()
file_path = Path(args.path)


@ -18,7 +18,7 @@ DEFAULT_BLACKLIST = [
"gpustat",
"opencensus",
"prometheus_client",
"smart_open"
"smart_open",
]
@ -28,19 +28,20 @@ def assert_packages_not_installed(blacklist: List[str]):
except ImportError: # pip < 10.0
from pip.operations import freeze
installed_packages = [
p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()
]
installed_packages = [p.split("==")[0].split(" @ ")[0] for p in freeze.freeze()]
assert not any(p in installed_packages for p in blacklist), \
f"Found blacklisted packages in installed python packages: " \
f"{[p for p in blacklist if p in installed_packages]}. " \
f"Minimal dependency tests could be tainted by this. " \
f"Check the install logs and primary dependencies if any of these " \
assert not any(p in installed_packages for p in blacklist), (
f"Found blacklisted packages in installed python packages: "
f"{[p for p in blacklist if p in installed_packages]}. "
f"Minimal dependency tests could be tainted by this. "
f"Check the install logs and primary dependencies if any of these "
f"packages were installed as part of another install step."
)
print(f"Confirmed that blacklisted packages are not installed in "
f"current Python environment: {blacklist}")
print(
f"Confirmed that blacklisted packages are not installed in "
f"current Python environment: {blacklist}"
)
if __name__ == "__main__":


@ -54,7 +54,8 @@ def run_tidy(task_queue, lock, timeout):
command = task_queue.get()
try:
proc = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if timeout is not None:
watchdog = threading.Timer(timeout, proc.kill)
@ -70,22 +71,21 @@ def run_tidy(task_queue, lock, timeout):
sys.stderr.flush()
except Exception as e:
with lock:
sys.stderr.write("Failed: " + str(e) + ": ".join(command) +
"\n")
sys.stderr.write("Failed: " + str(e) + ": ".join(command) + "\n")
finally:
with lock:
if timeout is not None and watchdog is not None:
if not watchdog.is_alive():
sys.stderr.write("Terminated by timeout: " +
" ".join(command) + "\n")
sys.stderr.write(
"Terminated by timeout: " + " ".join(command) + "\n"
)
watchdog.cancel()
task_queue.task_done()
def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout):
for _ in range(max_tasks):
t = threading.Thread(
target=tidy_caller, args=(task_queue, lock, timeout))
t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout))
t.daemon = True
t.start()
@ -119,84 +119,87 @@ def main():
parser = argparse.ArgumentParser(
description="Run clang-tidy against changed files, and "
"output diagnostics only for modified "
"lines.")
"lines."
)
parser.add_argument(
"-clang-tidy-binary",
metavar="PATH",
default="clang-tidy",
help="path to clang-tidy binary")
help="path to clang-tidy binary",
)
parser.add_argument(
"-p",
metavar="NUM",
default=0,
help="strip the smallest prefix containing P slashes")
help="strip the smallest prefix containing P slashes",
)
parser.add_argument(
"-regex",
metavar="PATTERN",
default=None,
help="custom pattern selecting file paths to check "
"(case sensitive, overrides -iregex)")
"(case sensitive, overrides -iregex)",
)
parser.add_argument(
"-iregex",
metavar="PATTERN",
default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)",
help="custom pattern selecting file paths to check "
"(case insensitive, overridden by -regex)")
"(case insensitive, overridden by -regex)",
)
parser.add_argument(
"-j",
type=int,
default=1,
help="number of tidy instances to be run in parallel.")
help="number of tidy instances to be run in parallel.",
)
parser.add_argument(
"-timeout",
type=int,
default=None,
help="timeout per each file in seconds.")
"-timeout", type=int, default=None, help="timeout per each file in seconds."
)
parser.add_argument(
"-fix",
action="store_true",
default=False,
help="apply suggested fixes")
"-fix", action="store_true", default=False, help="apply suggested fixes"
)
parser.add_argument(
"-checks",
help="checks filter, when not specified, use clang-tidy "
"default",
default="")
help="checks filter, when not specified, use clang-tidy " "default",
default="",
)
parser.add_argument(
"-path",
dest="build_path",
help="Path used to read a compile command database.")
"-path", dest="build_path", help="Path used to read a compile command database."
)
if yaml:
parser.add_argument(
"-export-fixes",
metavar="FILE",
dest="export_fixes",
help="Create a yaml file to store suggested fixes in, "
"which can be applied with clang-apply-replacements.")
"which can be applied with clang-apply-replacements.",
)
parser.add_argument(
"-extra-arg",
dest="extra_arg",
action="append",
default=[],
help="Additional argument to append to the compiler "
"command line.")
help="Additional argument to append to the compiler " "command line.",
)
parser.add_argument(
"-extra-arg-before",
dest="extra_arg_before",
action="append",
default=[],
help="Additional argument to prepend to the compiler "
"command line.")
help="Additional argument to prepend to the compiler " "command line.",
)
parser.add_argument(
"-quiet",
action="store_true",
default=False,
help="Run clang-tidy in quiet mode")
help="Run clang-tidy in quiet mode",
)
clang_tidy_args = []
argv = sys.argv[1:]
if "--" in argv:
clang_tidy_args.extend(argv[argv.index("--"):])
argv = argv[:argv.index("--")]
clang_tidy_args.extend(argv[argv.index("--") :])
argv = argv[: argv.index("--")]
args = parser.parse_args(argv)
@ -204,7 +207,7 @@ def main():
filename = None
lines_by_file = {}
for line in sys.stdin:
match = re.search('^\+\+\+\ \"?(.*?/){%s}([^ \t\n\"]*)' % args.p, line)
match = re.search('^\+\+\+\ "?(.*?/){%s}([^ \t\n"]*)' % args.p, line)
if match:
filename = match.group(2)
if filename is None:
@ -226,8 +229,7 @@ def main():
if line_count == 0:
continue
end_line = start_line + line_count - 1
lines_by_file.setdefault(filename,
[]).append([start_line, end_line])
lines_by_file.setdefault(filename, []).append([start_line, end_line])
if not any(lines_by_file):
print("No relevant changes found.")
@ -267,11 +269,8 @@ def main():
for name in lines_by_file:
line_filter_json = json.dumps(
[{
"name": name,
"lines": lines_by_file[name]
}],
separators=(",", ":"))
[{"name": name, "lines": lines_by_file[name]}], separators=(",", ":")
)
# Run clang-tidy on files containing changes.
command = [args.clang_tidy_binary]


@ -38,8 +38,10 @@ def is_pull_request():
for key in ["GITHUB_EVENT_NAME", "TRAVIS_EVENT_TYPE"]:
event_type = os.getenv(key, event_type)
if (os.environ.get("BUILDKITE")
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"):
if (
os.environ.get("BUILDKITE")
and os.environ.get("BUILDKITE_PULL_REQUEST") != "false"
):
event_type = "pull_request"
return event_type == "pull_request"
@ -67,8 +69,7 @@ def get_commit_range():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output", type=str, help="json or envvars", default="envvars")
parser.add_argument("--output", type=str, help="json or envvars", default="envvars")
args = parser.parse_args()
RAY_CI_TUNE_AFFECTED = 0
@ -103,8 +104,7 @@ if __name__ == "__main__":
try:
graph = pda.build_dep_graph()
rllib_tests = pda.list_rllib_tests()
print(
"Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
print("Total # of RLlib tests: ", len(rllib_tests), file=sys.stderr)
impacted = {}
for test in rllib_tests:
@ -120,9 +120,7 @@ if __name__ == "__main__":
print(e, file=sys.stderr)
# End of dry run.
skip_prefix_list = [
"doc/", "examples/", "dev/", "kubernetes/", "site/"
]
skip_prefix_list = ["doc/", "examples/", "dev/", "kubernetes/", "site/"]
for changed_file in files:
if changed_file.startswith("python/ray/tune"):
@ -181,7 +179,8 @@ if __name__ == "__main__":
# Java also depends on Python CLI to manage processes.
RAY_CI_JAVA_AFFECTED = 1
if changed_file.startswith("python/setup.py") or re.match(
".*requirements.*\.txt", changed_file):
".*requirements.*\.txt", changed_file
):
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED = 1
elif changed_file.startswith("java/"):
RAY_CI_JAVA_AFFECTED = 1
@ -190,12 +189,9 @@ if __name__ == "__main__":
elif changed_file.startswith("docker/"):
RAY_CI_DOCKER_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
elif changed_file.startswith("doc/") and changed_file.endswith(
".py"):
elif changed_file.startswith("doc/") and changed_file.endswith(".py"):
RAY_CI_DOC_AFFECTED = 1
elif any(
changed_file.startswith(prefix)
for prefix in skip_prefix_list):
elif any(changed_file.startswith(prefix) for prefix in skip_prefix_list):
# nothing is run but linting in these cases
pass
elif changed_file.endswith("build-docker-images.py"):
@ -246,26 +242,28 @@ if __name__ == "__main__":
RAY_CI_DASHBOARD_AFFECTED = 1
# Log the modified environment variables visible in console.
output_string = " ".join([
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(
RAY_CI_RLLIB_DIRECTLY_AFFECTED),
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED),
])
output_string = " ".join(
[
"RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED),
"RAY_CI_SGD_AFFECTED={}".format(RAY_CI_SGD_AFFECTED),
"RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED),
"RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED),
"RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(RAY_CI_RLLIB_DIRECTLY_AFFECTED),
"RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED),
"RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED),
"RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED),
"RAY_CI_CORE_CPP_AFFECTED={}".format(RAY_CI_CORE_CPP_AFFECTED),
"RAY_CI_CPP_AFFECTED={}".format(RAY_CI_CPP_AFFECTED),
"RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED),
"RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED),
"RAY_CI_LINUX_WHEELS_AFFECTED={}".format(RAY_CI_LINUX_WHEELS_AFFECTED),
"RAY_CI_MACOS_WHEELS_AFFECTED={}".format(RAY_CI_MACOS_WHEELS_AFFECTED),
"RAY_CI_DOCKER_AFFECTED={}".format(RAY_CI_DOCKER_AFFECTED),
"RAY_CI_PYTHON_DEPENDENCIES_AFFECTED={}".format(
RAY_CI_PYTHON_DEPENDENCIES_AFFECTED
),
]
)
# Debug purpose
print(output_string, file=sys.stderr)

View file

@ -5,7 +5,7 @@
# Cause the script to exit if a single command fails
set -euo pipefail
BLACK_IS_ENABLED=false
BLACK_IS_ENABLED=true
FLAKE8_VERSION_REQUIRED="3.9.1"
BLACK_VERSION_REQUIRED="21.12b0"
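For context, the change above switches Black on in CI and pins version 21.12b0. A minimal, illustrative Python sketch of the kind of gate such a script applies (the real script is shell; the target path here is an assumption, not taken from this diff):

# Illustrative sketch only: confirm the pinned Black version is installed,
# then run Black in check mode (reports files it would reformat, changes nothing).
import subprocess
import sys

BLACK_VERSION_REQUIRED = "21.12b0"  # pin taken from the script above


def check_black(paths=("python/",)) -> int:  # "python/" is an assumed target path
    version = subprocess.run(
        ["black", "--version"], capture_output=True, text=True, check=True
    ).stdout
    if BLACK_VERSION_REQUIRED not in version:
        print(f"black=={BLACK_VERSION_REQUIRED} required, found: {version.strip()}")
        return 1
    return subprocess.run(["black", "--check", *paths]).returncode


if __name__ == "__main__":
    sys.exit(check_black())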

View file

@ -17,19 +17,19 @@ import json
def gha_get_self_url():
import requests
# stringed together api call to get the current check's html url.
sha = os.environ["GITHUB_SHA"]
repo = os.environ["GITHUB_REPOSITORY"]
resp = requests.get(
"https://api.github.com/repos/{}/commits/{}/check-suites".format(
repo, sha))
"https://api.github.com/repos/{}/commits/{}/check-suites".format(repo, sha)
)
data = resp.json()
for check in data["check_suites"]:
slug = check["app"]["slug"]
if slug == "github-actions":
run_url = check["check_runs_url"]
html_url = (
requests.get(run_url).json()["check_runs"][0]["html_url"])
html_url = requests.get(run_url).json()["check_runs"][0]["html_url"]
return html_url
# Return a fallback url
@ -47,10 +47,12 @@ def get_build_env():
if os.environ.get("BUILDKITE"):
return {
"TRAVIS_COMMIT": os.environ["BUILDKITE_COMMIT"],
"TRAVIS_JOB_WEB_URL": (os.environ["BUILDKITE_BUILD_URL"] + "#" +
os.environ["BUILDKITE_BUILD_ID"]),
"TRAVIS_OS_NAME": # The map is used to stay consistent with Travis
{
"TRAVIS_JOB_WEB_URL": (
os.environ["BUILDKITE_BUILD_URL"]
+ "#"
+ os.environ["BUILDKITE_BUILD_ID"]
),
"TRAVIS_OS_NAME": { # The map is used to stay consistent with Travis
"linux": "linux",
"darwin": "osx",
"win32": "windows",
@ -70,13 +72,10 @@ def get_build_config():
return {"config": {"env": "Windows CI"}}
if os.environ.get("BUILDKITE"):
return {
"config": {
"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]
}
}
return {"config": {"env": "Buildkite " + os.environ["BUILDKITE_LABEL"]}}
import requests
url = "https://api.travis-ci.com/job/{job_id}?include=job.config"
url = url.format(job_id=os.environ["TRAVIS_JOB_ID"])
resp = requests.get(url, headers={"Travis-API-Version": "3"})
@ -87,9 +86,4 @@ if __name__ == "__main__":
build_env = get_build_env()
build_config = get_build_config()
print(
json.dumps(
{
"build_env": build_env,
"build_config": build_config
}, indent=2))
print(json.dumps({"build_env": build_env, "build_config": build_config}, indent=2))

View file

@ -43,7 +43,8 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
test: only return information about a specific test.
"""
tests_res = _run_shell(
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"])
["bazel", "query", "tests(//python/ray/rllib:*)", "--output", "label"]
)
all_tests = []
@ -53,15 +54,18 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
if test and t != test:
continue
src_out = _run_shell([
"bazel", "query", "kind(\"source file\", deps({}))".format(t),
"--output", "label"
])
src_out = _run_shell(
[
"bazel",
"query",
'kind("source file", deps({}))'.format(t),
"--output",
"label",
]
)
srcs = [f.strip() for f in src_out.splitlines()]
srcs = [
f for f in srcs if f.startswith("//python") and f.endswith(".py")
]
srcs = [f for f in srcs if f.startswith("//python") and f.endswith(".py")]
if srcs:
all_tests.append((t, srcs))
@ -73,8 +77,7 @@ def list_rllib_tests(n: int = -1, test: str = None) -> Tuple[str, List[str]]:
def _new_dep(graph: DepGraph, src_module: str, dep: str):
"""Create a new dependency between src_module and dep.
"""
"""Create a new dependency between src_module and dep."""
if dep not in graph.ids:
graph.ids[dep] = len(graph.ids)
@ -87,8 +90,7 @@ def _new_dep(graph: DepGraph, src_module: str, dep: str):
def _new_import(graph: DepGraph, src_module: str, dep_module: str):
"""Process a new import statement in src_module.
"""
"""Process a new import statement in src_module."""
# We don't care about system imports.
if not dep_module.startswith("ray"):
return
@ -97,8 +99,7 @@ def _new_import(graph: DepGraph, src_module: str, dep_module: str):
def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
"""Figure out if base.sub is a python module or not.
"""
"""Figure out if base.sub is a python module or not."""
# Special handling for _raylet, which is a C++ lib.
if module == "ray._raylet":
return False
@ -110,10 +111,10 @@ def _is_path_module(module: str, name: str, _base_dir: str) -> bool:
return False
def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
dep_name: str, _base_dir: str):
"""Process a new "from ... import ..." statement in src_module.
"""
def _new_from_import(
graph: DepGraph, src_module: str, dep_module: str, dep_name: str, _base_dir: str
):
"""Process a new "from ... import ..." statement in src_module."""
# We don't care about imports outside of ray package.
if not dep_module or not dep_module.startswith("ray"):
return
@ -126,10 +127,7 @@ def _new_from_import(graph: DepGraph, src_module: str, dep_module: str,
_new_dep(graph, src_module, dep_module)
def _process_file(graph: DepGraph,
src_path: str,
src_module: str,
_base_dir=""):
def _process_file(graph: DepGraph, src_path: str, src_module: str, _base_dir=""):
"""Create dependencies from src_module to all the valid imports in src_path.
Args:
@ -147,13 +145,13 @@ def _process_file(graph: DepGraph,
_new_import(graph, src_module, alias.name)
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
_new_from_import(graph, src_module, node.module,
alias.name, _base_dir)
_new_from_import(
graph, src_module, node.module, alias.name, _base_dir
)
def build_dep_graph() -> DepGraph:
"""Build index from py files to their immediate dependees.
"""
"""Build index from py files to their immediate dependees."""
graph = DepGraph()
# Assuming we run from root /ray directory.
@ -197,8 +195,7 @@ def _full_module_path(module, f) -> str:
def _should_skip(d: str) -> bool:
"""Skip directories that should not contain py sources.
"""
"""Skip directories that should not contain py sources."""
if d.startswith("python/.eggs/"):
return True
if d.startswith("python/."):
@ -224,14 +221,14 @@ def _bazel_path_to_module_path(d: str) -> str:
def _file_path_to_module_path(f: str) -> str:
"""Return the corresponding module path for a .py file.
"""
"""Return the corresponding module path for a .py file."""
dir, fn = os.path.split(f)
return _full_module_path(_bazel_path_to_module_path(dir), fn)
def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
qid: int) -> List[int]:
def _depends(
graph: DepGraph, visited: Dict[int, bool], tid: int, qid: int
) -> List[int]:
"""Whether there is a dependency path from module tid to module qid.
Given graph, and without going through visited.
@ -253,8 +250,9 @@ def _depends(graph: DepGraph, visited: Dict[int, bool], tid: int,
return []
def test_depends_on_file(graph: DepGraph, test: Tuple[str, Tuple[str]],
path: str) -> List[int]:
def test_depends_on_file(
graph: DepGraph, test: Tuple[str, Tuple[str]], path: str
) -> List[int]:
"""Give dependency graph, check if a test depends on a specific .py file.
Args:
@ -307,8 +305,7 @@ def _find_circular_dep_impl(graph: DepGraph, id: str, branch: str) -> bool:
def find_circular_dep(graph: DepGraph) -> Dict[str, List[int]]:
"""Find circular dependencies among a dependency graph.
"""
"""Find circular dependencies among a dependency graph."""
known = {}
circles = {}
for m, id in graph.ids.items():
@ -334,25 +331,29 @@ if __name__ == "__main__":
"--mode",
type=str,
default="test-dep",
help=("test-dep: find dependencies for a specified test. "
"circular-dep: find circular dependencies in "
"the specific codebase."))
help=(
"test-dep: find dependencies for a specified test. "
"circular-dep: find circular dependencies in "
"the specific codebase."
),
)
parser.add_argument(
"--file",
type=str,
help="Path of a .py source file relative to --base_dir.")
"--file", type=str, help="Path of a .py source file relative to --base_dir."
)
parser.add_argument("--test", type=str, help="Specific test to check.")
parser.add_argument(
"--smoke-test",
action="store_true",
help="Load only a few tests for testing.")
"--smoke-test", action="store_true", help="Load only a few tests for testing."
)
args = parser.parse_args()
print("building dep graph ...")
graph = build_dep_graph()
print("done. total {} files, {} of which have dependencies.".format(
len(graph.ids), len(graph.edges)))
print(
"done. total {} files, {} of which have dependencies.".format(
len(graph.ids), len(graph.edges)
)
)
if args.mode == "circular-dep":
circles = find_circular_dep(graph)

View file

@ -12,30 +12,33 @@ class TestPyDepAnalysis(unittest.TestCase):
f.close()
def test_full_module_path(self):
self.assertEqual(
pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
self.assertEqual(
pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(pda._full_module_path("aa.bb.cc", "__init__.py"), "aa.bb.cc")
self.assertEqual(pda._full_module_path("aa.bb.cc", "dd.py"), "aa.bb.cc.dd")
self.assertEqual(pda._full_module_path("", "dd.py"), "dd")
def test_bazel_path_to_module_path(self):
self.assertEqual(
pda._bazel_path_to_module_path("//python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)
self.assertEqual(
pda._bazel_path_to_module_path("python:ray/rllib/xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)
self.assertEqual(
pda._bazel_path_to_module_path("python/ray/rllib:xxx/yyy/dd"),
"ray.rllib.xxx.yyy.dd")
"ray.rllib.xxx.yyy.dd",
)
def test_file_path_to_module_path(self):
self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/env.py"),
"ray.rllib.env.env")
"ray.rllib.env.env",
)
self.assertEqual(
pda._file_path_to_module_path("python/ray/rllib/env/__init__.py"),
"ray.rllib.env")
"ray.rllib.env",
)
def test_import_line_continuation(self):
graph = pda.DepGraph()
@ -44,11 +47,13 @@ class TestPyDepAnalysis(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file(
src_path, """
src_path,
"""
import ray.rllib.env.\\
mock_env
b = 2
""")
""",
)
pda._process_file(graph, src_path, "ray")
self.assertEqual(len(graph.ids), 2)
@ -64,11 +69,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "continuation1.py")
self.create_tmp_file(
src_path, """
src_path,
"""
from ray.rllib.env import (ClassName,
module1, module2)
b = 2
""")
""",
)
pda._process_file(graph, src_path, "ray")
self.assertEqual(len(graph.ids), 2)
@ -84,11 +91,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py"
self.create_tmp_file(
os.path.join(tmpdir, src_path), """
os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import mock_env
a = 1
b = 2
""")
""",
)
# Touch ray/rllib/env/mock_env.py in tmpdir,
# so that it looks like a module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib", "env")
@ -112,11 +121,13 @@ b = 2
with tempfile.TemporaryDirectory() as tmpdir:
src_path = "multi_line_comment_3.py"
self.create_tmp_file(
os.path.join(tmpdir, src_path), """
os.path.join(tmpdir, src_path),
"""
from ray.rllib.env import MockEnv
a = 1
b = 2
""")
""",
)
# Touch ray/rllib/env.py in tmpdir,
# MockEnv is a class on env module.
module_dir = os.path.join(tmpdir, "python", "ray", "rllib")
@ -138,4 +149,5 @@ b = 2
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -22,8 +22,11 @@ import ray.ray_constants as ray_constants
import ray._private.services
import ray._private.utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsPublisher
from ray._private.gcs_utils import GcsClient, \
get_gcs_address_from_redis, use_gcs_for_bootstrap
from ray._private.gcs_utils import (
GcsClient,
get_gcs_address_from_redis,
use_gcs_for_bootstrap,
)
from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc
from ray._private.ray_logging import setup_component_logger
@ -42,23 +45,25 @@ aiogrpc.init_grpc_aio()
class DashboardAgent(object):
def __init__(self,
node_ip_address,
redis_address,
dashboard_agent_port,
gcs_address,
minimal,
redis_password=None,
temp_dir=None,
session_dir=None,
runtime_env_dir=None,
log_dir=None,
metrics_export_port=None,
node_manager_port=None,
listen_port=0,
object_store_name=None,
raylet_name=None,
logging_params=None):
def __init__(
self,
node_ip_address,
redis_address,
dashboard_agent_port,
gcs_address,
minimal,
redis_password=None,
temp_dir=None,
session_dir=None,
runtime_env_dir=None,
log_dir=None,
metrics_export_port=None,
node_manager_port=None,
listen_port=0,
object_store_name=None,
raylet_name=None,
logging_params=None,
):
"""Initialize the DashboardAgent object."""
# Public attributes are accessible for all agent modules.
self.ip = node_ip_address
@ -92,15 +97,16 @@ class DashboardAgent(object):
self.ppid = int(os.environ["RAY_RAYLET_PID"])
assert self.ppid > 0
logger.info("Parent pid is %s", self.ppid)
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:{self.dashboard_agent_port}")
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip,
self.grpc_port)
options = (("grpc.enable_http_proxy", 0), )
self.server, f"{grpc_ip}:{self.dashboard_agent_port}"
)
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port)
options = (("grpc.enable_http_proxy", 0),)
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True)
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
)
# If the agent is started as non-minimal version, http server should
# be configured to communicate with the dashboard in a head node.
@ -108,6 +114,7 @@ class DashboardAgent(object):
async def _configure_http_server(self, modules):
from ray.dashboard.http_server_agent import HttpServerAgent
http_server = HttpServerAgent(self.ip, self.listen_port)
await http_server.start(modules)
return http_server
@ -116,10 +123,12 @@ class DashboardAgent(object):
"""Load dashboard agent modules."""
modules = []
agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
dashboard_utils.DashboardAgentModule
)
for cls in agent_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardAgentModule.__name__, cls)
logger.info(
"Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
)
c = cls(self)
modules.append(c)
logger.info("Loaded %d modules.", len(modules))
@ -137,13 +146,12 @@ class DashboardAgent(object):
curr_proc = psutil.Process()
while True:
parent = curr_proc.parent()
if (parent is None or parent.pid == 1
or self.ppid != parent.pid):
if parent is None or parent.pid == 1 or self.ppid != parent.pid:
logger.error("Raylet is dead, exiting.")
sys.exit(0)
await asyncio.sleep(
dashboard_consts.
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
dashboard_consts.DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS
)
except Exception:
logger.error("Failed to check parent PID, exiting.")
sys.exit(1)
@ -154,15 +162,17 @@ class DashboardAgent(object):
if not use_gcs_for_bootstrap():
# Create an aioredis client for all modules.
try:
self.aioredis_client = \
await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address,
self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
)
except (socket.gaierror, ConnectionRefusedError):
logger.error(
"Dashboard agent exiting: "
"Failed to connect to redis at %s", self.redis_address)
"Dashboard agent exiting: " "Failed to connect to redis at %s",
self.redis_address,
)
sys.exit(-1)
# Start a grpc asyncio server.
@ -170,7 +180,8 @@ class DashboardAgent(object):
if not use_gcs_for_bootstrap():
gcs_address = await self.aioredis_client.get(
dashboard_consts.GCS_SERVER_ADDRESS)
dashboard_consts.GCS_SERVER_ADDRESS
)
self.gcs_client = GcsClient(address=gcs_address.decode())
else:
self.gcs_client = GcsClient(address=self.gcs_address)
@ -192,17 +203,21 @@ class DashboardAgent(object):
internal_kv._internal_kv_put(
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}",
json.dumps([http_port, self.grpc_port]),
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# Register agent to agent manager.
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
self.aiogrpc_raylet_channel)
self.aiogrpc_raylet_channel
)
await raylet_stub.RegisterAgent(
agent_manager_pb2.RegisterAgentRequest(
agent_pid=os.getpid(),
agent_port=self.grpc_port,
agent_ip_address=self.ip))
agent_ip_address=self.ip,
)
)
tasks = [m.run(self.server) for m in modules]
if sys.platform not in ["win32", "cygwin"]:
@ -221,123 +236,139 @@ if __name__ == "__main__":
"--node-ip-address",
required=True,
type=str,
help="the IP address of this node.")
help="the IP address of this node.",
)
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
)
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
"--redis-address", required=True, type=str, help="The address to use for Redis."
)
parser.add_argument(
"--metrics-export-port",
required=True,
type=int,
help="The port to expose metrics through Prometheus.")
help="The port to expose metrics through Prometheus.",
)
parser.add_argument(
"--dashboard-agent-port",
required=True,
type=int,
help="The port on which the dashboard agent will receive GRPCs.")
help="The port on which the dashboard agent will receive GRPCs.",
)
parser.add_argument(
"--node-manager-port",
required=True,
type=int,
help="The port to use for starting the node manager")
help="The port to use for starting the node manager",
)
parser.add_argument(
"--object-store-name",
required=True,
type=str,
default=None,
help="The socket name of the plasma store")
help="The socket name of the plasma store",
)
parser.add_argument(
"--listen-port",
required=False,
type=int,
default=0,
help="Port for HTTP server to listen on")
help="Port for HTTP server to listen on",
)
parser.add_argument(
"--raylet-name",
required=True,
type=str,
default=None,
help="The socket path of the raylet process")
help="The socket path of the raylet process",
)
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
help=ray_constants.LOGGER_FORMAT_HELP,
)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\".".format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME))
'log to stdout if set empty, default is "{}".'.format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
),
)
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
ray_constants.LOGGING_ROTATE_BYTES))
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
)
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
help="Specify the backup count of rotated log file, default is {}.".format(
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
),
)
parser.add_argument(
"--log-dir",
required=True,
type=str,
default=None,
help="Specify the path of log directory.")
help="Specify the path of log directory.",
)
parser.add_argument(
"--temp-dir",
required=True,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
help="Specify the path of the temporary directory use by Ray process.",
)
parser.add_argument(
"--session-dir",
required=True,
type=str,
default=None,
help="Specify the path of this session.")
help="Specify the path of this session.",
)
parser.add_argument(
"--runtime-env-dir",
required=True,
type=str,
default=None,
help="Specify the path of the resource directory used by runtime_env.")
help="Specify the path of the resource directory used by runtime_env.",
)
parser.add_argument(
"--minimal",
action="store_true",
help=(
"Minimal agent only contains a subset of features that don't "
"require additional dependencies installed when ray is installed "
"by `pip install ray[default]`."))
"by `pip install ray[default]`."
),
)
args = parser.parse_args()
try:
@ -347,7 +378,8 @@ if __name__ == "__main__":
log_dir=args.log_dir,
filename=args.logging_filename,
max_bytes=args.logging_rotate_bytes,
backup_count=args.logging_rotate_backup_count)
backup_count=args.logging_rotate_backup_count,
)
setup_component_logger(**logging_params)
agent = DashboardAgent(
@ -366,7 +398,8 @@ if __name__ == "__main__":
listen_port=args.listen_port,
object_store_name=args.object_store_name,
raylet_name=args.raylet_name,
logging_params=logging_params)
logging_params=logging_params,
)
if os.environ.get("_RAY_AGENT_FAILING"):
raise Exception("Failure injection failure.")
@ -390,15 +423,19 @@ if __name__ == "__main__":
gcs_publisher = GcsPublisher(args.gcs_address)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
gcs_publisher = GcsPublisher(
address=get_gcs_address_from_redis(redis_client))
address=get_gcs_address_from_redis(redis_client)
)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
traceback.format_exc()
)
message = (
f"(ip={node_ip}) "
f"The agent on node {platform.uname()[1]} failed to "
@ -409,12 +446,14 @@ if __name__ == "__main__":
"\n 2. Metrics on this node won't be reported."
"\n 3. runtime_env APIs won't work."
"\nCheck out the `dashboard_agent.log` to see the "
"detailed failure messages.")
"detailed failure messages."
)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_AGENT_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
gcs_publisher=gcs_publisher,
)
logger.error(message)
logger.exception(e)
exit(1)

View file

@ -12,12 +12,13 @@ DASHBOARD_RPC_ADDRESS = "dashboard_rpc"
GCS_SERVER_ADDRESS = "GcsServerAddress"
# GCS check alive
GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR = env_integer(
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10)
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer(
"GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
"GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR", 10
)
GCS_CHECK_ALIVE_INTERVAL_SECONDS = env_integer("GCS_CHECK_ALIVE_INTERVAL_SECONDS", 5)
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)
GCS_RETRY_CONNECT_INTERVAL_SECONDS = env_integer(
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2)
"GCS_RETRY_CONNECT_INTERVAL_SECONDS", 2
)
# aiohttp_cache
AIOHTTP_CACHE_TTL_SECONDS = 2
AIOHTTP_CACHE_MAX_SIZE = 128
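The constants above read overrides from the environment via env_integer; a rough, assumed-equivalent sketch of that helper for illustration (not the Ray implementation):

import os


def env_integer(key: str, default: int) -> int:
    # Use the environment override when present and numeric, else the default.
    value = os.environ.get(key)
    if value is not None and value.isdigit():
        return int(value)
    return default


# Mirrors the usage above.
GCS_CHECK_ALIVE_RPC_TIMEOUT = env_integer("GCS_CHECK_ALIVE_RPC_TIMEOUT", 10)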

View file

@ -38,17 +38,21 @@ class FrontendNotFoundError(OSError):
def setup_static_dir():
build_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "client", "build")
os.path.dirname(os.path.abspath(__file__)), "client", "build"
)
module_name = os.path.basename(os.path.dirname(__file__))
if not os.path.isdir(build_dir):
raise FrontendNotFoundError(
errno.ENOENT, "Dashboard build directory not found. If installing "
errno.ENOENT,
"Dashboard build directory not found. If installing "
"from source, please follow the additional steps "
"required to build the dashboard"
f"(cd python/ray/{module_name}/client "
"&& npm install "
"&& npm ci "
"&& npm run build)", build_dir)
"&& npm run build)",
build_dir,
)
static_dir = os.path.join(build_dir, "static")
routes.static("/static", static_dir, follow_symlinks=True)
@ -72,14 +76,16 @@ class Dashboard:
log_dir(str): Log directory of dashboard.
"""
def __init__(self,
host,
port,
port_retries,
gcs_address,
redis_address,
redis_password=None,
log_dir=None):
def __init__(
self,
host,
port,
port_retries,
gcs_address,
redis_address,
redis_password=None,
log_dir=None,
):
self.dashboard_head = dashboard_head.DashboardHead(
http_host=host,
http_port=port,
@ -87,7 +93,8 @@ class Dashboard:
gcs_address=gcs_address,
redis_address=redis_address,
redis_password=redis_password,
log_dir=log_dir)
log_dir=log_dir,
)
# Setup Dashboard Routes
try:
@ -107,15 +114,17 @@ class Dashboard:
async def get_index(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/index.html"))
os.path.dirname(os.path.abspath(__file__)), "client/build/index.html"
)
)
@routes.get("/favicon.ico")
async def get_favicon(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/favicon.ico"))
os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico"
)
)
async def run(self):
await self.dashboard_head.run()
@ -124,92 +133,96 @@ class Dashboard:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Ray dashboard.")
parser.add_argument(
"--host",
required=True,
type=str,
help="The host to use for the HTTP server.")
"--host", required=True, type=str, help="The host to use for the HTTP server."
)
parser.add_argument(
"--port",
required=True,
type=int,
help="The port to use for the HTTP server.")
"--port", required=True, type=int, help="The port to use for the HTTP server."
)
parser.add_argument(
"--port-retries",
required=False,
type=int,
default=0,
help="The retry times to select a valid port.")
help="The retry times to select a valid port.",
)
parser.add_argument(
"--gcs-address",
required=False,
type=str,
help="The address (ip:port) of GCS.")
"--gcs-address", required=False, type=str, help="The address (ip:port) of GCS."
)
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
"--redis-address", required=True, type=str, help="The address to use for Redis."
)
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
help=ray_constants.LOGGER_LEVEL_HELP,
)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
help=ray_constants.LOGGER_FORMAT_HELP,
)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default=dashboard_consts.DASHBOARD_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\"".format(
dashboard_consts.DASHBOARD_LOG_FILENAME))
'log to stdout if set empty, default is "{}"'.format(
dashboard_consts.DASHBOARD_LOG_FILENAME
),
)
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
ray_constants.LOGGING_ROTATE_BYTES))
"log file, default is {} bytes.".format(ray_constants.LOGGING_ROTATE_BYTES),
)
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(ray_constants.LOGGING_ROTATE_BACKUP_COUNT))
help="Specify the backup count of rotated log file, default is {}.".format(
ray_constants.LOGGING_ROTATE_BACKUP_COUNT
),
)
parser.add_argument(
"--log-dir",
required=True,
type=str,
default=None,
help="Specify the path of log directory.")
help="Specify the path of log directory.",
)
parser.add_argument(
"--temp-dir",
required=True,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
help="Specify the path of the temporary directory use by Ray process.",
)
parser.add_argument(
"--minimal",
action="store_true",
help=(
"Minimal dashboard only contains a subset of features that don't "
"require additional dependencies installed when ray is installed "
"by `pip install ray[default]`."))
"by `pip install ray[default]`."
),
)
args = parser.parse_args()
@ -226,7 +239,8 @@ if __name__ == "__main__":
log_dir=args.log_dir,
filename=args.logging_filename,
max_bytes=args.logging_rotate_bytes,
backup_count=args.logging_rotate_backup_count)
backup_count=args.logging_rotate_backup_count,
)
dashboard = Dashboard(
args.host,
@ -235,25 +249,27 @@ if __name__ == "__main__":
args.gcs_address,
args.redis_address,
redis_password=args.redis_password,
log_dir=args.log_dir)
log_dir=args.log_dir,
)
# TODO(fyrestone): Avoid using ray.state in dashboard, it's not
# asynchronous and will lead to low performance. ray disconnect()
# will be hang when the ray.state is connected and the GCS is exit.
# Please refer to: https://github.com/ray-project/ray/issues/16328
service_discovery = PrometheusServiceDiscoveryWriter(
args.redis_address, args.redis_password, args.gcs_address,
args.temp_dir)
args.redis_address, args.redis_password, args.gcs_address, args.temp_dir
)
# Need daemon True to avoid dashboard hangs at exit.
service_discovery.daemon = True
service_discovery.start()
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
except Exception as e:
traceback_str = ray._private.utils.format_error_message(
traceback.format_exc())
message = f"The dashboard on node {platform.uname()[1]} " \
f"failed with the following " \
f"error:\n{traceback_str}"
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
message = (
f"The dashboard on node {platform.uname()[1]} "
f"failed with the following "
f"error:\n{traceback_str}"
)
if isinstance(e, FrontendNotFoundError):
logger.warning(message)
else:
@ -268,17 +284,21 @@ if __name__ == "__main__":
gcs_publisher = GcsPublisher(args.gcs_address)
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
gcs_publisher = GcsPublisher(
address=gcs_utils.get_gcs_address_from_redis(redis_client))
address=gcs_utils.get_gcs_address_from_redis(redis_client)
)
redis_client = None
else:
redis_client = ray._private.services.create_redis_client(
args.redis_address, password=args.redis_password)
args.redis_address, password=args.redis_password
)
ray._private.utils.publish_error_to_driver(
redis_client,
ray_constants.DASHBOARD_DIED_ERROR,
message,
redis_client=redis_client,
gcs_publisher=gcs_publisher)
gcs_publisher=gcs_publisher,
)

View file

@ -2,9 +2,9 @@ import asyncio
import logging
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.memory_utils as memory_utils
# TODO(fyrestone): Not import from dashboard module.
from ray.dashboard.modules.actor.actor_utils import \
actor_classname_from_task_spec
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.dashboard.utils import Dict, Signal, async_loop_forever
logger = logging.getLogger(__name__)
@ -132,9 +132,9 @@ class DataOrganizer:
worker["errorCount"] = len(node_errs.get(str(pid), []))
worker["coreWorkerStats"] = pid_to_worker_stats.get(pid, [])
worker["language"] = pid_to_language.get(
pid, dashboard_consts.DEFAULT_LANGUAGE)
worker["jobId"] = pid_to_job_id.get(
pid, dashboard_consts.DEFAULT_JOB_ID)
pid, dashboard_consts.DEFAULT_LANGUAGE
)
worker["jobId"] = pid_to_job_id.get(pid, dashboard_consts.DEFAULT_JOB_ID)
await GlobalSignals.worker_info_fetched.send(node_id, worker)
@ -143,8 +143,7 @@ class DataOrganizer:
@classmethod
async def get_node_info(cls, node_id):
node_physical_stats = dict(
DataSource.node_physical_stats.get(node_id, {}))
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
node_stats = dict(DataSource.node_stats.get(node_id, {}))
node = DataSource.nodes.get(node_id, {})
node_ip = DataSource.node_id_to_ip.get(node_id)
@ -162,8 +161,8 @@ class DataOrganizer:
view_data = node_stats.get("viewData", [])
ray_stats = cls._extract_view_data(
view_data,
{"object_store_used_memory", "object_store_available_memory"})
view_data, {"object_store_used_memory", "object_store_available_memory"}
)
node_info = node_physical_stats
# Merge node stats to node physical stats under raylet
@ -184,8 +183,7 @@ class DataOrganizer:
@classmethod
async def get_node_summary(cls, node_id):
node_physical_stats = dict(
DataSource.node_physical_stats.get(node_id, {}))
node_physical_stats = dict(DataSource.node_physical_stats.get(node_id, {}))
node_stats = dict(DataSource.node_stats.get(node_id, {}))
node = DataSource.nodes.get(node_id, {})
@ -193,8 +191,8 @@ class DataOrganizer:
node_stats.pop("workersStats", None)
view_data = node_stats.get("viewData", [])
ray_stats = cls._extract_view_data(
view_data,
{"object_store_used_memory", "object_store_available_memory"})
view_data, {"object_store_used_memory", "object_store_available_memory"}
)
node_stats.pop("viewData", None)
node_summary = node_physical_stats
@ -244,8 +242,9 @@ class DataOrganizer:
actor = dict(actor)
worker_id = actor["address"]["workerId"]
core_worker_stats = DataSource.core_worker_stats.get(worker_id, {})
actor_constructor = core_worker_stats.get("actorTitle",
"Unknown actor constructor")
actor_constructor = core_worker_stats.get(
"actorTitle", "Unknown actor constructor"
)
actor["actorConstructor"] = actor_constructor
actor.update(core_worker_stats)
@ -275,8 +274,12 @@ class DataOrganizer:
@classmethod
async def get_actor_creation_tasks(cls):
infeasible_tasks = sum(
(list(node_stats.get("infeasibleTasks", []))
for node_stats in DataSource.node_stats.values()), [])
(
list(node_stats.get("infeasibleTasks", []))
for node_stats in DataSource.node_stats.values()
),
[],
)
new_infeasible_tasks = []
for task in infeasible_tasks:
task = dict(task)
@ -285,8 +288,12 @@ class DataOrganizer:
new_infeasible_tasks.append(task)
resource_pending_tasks = sum(
(list(data.get("readyTasks", []))
for data in DataSource.node_stats.values()), [])
(
list(data.get("readyTasks", []))
for data in DataSource.node_stats.values()
),
[],
)
new_resource_pending_tasks = []
for task in resource_pending_tasks:
task = dict(task)
@ -301,14 +308,17 @@ class DataOrganizer:
return results
@classmethod
async def get_memory_table(cls,
sort_by=memory_utils.SortingType.OBJECT_SIZE,
group_by=memory_utils.GroupByType.STACK_TRACE):
async def get_memory_table(
cls,
sort_by=memory_utils.SortingType.OBJECT_SIZE,
group_by=memory_utils.GroupByType.STACK_TRACE,
):
all_worker_stats = []
for node_stats in DataSource.node_stats.values():
all_worker_stats.extend(node_stats.get("coreWorkersStats", []))
memory_information = memory_utils.construct_memory_table(
all_worker_stats, group_by=group_by, sort_by=sort_by)
all_worker_stats, group_by=group_by, sort_by=sort_by
)
return memory_information
@staticmethod

View file

@ -10,6 +10,7 @@ from queue import Queue
from distutils.version import LooseVersion
import grpc
try:
from grpc import aio as aiogrpc
except ImportError:
@ -23,8 +24,11 @@ import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray import ray_constants
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
GcsAioErrorSubscriber, GcsAioLogSubscriber
from ray._private.gcs_pubsub import (
gcs_pubsub_enabled,
GcsAioErrorSubscriber,
GcsAioLogSubscriber,
)
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.dashboard.datacenter import DataOrganizer
@ -42,33 +46,33 @@ aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length",
ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)
async def get_gcs_address_with_retry(redis_client) -> str:
while True:
try:
gcs_address = (await redis_client.get(
dashboard_consts.GCS_SERVER_ADDRESS)).decode()
gcs_address = (
await redis_client.get(dashboard_consts.GCS_SERVER_ADDRESS)
).decode()
if not gcs_address:
raise Exception("GCS address not found.")
logger.info("Connect to GCS at %s", gcs_address)
return gcs_address
except Exception as ex:
logger.error("Connect to GCS failed: %s, retry...", ex)
await asyncio.sleep(
dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
await asyncio.sleep(dashboard_consts.GCS_RETRY_CONNECT_INTERVAL_SECONDS)
class GCSHealthCheckThread(threading.Thread):
def __init__(self, gcs_address: str):
self.grpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, options=GRPC_CHANNEL_OPTIONS)
self.gcs_heartbeat_info_stub = (
gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self.grpc_gcs_channel))
gcs_address, options=GRPC_CHANNEL_OPTIONS
)
self.gcs_heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self.grpc_gcs_channel
)
self.work_queue = Queue()
super().__init__(daemon=True)
@ -83,10 +87,10 @@ class GCSHealthCheckThread(threading.Thread):
request = gcs_service_pb2.CheckAliveRequest()
try:
reply = self.gcs_heartbeat_info_stub.CheckAlive(
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT)
request, timeout=dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT
)
if reply.status.code != 0:
logger.exception(
f"Failed to CheckAlive: {reply.status.message}")
logger.exception(f"Failed to CheckAlive: {reply.status.message}")
return False
except grpc.RpcError: # Deadline Exceeded
logger.exception("Got RpcError when checking GCS is alive")
@ -95,9 +99,9 @@ class GCSHealthCheckThread(threading.Thread):
async def check_once(self) -> bool:
"""Ask the thread to perform a healthcheck."""
assert threading.current_thread != self, (
"caller shouldn't be from the same thread as GCSHealthCheckThread."
)
assert (
threading.current_thread != self
), "caller shouldn't be from the same thread as GCSHealthCheckThread."
future = Future()
self.work_queue.put(future)
@ -105,8 +109,16 @@ class GCSHealthCheckThread(threading.Thread):
class DashboardHead:
def __init__(self, http_host, http_port, http_port_retries, gcs_address,
redis_address, redis_password, log_dir):
def __init__(
self,
http_host,
http_port,
http_port_retries,
gcs_address,
redis_address,
redis_password,
log_dir,
):
self.health_check_thread: GCSHealthCheckThread = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.
@ -134,12 +146,12 @@ class DashboardHead:
else:
ip, port = gcs_address.split(":")
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
self.grpc_port = ray._private.tls_utils.add_port_to_grpc_server(
self.server, f"{grpc_ip}:0")
logger.info("Dashboard head grpc address: %s:%s", grpc_ip,
self.grpc_port)
self.server, f"{grpc_ip}:0"
)
logger.info("Dashboard head grpc address: %s:%s", grpc_ip, self.grpc_port)
@async_loop_forever(dashboard_consts.GCS_CHECK_ALIVE_INTERVAL_SECONDS)
async def _gcs_check_alive(self):
@ -149,7 +161,8 @@ class DashboardHead:
# Otherwise, the dashboard will always think that gcs is alive.
try:
is_alive = await asyncio.wait_for(
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1)
check_future, dashboard_consts.GCS_CHECK_ALIVE_RPC_TIMEOUT + 1
)
except asyncio.TimeoutError:
logger.error("Failed to check gcs health, client timed out.")
is_alive = False
@ -158,13 +171,16 @@ class DashboardHead:
self._gcs_rpc_error_counter = 0
else:
self._gcs_rpc_error_counter += 1
if self._gcs_rpc_error_counter > \
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR:
if (
self._gcs_rpc_error_counter
> dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR
):
logger.error(
"Dashboard exiting because it received too many GCS RPC "
"errors count: %s, threshold is %s.",
self._gcs_rpc_error_counter,
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR)
dashboard_consts.GCS_CHECK_ALIVE_MAX_COUNT_OF_RPC_ERROR,
)
# TODO(fyrestone): Do not use ray.state in
# PrometheusServiceDiscoveryWriter.
# Currently, we use os._exit() here to avoid hanging at the ray
@ -176,10 +192,12 @@ class DashboardHead:
"""Load dashboard head modules."""
modules = []
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
dashboard_utils.DashboardHeadModule
)
for cls in head_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardHeadModule.__name__, cls)
logger.info(
"Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls
)
c = cls(self)
dashboard_optional_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
@ -192,15 +210,17 @@ class DashboardHead:
return self.gcs_address
else:
try:
self.aioredis_client = \
await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address,
self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES,
)
except (socket.gaierror, ConnectionError):
logger.error(
"Dashboard head exiting: "
"Failed to connect to redis at %s", self.redis_address)
"Dashboard head exiting: " "Failed to connect to redis at %s",
self.redis_address,
)
sys.exit(-1)
return await get_gcs_address_with_retry(self.aioredis_client)
@ -209,22 +229,20 @@ class DashboardHead:
# Create a http session for all modules.
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
else:
self.http_session = aiohttp.ClientSession()
gcs_address = await self.get_gcs_address()
# Dashboard will handle connection failure automatically
self.gcs_client = GcsClient(
address=gcs_address, nums_reconnect_retry=0)
self.gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0)
internal_kv._initialize_internal_kv(self.gcs_client)
self.aiogrpc_gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True)
gcs_address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
if gcs_pubsub_enabled():
self.gcs_error_subscriber = GcsAioErrorSubscriber(
address=gcs_address)
self.gcs_error_subscriber = GcsAioErrorSubscriber(address=gcs_address)
self.gcs_log_subscriber = GcsAioLogSubscriber(address=gcs_address)
await self.gcs_error_subscriber.subscribe()
await self.gcs_log_subscriber.subscribe()
@ -248,7 +266,7 @@ class DashboardHead:
# Http server should be initialized after all modules loaded.
# working_dir uploads for job submission can be up to 100MiB.
app = aiohttp.web.Application(client_max_size=100 * 1024**2)
app = aiohttp.web.Application(client_max_size=100 * 1024 ** 2)
app.add_routes(routes=routes.bound_routes())
runner = aiohttp.web.AppRunner(app)
@ -256,8 +274,7 @@ class DashboardHead:
last_ex = None
for i in range(1 + self.http_port_retries):
try:
site = aiohttp.web.TCPSite(runner, self.http_host,
self.http_port)
site = aiohttp.web.TCPSite(runner, self.http_host, self.http_port)
await site.start()
break
except OSError as e:
@ -265,11 +282,14 @@ class DashboardHead:
self.http_port += 1
logger.warning("Try to use port %s: %s", self.http_port, e)
else:
raise Exception(f"Failed to find a valid port for dashboard after "
f"{self.http_port_retries} retries: {last_ex}")
raise Exception(
f"Failed to find a valid port for dashboard after "
f"{self.http_port_retries} retries: {last_ex}"
)
http_host, http_port, *_ = site._server.sockets[0].getsockname()
http_host = self.ip if ipaddress.ip_address(
http_host).is_unspecified else http_host
http_host = (
self.ip if ipaddress.ip_address(http_host).is_unspecified else http_host
)
logger.info("Dashboard head http address: %s:%s", http_host, http_port)
# TODO: Use async version if performance is an issue
@ -277,16 +297,16 @@ class DashboardHead:
internal_kv._internal_kv_put(
ray_constants.DASHBOARD_ADDRESS,
f"{http_host}:{http_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
internal_kv._internal_kv_put(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
f"{self.ip}:{self.grpc_port}",
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))
@ -299,6 +319,5 @@ class DashboardHead:
DataOrganizer.purge(),
DataOrganizer.organize(),
]
await asyncio.gather(*concurrent_tasks,
*(m.run(self.server) for m in modules))
await asyncio.gather(*concurrent_tasks, *(m.run(self.server) for m in modules))
await self.server.wait_for_termination()

View file

@ -23,8 +23,7 @@ class HttpServerAgent:
# Create a http session for all modules.
# aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = aiohttp.ClientSession(loop=asyncio.get_event_loop())
else:
self.http_session = aiohttp.ClientSession()
@ -47,25 +46,26 @@ class HttpServerAgent:
allow_methods="*",
allow_headers=("Content-Type", "X-Header"),
)
})
},
)
for route in list(app.router.routes()):
cors.add(route)
self.runner = aiohttp.web.AppRunner(app)
await self.runner.setup()
site = aiohttp.web.TCPSite(
self.runner, "127.0.0.1"
if self.ip == "127.0.0.1" else "0.0.0.0", self.listen_port)
self.runner,
"127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0",
self.listen_port,
)
await site.start()
self.http_host, self.http_port, *_ = (
site._server.sockets[0].getsockname())
logger.info("Dashboard agent http address: %s:%s", self.http_host,
self.http_port)
self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
logger.info(
"Dashboard agent http address: %s:%s", self.http_host, self.http_port
)
# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))

View file

@ -31,7 +31,7 @@ def cpu_percent():
delta in total host cpu usage, averaged over host's cpus.
Since deltas are not initially available, return 0.0 on first call.
""" # noqa
""" # noqa
global last_system_usage
global last_cpu_usage
try:
@ -43,12 +43,10 @@ def cpu_percent():
else:
cpu_delta = cpu_usage - last_cpu_usage
# "System time passed." (Typically close to clock time.)
system_delta = (
(system_usage - last_system_usage) / _host_num_cpus())
system_delta = (system_usage - last_system_usage) / _host_num_cpus()
quotient = cpu_delta / system_delta
cpu_percent = round(
quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
cpu_percent = round(quotient * 100 / ray._private.utils.get_k8s_cpus(), 1)
last_system_usage = system_usage
last_cpu_usage = cpu_usage
# Computed percentage might be slightly above 100%.
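A worked example of the formula above, using made-up numbers for a pod limited to 2 CPUs on an 8-CPU host:

# Illustrative numbers only.
cpu_delta = 4e8                       # container burned 0.4 s of CPU time (ns)
system_delta = 1.6e9 / 8              # host total / num CPUs ~= 0.2 s of wall time (ns)
quotient = cpu_delta / system_delta   # 2.0 CPUs' worth of work
cpu_percent = round(quotient * 100 / 2.0, 1)  # 2-CPU limit -> 100.0, i.e. at the limit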
@ -73,14 +71,14 @@ def _system_usage():
See also the /proc/stat entry here:
https://man7.org/linux/man-pages/man5/proc.5.html
""" # noqa
""" # noqa
cpu_summary_str = open(PROC_STAT_PATH).read().split("\n")[0]
parts = cpu_summary_str.split()
assert parts[0] == "cpu"
usage_data = parts[1:8]
total_clock_ticks = sum(int(entry) for entry in usage_data)
# 100 clock ticks per second, 10^9 ns per second
usage_ns = total_clock_ticks * 10**7
usage_ns = total_clock_ticks * 10 ** 7
return usage_ns
@ -91,7 +89,8 @@ def _host_num_cpus():
proc_stat_lines = open(PROC_STAT_PATH).read().split("\n")
split_proc_stat_lines = [line.split() for line in proc_stat_lines]
cpu_lines = [
split_line for split_line in split_proc_stat_lines
split_line
for split_line in split_proc_stat_lines
if len(split_line) > 0 and "cpu" in split_line[0]
]
# Number of lines starting with a word including 'cpu', subtracting

View file

@ -6,7 +6,7 @@ from typing import List
import ray
from ray._raylet import (TaskID, ActorID, JobID)
from ray._raylet import TaskID, ActorID, JobID
from ray.internal.internal_api import node_stats
import logging
@ -69,8 +69,10 @@ def get_sorting_type(sort_by: str):
elif sort_by == "REFERENCE_TYPE":
return SortingType.REFERENCE_TYPE
else:
raise Exception("The sort-by input provided is not one of\
PID, OBJECT_SIZE, or REFERENCE_TYPE.")
raise Exception(
"The sort-by input provided is not one of\
PID, OBJECT_SIZE, or REFERENCE_TYPE."
)
def get_group_by_type(group_by: str):
@ -81,13 +83,16 @@ def get_group_by_type(group_by: str):
elif group_by == "STACK_TRACE":
return GroupByType.STACK_TRACE
else:
raise Exception("The group-by input provided is not one of\
NODE_ADDRESS or STACK_TRACE.")
raise Exception(
"The group-by input provided is not one of\
NODE_ADDRESS or STACK_TRACE."
)
class MemoryTableEntry:
def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
pid: int):
def __init__(
self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int
):
# worker info
self.is_driver = is_driver
self.pid = pid
@ -97,13 +102,13 @@ class MemoryTableEntry:
self.object_size = int(object_ref.get("objectSize", -1))
self.call_site = object_ref.get("callSite", "<Unknown>")
self.object_ref = ray.ObjectRef(
decode_object_ref_if_needed(object_ref["objectId"]))
decode_object_ref_if_needed(object_ref["objectId"])
)
# reference info
self.local_ref_count = int(object_ref.get("localRefCount", 0))
self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
self.submitted_task_ref_count = int(
object_ref.get("submittedTaskRefCount", 0))
self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0))
self.contained_in_owned = [
ray.ObjectRef(decode_object_ref_if_needed(object_ref))
for object_ref in object_ref.get("containedInOwned", [])
@ -113,9 +118,12 @@ class MemoryTableEntry:
def is_valid(self) -> bool:
# If the entry doesn't have a reference type or some invalid state,
# (e.g., no object ref presented), it is considered invalid.
if (not self.pinned_in_memory and self.local_ref_count == 0
and self.submitted_task_ref_count == 0
and len(self.contained_in_owned) == 0):
if (
not self.pinned_in_memory
and self.local_ref_count == 0
and self.submitted_task_ref_count == 0
and len(self.contained_in_owned) == 0
):
return False
elif self.object_ref.is_nil():
return False
@ -153,10 +161,10 @@ class MemoryTableEntry:
# are not all 'f', that means it is an actor creation
# task, which is an actor handle.
random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
actor_random_bits = object_ref_hex[TASKID_RANDOM_BITS_SIZE:
TASKID_RANDOM_BITS_SIZE +
ACTORID_RANDOM_BITS_SIZE]
if (random_bits == "f" * 16 and not actor_random_bits == "f" * 24):
actor_random_bits = object_ref_hex[
TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
]
if random_bits == "f" * 16 and not actor_random_bits == "f" * 24:
return True
else:
return False
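As an illustration of the check above, with a made-up 56-character hex string and the bit sizes implied by the comparisons ("f" * 16 and "f" * 24):

# Made-up ObjectRef hex: task random bits all "f", actor bits not all "f".
TASKID_RANDOM_BITS_SIZE = 16    # assumed to match the constants used above
ACTORID_RANDOM_BITS_SIZE = 24

object_ref_hex = "f" * 16 + "0123456789abcdef01234567" + "0" * 16
random_bits = object_ref_hex[:TASKID_RANDOM_BITS_SIZE]
actor_random_bits = object_ref_hex[
    TASKID_RANDOM_BITS_SIZE : TASKID_RANDOM_BITS_SIZE + ACTORID_RANDOM_BITS_SIZE
]
# True: this reference would be classified as an actor handle.
print(random_bits == "f" * 16 and not actor_random_bits == "f" * 24)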
@ -175,7 +183,7 @@ class MemoryTableEntry:
"contained_in_owned": [
object_ref.hex() for object_ref in self.contained_in_owned
],
"type": "Driver" if self.is_driver else "Worker"
"type": "Driver" if self.is_driver else "Worker",
}
def __str__(self):
@ -186,10 +194,12 @@ class MemoryTableEntry:
class MemoryTable:
def __init__(self,
entries: List[MemoryTableEntry],
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
sort_by_type: SortingType = SortingType.PID):
def __init__(
self,
entries: List[MemoryTableEntry],
group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
sort_by_type: SortingType = SortingType.PID,
):
self.table = entries
# Group is a list of memory tables grouped by a group key.
self.group = {}
@ -247,7 +257,7 @@ class MemoryTable:
"total_pinned_in_memory": total_pinned_in_memory,
"total_used_by_pending_task": total_used_by_pending_task,
"total_captured_in_objects": total_captured_in_objects,
"total_actor_handles": total_actor_handles
"total_actor_handles": total_actor_handles,
}
return self
@ -278,7 +288,8 @@ class MemoryTable:
# Build a group table.
for group_key, entries in group.items():
self.group[group_key] = MemoryTable(
entries, group_by_type=None, sort_by_type=None)
entries, group_by_type=None, sort_by_type=None
)
for group_key, group_memory_table in self.group.items():
group_memory_table.summarize()
return self
@ -289,10 +300,10 @@ class MemoryTable:
"group": {
group_key: {
"entries": group_memory_table.get_entries(),
"summary": group_memory_table.summary
"summary": group_memory_table.summary,
}
for group_key, group_memory_table in self.group.items()
}
},
}
def get_entries(self) -> List[dict]:
@ -305,9 +316,11 @@ class MemoryTable:
return self.__repr__()
def construct_memory_table(workers_stats: List,
group_by: GroupByType = GroupByType.NODE_ADDRESS,
sort_by=SortingType.OBJECT_SIZE) -> MemoryTable:
def construct_memory_table(
workers_stats: List,
group_by: GroupByType = GroupByType.NODE_ADDRESS,
sort_by=SortingType.OBJECT_SIZE,
) -> MemoryTable:
memory_table_entries = []
for core_worker_stats in workers_stats:
pid = core_worker_stats["pid"]
@ -320,11 +333,13 @@ def construct_memory_table(workers_stats: List,
object_ref=object_ref,
node_address=node_address,
is_driver=is_driver,
pid=pid)
pid=pid,
)
if memory_table_entry.is_valid():
memory_table_entries.append(memory_table_entry)
memory_table = MemoryTable(
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by)
memory_table_entries, group_by_type=group_by, sort_by_type=sort_by
)
return memory_table
@ -337,7 +352,7 @@ def track_reference_size(group):
"PINNED_IN_MEMORY": "total_pinned_in_memory",
"USED_BY_PENDING_TASK": "total_used_by_pending_task",
"CAPTURED_IN_OBJECT": "total_captured_in_objects",
"ACTOR_HANDLE": "total_actor_handles"
"ACTOR_HANDLE": "total_actor_handles",
}
for entry in group["entries"]:
size = entry["object_size"]
@ -348,51 +363,64 @@ def track_reference_size(group):
return d
def memory_summary(state,
group_by="NODE_ADDRESS",
sort_by="OBJECT_SIZE",
line_wrap=True,
unit="B",
num_entries=None) -> str:
from ray.dashboard.modules.node.node_head\
import node_stats_to_dict
def memory_summary(
state,
group_by="NODE_ADDRESS",
sort_by="OBJECT_SIZE",
line_wrap=True,
unit="B",
num_entries=None,
) -> str:
from ray.dashboard.modules.node.node_head import node_stats_to_dict
# Get terminal size
import shutil
size = shutil.get_terminal_size((80, 20)).columns
line_wrap_threshold = 137
# Unit conversions
units = {"B": 10**0, "KB": 10**3, "MB": 10**6, "GB": 10**9}
units = {"B": 10 ** 0, "KB": 10 ** 3, "MB": 10 ** 6, "GB": 10 ** 9}
# Fetch core memory worker stats, store as a dictionary
core_worker_stats = []
for raylet in state.node_table():
stats = node_stats_to_dict(
node_stats(raylet["NodeManagerAddress"],
raylet["NodeManagerPort"]))
node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
)
core_worker_stats.extend(stats["coreWorkersStats"])
assert type(stats) is dict and "coreWorkersStats" in stats
# Build memory table with "group_by" and "sort_by" parameters
group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
memory_table = construct_memory_table(core_worker_stats, group_by,
sort_by).as_dict()
memory_table = construct_memory_table(
core_worker_stats, group_by, sort_by
).as_dict()
assert "summary" in memory_table and "group" in memory_table
# Build memory summary
mem = ""
group_by, sort_by = group_by.name.lower().replace(
"_", " "), sort_by.name.lower().replace("_", " ")
"_", " "
), sort_by.name.lower().replace("_", " ")
summary_labels = [
"Mem Used by Objects", "Local References", "Pinned", "Pending Tasks",
"Captured in Objects", "Actor Handles"
"Mem Used by Objects",
"Local References",
"Pinned",
"Pending Tasks",
"Captured in Objects",
"Actor Handles",
]
summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"
object_ref_labels = [
"IP Address", "PID", "Type", "Call Site", "Size", "Reference Type",
"Object Ref"
"IP Address",
"PID",
"Type",
"Call Site",
"Size",
"Reference Type",
"Object Ref",
]
object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
| {:<8} | {:<14} | {:<10}\n"
@ -416,22 +444,21 @@ entries per group...\n\n\n"
else:
summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"
mem += f"--- Summary for {group_by}: {key} ---\n"
mem += summary_string\
.format(*summary_labels)
mem += summary_string\
.format(*summary.values()) + "\n"
mem += summary_string.format(*summary_labels)
mem += summary_string.format(*summary.values()) + "\n"
# Memory table per group
mem += f"--- Object references for {group_by}: {key} ---\n"
mem += object_ref_string\
.format(*object_ref_labels)
mem += object_ref_string.format(*object_ref_labels)
n = 1 # Counter for num entries per group
for entry in group["entries"]:
if num_entries is not None and n > num_entries:
break
entry["object_size"] = str(
entry["object_size"] /
units[unit]) + f" {unit}" if entry["object_size"] > -1 else "?"
entry["object_size"] = (
str(entry["object_size"] / units[unit]) + f" {unit}"
if entry["object_size"] > -1
else "?"
)
num_lines = 1
if size > line_wrap_threshold and line_wrap:
call_site_length = 22
@ -439,30 +466,36 @@ entries per group...\n\n\n"
entry["call_site"] = ["disabled"]
else:
entry["call_site"] = [
entry["call_site"][i:i + call_site_length] for i in
range(0, len(entry["call_site"]), call_site_length)
entry["call_site"][i : i + call_site_length]
for i in range(0, len(entry["call_site"]), call_site_length)
]
num_lines = len(entry["call_site"])
else:
mem += "\n"
object_ref_values = [
entry["node_ip_address"], entry["pid"], entry["type"],
entry["call_site"], entry["object_size"],
entry["reference_type"], entry["object_ref"]
entry["node_ip_address"],
entry["pid"],
entry["type"],
entry["call_site"],
entry["object_size"],
entry["reference_type"],
entry["object_ref"],
]
for i in range(len(object_ref_values)):
if not isinstance(object_ref_values[i], list):
object_ref_values[i] = [object_ref_values[i]]
object_ref_values[i].extend(
["" for x in range(num_lines - len(object_ref_values[i]))])
["" for x in range(num_lines - len(object_ref_values[i]))]
)
for i in range(num_lines):
row = [elem[i] for elem in object_ref_values]
mem += object_ref_string\
.format(*row)
mem += object_ref_string.format(*row)
mem += "\n"
n += 1
mem += "To record callsite information for each ObjectRef created, set " \
"env variable RAY_record_ref_creation_sites=1\n\n"
mem += (
"To record callsite information for each ObjectRef created, set "
"env variable RAY_record_ref_creation_sites=1\n\n"
)
return mem
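
Most hunks in this file follow Black's two core rules: a call or signature that fits within the default 88-column limit is collapsed onto one line, and one that does not is split across lines, usually gaining a trailing comma. (The 10 ** 0 spacing is also Black's doing with the release used for this commit; later releases hug simple power operands and would write 10**0.) Below is a minimal, hedged sketch of the collapse rule through Black's public Python API, using a placeholder fragment rather than code from this diff:

import black

# Illustrative pre-Black fragment; the names are placeholders, not code from this diff.
sample = "short_call(first_argument,\n           second_argument)\n"

print(black.format_str(sample, mode=black.FileMode()))
# -> short_call(first_argument, second_argument)
# Calls that fit under the default 88-column limit are joined onto one line;
# calls that would not fit are split across lines instead, as the hunks above show.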

View file

@ -4,6 +4,7 @@ import aiohttp.web
import ray._private.utils
from ray.dashboard.modules.actor import actor_utils
from aioredis.pubsub import Receiver
try:
from grpc import aio as aiogrpc
except ImportError:
@ -15,8 +16,7 @@ import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
from ray.dashboard.modules.actor import actor_consts
from ray.dashboard.modules.actor.actor_utils import \
actor_classname_from_task_spec
from ray.dashboard.modules.actor.actor_utils import actor_classname_from_task_spec
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
@ -30,12 +30,22 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
def actor_table_data_to_dict(message):
orig_message = dashboard_utils.message_to_dict(
message, {
"actorId", "parentId", "jobId", "workerId", "rayletId",
"actorCreationDummyObjectId", "callerId", "taskId", "parentTaskId",
"sourceActorId", "placementGroupId"
message,
{
"actorId",
"parentId",
"jobId",
"workerId",
"rayletId",
"actorCreationDummyObjectId",
"callerId",
"taskId",
"parentTaskId",
"sourceActorId",
"placementGroupId",
},
including_default_value_fields=True)
including_default_value_fields=True,
)
# The complete schema for actor table is here:
# src/ray/protobuf/gcs.proto
# It is super big and for dashboard, we don't need that much information.
@ -58,8 +68,7 @@ def actor_table_data_to_dict(message):
light_message["actorClass"] = actor_class
if "functionDescriptor" in light_message["taskSpec"]:
light_message["taskSpec"] = {
"functionDescriptor": light_message["taskSpec"][
"functionDescriptor"]
"functionDescriptor": light_message["taskSpec"]["functionDescriptor"]
}
else:
light_message.pop("taskSpec")
@ -81,11 +90,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if change.new:
# TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub
@ -96,7 +107,8 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
logger.info("Getting all actor info from GCS.")
request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=5)
request, timeout=5
)
if reply.status.code == 0:
actors = {}
for message in reply.actor_table_data:
@ -110,24 +122,25 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
for actor_id, actor_table_data in actors.items():
job_id = actor_table_data["jobId"]
node_id = actor_table_data["address"]["rayletId"]
job_actors.setdefault(job_id,
{})[actor_id] = actor_table_data
job_actors.setdefault(job_id, {})[actor_id] = actor_table_data
# Update only when node_id is not Nil.
if node_id != actor_consts.NIL_NODE_ID:
node_actors.setdefault(
node_id, {})[actor_id] = actor_table_data
node_actors.setdefault(node_id, {})[
actor_id
] = actor_table_data
DataSource.job_actors.reset(job_actors)
DataSource.node_actors.reset(node_actors)
logger.info("Received %d actor info from GCS.",
len(actors))
logger.info("Received %d actor info from GCS.", len(actors))
break
else:
raise Exception(
f"Failed to GetAllActorInfo: {reply.status.message}")
f"Failed to GetAllActorInfo: {reply.status.message}"
)
except Exception:
logger.exception("Error Getting all actor info from GCS.")
await asyncio.sleep(
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)
actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS
)
state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
@ -167,8 +180,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
if actor_id is not None:
# Convert to lower case hex ID.
actor_id = actor_id.hex()
process_actor_data_from_pubsub(actor_id,
actor_table_data)
process_actor_data_from_pubsub(actor_id, actor_table_data)
except Exception:
logger.exception("Error processing actor info from GCS.")
@ -183,12 +195,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
async for sender, msg in receiver.iter():
try:
actor_id, actor_table_data = msg
actor_id = actor_id.decode("UTF-8")[len(
gcs_utils.TablePrefix_ACTOR_string + ":"):]
actor_id = actor_id.decode("UTF-8")[
len(gcs_utils.TablePrefix_ACTOR_string + ":") :
]
pubsub_message = gcs_utils.PubSubMessage.FromString(
actor_table_data)
actor_table_data
)
actor_table_data = gcs_utils.ActorTableData.FromString(
pubsub_message.data)
pubsub_message.data
)
process_actor_data_from_pubsub(actor_id, actor_table_data)
except Exception:
logger.exception("Error processing actor info from Redis.")
@ -203,17 +218,15 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
actors.update(actor_creation_tasks)
actor_groups = actor_utils.construct_actor_groups(actors)
return rest_response(
success=True,
message="Fetched actor groups.",
actor_groups=actor_groups)
success=True, message="Fetched actor groups.", actor_groups=actor_groups
)
@routes.get("/logical/actors")
@dashboard_optional_utils.aiohttp_cache
async def get_all_actors(self, req) -> aiohttp.web.Response:
return rest_response(
success=True,
message="All actors fetched.",
actors=DataSource.actors)
success=True, message="All actors fetched.", actors=DataSource.actors
)
@routes.get("/logical/kill_actor")
async def kill_actor(self, req) -> aiohttp.web.Response:
@ -224,15 +237,17 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
except KeyError:
return rest_response(success=False, message="Bad Request")
try:
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True)
f"{ip_address}:{port}", options=options, asynchronous=True
)
stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
await stub.KillActor(
core_worker_pb2.KillActorRequest(
intended_actor_id=ray._private.utils.hex_to_binary(
actor_id)))
intended_actor_id=ray._private.utils.hex_to_binary(actor_id)
)
)
except aiogrpc.AioRpcError:
# This always throws an exception because the worker
@ -240,13 +255,13 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
# before this handler, however it deletes the actor correctly.
pass
return rest_response(
success=True, message=f"Killed actor with id {actor_id}")
return rest_response(success=True, message=f"Killed actor with id {actor_id}")
async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = \
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(gcs_channel)
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)
await asyncio.gather(self._update_actors())
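
The actor_table_data_to_dict hunk above is a clean example of Black's magic trailing comma: once the rewritten set literal ends with a trailing comma, Black keeps it exploded one element per line rather than re-packing it. A short sketch with example field names follows; the single-element tuple line mirrors the grpc options change in the same file:

# Fits on one line and has no trailing comma: left (or joined) as a single line.
compact = {"actorId", "parentId", "jobId"}

# The trailing comma after the last element keeps this exploded, one per line.
exploded = {
    "actorId",
    "parentId",
    "jobId",
}

# Single-element tuples keep their comma but lose the stray space before ")".
options = (("grpc.enable_http_proxy", 0),)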

View file

@ -7,27 +7,29 @@ PYCLASSNAME_RE = re.compile(r"(.+?)\(")
def construct_actor_groups(actors):
"""actors is a dict from actor id to an actor or an
actor creation task The shared fields currently are
"actorClass", "actorId", and "state" """
actor creation task The shared fields currently are
"actorClass", "actorId", and "state" """
actor_groups = _group_actors_by_python_class(actors)
stats_by_group = {
name: _get_actor_group_stats(group)
for name, group in actor_groups.items()
name: _get_actor_group_stats(group) for name, group in actor_groups.items()
}
summarized_actor_groups = {}
for name, group in actor_groups.items():
summarized_actor_groups[name] = {
"entries": group,
"summary": stats_by_group[name]
"summary": stats_by_group[name],
}
return summarized_actor_groups
def actor_classname_from_task_spec(task_spec):
return task_spec.get("functionDescriptor", {})\
.get("pythonFunctionDescriptor", {})\
.get("className", "Unknown actor class").split(".")[-1]
return (
task_spec.get("functionDescriptor", {})
.get("pythonFunctionDescriptor", {})
.get("className", "Unknown actor class")
.split(".")[-1]
)
def _group_actors_by_python_class(actors):
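
The actor_classname_from_task_spec hunk above swaps backslash continuations for a parenthesized expression with one chained .get() per line; whether Black alone or a small manual touch-up produced it, parentheses are the idiomatic way to continue an expression across lines. A self-contained sketch with a hypothetical task-spec dict:

def class_name(task_spec: dict) -> str:
    # Parentheses, not backslashes, carry the expression across lines.
    return (
        task_spec.get("functionDescriptor", {})
        .get("pythonFunctionDescriptor", {})
        .get("className", "Unknown actor class")
        .split(".")[-1]
    )


print(class_name({}))  # -> Unknown actor class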

View file

@ -50,8 +50,7 @@ def test_actor_groups(ray_start_with_dashboard):
response = requests.get(webui_url + "/logical/actor_groups")
response.raise_for_status()
actor_groups_resp = response.json()
assert actor_groups_resp["result"] is True, actor_groups_resp[
"msg"]
assert actor_groups_resp["result"] is True, actor_groups_resp["msg"]
actor_groups = actor_groups_resp["data"]["actorGroups"]
assert "Foo" in actor_groups
summary = actor_groups["Foo"]["summary"]
@ -78,9 +77,13 @@ def test_actor_groups(ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
@ -135,9 +138,13 @@ def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
@ -183,8 +190,9 @@ def test_kill_actor(ray_start_with_dashboard):
params={
"actorId": actor["actorId"],
"ipAddress": actor["ipAddress"],
"port": actor["port"]
})
"port": actor["port"],
},
)
resp.raise_for_status()
resp_json = resp.json()
assert resp_json["result"] is True, "msg" in resp_json
@ -199,19 +207,17 @@ def test_kill_actor(ray_start_with_dashboard):
break
except (KeyError, AssertionError) as e:
last_exc = e
time.sleep(.1)
time.sleep(0.1)
assert last_exc is None
def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
timeout = 5
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
address_info = ray_start_with_dashboard
if gcs_pubsub.gcs_pubsub_enabled():
sub = gcs_pubsub.GcsActorSubscriber(
address=address_info["gcs_address"])
sub = gcs_pubsub.GcsActorSubscriber(address=address_info["gcs_address"])
sub.subscribe()
else:
address = address_info["redis_address"]
@ -221,7 +227,8 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
client = redis.StrictRedis(
host=address[0],
port=int(address[1]),
password=ray_constants.REDIS_DEFAULT_PASSWORD)
password=ray_constants.REDIS_DEFAULT_PASSWORD,
)
sub = client.pubsub(ignore_subscribe_messages=True)
sub.psubscribe(gcs_utils.RAY_ACTOR_PUBSUB_PATTERN)
@ -245,8 +252,7 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
time.sleep(0.01)
continue
pubsub_msg = gcs_utils.PubSubMessage.FromString(msg["data"])
actor_data = gcs_utils.ActorTableData.FromString(
pubsub_msg.data)
actor_data = gcs_utils.ActorTableData.FromString(pubsub_msg.data)
if actor_data is None:
continue
msgs.append(actor_data)
@ -266,12 +272,22 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
def actor_table_data_to_dict(message):
return dashboard_utils.message_to_dict(
message, {
"actorId", "parentId", "jobId", "workerId", "rayletId",
"actorCreationDummyObjectId", "callerId", "taskId",
"parentTaskId", "sourceActorId", "placementGroupId"
message,
{
"actorId",
"parentId",
"jobId",
"workerId",
"rayletId",
"actorCreationDummyObjectId",
"callerId",
"taskId",
"parentTaskId",
"sourceActorId",
"placementGroupId",
},
including_default_value_fields=False)
including_default_value_fields=False,
)
non_state_keys = ("actorId", "jobId", "taskSpec")
@ -287,23 +303,31 @@ def test_actor_pubsub(disable_aiohttp_cache, ray_start_with_dashboard):
# be published.
elif actor_data_dict["state"] in ("ALIVE", "DEAD"):
assert actor_data_dict.keys() >= {
"state", "address", "timestamp", "pid", "rayNamespace"
"state",
"address",
"timestamp",
"pid",
"rayNamespace",
}
elif actor_data_dict["state"] == "PENDING_CREATION":
assert actor_data_dict.keys() == {
"state", "address", "actorId", "actorCreationDummyObjectId",
"jobId", "ownerAddress", "taskSpec", "className",
"serializedRuntimeEnv", "rayNamespace"
"state",
"address",
"actorId",
"actorCreationDummyObjectId",
"jobId",
"ownerAddress",
"taskSpec",
"className",
"serializedRuntimeEnv",
"rayNamespace",
}
else:
raise Exception("Unknown state: {}".format(
actor_data_dict["state"]))
raise Exception("Unknown state: {}".format(actor_data_dict["state"]))
def test_nil_node(enable_test_module, disable_aiohttp_cache,
ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
def test_nil_node(enable_test_module, disable_aiohttp_cache, ray_start_with_dashboard):
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
assert wait_until_server_available(webui_url)
webui_url = format_web_url(webui_url)
@ -334,9 +358,13 @@ def test_nil_node(enable_test_module, disable_aiohttp_cache,
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")

View file

@ -24,13 +24,11 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
os.makedirs(self._event_dir, exist_ok=True)
self._monitor: Union[asyncio.Task, None] = None
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
self._cached_events = asyncio.Queue(
event_consts.EVENT_AGENT_CACHE_SIZE)
logger.info("Event agent cache buffer size: %s",
self._cached_events.maxsize)
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
logger.info("Event agent cache buffer size: %s", self._cached_events.maxsize)
async def _connect_to_dashboard(self):
""" Connect to the dashboard. If the dashboard is not started, then
"""Connect to the dashboard. If the dashboard is not started, then
this method will never returns.
Returns:
@ -41,23 +39,24 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
# TODO: Use async version if performance is an issue
dashboard_rpc_address = internal_kv._internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = utils.init_grpc_channel(
dashboard_rpc_address,
options=options,
asynchronous=True)
dashboard_rpc_address, options=options, asynchronous=True
)
return event_pb2_grpc.ReportEventServiceStub(channel)
except Exception:
logger.exception("Connect to dashboard failed.")
await asyncio.sleep(
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS)
event_consts.RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS
)
@async_loop_forever(event_consts.EVENT_AGENT_REPORT_INTERVAL_SECONDS)
async def report_events(self):
""" Report events from cached events queue. Reconnect to dashboard if
"""Report events from cached events queue. Reconnect to dashboard if
report failed. Log error after retry EVENT_AGENT_RETRY_TIMES.
This method will never returns.
@ -70,14 +69,15 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
await self._stub.ReportEvents(request)
break
except Exception:
logger.exception("Report event failed, reconnect to the "
"dashboard.")
logger.exception("Report event failed, reconnect to the " "dashboard.")
self._stub = await self._connect_to_dashboard()
else:
data_str = str(data)
limit = event_consts.LOG_ERROR_EVENT_STRING_LENGTH_LIMIT
logger.error("Report event failed: %s",
data_str[:limit] + (data_str[limit:] and "..."))
logger.error(
"Report event failed: %s",
data_str[:limit] + (data_str[limit:] and "..."),
)
async def run(self, server):
# Connect to dashboard.
@ -86,7 +86,8 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
self._monitor = monitor_events(
self._event_dir,
lambda data: create_task(self._cached_events.put(data)),
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES)
source_types=event_consts.EVENT_AGENT_MONITOR_SOURCE_TYPES,
)
# Start reporting events.
await self.report_events()

View file

@ -4,22 +4,20 @@ from ray.core.generated import event_pb2
LOG_ERROR_EVENT_STRING_LENGTH_LIMIT = 1000
RETRY_CONNECT_TO_DASHBOARD_INTERVAL_SECONDS = 2
# Monitor events
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer(
"SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
SCAN_EVENT_DIR_INTERVAL_SECONDS = env_integer("SCAN_EVENT_DIR_INTERVAL_SECONDS", 2)
SCAN_EVENT_START_OFFSET_SECONDS = -30 * 60
CONCURRENT_READ_LIMIT = 50
EVENT_READ_LINE_COUNT_LIMIT = 200
EVENT_READ_LINE_LENGTH_LIMIT = env_integer("EVENT_READ_LINE_LENGTH_LIMIT",
2 * 1024 * 1024) # 2MB
EVENT_READ_LINE_LENGTH_LIMIT = env_integer(
"EVENT_READ_LINE_LENGTH_LIMIT", 2 * 1024 * 1024
) # 2MB
# Report events
EVENT_AGENT_REPORT_INTERVAL_SECONDS = 0.1
EVENT_AGENT_RETRY_TIMES = 10
EVENT_AGENT_CACHE_SIZE = 10240
# Event sources
EVENT_HEAD_MONITOR_SOURCE_TYPES = [
event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
]
EVENT_HEAD_MONITOR_SOURCE_TYPES = [event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)]
EVENT_AGENT_MONITOR_SOURCE_TYPES = list(
set(event_pb2.Event.SourceType.keys()) -
set(EVENT_HEAD_MONITOR_SOURCE_TYPES))
set(event_pb2.Event.SourceType.keys()) - set(EVENT_HEAD_MONITOR_SOURCE_TYPES)
)
EVENT_SOURCE_ALL = event_pb2.Event.SourceType.keys()

View file

@ -24,8 +24,9 @@ JobEvents = OrderedDict
dashboard_utils._json_compatible_types.add(JobEvents)
class EventHead(dashboard_utils.DashboardHeadModule,
event_pb2_grpc.ReportEventServiceServicer):
class EventHead(
dashboard_utils.DashboardHeadModule, event_pb2_grpc.ReportEventServiceServicer
):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
@ -70,21 +71,24 @@ class EventHead(dashboard_utils.DashboardHeadModule,
for job_id, job_events in DataSource.events.items()
}
return dashboard_optional_utils.rest_response(
success=True, message="All events fetched.", events=all_events)
success=True, message="All events fetched.", events=all_events
)
job_events = DataSource.events.get(job_id, {})
return dashboard_optional_utils.rest_response(
success=True,
message="Job events fetched.",
job_id=job_id,
events=list(job_events.values()))
events=list(job_events.values()),
)
async def run(self, server):
event_pb2_grpc.add_ReportEventServiceServicer_to_server(self, server)
self._monitor = monitor_events(
self._event_dir,
lambda data: self._update_events(parse_event_strings(data)),
source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES)
source_types=event_consts.EVENT_HEAD_MONITOR_SOURCE_TYPES,
)
@staticmethod
def is_minimal_module():

View file

@ -19,8 +19,7 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
source_files = {}
all_source_types = set(event_consts.EVENT_SOURCE_ALL)
for source_type in source_types or event_consts.EVENT_SOURCE_ALL:
assert source_type in all_source_types, \
f"Invalid source type: {source_type}"
assert source_type in all_source_types, f"Invalid source type: {source_type}"
files = []
for n in event_log_names:
if fnmatch.fnmatch(n, f"*{source_type}*"):
@ -35,9 +34,9 @@ def _get_source_files(event_dir, source_types=None, event_file_filter=None):
def _restore_newline(event_dict):
try:
event_dict["message"] = event_dict["message"]\
.replace("\\n", "\n")\
.replace("\\r", "\n")
event_dict["message"] = (
event_dict["message"].replace("\\n", "\n").replace("\\r", "\n")
)
except Exception:
logger.exception("Restore newline for event failed: %s", event_dict)
return event_dict
@ -61,13 +60,13 @@ def parse_event_strings(event_string_list):
ReadFileResult = collections.namedtuple(
"ReadFileResult", ["fid", "size", "mtime", "position", "lines"])
"ReadFileResult", ["fid", "size", "mtime", "position", "lines"]
)
def _read_file(file,
pos,
n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT,
closefd=True):
def _read_file(
file, pos, n_lines=event_consts.EVENT_READ_LINE_COUNT_LIMIT, closefd=True
):
with open(file, "rb", closefd=closefd) as f:
# The ino may be 0 on Windows.
stat = os.stat(f.fileno())
@ -82,24 +81,25 @@ def _read_file(file,
if sep - start <= event_consts.EVENT_READ_LINE_LENGTH_LIMIT:
lines.append(mm[start:sep].decode("utf-8"))
else:
truncated_size = min(
100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
truncated_size = min(100, event_consts.EVENT_READ_LINE_LENGTH_LIMIT)
logger.warning(
"Ignored long string: %s...(%s chars)",
mm[start:start + truncated_size].decode("utf-8"),
sep - start)
mm[start : start + truncated_size].decode("utf-8"),
sep - start,
)
start = sep + 1
return ReadFileResult(fid, stat.st_size, stat.st_mtime, start, lines)
def monitor_events(
event_dir,
callback,
scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
monitor_files=None,
source_types=None):
""" Monitor events in directory. New events will be read and passed to the
event_dir,
callback,
scan_interval_seconds=event_consts.SCAN_EVENT_DIR_INTERVAL_SECONDS,
start_mtime=time.time() + event_consts.SCAN_EVENT_START_OFFSET_SECONDS,
monitor_files=None,
source_types=None,
):
"""Monitor events in directory. New events will be read and passed to the
callback.
Args:
@ -121,20 +121,22 @@ def monitor_events(
monitor_files = {}
logger.info(
"Monitor events logs modified after %s on %s, "
"the source types are %s.", start_mtime, event_dir, "all"
if source_types is None else source_types)
"Monitor events logs modified after %s on %s, " "the source types are %s.",
start_mtime,
event_dir,
"all" if source_types is None else source_types,
)
MonitorFile = collections.namedtuple("MonitorFile",
["size", "mtime", "position"])
MonitorFile = collections.namedtuple("MonitorFile", ["size", "mtime", "position"])
def _source_file_filter(source_file):
stat = os.stat(source_file)
return stat.st_mtime > start_mtime
def _read_monitor_file(file, pos):
assert isinstance(file, str), \
f"File should be a str, but a {type(file)}({file}) found"
assert isinstance(
file, str
), f"File should be a str, but a {type(file)}({file}) found"
fd = os.open(file, os.O_RDONLY)
try:
stat = os.stat(fd)
@ -145,12 +147,14 @@ def monitor_events(
fid = stat.st_ino or file
monitor_file = monitor_files.get(fid)
if monitor_file:
if (monitor_file.position == monitor_file.size
and monitor_file.size == stat.st_size
and monitor_file.mtime == stat.st_mtime):
if (
monitor_file.position == monitor_file.size
and monitor_file.size == stat.st_size
and monitor_file.mtime == stat.st_mtime
):
logger.debug(
"Skip reading the file because "
"there is no change: %s", file)
"Skip reading the file because " "there is no change: %s", file
)
return []
position = monitor_file.position
else:
@ -169,22 +173,23 @@ def monitor_events(
@async_loop_forever(scan_interval_seconds, cancellable=True)
async def _scan_event_log_files():
# Scan event files.
source_files = await loop.run_in_executor(None, _get_source_files,
event_dir, source_types,
_source_file_filter)
source_files = await loop.run_in_executor(
None, _get_source_files, event_dir, source_types, _source_file_filter
)
# Limit concurrent read to avoid fd exhaustion.
semaphore = asyncio.Semaphore(event_consts.CONCURRENT_READ_LIMIT)
async def _concurrent_coro(filename):
async with semaphore:
return await loop.run_in_executor(None, _read_monitor_file,
filename, 0)
return await loop.run_in_executor(None, _read_monitor_file, filename, 0)
# Read files.
await asyncio.gather(*[
_concurrent_coro(filename)
for filename in list(itertools.chain(*source_files.values()))
])
await asyncio.gather(
*[
_concurrent_coro(filename)
for filename in list(itertools.chain(*source_files.values()))
]
)
return create_task(_scan_event_log_files())
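
One thing Black deliberately leaves alone is implicit string concatenation, which is why several log messages above (for example the "Skip reading the file because ..." warning) end up as two adjacent literals on one line after reformatting. They still concatenate at compile time; merging them into a single literal would be a separate manual cleanup. A tiny sketch:

# Adjacent string literals are concatenated by the compiler, so the reformatted
# line means exactly what the original wrapped version meant.
message = "Skip reading the file because " "there is no change: %s"
assert message == "Skip reading the file because there is no change: %s"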

View file

@ -23,7 +23,8 @@ from ray._private.test_utils import (
wait_for_condition,
)
from ray.dashboard.modules.event.event_utils import (
monitor_events, )
monitor_events,
)
logger = logging.getLogger(__name__)
@ -32,7 +33,8 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
return {
"event_id": binary_to_hex(np.random.bytes(18)),
"source_type": random.choice(event_pb2.Event.SourceType.keys())
if source_type is None else source_type,
if source_type is None
else source_type,
"host_name": "po-dev.inc.alipay.net",
"pid": random.randint(1, 65536),
"label": "",
@ -41,16 +43,18 @@ def _get_event(msg="empty message", job_id=None, source_type=None):
"severity": "INFO",
"custom_fields": {
"job_id": ray.JobID.from_int(random.randint(1, 100)).hex()
if job_id is None else job_id,
if job_id is None
else job_id,
"node_id": "",
"task_id": "",
}
},
}
def _test_logger(name, log_file, max_bytes, backup_count):
handler = logging.handlers.RotatingFileHandler(
log_file, maxBytes=max_bytes, backupCount=backup_count)
log_file, maxBytes=max_bytes, backupCount=backup_count
)
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
@ -63,15 +67,14 @@ def _test_logger(name, log_file, max_bytes, backup_count):
def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
session_dir = ray_start_with_dashboard["session_dir"]
event_dir = os.path.join(session_dir, "logs", "events")
job_id = ray.JobID.from_int(100).hex()
source_type_gcs = event_pb2.Event.SourceType.Name(event_pb2.Event.GCS)
source_type_raylet = event_pb2.Event.SourceType.Name(
event_pb2.Event.RAYLET)
source_type_raylet = event_pb2.Event.SourceType.Name(event_pb2.Event.RAYLET)
test_count = 20
for source_type in [source_type_gcs, source_type_raylet]:
@ -80,10 +83,10 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
__name__ + str(random.random()),
test_log_file,
max_bytes=2000,
backup_count=1000)
backup_count=1000,
)
for i in range(test_count):
sample_event = _get_event(
str(i), job_id=job_id, source_type=source_type)
sample_event = _get_event(str(i), job_id=job_id, source_type=source_type)
test_logger.info("%s", json.dumps(sample_event))
def _check_events():
@ -112,10 +115,11 @@ def test_event_basic(disable_aiohttp_cache, ray_start_with_dashboard):
wait_for_condition(_check_events, timeout=15)
def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
ray_start_with_dashboard):
def test_event_message_limit(
small_event_line_limit, disable_aiohttp_cache, ray_start_with_dashboard
):
event_read_line_length_limit = small_event_line_limit
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
session_dir = ray_start_with_dashboard["session_dir"]
event_dir = os.path.join(session_dir, "logs", "events")
@ -148,8 +152,8 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
except Exception:
pass
os.rename(
os.path.join(event_dir, "tmp.log"),
os.path.join(event_dir, "event_GCS.log"))
os.path.join(event_dir, "tmp.log"), os.path.join(event_dir, "event_GCS.log")
)
def _check_events():
try:
@ -157,14 +161,14 @@ def test_event_message_limit(small_event_line_limit, disable_aiohttp_cache,
resp.raise_for_status()
result = resp.json()
all_events = result["data"]["events"]
assert len(all_events[job_id]
) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
assert (
len(all_events[job_id]) >= event_consts.EVENT_READ_LINE_COUNT_LIMIT + 10
)
messages = [e["message"] for e in all_events[job_id]]
for i in range(10):
assert str(i) * message_len in messages
assert "2" * (message_len + 1) not in messages
assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT -
1) in messages
assert str(event_consts.EVENT_READ_LINE_COUNT_LIMIT - 1) in messages
return True
except Exception as ex:
logger.exception(ex)
@ -179,15 +183,12 @@ async def test_monitor_events():
common = event_pb2.Event.SourceType.Name(event_pb2.Event.COMMON)
common_log = os.path.join(temp_dir, f"event_{common}.log")
test_logger = _test_logger(
__name__ + str(random.random()),
common_log,
max_bytes=10,
backup_count=10)
__name__ + str(random.random()), common_log, max_bytes=10, backup_count=10
)
test_events1 = []
monitor_task = monitor_events(
temp_dir,
lambda x: test_events1.extend(x),
scan_interval_seconds=0.01)
temp_dir, lambda x: test_events1.extend(x), scan_interval_seconds=0.01
)
assert not monitor_task.done()
count = 10
@ -206,7 +207,8 @@ async def test_monitor_events():
if time.time() - start_time > timeout:
raise TimeoutError(
f"Timeout, read events: {sorted_events}, "
f"expect events: {expect_events}")
f"expect events: {expect_events}"
)
if len(sorted_events) == len(expect_events):
if sorted_events == expect_events:
break
@ -214,40 +216,37 @@ async def test_monitor_events():
await asyncio.gather(
_writer(count, read_events=test_events1),
_check_events(
[str(i) for i in range(count)], read_events=test_events1))
_check_events([str(i) for i in range(count)], read_events=test_events1),
)
monitor_task.cancel()
test_events2 = []
monitor_task = monitor_events(
temp_dir,
lambda x: test_events2.extend(x),
scan_interval_seconds=0.1)
temp_dir, lambda x: test_events2.extend(x), scan_interval_seconds=0.1
)
await _check_events(
[str(i) for i in range(count)], read_events=test_events2)
await _check_events([str(i) for i in range(count)], read_events=test_events2)
await _writer(count, count * 2, read_events=test_events2)
await _check_events(
[str(i) for i in range(count * 2)], read_events=test_events2)
[str(i) for i in range(count * 2)], read_events=test_events2
)
log_file_count = len(os.listdir(temp_dir))
test_logger = _test_logger(
__name__ + str(random.random()),
common_log,
max_bytes=1000,
backup_count=10)
__name__ + str(random.random()), common_log, max_bytes=1000, backup_count=10
)
assert len(os.listdir(temp_dir)) == log_file_count
await _writer(
count * 2, count * 3, spin=False, read_events=test_events2)
await _writer(count * 2, count * 3, spin=False, read_events=test_events2)
await _check_events(
[str(i) for i in range(count * 3)], read_events=test_events2)
await _writer(
count * 3, count * 4, spin=False, read_events=test_events2)
[str(i) for i in range(count * 3)], read_events=test_events2
)
await _writer(count * 3, count * 4, spin=False, read_events=test_events2)
await _check_events(
[str(i) for i in range(count * 4)], read_events=test_events2)
[str(i) for i in range(count * 4)], read_events=test_events2
)
# Test cancel monitor task.
monitor_task.cancel()
@ -255,8 +254,7 @@ async def test_monitor_events():
await monitor_task
assert monitor_task.done()
assert len(
os.listdir(temp_dir)) > 1, "Event log should have rollovers."
assert len(os.listdir(temp_dir)) > 1, "Event log should have rollovers."
if __name__ == "__main__":

View file

@ -7,21 +7,21 @@ import yaml
import click
from ray.autoscaler._private.cli_logger import (add_click_logging_options,
cli_logger, cf)
from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf
from ray.dashboard.modules.job.common import JobStatus
from ray.dashboard.modules.job.sdk import JobSubmissionClient
def _get_sdk_client(address: Optional[str],
create_cluster_if_needed: bool = False
) -> JobSubmissionClient:
def _get_sdk_client(
address: Optional[str], create_cluster_if_needed: bool = False
) -> JobSubmissionClient:
if address is None:
if "RAY_ADDRESS" not in os.environ:
raise ValueError(
"Address must be specified using either the --address flag "
"or RAY_ADDRESS environment variable.")
"or RAY_ADDRESS environment variable."
)
address = os.environ["RAY_ADDRESS"]
cli_logger.labeled_value("Job submission server address", address)
@ -73,55 +73,67 @@ def job_cli_group():
pass
@job_cli_group.command(
"submit", help="Submit a job to be executed on the cluster.")
@job_cli_group.command("submit", help="Submit a job to be executed on the cluster.")
@click.option(
"--address",
type=str,
default=None,
required=False,
help=("Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."))
help=(
"Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."
),
)
@click.option(
"--job-id",
type=str,
default=None,
required=False,
help=("Job ID to specify for the job. "
"If not provided, one will be generated."))
help=("Job ID to specify for the job. " "If not provided, one will be generated."),
)
@click.option(
"--runtime-env",
type=str,
default=None,
required=False,
help="Path to a local YAML file containing a runtime_env definition.")
help="Path to a local YAML file containing a runtime_env definition.",
)
@click.option(
"--runtime-env-json",
type=str,
default=None,
required=False,
help="JSON-serialized runtime_env dictionary.")
help="JSON-serialized runtime_env dictionary.",
)
@click.option(
"--working-dir",
type=str,
default=None,
required=False,
help=("Directory containing files that your job will run in. Can be a "
"local directory or a remote URI to a .zip file (S3, GS, HTTP). "
"If specified, this overrides the option in --runtime-env."),
help=(
"Directory containing files that your job will run in. Can be a "
"local directory or a remote URI to a .zip file (S3, GS, HTTP). "
"If specified, this overrides the option in --runtime-env."
),
)
@click.option(
"--no-wait",
is_flag=True,
type=bool,
default=False,
help="If set, will not stream logs and wait for the job to exit.")
help="If set, will not stream logs and wait for the job to exit.",
)
@add_click_logging_options
@click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
def job_submit(address: Optional[str], job_id: Optional[str],
runtime_env: Optional[str], runtime_env_json: Optional[str],
working_dir: Optional[str], entrypoint: Tuple[str],
no_wait: bool):
def job_submit(
address: Optional[str],
job_id: Optional[str],
runtime_env: Optional[str],
runtime_env_json: Optional[str],
working_dir: Optional[str],
entrypoint: Tuple[str],
no_wait: bool,
):
"""Submits a job to be run on the cluster.
Example:
@ -132,8 +144,9 @@ def job_submit(address: Optional[str], job_id: Optional[str],
final_runtime_env = {}
if runtime_env is not None:
if runtime_env_json is not None:
raise ValueError("Only one of --runtime_env and "
"--runtime-env-json can be provided.")
raise ValueError(
"Only one of --runtime_env and " "--runtime-env-json can be provided."
)
with open(runtime_env, "r") as f:
final_runtime_env = yaml.safe_load(f)
@ -143,14 +156,14 @@ def job_submit(address: Optional[str], job_id: Optional[str],
if working_dir is not None:
if "working_dir" in final_runtime_env:
cli_logger.warning(
"Overriding runtime_env working_dir with --working-dir option")
"Overriding runtime_env working_dir with --working-dir option"
)
final_runtime_env["working_dir"] = working_dir
job_id = client.submit_job(
entrypoint=" ".join(entrypoint),
job_id=job_id,
runtime_env=final_runtime_env)
entrypoint=" ".join(entrypoint), job_id=job_id, runtime_env=final_runtime_env
)
_log_big_success_msg(f"Job '{job_id}' submitted successfully")
@ -172,15 +185,16 @@ def job_submit(address: Optional[str], job_id: Optional[str],
# sdk version 0 does not have log streaming
if not no_wait:
if int(sdk_version) > 0:
cli_logger.print("Tailing logs until the job exits "
"(disable with --no-wait):")
asyncio.get_event_loop().run_until_complete(
_tail_logs(client, job_id))
cli_logger.print(
"Tailing logs until the job exits " "(disable with --no-wait):"
)
asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
else:
cli_logger.warning(
"Tailing logs is not enabled for job sdk client version "
f"{sdk_version}. Please upgrade your ray to latest version "
"for this feature.")
"for this feature."
)
@job_cli_group.command("status", help="Get the status of a running job.")
@ -189,8 +203,11 @@ def job_submit(address: Optional[str], job_id: Optional[str],
type=str,
default=None,
required=False,
help=("Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."))
help=(
"Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."
),
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_status(address: Optional[str], job_id: str):
@ -209,14 +226,18 @@ def job_status(address: Optional[str], job_id: str):
type=str,
default=None,
required=False,
help=("Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."))
help=(
"Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."
),
)
@click.option(
"--no-wait",
is_flag=True,
type=bool,
default=False,
help="If set, will not wait for the job to exit.")
help="If set, will not wait for the job to exit.",
)
@click.argument("job-id", type=str)
@add_click_logging_options
def job_stop(address: Optional[str], no_wait: bool, job_id: str):
@ -232,14 +253,13 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
if no_wait:
return
else:
cli_logger.print(f"Waiting for job '{job_id}' to exit "
f"(disable with --no-wait):")
cli_logger.print(
f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):"
)
while True:
status = client.get_job_status(job_id)
if status.status in {
JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED
}:
if status.status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}:
_log_job_status(client, job_id)
break
else:
@ -253,8 +273,11 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
type=str,
default=None,
required=False,
help=("Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."))
help=(
"Address of the Ray cluster to connect to. Can also be specified "
"using the RAY_ADDRESS environment variable."
),
)
@click.argument("job-id", type=str)
@click.option(
"-f",
@ -262,7 +285,8 @@ def job_stop(address: Optional[str], no_wait: bool, job_id: str):
is_flag=True,
type=bool,
default=False,
help="If set, follow the logs (like `tail -f`).")
help="If set, follow the logs (like `tail -f`).",
)
@add_click_logging_options
def job_logs(address: Optional[str], job_id: str, follow: bool):
"""Gets the logs of a job.
@ -275,12 +299,12 @@ def job_logs(address: Optional[str], job_id: str, follow: bool):
# sdk version 0 did not have log streaming
if follow:
if int(sdk_version) > 0:
asyncio.get_event_loop().run_until_complete(
_tail_logs(client, job_id))
asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
else:
cli_logger.warning(
"Tailing logs is not enabled for job sdk client version "
f"{sdk_version}. Please upgrade your ray to latest version "
"for this feature.")
"for this feature."
)
else:
print(client.get_job_logs(job_id), end="")

View file

@ -39,8 +39,10 @@ class JobStatusInfo:
def __post_init__(self):
if self.message is None:
if self.status == JobStatus.PENDING:
self.message = ("Job has not started yet, likely waiting "
"for the runtime_env to be set up.")
self.message = (
"Job has not started yet, likely waiting "
"for the runtime_env to be set up."
)
elif self.status == JobStatus.RUNNING:
self.message = "Job is currently running."
elif self.status == JobStatus.STOPPED:
@ -55,6 +57,7 @@ class JobStatusStorageClient:
"""
Handles formatting of status storage key given job id.
"""
JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"
def __init__(self):
@ -69,12 +72,14 @@ class JobStatusStorageClient:
_internal_kv_put(
self.JOB_STATUS_KEY.format(job_id=job_id),
pickle.dumps(status),
namespace=ray_constants.KV_NAMESPACE_JOB)
namespace=ray_constants.KV_NAMESPACE_JOB,
)
def get_status(self, job_id: str) -> Optional[JobStatusInfo]:
pickled_status = _internal_kv_get(
self.JOB_STATUS_KEY.format(job_id=job_id),
namespace=ray_constants.KV_NAMESPACE_JOB)
namespace=ray_constants.KV_NAMESPACE_JOB,
)
if pickled_status is None:
return None
else:
@ -87,18 +92,16 @@ def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
# We need to strip the gcs:// prefix and .zip suffix to make it
# possible to pass the package_uri over HTTP.
protocol, package_name = parse_uri(package_uri)
return protocol.value, package_name[:-len(".zip")]
return protocol.value, package_name[: -len(".zip")]
def http_uri_components_to_uri(protocol: str, package_name: str) -> str:
if package_name.endswith(".zip"):
raise ValueError(
f"package_name ({package_name}) should not end in .zip")
raise ValueError(f"package_name ({package_name}) should not end in .zip")
return f"{protocol}://{package_name}.zip"
def validate_request_type(json_data: Dict[str, Any],
request_type: dataclass) -> Any:
def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any:
return request_type(**json_data)
@ -124,8 +127,7 @@ class JobSubmitRequest:
def __post_init__(self):
if not isinstance(self.entrypoint, str):
raise TypeError(
f"entrypoint must be a string, got {type(self.entrypoint)}")
raise TypeError(f"entrypoint must be a string, got {type(self.entrypoint)}")
if self.job_id is not None and not isinstance(self.job_id, str):
raise TypeError(
@ -141,21 +143,21 @@ class JobSubmitRequest:
for k in self.runtime_env.keys():
if not isinstance(k, str):
raise TypeError(
f"runtime_env keys must be strings, got {type(k)}")
f"runtime_env keys must be strings, got {type(k)}"
)
if self.metadata is not None:
if not isinstance(self.metadata, dict):
raise TypeError(
f"metadata must be a dict, got {type(self.metadata)}")
raise TypeError(f"metadata must be a dict, got {type(self.metadata)}")
else:
for k in self.metadata.keys():
if not isinstance(k, str):
raise TypeError(
f"metadata keys must be strings, got {type(k)}")
raise TypeError(f"metadata keys must be strings, got {type(k)}")
for v in self.metadata.values():
if not isinstance(v, str):
raise TypeError(
f"metadata values must be strings, got {type(v)}")
f"metadata values must be strings, got {type(v)}"
)
@dataclass
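
The uri_to_http_components hunk above shows Black's slice spacing, taken from PEP 8: when a slice bound is an expression rather than a plain name or number, the colon is treated like an operator and separated by a space. A small sketch with sample data:

name = "package.zip"  # sample value, not taken from this diff

simple = name[:7]            # simple bound: no space around the colon
expr = name[: -len(".zip")]  # expression bound: a space separates it from the colon

assert simple == "package" and expr == "package"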

View file

@ -12,8 +12,7 @@ import ray
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_utils import use_gcs_for_bootstrap
from ray._private.runtime_env.packaging import (package_exists,
upload_package_to_gcs)
from ray._private.runtime_env.packaging import package_exists, upload_package_to_gcs
from ray.dashboard.modules.job.common import (
CURRENT_VERSION,
http_uri_components_to_uri,
@ -45,19 +44,20 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
if use_gcs_for_bootstrap():
address = self._dashboard_head.gcs_address
redis_pw = None
logger.info(
f"Connecting to ray with address={address}")
logger.info(f"Connecting to ray with address={address}")
else:
ip, port = self._dashboard_head.redis_address
redis_pw = self._dashboard_head.redis_password
address = f"{ip}:{port}"
logger.info(
f"Connecting to ray with address={address}, "
f"redis_pw={redis_pw}")
f"redis_pw={redis_pw}"
)
ray.init(
address=address,
namespace=RAY_INTERNAL_JOBS_NAMESPACE,
_redis_password=redis_pw)
_redis_password=redis_pw,
)
except Exception as e:
ray.shutdown()
raise e from None
@ -67,7 +67,8 @@ def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
logger.exception(f"Unexpected error in handler: {e}")
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code)
status=aiohttp.web.HTTPInternalServerError.status_code,
)
return check
@ -77,8 +78,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
super().__init__(dashboard_head)
self._job_manager = None
async def _parse_and_validate_request(self, req: Request,
request_type: dataclass) -> Any:
async def _parse_and_validate_request(
self, req: Request, request_type: dataclass
) -> Any:
"""Parse request and cast to request type. If parsing failed, return a
Response object with status 400 and stacktrace instead.
"""
@ -88,7 +90,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
logger.info(f"Got invalid request type: {e}")
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code)
status=aiohttp.web.HTTPBadRequest.status_code,
)
def job_exists(self, job_id: str) -> bool:
status = self._job_manager.get_job_status(job_id)
@ -101,7 +104,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
resp = VersionResponse(
version=CURRENT_VERSION,
ray_version=ray.__version__,
ray_commit=ray.__commit__)
ray_commit=ray.__commit__,
)
return Response(
text=json.dumps(dataclasses.asdict(resp)),
content_type="application/json",
@ -113,12 +117,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def get_package(self, req: Request) -> Response:
package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"],
package_name=req.match_info["package_name"])
package_name=req.match_info["package_name"],
)
if not package_exists(package_uri):
return Response(
text=f"Package {package_uri} does not exist",
status=aiohttp.web.HTTPNotFound.status_code)
status=aiohttp.web.HTTPNotFound.status_code,
)
return Response()
@ -127,14 +133,16 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def upload_package(self, req: Request):
package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"],
package_name=req.match_info["package_name"])
package_name=req.match_info["package_name"],
)
logger.info(f"Uploading package {package_uri} to the GCS.")
try:
upload_package_to_gcs(package_uri, await req.read())
except Exception:
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code)
status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response(status=aiohttp.web.HTTPOk.status_code)
@ -153,17 +161,20 @@ class JobHead(dashboard_utils.DashboardHeadModule):
entrypoint=submit_request.entrypoint,
job_id=submit_request.job_id,
runtime_env=submit_request.runtime_env,
metadata=submit_request.metadata)
metadata=submit_request.metadata,
)
resp = JobSubmitResponse(job_id=job_id)
except (TypeError, ValueError):
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code)
status=aiohttp.web.HTTPBadRequest.status_code,
)
except Exception:
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code)
status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response(
text=json.dumps(dataclasses.asdict(resp)),
@ -178,7 +189,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id):
return Response(
text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code)
status=aiohttp.web.HTTPNotFound.status_code,
)
try:
stopped = self._job_manager.stop_job(job_id)
@ -186,11 +198,12 @@ class JobHead(dashboard_utils.DashboardHeadModule):
except Exception:
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPInternalServerError.status_code)
status=aiohttp.web.HTTPInternalServerError.status_code,
)
return Response(
text=json.dumps(dataclasses.asdict(resp)),
content_type="application/json")
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
)
@routes.get("/api/jobs/{job_id}")
@_init_ray_and_catch_exceptions
@ -199,13 +212,14 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id):
return Response(
text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code)
status=aiohttp.web.HTTPNotFound.status_code,
)
status: JobStatusInfo = self._job_manager.get_job_status(job_id)
resp = JobStatusResponse(status=status.status, message=status.message)
return Response(
text=json.dumps(dataclasses.asdict(resp)),
content_type="application/json")
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
)
@routes.get("/api/jobs/{job_id}/logs")
@_init_ray_and_catch_exceptions
@ -214,12 +228,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id):
return Response(
text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code)
status=aiohttp.web.HTTPNotFound.status_code,
)
resp = JobLogsResponse(logs=self._job_manager.get_job_logs(job_id))
return Response(
text=json.dumps(dataclasses.asdict(resp)),
content_type="application/json")
text=json.dumps(dataclasses.asdict(resp)), content_type="application/json"
)
@routes.get("/api/jobs/{job_id}/logs/tail")
@_init_ray_and_catch_exceptions
@ -228,7 +243,8 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self.job_exists(job_id):
return Response(
text=f"Job {job_id} does not exist",
status=aiohttp.web.HTTPNotFound.status_code)
status=aiohttp.web.HTTPNotFound.status_code,
)
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(req)
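
Since the point of the change is to keep the tree Black-clean in CI, it can help to reproduce the check locally. Below is a hedged sketch using Black's Python API; the file path is a placeholder, and the exact Black version and configuration pinned by the repo are not shown in this diff:

import black


def is_black_clean(path: str) -> bool:
    """Return True if the file is already formatted the way Black would format it."""
    with open(path, encoding="utf-8") as f:
        source = f.read()
    # FileMode() means Black's defaults (88-column lines, string normalization);
    # any repo-specific pyproject.toml configuration is not reproduced here.
    return black.format_str(source, mode=black.FileMode()) == source


if __name__ == "__main__":
    print(is_black_clean("dashboard/modules/job/job_head.py"))  # placeholder path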

View file

@ -15,8 +15,12 @@ from ray.exceptions import RuntimeEnvSetupError
import ray.ray_constants as ray_constants
from ray.actor import ActorHandle
from ray.dashboard.modules.job.common import (
JobStatus, JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY)
JobStatus,
JobStatusInfo,
JobStatusStorageClient,
JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.utils import file_tail_iterator
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
@ -36,8 +40,8 @@ def generate_job_id() -> str:
"""
rand = random.SystemRandom()
possible_characters = list(
set(string.ascii_letters + string.digits) -
{"I", "l", "o", "O", "0"} # No confusing characters
set(string.ascii_letters + string.digits)
- {"I", "l", "o", "O", "0"} # No confusing characters
)
id_part = "".join(rand.choices(possible_characters, k=16))
return f"raysubmit_{id_part}"
@ -47,6 +51,7 @@ class JobLogStorageClient:
"""
Disk storage for stdout / stderr of driver script logs.
"""
JOB_LOGS_PATH = "job-driver-{job_id}.log"
# Number of last N lines to put in job message upon failure.
NUM_LOG_LINES_ON_ERROR = 10
@ -61,9 +66,9 @@ class JobLogStorageClient:
def tail_logs(self, job_id: str) -> Iterator[str]:
return file_tail_iterator(self.get_log_file_path(job_id))
def get_last_n_log_lines(self,
job_id: str,
num_log_lines=NUM_LOG_LINES_ON_ERROR) -> str:
def get_last_n_log_lines(
self, job_id: str, num_log_lines=NUM_LOG_LINES_ON_ERROR
) -> str:
log_tail_iter = self.tail_logs(job_id)
log_tail_deque = deque(maxlen=num_log_lines)
for line in log_tail_iter:
@ -80,7 +85,8 @@ class JobLogStorageClient:
"""
return os.path.join(
ray.worker._global_node.get_logs_dir_path(),
self.JOB_LOGS_PATH.format(job_id=job_id))
self.JOB_LOGS_PATH.format(job_id=job_id),
)
class JobSupervisor:
@ -95,8 +101,7 @@ class JobSupervisor:
SUBPROCESS_POLL_PERIOD_S = 0.1
def __init__(self, job_id: str, entrypoint: str,
user_metadata: Dict[str, str]):
def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
self._job_id = job_id
self._status_client = JobStatusStorageClient()
self._log_client = JobLogStorageClient()
@ -104,10 +109,7 @@ class JobSupervisor:
self._entrypoint = entrypoint
# Default metadata if not passed by the user.
self._metadata = {
JOB_ID_METADATA_KEY: job_id,
JOB_NAME_METADATA_KEY: job_id
}
self._metadata = {JOB_ID_METADATA_KEY: job_id, JOB_NAME_METADATA_KEY: job_id}
self._metadata.update(user_metadata)
# fire and forget call from outer job manager to this actor
@ -142,7 +144,8 @@ class JobSupervisor:
shell=True,
start_new_session=True,
stdout=logs_file,
stderr=subprocess.STDOUT)
stderr=subprocess.STDOUT,
)
parent_pid = os.getpid()
# Create new pgid with new subprocess to execute driver command
child_pid = child_process.pid
@ -177,9 +180,10 @@ class JobSupervisor:
return 1
async def run(
self,
# Signal actor used in testing to capture PENDING -> RUNNING cases
_start_signal_actor: Optional[ActorHandle] = None):
self,
# Signal actor used in testing to capture PENDING -> RUNNING cases
_start_signal_actor: Optional[ActorHandle] = None,
):
"""
        Stop and start both happen asynchronously, coordinated by asyncio event
        and coroutine, respectively.
@ -190,26 +194,26 @@ class JobSupervisor:
        3) Handle concurrent events of driver execution and stop requests.
"""
cur_status = self._get_status()
assert cur_status.status == JobStatus.PENDING, (
"Run should only be called once.")
assert cur_status.status == JobStatus.PENDING, "Run should only be called once."
if _start_signal_actor:
# Block in PENDING state until start signal received.
await _start_signal_actor.wait.remote()
self._status_client.put_status(self._job_id,
JobStatusInfo(JobStatus.RUNNING))
self._status_client.put_status(self._job_id, JobStatusInfo(JobStatus.RUNNING))
try:
# Set JobConfig for the child process (runtime_env, metadata).
os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps({
"runtime_env": self._runtime_env,
"metadata": self._metadata,
})
os.environ[RAY_JOB_CONFIG_JSON_ENV_VAR] = json.dumps(
{
"runtime_env": self._runtime_env,
"metadata": self._metadata,
}
)
# Set RAY_ADDRESS to local Ray address, if it is not set.
os.environ[
ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE] = \
ray._private.services.get_ray_address_from_environment()
ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE
] = ray._private.services.get_ray_address_from_environment()
# Set PYTHONUNBUFFERED=1 to stream logs during the job instead of
# only streaming them upon completion of the job.
os.environ["PYTHONUNBUFFERED"] = "1"
@ -218,8 +222,8 @@ class JobSupervisor:
polling_task = create_task(self._polling(child_process))
finished, _ = await asyncio.wait(
[polling_task, self._stop_event.wait()],
return_when=FIRST_COMPLETED)
[polling_task, self._stop_event.wait()], return_when=FIRST_COMPLETED
)
if self._stop_event.is_set():
polling_task.cancel()
@ -229,29 +233,29 @@ class JobSupervisor:
else:
# Child process finished execution and no stop event is set
# at the same time
assert len(
finished) == 1, "Should have only one coroutine done"
assert len(finished) == 1, "Should have only one coroutine done"
[child_process_task] = finished
return_code = child_process_task.result()
if return_code == 0:
self._status_client.put_status(self._job_id,
JobStatus.SUCCEEDED)
self._status_client.put_status(self._job_id, JobStatus.SUCCEEDED)
else:
log_tail = self._log_client.get_last_n_log_lines(
self._job_id)
log_tail = self._log_client.get_last_n_log_lines(self._job_id)
if log_tail is not None and log_tail != "":
message = ("Job failed due to an application error, "
"last available logs:\n" + log_tail)
message = (
"Job failed due to an application error, "
"last available logs:\n" + log_tail
)
else:
message = None
self._status_client.put_status(
self._job_id,
JobStatusInfo(
status=JobStatus.FAILED, message=message))
JobStatusInfo(status=JobStatus.FAILED, message=message),
)
except Exception:
logger.error(
"Got unexpected exception while trying to execute driver "
f"command. {traceback.format_exc()}")
f"command. {traceback.format_exc()}"
)
finally:
# clean up actor after tasks are finished
ray.actor.exit_actor()
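
The run() body above races the subprocess-polling task against the stop event with asyncio.wait(..., return_when=FIRST_COMPLETED); whichever finishes first decides whether the job ends up STOPPED or SUCCEEDED/FAILED. A self-contained sketch of that coordination pattern with the polling stubbed out (all names here are illustrative):

import asyncio
from asyncio import FIRST_COMPLETED

async def run_until_stopped() -> str:
    stop_event = asyncio.Event()

    async def poll_child() -> int:
        # Stand-in for polling a child process until it exits.
        await asyncio.sleep(1.0)
        return 0

    polling_task = asyncio.create_task(poll_child())
    stop_task = asyncio.create_task(stop_event.wait())
    # For the demo, request a stop shortly after starting.
    asyncio.get_running_loop().call_later(0.2, stop_event.set)

    finished, _ = await asyncio.wait({polling_task, stop_task}, return_when=FIRST_COMPLETED)
    if stop_event.is_set():
        polling_task.cancel()
        return "STOPPED"
    # Otherwise the child finished first; map its exit code to a terminal state.
    return "SUCCEEDED" if polling_task.result() == 0 else "FAILED"

# asyncio.run(run_until_stopped()) returns "STOPPED" in this demo.
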
@ -260,8 +264,7 @@ class JobSupervisor:
return self._status_client.get_status(self._job_id)
def stop(self):
"""Set step_event and let run() handle the rest in its asyncio.wait().
"""
"""Set step_event and let run() handle the rest in its asyncio.wait()."""
self._stop_event.set()
@ -271,6 +274,7 @@ class JobManager:
    It does not provide persistence; all info will be lost if the cluster
    goes down.
"""
JOB_ACTOR_NAME = "_ray_internal_job_actor_{job_id}"
# Time that we will sleep while tailing logs if no new log line is
# available.
@ -300,11 +304,9 @@ class JobManager:
if key.startswith("node:"):
return key
else:
raise ValueError(
"Cannot find the node dictionary for current node.")
raise ValueError("Cannot find the node dictionary for current node.")
def _handle_supervisor_startup(self, job_id: str,
result: Optional[Exception]):
def _handle_supervisor_startup(self, job_id: str, result: Optional[Exception]):
"""Handle the result of starting a job supervisor actor.
If started successfully, result should be None. Otherwise it should be
@ -321,26 +323,30 @@ class JobManager:
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=(f"runtime_env setup failed: {result}")))
message=(f"runtime_env setup failed: {result}"),
),
)
elif isinstance(result, Exception):
logger.error(
f"Failed to start supervisor for job {job_id}: {result}.")
logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=f"Error occurred while starting the job: {result}")
message=f"Error occurred while starting the job: {result}",
),
)
else:
assert False, "This should not be reached."
def submit_job(self,
*,
entrypoint: str,
job_id: Optional[str] = None,
runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None,
_start_signal_actor: Optional[ActorHandle] = None) -> str:
def submit_job(
self,
*,
entrypoint: str,
job_id: Optional[str] = None,
runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None,
_start_signal_actor: Optional[ActorHandle] = None,
) -> str:
"""
Job execution happens asynchronously.
@ -390,8 +396,8 @@ class JobManager:
resources={
self._get_current_node_resource_key(): 0.001,
},
runtime_env=runtime_env).remote(job_id, entrypoint, metadata
or {})
runtime_env=runtime_env,
).remote(job_id, entrypoint, metadata or {})
actor.run.remote(_start_signal_actor=_start_signal_actor)
def callback(result: Optional[Exception]):
@ -441,7 +447,8 @@ class JobManager:
# updating GCS with latest status.
last_status = self._status_client.get_status(job_id)
if last_status and last_status.status in {
JobStatus.PENDING, JobStatus.RUNNING
JobStatus.PENDING,
JobStatus.RUNNING,
}:
self._status_client.put_status(job_id, JobStatus.FAILED)
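
Worth noting in the hunks above: submit_job schedules the JobSupervisor actor with resources={<current node key>: 0.001}, i.e. it pins the supervisor to the node serving the request by reserving a tiny slice of that node's own node:<ip> resource. A hedged sketch of the same pinning trick, using a task instead of an actor (assumes a running single-node cluster; ray.nodes() is standard Ray API, the rest is illustrative):

import socket
import ray

ray.init()

# Every Ray node advertises a resource named "node:<ip>"; asking for a
# sliver of it constrains scheduling to that specific node.
node_key = next(
    key
    for node in ray.nodes()
    if node["Alive"]
    for key in node["Resources"]
    if key.startswith("node:")
)

@ray.remote(num_cpus=0, resources={node_key: 0.001})
def where_am_i() -> str:
    return socket.gethostname()

print(ray.get(where_am_i.remote()))  # hostname of the pinned node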

View file

@ -13,10 +13,19 @@ except ImportError:
requests = None
from ray._private.runtime_env.packaging import (
create_package, get_uri_for_directory, parse_uri)
create_package,
get_uri_for_directory,
parse_uri,
)
from ray.dashboard.modules.job.common import (
JobSubmitRequest, JobSubmitResponse, JobStopResponse, JobStatusInfo,
JobStatusResponse, JobLogsResponse, uri_to_http_components)
JobSubmitRequest,
JobSubmitResponse,
JobStopResponse,
JobStatusInfo,
JobStatusResponse,
JobLogsResponse,
uri_to_http_components,
)
from ray.client_builder import _split_address
@ -33,51 +42,49 @@ class ClusterInfo:
def get_job_submission_client_cluster_info(
address: str,
# For backwards compatibility
*,
# only used in importlib case in parse_cluster_info, but needed
# in function signature.
create_cluster_if_needed: Optional[bool] = False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
address: str,
# For backwards compatibility
*,
# only used in importlib case in parse_cluster_info, but needed
# in function signature.
create_cluster_if_needed: Optional[bool] = False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
"""Get address, cookies, and metadata used for JobSubmissionClient.
Args:
address (str): Address without the module prefix that is passed
to JobSubmissionClient.
create_cluster_if_needed (bool): Indicates whether the cluster
of the address returned needs to be running. Ray doesn't
start a cluster before interacting with jobs, but other
implementations may do so.
Args:
address (str): Address without the module prefix that is passed
to JobSubmissionClient.
create_cluster_if_needed (bool): Indicates whether the cluster
of the address returned needs to be running. Ray doesn't
start a cluster before interacting with jobs, but other
implementations may do so.
Returns:
ClusterInfo object consisting of address, cookies, and metadata
for JobSubmissionClient to use.
"""
Returns:
ClusterInfo object consisting of address, cookies, and metadata
for JobSubmissionClient to use.
"""
return ClusterInfo(
address="http://" + address,
cookies=cookies,
metadata=metadata,
headers=headers)
address="http://" + address, cookies=cookies, metadata=metadata, headers=headers
)
def parse_cluster_info(
address: str,
create_cluster_if_needed: bool = False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None) -> ClusterInfo:
address: str,
create_cluster_if_needed: bool = False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
module_string, inner_address = _split_address(address.rstrip("/"))
# If user passes in a raw HTTP(S) address, just pass it through.
if module_string == "http" or module_string == "https":
return ClusterInfo(
address=address,
cookies=cookies,
metadata=metadata,
headers=headers)
address=address, cookies=cookies, metadata=metadata, headers=headers
)
# If user passes in a Ray address, convert it to HTTP.
elif module_string == "ray":
return get_job_submission_client_cluster_info(
@ -85,7 +92,8 @@ def parse_cluster_info(
create_cluster_if_needed=create_cluster_if_needed,
cookies=cookies,
metadata=metadata,
headers=headers)
headers=headers,
)
# Try to dynamically import the function to get cluster info.
else:
try:
@ -93,33 +101,40 @@ def parse_cluster_info(
except Exception:
raise RuntimeError(
f"Module: {module_string} does not exist.\n"
f"This module was parsed from Address: {address}") from None
f"This module was parsed from Address: {address}"
) from None
assert "get_job_submission_client_cluster_info" in dir(module), (
f"Module: {module_string} does "
"not have `get_job_submission_client_cluster_info`.")
"not have `get_job_submission_client_cluster_info`."
)
return module.get_job_submission_client_cluster_info(
inner_address,
create_cluster_if_needed=create_cluster_if_needed,
cookies=cookies,
metadata=metadata,
headers=headers)
headers=headers,
)
class JobSubmissionClient:
def __init__(self,
address: str,
create_cluster_if_needed=False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None):
def __init__(
self,
address: str,
create_cluster_if_needed=False,
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
):
if requests is None:
raise RuntimeError(
"The Ray jobs CLI & SDK require the ray[default] "
"installation: `pip install 'ray[default']``")
"installation: `pip install 'ray[default']``"
)
cluster_info = parse_cluster_info(address, create_cluster_if_needed,
cookies, metadata, headers)
cluster_info = parse_cluster_info(
address, create_cluster_if_needed, cookies, metadata, headers
)
self._address = cluster_info.address
self._cookies = cluster_info.cookies
self._default_metadata = cluster_info.metadata or {}
@ -136,38 +151,43 @@ class JobSubmissionClient:
raise RuntimeError(
"Jobs API not supported on the Ray cluster. "
"Please ensure the cluster is running "
"Ray 1.9 or higher.")
"Ray 1.9 or higher."
)
r.raise_for_status()
# TODO(edoakes): check the version if/when we break compatibility.
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Failed to connect to Ray at address: {self._address}.")
f"Failed to connect to Ray at address: {self._address}."
)
def _raise_error(self, r: "requests.Response"):
raise RuntimeError(
f"Request failed with status code {r.status_code}: {r.text}.")
f"Request failed with status code {r.status_code}: {r.text}."
)
def _do_request(self,
method: str,
endpoint: str,
*,
data: Optional[bytes] = None,
json_data: Optional[dict] = None) -> Optional[object]:
def _do_request(
self,
method: str,
endpoint: str,
*,
data: Optional[bytes] = None,
json_data: Optional[dict] = None,
) -> Optional[object]:
url = self._address + endpoint
logger.debug(
f"Sending request to {url} with json data: {json_data or {}}.")
logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
return requests.request(
method,
url,
cookies=self._cookies,
data=data,
json=json_data,
headers=self._headers)
headers=self._headers,
)
def _package_exists(
self,
package_uri: str,
self,
package_uri: str,
) -> bool:
protocol, package_name = uri_to_http_components(package_uri)
r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
@ -181,11 +201,13 @@ class JobSubmissionClient:
else:
self._raise_error(r)
def _upload_package(self,
package_uri: str,
package_path: str,
include_parent_dir: Optional[bool] = False,
excludes: Optional[List[str]] = None) -> bool:
def _upload_package(
self,
package_uri: str,
package_path: str,
include_parent_dir: Optional[bool] = False,
excludes: Optional[List[str]] = None,
) -> bool:
logger.info(f"Uploading package {package_uri}.")
with tempfile.TemporaryDirectory() as tmp_dir:
protocol, package_name = uri_to_http_components(package_uri)
@ -194,26 +216,27 @@ class JobSubmissionClient:
package_path,
package_file,
include_parent_dir=include_parent_dir,
excludes=excludes)
excludes=excludes,
)
try:
r = self._do_request(
"PUT",
f"/api/packages/{protocol}/{package_name}",
data=package_file.read_bytes())
data=package_file.read_bytes(),
)
if r.status_code != 200:
self._raise_error(r)
finally:
package_file.unlink()
def _upload_package_if_needed(self,
package_path: str,
excludes: Optional[List[str]] = None) -> str:
def _upload_package_if_needed(
self, package_path: str, excludes: Optional[List[str]] = None
) -> str:
package_uri = get_uri_for_directory(package_path, excludes=excludes)
if not self._package_exists(package_uri):
self._upload_package(package_uri, package_path, excludes=excludes)
else:
logger.info(
f"Package {package_uri} already exists, skipping upload.")
logger.info(f"Package {package_uri} already exists, skipping upload.")
return package_uri
@ -230,7 +253,8 @@ class JobSubmissionClient:
if not is_uri:
logger.debug("working_dir is not a URI, attempting to upload.")
package_uri = self._upload_package_if_needed(
working_dir, excludes=runtime_env.get("excludes", None))
working_dir, excludes=runtime_env.get("excludes", None)
)
runtime_env["working_dir"] = package_uri
def get_version(self) -> str:
@ -241,12 +265,12 @@ class JobSubmissionClient:
self._raise_error(r)
def submit_job(
self,
*,
entrypoint: str,
job_id: Optional[str] = None,
runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None,
self,
*,
entrypoint: str,
job_id: Optional[str] = None,
runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
runtime_env = runtime_env or {}
metadata = metadata or {}
@ -257,11 +281,11 @@ class JobSubmissionClient:
entrypoint=entrypoint,
job_id=job_id,
runtime_env=runtime_env,
metadata=metadata)
metadata=metadata,
)
logger.debug(f"Submitting job with job_id={job_id}.")
r = self._do_request(
"POST", "/api/jobs/", json_data=dataclasses.asdict(req))
r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))
if r.status_code == 200:
return JobSubmitResponse(**r.json()).job_id
@ -269,8 +293,8 @@ class JobSubmissionClient:
self._raise_error(r)
def stop_job(
self,
job_id: str,
self,
job_id: str,
) -> bool:
logger.debug(f"Stopping job with job_id={job_id}.")
r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
@ -281,15 +305,14 @@ class JobSubmissionClient:
self._raise_error(r)
def get_job_status(
self,
job_id: str,
self,
job_id: str,
) -> JobStatusInfo:
r = self._do_request("GET", f"/api/jobs/{job_id}")
if r.status_code == 200:
response = JobStatusResponse(**r.json())
return JobStatusInfo(
status=response.status, message=response.message)
return JobStatusInfo(status=response.status, message=response.message)
else:
self._raise_error(r)
@ -304,7 +327,8 @@ class JobSubmissionClient:
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
async with aiohttp.ClientSession(cookies=self._cookies) as session:
ws = await session.ws_connect(
f"{self._address}/api/jobs/{job_id}/logs/tail")
f"{self._address}/api/jobs/{job_id}/logs/tail"
)
while True:
msg = await ws.receive()
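
End to end, the SDK above turns an address (http(s)://, ray://, or a custom module prefix) into a ClusterInfo and then drives the /api/jobs/ REST endpoints. A short usage sketch against a local cluster (the address and entrypoint are placeholders, not values from this diff):

from ray.dashboard.modules.job.sdk import JobSubmissionClient

# Assumes a Ray cluster whose dashboard listens on the default port.
client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(
    entrypoint="echo hello",
    runtime_env={"pip": ["pip-install-test"]},
)
print(client.get_job_status(job_id))  # JobStatusInfo with status and optional message
print(client.get_job_logs(job_id))    # stdout/stderr captured for the driver
client.stop_job(job_id)               # returns True if the job was stopped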

View file

@ -7,13 +7,13 @@ we ended up using job submission API call's runtime_env instead of scripts
def run():
import ray
import os
ray.init(
address=os.environ["RAY_ADDRESS"],
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "SHOULD_BE_OVERRIDEN"}
},
)
@ray.remote
def foo():

View file

@ -21,15 +21,14 @@ def conda_env(env_name):
# Clean up created conda env upon test exit to prevent leaking
del os.environ["JOB_COMPATIBILITY_TEST_TEMP_ENV"]
subprocess.run(
f"conda env remove -y --name {env_name}",
shell=True,
stdout=subprocess.PIPE)
f"conda env remove -y --name {env_name}", shell=True, stdout=subprocess.PIPE
)
def _compatibility_script_path(file_name: str) -> str:
return os.path.join(
os.path.dirname(__file__), "backwards_compatibility_scripts",
file_name)
os.path.dirname(__file__), "backwards_compatibility_scripts", file_name
)
class TestBackwardsCompatibility:
@ -48,8 +47,7 @@ class TestBackwardsCompatibility:
shell_cmd = f"{_compatibility_script_path('test_backwards_compatibility.sh')}" # noqa: E501
try:
subprocess.check_output(
shell_cmd, shell=True, stderr=subprocess.STDOUT)
subprocess.check_output(shell_cmd, shell=True, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
logger.error(str(e))
logger.error(e.stdout.decode())

View file

@ -34,8 +34,7 @@ def mock_sdk_client():
if "RAY_ADDRESS" in os.environ:
del os.environ["RAY_ADDRESS"]
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient"
) as mock_client:
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient") as mock_client:
# In python 3.6 it will fail with error
# 'async for' requires an object with __aiter__ method, got MagicMock"
mock_client().tail_job_logs.return_value = AsyncIterator(range(10))
@ -52,9 +51,7 @@ def runtime_env_formats():
"working_dir": "s3://bogus.zip",
"conda": "conda_env",
"pip": ["pip-install-test"],
"env_vars": {
"hi": "hi2"
}
"env_vars": {"hi": "hi2"},
}
yaml_file = path / "env.yaml"
@ -86,14 +83,13 @@ class TestSubmit:
# Test passing address via command line.
result = runner.invoke(
job_cli_group,
["submit", "--address=arg_addr", "--", "echo hello"])
job_cli_group, ["submit", "--address=arg_addr", "--", "echo hello"]
)
assert mock_sdk_client.called_with("arg_addr")
assert result.exit_code == 0
# Test passing address via env var.
with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_sdk_client.called_with("env_addr")
# Test passing no address.
@ -106,24 +102,22 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value
with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env={})
result = runner.invoke(
job_cli_group,
["submit", "--", "--working-dir", "blah", "--", "echo hello"])
["submit", "--", "--working-dir", "blah", "--", "echo hello"],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(
runtime_env={"working_dir": "blah"})
assert mock_client_instance.called_with(runtime_env={"working_dir": "blah"})
result = runner.invoke(
job_cli_group,
["submit", "--", "--working-dir='.'", "--", "echo hello"])
job_cli_group, ["submit", "--", "--working-dir='.'", "--", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(
runtime_env={"working_dir": "."})
assert mock_client_instance.called_with(runtime_env={"working_dir": "."})
def test_runtime_env(self, mock_sdk_client, runtime_env_formats):
runner = CliRunner()
@ -133,39 +127,64 @@ class TestSubmit:
with set_env_var("RAY_ADDRESS", "env_addr"):
# Test passing via file.
result = runner.invoke(
job_cli_group,
["submit", "--runtime-env", env_yaml, "--", "echo hello"])
job_cli_group, ["submit", "--runtime-env", env_yaml, "--", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)
# Test passing via json.
result = runner.invoke(
job_cli_group,
["submit", "--runtime-env-json", env_json, "--", "echo hello"])
["submit", "--runtime-env-json", env_json, "--", "echo hello"],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)
# Test passing both throws an error.
result = runner.invoke(job_cli_group, [
"submit", "--runtime-env", env_yaml, "--runtime-env-json",
env_json, "--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env",
env_yaml,
"--runtime-env-json",
env_json,
"--",
"echo hello",
],
)
assert result.exit_code == 1
assert "Only one of" in str(result.exception)
# Test overriding working_dir.
env_dict.update(working_dir=".")
result = runner.invoke(job_cli_group, [
"submit", "--runtime-env", env_yaml, "--working-dir", ".",
"--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env",
env_yaml,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)
result = runner.invoke(job_cli_group, [
"submit", "--runtime-env-json", env_json, "--working-dir", ".",
"--", "echo hello"
])
result = runner.invoke(
job_cli_group,
[
"submit",
"--runtime-env-json",
env_json,
"--working-dir",
".",
"--",
"echo hello",
],
)
assert result.exit_code == 0
assert mock_client_instance.called_with(runtime_env=env_dict)
@ -174,18 +193,18 @@ class TestSubmit:
mock_client_instance = mock_sdk_client.return_value
with set_env_var("RAY_ADDRESS", "env_addr"):
result = runner.invoke(job_cli_group,
["submit", "--", "echo hello"])
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
assert result.exit_code == 0
assert mock_client_instance.called_with(job_id=None)
result = runner.invoke(
job_cli_group,
["submit", "--", "--job-id=my_job_id", "echo hello"])
job_cli_group, ["submit", "--", "--job-id=my_job_id", "echo hello"]
)
assert result.exit_code == 0
assert mock_client_instance.called_with(job_id="my_job_id")
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -60,25 +60,27 @@ class TestRayAddress:
def test_empty_ray_address(self, ray_start_stop):
with set_env_var("RAY_ADDRESS", None):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stderr=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stderr=subprocess.PIPE
)
stderr = completed_process.stderr.decode("utf-8")
            # The current dashboard module raises no exception from requests.
assert ("Address must be specified using either the "
"--address flag or RAY_ADDRESS environment") in stderr
assert (
"Address must be specified using either the "
"--address flag or RAY_ADDRESS environment"
) in stderr
def test_ray_client_address(self, ray_start_stop):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout
def test_valid_http_ray_address(self, ray_start_stop):
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout
@ -87,8 +89,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "http://127.0.0.1:8265"):
with ray_cluster_manager():
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout
@ -97,8 +99,8 @@ class TestRayAddress:
with set_env_var("RAY_ADDRESS", "127.0.0.1:8265"):
with ray_cluster_manager():
completed_process = subprocess.run(
["ray", "job", "submit", "--", "echo hello"],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--", "echo hello"], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" in stdout
assert "succeeded" in stdout
@ -109,7 +111,8 @@ class TestJobSubmit:
"""Should tail logs and wait for process to exit."""
cmd = "sleep 1 && echo hello && sleep 1 && echo hello"
completed_process = subprocess.run(
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE)
["ray", "job", "submit", "--", cmd], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello\nhello" in stdout
assert "succeeded" in stdout
@ -118,8 +121,8 @@ class TestJobSubmit:
"""Should exit immediately w/o printing logs."""
cmd = "echo hello && sleep 1000"
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", "--", cmd],
stdout=subprocess.PIPE)
["ray", "job", "submit", "--no-wait", "--", cmd], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "hello" not in stdout
assert "Tailing logs until the job exits" not in stdout
@ -130,13 +133,13 @@ class TestJobStop:
"""Should wait until the job is stopped."""
cmd = "sleep 1000"
job_id = "test_basic_stop"
completed_process = subprocess.run([
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
cmd
])
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
)
completed_process = subprocess.run(
["ray", "job", "stop", job_id], stdout=subprocess.PIPE)
["ray", "job", "stop", job_id], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" in stdout
assert f"Job '{job_id}' was stopped" in stdout
@ -145,14 +148,13 @@ class TestJobStop:
"""Should not wait until the job is stopped."""
cmd = "echo hello && sleep 1000"
job_id = "test_stop_no_wait"
completed_process = subprocess.run([
"ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--",
cmd
])
completed_process = subprocess.run(
["ray", "job", "submit", "--no-wait", f"--job-id={job_id}", "--", cmd]
)
completed_process = subprocess.run(
["ray", "job", "stop", "--no-wait", job_id],
stdout=subprocess.PIPE)
["ray", "job", "stop", "--no-wait", job_id], stdout=subprocess.PIPE
)
stdout = completed_process.stdout.decode("utf-8")
assert "Waiting for job" not in stdout
assert f"Job '{job_id}' was stopped" not in stdout

View file

@ -24,82 +24,61 @@ class TestJobSubmitRequestValidation:
assert r.entrypoint == "abc"
assert r.job_id is None
r = validate_request_type({
"entrypoint": "abc",
"job_id": "123"
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "job_id": "123"}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.job_id == "123"
with pytest.raises(TypeError, match="must be a string"):
validate_request_type({
"entrypoint": 123,
"job_id": 1
}, JobSubmitRequest)
validate_request_type({"entrypoint": 123, "job_id": 1}, JobSubmitRequest)
def test_validate_runtime_env(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc"
assert r.runtime_env is None
r = validate_request_type({
"entrypoint": "abc",
"runtime_env": {
"hi": "hi2"
}
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "runtime_env": {"hi": "hi2"}}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.runtime_env == {"hi": "hi2"}
with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({
"entrypoint": "abc",
"runtime_env": 123
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "runtime_env": 123}, JobSubmitRequest
)
with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({
"entrypoint": "abc",
"runtime_env": {
1: "hi"
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "runtime_env": {1: "hi"}}, JobSubmitRequest
)
def test_validate_metadata(self):
r = validate_request_type({"entrypoint": "abc"}, JobSubmitRequest)
assert r.entrypoint == "abc"
assert r.metadata is None
r = validate_request_type({
"entrypoint": "abc",
"metadata": {
"hi": "hi2"
}
}, JobSubmitRequest)
r = validate_request_type(
{"entrypoint": "abc", "metadata": {"hi": "hi2"}}, JobSubmitRequest
)
assert r.entrypoint == "abc"
assert r.metadata == {"hi": "hi2"}
with pytest.raises(TypeError, match="must be a dict"):
validate_request_type({
"entrypoint": "abc",
"metadata": 123
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": 123}, JobSubmitRequest
)
with pytest.raises(TypeError, match="keys must be strings"):
validate_request_type({
"entrypoint": "abc",
"metadata": {
1: "hi"
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": {1: "hi"}}, JobSubmitRequest
)
with pytest.raises(TypeError, match="values must be strings"):
validate_request_type({
"entrypoint": "abc",
"metadata": {
"hi": 1
}
}, JobSubmitRequest)
validate_request_type(
{"entrypoint": "abc", "metadata": {"hi": 1}}, JobSubmitRequest
)
def test_uri_to_http_and_back():
@ -127,4 +106,5 @@ def test_uri_to_http_and_back():
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -8,11 +8,17 @@ import pytest
import ray
from ray.dashboard.tests.conftest import * # noqa
from ray.tests.conftest import _ray_start
from ray._private.test_utils import (format_web_url, wait_for_condition,
wait_until_server_available)
from ray._private.test_utils import (
format_web_url,
wait_for_condition,
wait_until_server_available,
)
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
from ray.dashboard.modules.job.sdk import (ClusterInfo, JobSubmissionClient,
parse_cluster_info)
from ray.dashboard.modules.job.sdk import (
ClusterInfo,
JobSubmissionClient,
parse_cluster_info,
)
from unittest.mock import patch
logger = logging.getLogger(__name__)
@ -50,8 +56,8 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:
@pytest.fixture(
scope="module",
params=["no_working_dir", "local_working_dir", "s3_working_dir"])
scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"]
)
def working_dir_option(request):
if request.param == "no_working_dir":
yield {
@ -81,9 +87,7 @@ def working_dir_option(request):
f.write("from test_module.test import run_test\n")
yield {
"runtime_env": {
"working_dir": tmp_dir
},
"runtime_env": {"working_dir": tmp_dir},
"entrypoint": "python test.py",
"expected_logs": "Hello from test_module!\n",
}
@ -104,7 +108,8 @@ def test_submit_job(job_sdk_client, working_dir_option):
job_id = client.submit_job(
entrypoint=working_dir_option["entrypoint"],
runtime_env=working_dir_option["runtime_env"])
runtime_env=working_dir_option["runtime_env"],
)
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
@ -133,7 +138,8 @@ def test_http_bad_request(job_sdk_client):
def test_invalid_runtime_env(job_sdk_client):
client = job_sdk_client
job_id = client.submit_job(
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"})
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
)
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
status = client.get_job_status(job_id)
@ -143,8 +149,8 @@ def test_invalid_runtime_env(job_sdk_client):
def test_runtime_env_setup_failure(job_sdk_client):
client = job_sdk_client
job_id = client.submit_job(
entrypoint="echo hello",
runtime_env={"working_dir": "s3://does_not_exist.zip"})
entrypoint="echo hello", runtime_env={"working_dir": "s3://does_not_exist.zip"}
)
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
status = client.get_job_status(job_id)
@ -168,8 +174,8 @@ raise RuntimeError('Intentionally failed.')
file.write(driver_script)
job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
logs = client.get_job_logs(job_id)
@ -196,8 +202,8 @@ raise RuntimeError('Intentionally failed.')
file.write(driver_script)
job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)
assert client.stop_job(job_id) is True
wait_for_condition(_check_job_stopped, client=client, job_id=job_id)
@ -206,28 +212,31 @@ def test_job_metadata(job_sdk_client):
client = job_sdk_client
print_metadata_cmd = (
"python -c\""
'python -c"'
"import ray;"
"ray.init();"
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
"print(dict(sorted(job_config.metadata.items())))"
"\"")
'"'
)
job_id = client.submit_job(
entrypoint=print_metadata_cmd,
metadata={
"key1": "val1",
"key2": "val2"
})
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
assert str({
"job_name": job_id,
"job_submission_id": job_id,
"key1": "val1",
"key2": "val2"
}) in client.get_job_logs(job_id)
assert (
str(
{
"job_name": job_id,
"job_submission_id": job_id,
"key1": "val1",
"key2": "val2",
}
)
in client.get_job_logs(job_id)
)
def test_pass_job_id(job_sdk_client):
@ -261,19 +270,19 @@ def test_submit_optional_args(job_sdk_client):
json_data={"entrypoint": "ls"},
)
wait_for_condition(
_check_job_succeeded, client=client, job_id=r.json()["job_id"])
wait_for_condition(_check_job_succeeded, client=client, job_id=r.json()["job_id"])
def test_missing_resources(job_sdk_client):
"""Check that 404s are raised for resources that don't exist."""
client = job_sdk_client
conditions = [("GET",
"/api/jobs/fake_job_id"), ("GET",
"/api/jobs/fake_job_id/logs"),
("POST", "/api/jobs/fake_job_id/stop"),
("GET", "/api/packages/fake_package_uri")]
conditions = [
("GET", "/api/jobs/fake_job_id"),
("GET", "/api/jobs/fake_job_id/logs"),
("POST", "/api/jobs/fake_job_id/stop"),
("GET", "/api/packages/fake_package_uri"),
]
for method, route in conditions:
assert client._do_request(method, route).status_code == 404
@ -287,7 +296,7 @@ def test_version_endpoint(job_sdk_client):
assert r.json() == {
"version": CURRENT_VERSION,
"ray_version": ray.__version__,
"ray_commit": ray.__commit__
"ray_commit": ray.__commit__,
}
@ -306,26 +315,31 @@ def test_request_headers(job_sdk_client):
cookies=None,
data=None,
json={"entrypoint": "ls"},
headers={
"Connection": "keep-alive",
"Authorization": "TOK:<MY_TOKEN>"
})
headers={"Connection": "keep-alive", "Authorization": "TOK:<MY_TOKEN>"},
)
@pytest.mark.parametrize("address", [
"http://127.0.0.1", "https://127.0.0.1", "ray://127.0.0.1",
"fake_module://127.0.0.1"
])
@pytest.mark.parametrize(
"address",
[
"http://127.0.0.1",
"https://127.0.0.1",
"ray://127.0.0.1",
"fake_module://127.0.0.1",
],
)
def test_parse_cluster_info(address: str):
if address.startswith("ray"):
assert parse_cluster_info(address, False) == ClusterInfo(
address="http" + address[address.index("://"):],
address="http" + address[address.index("://") :],
cookies=None,
metadata=None,
headers=None)
headers=None,
)
elif address.startswith("http") or address.startswith("https"):
assert parse_cluster_info(address, False) == ClusterInfo(
address=address, cookies=None, metadata=None, headers=None)
address=address, cookies=None, metadata=None, headers=None
)
else:
with pytest.raises(RuntimeError):
parse_cluster_info(address, False)
@ -347,8 +361,8 @@ for i in range(100):
f.write(driver_script)
job_id = client.submit_job(
entrypoint="python test_script.py",
runtime_env={"working_dir": tmp_dir})
entrypoint="python test_script.py", runtime_env={"working_dir": tmp_dir}
)
i = 0
async for lines in client.tail_job_logs(job_id):

View file

@ -8,8 +8,11 @@ import signal
import pytest
import ray
from ray.dashboard.modules.job.common import (JobStatus, JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY)
from ray.dashboard.modules.job.common import (
JobStatus,
JOB_ID_METADATA_KEY,
JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
from ray._private.test_utils import SignalActor, wait_for_condition
@ -28,7 +31,8 @@ def job_manager(shared_ray_instance):
def _driver_script_path(file_name: str) -> str:
return os.path.join(
os.path.dirname(__file__), "subprocess_driver_scripts", file_name)
os.path.dirname(__file__), "subprocess_driver_scripts", file_name
)
def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
@ -36,12 +40,15 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
pid_file = os.path.join(tmp_dir, "pid")
# Write subprocess pid to pid_file and block until tmp_file is present.
wait_for_file_cmd = (f"echo $$ > {pid_file} && "
f"until [ -f {tmp_file} ]; "
"do echo 'Waiting...' && sleep 1; "
"done")
wait_for_file_cmd = (
f"echo $$ > {pid_file} && "
f"until [ -f {tmp_file} ]; "
"do echo 'Waiting...' && sleep 1; "
"done"
)
job_id = job_manager.submit_job(
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor)
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
)
status = job_manager.get_job_status(job_id)
if start_signal_actor:
@ -50,11 +57,9 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
logs = job_manager.get_job_logs(job_id)
assert logs == ""
else:
wait_for_condition(
check_job_running, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)
wait_for_condition(
lambda: "Waiting..." in job_manager.get_job_logs(job_id))
wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))
return pid_file, tmp_file, job_id
@ -63,25 +68,19 @@ def check_job_succeeded(job_manager, job_id):
status = job_manager.get_job_status(job_id)
if status.status == JobStatus.FAILED:
raise RuntimeError(f"Job failed! {status.message}")
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED}
return status.status == JobStatus.SUCCEEDED
def check_job_failed(job_manager, job_id):
status = job_manager.get_job_status(job_id)
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
return status.status == JobStatus.FAILED
def check_job_stopped(job_manager, job_id):
status = job_manager.get_job_status(job_id)
assert status.status in {
JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED
}
assert status.status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
return status.status == JobStatus.STOPPED
@ -111,12 +110,10 @@ def test_generate_job_id():
def test_pass_job_id(job_manager):
job_id = "my_custom_id"
returned_id = job_manager.submit_job(
entrypoint="echo hello", job_id=job_id)
returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
assert returned_id == job_id
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
# Check that the same job_id is rejected.
with pytest.raises(RuntimeError):
@ -127,23 +124,20 @@ class TestShellScriptExecution:
def test_submit_basic_echo(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo hello")
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "hello\n"
def test_submit_stderr(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "error\n"
def test_submit_ls_grep(self, job_manager):
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
job_id = job_manager.submit_job(entrypoint=grep_cmd)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
def test_subprocess_exception(self, job_manager):
@ -161,8 +155,7 @@ class TestShellScriptExecution:
status = job_manager.get_job_status(job_id)
if status.status != JobStatus.FAILED:
return False
if ("Exception: Script failed with exception !" not in
status.message):
if "Exception: Script failed with exception !" not in status.message:
return False
return job_manager._get_actor_for_job(job_id) is None
@ -172,14 +165,13 @@ class TestShellScriptExecution:
def test_submit_with_s3_runtime_env(self, job_manager):
job_id = job_manager.submit_job(
entrypoint="python script.py",
runtime_env={
"working_dir": "s3://runtime-env-test/script_runtime_env.zip"
})
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(
job_id) == "Executing main() from script.py !!\n"
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert (
job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
)
class TestRuntimeEnv:
@ -193,14 +185,10 @@ class TestRuntimeEnv:
"""
job_id = job_manager.submit_job(
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"
}
})
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert job_manager.get_job_logs(job_id) == "233\n"
def test_multiple_runtime_envs(self, job_manager):
@ -208,28 +196,32 @@ class TestRuntimeEnv:
job_id_1 = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_1)
check_job_succeeded, job_manager=job_manager, job_id=job_id_1
)
logs = job_manager.get_job_logs(job_id_1)
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs # noqa: E501
assert (
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
) # noqa: E501
job_id_2 = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_2)
check_job_succeeded, job_manager=job_manager, job_id=job_id_2
)
logs = job_manager.get_job_logs(job_id_2)
assert "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs # noqa: E501
assert (
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
) # noqa: E501
def test_env_var_and_driver_job_config_warning(self, job_manager):
"""Ensure we got error message from worker.py and job logs
@ -238,17 +230,15 @@ class TestRuntimeEnv:
job_id = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('override_env_var.py')}",
runtime_env={
"env_vars": {
"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"
}
})
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
logs = job_manager.get_job_logs(job_id)
assert logs.startswith(
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) "
"are provided")
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
)
assert "JOB_1_VAR" in logs
def test_failed_runtime_env_validation(self, job_manager):
@ -257,7 +247,8 @@ class TestRuntimeEnv:
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"})
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
)
status = job_manager.get_job_status(job_id)
assert status.status == JobStatus.FAILED
@ -269,11 +260,10 @@ class TestRuntimeEnv:
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
entrypoint=run_cmd,
runtime_env={"working_dir": "s3://does_not_exist.zip"})
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
)
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
status = job_manager.get_job_status(job_id)
assert "runtime_env setup failed" in status.message
@ -283,69 +273,67 @@ class TestRuntimeEnv:
return str(dict(sorted(d.items())))
print_metadata_cmd = (
"python -c\""
'python -c"'
"import ray;"
"ray.init();"
"job_config=ray.worker.global_worker.core_worker.get_job_config();"
"print(dict(sorted(job_config.metadata.items())))"
"\"")
'"'
)
# Check that we default to only the job ID and job name.
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)
# Check that we can pass custom metadata.
job_id = job_manager.submit_job(
entrypoint=print_metadata_cmd,
metadata={
"key1": "val1",
"key2": "val2"
})
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id,
"key1": "val1",
"key2": "val2"
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert (
dict_to_str(
{
JOB_NAME_METADATA_KEY: job_id,
JOB_ID_METADATA_KEY: job_id,
"key1": "val1",
"key2": "val2",
}
)
in job_manager.get_job_logs(job_id)
)
# Check that we can override job name.
job_id = job_manager.submit_job(
entrypoint=print_metadata_cmd,
metadata={JOB_NAME_METADATA_KEY: "custom_name"})
metadata={JOB_NAME_METADATA_KEY: "custom_name"},
)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str({
JOB_NAME_METADATA_KEY: "custom_name",
JOB_ID_METADATA_KEY: job_id
}) in job_manager.get_job_logs(job_id)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)
class TestAsyncAPI:
def test_status_and_logs_while_blocking(self, job_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, tmp_file, job_id = _run_hanging_command(
job_manager, tmp_dir)
pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"
# Signal the job to exit by writing to the file.
with open(tmp_file, "w") as f:
print("hello", file=f)
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
wait_for_condition(check_subprocess_cleaned, pid=pid)
@ -356,7 +344,8 @@ class TestAsyncAPI:
assert job_manager.stop_job(job_id) is True
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Assert re-stopping a stopped job also returns False
wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
# Assert stopping non-existent job returns False
@ -375,13 +364,11 @@ class TestAsyncAPI:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
@ -398,16 +385,18 @@ class TestAsyncAPI:
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
assert not os.path.exists(pid_file), (
"driver subprocess should NOT be running while job is "
"still PENDING.")
"driver subprocess should NOT be running while job is " "still PENDING."
)
assert job_manager.stop_job(job_id) is True
# Send run signal to unblock run function
ray.get(start_signal_actor.send.remote())
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)
def test_kill_job_actor_in_pending(self, job_manager):
"""
@ -420,16 +409,16 @@ class TestAsyncAPI:
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
assert not os.path.exists(pid_file), (
"driver subprocess should NOT be running while job is "
"still PENDING.")
"driver subprocess should NOT be running while job is " "still PENDING."
)
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
"""
@ -442,12 +431,12 @@ class TestAsyncAPI:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), (
"driver subprocess should be running")
assert psutil.pid_exists(pid), "driver subprocess should be running"
assert job_manager.stop_job(job_id) is True
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
@ -455,11 +444,9 @@ class TestAsyncAPI:
class TestTailLogs:
async def _tail_and_assert_logs(self,
job_id,
job_manager,
expected_log="",
num_iteration=5):
async def _tail_and_assert_logs(
self, job_id, job_manager, expected_log="", num_iteration=5
):
i = 0
async for lines in job_manager.tail_job_logs(job_id):
assert all(s == expected_log for s in lines.strip().split("\n"))
@ -470,8 +457,7 @@ class TestTailLogs:
@pytest.mark.asyncio
async def test_unknown_job(self, job_manager):
with pytest.raises(
RuntimeError, match="Job 'unknown' does not exist."):
with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
async for _ in job_manager.tail_job_logs("unknown"):
pass
@ -482,33 +468,31 @@ class TestTailLogs:
with tempfile.TemporaryDirectory() as tmp_dir:
_, tmp_file, job_id = _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor)
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
# TODO(edoakes): check we get no logs before actor starts (not sure
# how to timeout the iterator call).
assert job_manager.get_job_status(
job_id).status == JobStatus.PENDING
assert job_manager.get_job_status(job_id).status == JobStatus.PENDING
# Signal job to start.
ray.get(start_signal_actor.send.remote())
await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)
# Signal the job to exit by writing to the file.
with open(tmp_file, "w") as f:
print("hello", file=f)
async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id)
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
@pytest.mark.asyncio
async def test_failed_job(self, job_manager):
@ -517,22 +501,18 @@ class TestTailLogs:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)
# Kill the job unexpectedly.
with open(pid_file, "r") as f:
os.kill(int(f.read()), signal.SIGKILL)
async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
@pytest.mark.asyncio
async def test_stopped_job(self, job_manager):
@ -541,21 +521,19 @@ class TestTailLogs:
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)
await self._tail_and_assert_logs(
job_id,
job_manager,
expected_log="Waiting...",
num_iteration=5)
job_id, job_manager, expected_log="Waiting...", num_iteration=5
)
# Stop the job via the API.
job_manager.stop_job(job_id)
async for lines in job_manager.tail_job_logs(job_id):
assert all(
s == "Waiting..." for s in lines.strip().split("\n"))
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id)
check_job_stopped, job_manager=job_manager, job_id=job_id
)
def test_logs_streaming(job_manager):
@ -568,7 +546,7 @@ while True:
time.sleep(1)
"""
stream_logs_cmd = f"python -c \"{stream_logs_script}\""
stream_logs_cmd = f'python -c "{stream_logs_script}"'
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
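A minimal, self-contained sketch of the quote change visible just above on the stream_logs_cmd line: Black prefers double quotes, but flips the outer quotes to single quotes when that avoids escaping quotes inside the string. The names below are illustrative, not taken from this test.

# Hypothetical stand-in for the entrypoint script used in this test.
stream_logs_script = "import sys; print('STREAMED'); sys.stdout.flush()"

# Pre-Black spelling: a double-quoted f-string, so the inner quotes need escaping.
cmd_before = f"python -c \"{stream_logs_script}\""

# Post-Black spelling: the outer quotes become single quotes, no escapes needed.
cmd_after = f'python -c "{stream_logs_script}"'

# The change is purely cosmetic; both commands are identical at runtime.
assert cmd_before == cmd_after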

View file

@ -12,7 +12,7 @@ def tmp():
yield f.name
class TestIterLine():
class TestIterLine:
def test_invalid_type(self):
with pytest.raises(TypeError, match="path must be a string"):
next(file_tail_iterator(1))
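The only change in this small test file is Black dropping the empty parentheses from the class statement. A throwaway sketch (hypothetical class names) showing that the two spellings define the same thing:

class IterLineOld():  # pre-Black: legal, but the empty parentheses are redundant
    pass


class IterLineNew:  # post-Black: same semantics, no parentheses
    pass


# Either way the class has object as its sole base.
assert IterLineOld.__bases__ == IterLineNew.__bases__ == (object,)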

View file

@ -19,10 +19,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
self._proxy_session = aiohttp.ClientSession(auto_decompress=False)
log_utils.register_mimetypes()
routes.static("/logs", self._dashboard_head.log_dir, show_index=True)
GlobalSignals.node_info_fetched.append(
self.insert_log_url_to_node_info)
GlobalSignals.node_summary_fetched.append(
self.insert_log_url_to_node_info)
GlobalSignals.node_info_fetched.append(self.insert_log_url_to_node_info)
GlobalSignals.node_summary_fetched.append(self.insert_log_url_to_node_info)
async def insert_log_url_to_node_info(self, node_info):
node_id = node_info.get("raylet", {}).get("nodeId")
@ -33,7 +31,8 @@ class LogHead(dashboard_utils.DashboardHeadModule):
return
agent_http_port, _ = agent_port
log_url = self.LOG_URL_TEMPLATE.format(
ip=node_info.get("ip"), port=agent_http_port)
ip=node_info.get("ip"), port=agent_http_port
)
node_info["logUrl"] = log_url
@routes.get("/log_index")
@ -43,15 +42,16 @@ class LogHead(dashboard_utils.DashboardHeadModule):
for node_id, ports in DataSource.agents.items():
ip = DataSource.node_id_to_ip[node_id]
agent_ips.append(ip)
url_list.append(
self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
url_list.append(self.LOG_URL_TEMPLATE.format(ip=ip, port=str(ports[0])))
if self._dashboard_head.ip not in agent_ips:
url_list.append(
self.LOG_URL_TEMPLATE.format(
ip=self._dashboard_head.ip,
port=self._dashboard_head.http_port))
ip=self._dashboard_head.ip, port=self._dashboard_head.http_port
)
)
return aiohttp.web.Response(
text=self._directory_as_html(url_list), content_type="text/html")
text=self._directory_as_html(url_list), content_type="text/html"
)
@routes.get("/log_proxy")
async def get_log_from_proxy(self, req) -> aiohttp.web.StreamResponse:
@ -60,9 +60,11 @@ class LogHead(dashboard_utils.DashboardHeadModule):
raise Exception("url is None.")
body = await req.read()
async with self._proxy_session.request(
req.method, url, data=body, headers=req.headers) as r:
req.method, url, data=body, headers=req.headers
) as r:
sr = aiohttp.web.StreamResponse(
status=r.status, reason=r.reason, headers=req.headers)
status=r.status, reason=r.reason, headers=req.headers
)
sr.content_length = r.content_length
sr.content_type = r.content_type
sr.charset = r.charset
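The rewrites in this module show Black's first wrapping step for calls that exceed the line limit: the arguments move onto a single indented line and the closing parenthesis is dedented onto its own line. A standalone sketch under default Black settings (88-column lines); the template and names are illustrative, not the module's real ones.

LOG_URL_TEMPLATE = "http://{ip}:{port}/logs"  # illustrative, not the real template


def build_log_url(node_manager_address: str, dashboard_agent_http_port: int) -> str:
    # Too long for one 88-column line, so the arguments get their own indented
    # line and the closing parenthesis is dedented, as in the hunks above.
    return LOG_URL_TEMPLATE.format(
        ip=node_manager_address, port=dashboard_agent_http_port
    )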

View file

@ -40,8 +40,7 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
test_log_text = "test_log_text"
ray.get(write_log.remote(test_log_text))
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
node_id = ray_start_with_dashboard["node_id"]
@ -82,8 +81,8 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
# Test range request.
response = requests.get(
webui_url + "/logs/dashboard.log",
headers={"Range": "bytes=44-52"})
webui_url + "/logs/dashboard.log", headers={"Range": "bytes=44-52"}
)
response.raise_for_status()
assert response.text == "Dashboard"
@ -100,16 +99,19 @@ def test_log(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
def test_log_proxy(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
@ -122,21 +124,27 @@ def test_log_proxy(ray_start_with_dashboard):
# Test range request.
response = requests.get(
f"{webui_url}/log_proxy?url={webui_url}/logs/dashboard.log",
headers={"Range": "bytes=44-52"})
headers={"Range": "bytes=44-52"},
)
response.raise_for_status()
assert response.text == "Dashboard"
# Test 404.
response = requests.get(f"{webui_url}/log_proxy?"
f"url={webui_url}/logs/not_exist_file.log")
response = requests.get(
f"{webui_url}/log_proxy?" f"url={webui_url}/logs/not_exist_file.log"
)
assert response.status_code == 404
break
except Exception as ex:
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
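The ex_stack rewrite above is how Black handles a long conditional expression: the whole value is wrapped in parentheses so the if/else branches can break onto their own lines. A self-contained sketch of the same shape:

import traceback
from typing import Optional


def format_last_exception(last_ex: Optional[BaseException]) -> str:
    # Black parenthesizes the long conditional expression so each branch of the
    # if/else can sit on its own line within the 88-column limit.
    ex_stack = (
        traceback.format_exception(type(last_ex), last_ex, last_ex.__traceback__)
        if last_ex
        else []
    )
    return "".join(ex_stack)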

View file

@ -9,8 +9,10 @@ import ray._private.utils
import ray._private.gcs_utils as gcs_utils
from ray import ray_constants
from ray.dashboard.modules.node import node_consts
from ray.dashboard.modules.node.node_consts import (MAX_LOGS_TO_CACHE,
LOG_PRUNE_THREASHOLD)
from ray.dashboard.modules.node.node_consts import (
MAX_LOGS_TO_CACHE,
LOG_PRUNE_THREASHOLD,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.consts as dashboard_consts
@ -28,13 +30,21 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
def gcs_node_info_to_dict(message):
return dashboard_utils.message_to_dict(
message, {"nodeId"}, including_default_value_fields=True)
message, {"nodeId"}, including_default_value_fields=True
)
def node_stats_to_dict(message):
decode_keys = {
"actorId", "jobId", "taskId", "parentTaskId", "sourceActorId",
"callerId", "rayletId", "workerId", "placementGroupId"
"actorId",
"jobId",
"taskId",
"parentTaskId",
"sourceActorId",
"callerId",
"rayletId",
"workerId",
"placementGroupId",
}
core_workers_stats = message.core_workers_stats
message.ClearField("core_workers_stats")
@ -42,7 +52,8 @@ def node_stats_to_dict(message):
result = dashboard_utils.message_to_dict(message, decode_keys)
result["coreWorkersStats"] = [
dashboard_utils.message_to_dict(
m, decode_keys, including_default_value_fields=True)
m, decode_keys, including_default_value_fields=True
)
for m in core_workers_stats
]
return result
@ -66,11 +77,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if change.new:
# TODO(fyrestone): Handle exceptions.
node_id, node_info = change.new
address = "{}:{}".format(node_info["nodeManagerAddress"],
int(node_info["nodeManagerPort"]))
options = (("grpc.enable_http_proxy", 0), )
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True)
address, options, asynchronous=True
)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
self._stubs[node_id] = stub
@ -81,8 +94,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
A dict of information about the nodes in the cluster.
"""
request = gcs_service_pb2.GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
request, timeout=2)
reply = await self._gcs_node_info_stub.GetAllNodeInfo(request, timeout=2)
if reply.status.code == 0:
result = {}
for node_info in reply.node_info_list:
@ -116,11 +128,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
agents = dict(DataSource.agents)
for node_id in alive_node_ids:
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" \
f"{node_id}"
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}" f"{node_id}"
# TODO: Use async version if performance is an issue
agent_port = ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
if agent_port:
agents[node_id] = json.loads(agent_port)
for node_id in agents.keys() - set(alive_node_ids):
@ -142,9 +154,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if view == "summary":
all_node_summary = await DataOrganizer.get_all_node_summary()
return dashboard_optional_utils.rest_response(
success=True,
message="Node summary fetched.",
summary=all_node_summary)
success=True, message="Node summary fetched.", summary=all_node_summary
)
elif view == "details":
all_node_details = await DataOrganizer.get_all_node_details()
return dashboard_optional_utils.rest_response(
@ -160,10 +171,12 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
return dashboard_optional_utils.rest_response(
success=True,
message="Node hostname list fetched.",
host_name_list=list(alive_hostnames))
host_name_list=list(alive_hostnames),
)
else:
return dashboard_optional_utils.rest_response(
success=False, message=f"Unknown view {view}")
success=False, message=f"Unknown view {view}"
)
@routes.get("/nodes/{node_id}")
@dashboard_optional_utils.aiohttp_cache
@ -171,7 +184,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
node_id = req.match_info.get("node_id")
node_info = await DataOrganizer.get_node_info(node_id)
return dashboard_optional_utils.rest_response(
success=True, message="Node details fetched.", detail=node_info)
success=True, message="Node details fetched.", detail=node_info
)
@routes.get("/memory/memory_table")
async def get_memory_table(self, req) -> aiohttp.web.Response:
@ -187,7 +201,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
return dashboard_optional_utils.rest_response(
success=True,
message="Fetched memory table",
memory_table=memory_table.as_dict())
memory_table=memory_table.as_dict(),
)
@routes.get("/memory/set_fetch")
async def set_fetch_memory_info(self, req) -> aiohttp.web.Response:
@ -198,11 +213,11 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
self._collect_memory_info = False
else:
return dashboard_optional_utils.rest_response(
success=False,
message=f"Unknown argument to set_fetch {should_fetch}")
success=False, message=f"Unknown argument to set_fetch {should_fetch}"
)
return dashboard_optional_utils.rest_response(
success=True,
message=f"Successfully set fetching to {should_fetch}")
success=True, message=f"Successfully set fetching to {should_fetch}"
)
@routes.get("/node_logs")
async def get_logs(self, req) -> aiohttp.web.Response:
@ -212,7 +227,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if pid:
node_logs = {str(pid): node_logs.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched logs.", logs=node_logs)
success=True, message="Fetched logs.", logs=node_logs
)
@routes.get("/node_errors")
async def get_errors(self, req) -> aiohttp.web.Response:
@ -222,7 +238,8 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if pid:
node_errors = {str(pid): node_errors.get(pid, [])}
return dashboard_optional_utils.rest_response(
success=True, message="Fetched errors.", errors=node_errors)
success=True, message="Fetched errors.", errors=node_errors
)
@async_loop_forever(node_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
async def _update_node_stats(self):
@ -234,8 +251,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
try:
reply = await stub.GetNodeStats(
node_manager_pb2.GetNodeStatsRequest(
include_memory_info=self._collect_memory_info),
timeout=2)
include_memory_info=self._collect_memory_info
),
timeout=2,
)
reply_dict = node_stats_to_dict(reply)
DataSource.node_stats[node_id] = reply_dict
except Exception:
@ -263,8 +282,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if self._dashboard_head.gcs_log_subscriber:
while True:
try:
log_batch = await \
self._dashboard_head.gcs_log_subscriber.poll()
log_batch = await self._dashboard_head.gcs_log_subscriber.poll()
if log_batch is None:
continue
process_log_batch(log_batch)
@ -296,11 +314,13 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
ip = match.group(2)
errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
pid_errors = list(errs_for_ip.get(pid, []))
pid_errors.append({
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type
})
pid_errors.append(
{
"message": message,
"timestamp": error_data.timestamp,
"type": error_data.type,
}
)
errs_for_ip[pid] = pid_errors
DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
logger.info(f"Received error entry for {ip} {pid}")
@ -308,8 +328,10 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
if self._dashboard_head.gcs_error_subscriber:
while True:
try:
_, error_data = await \
self._dashboard_head.gcs_error_subscriber.poll()
(
_,
error_data,
) = await self._dashboard_head.gcs_error_subscriber.poll()
if error_data is None:
continue
process_error(error_data)
@ -328,20 +350,23 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
try:
_, data = msg
pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
error_data = gcs_utils.ErrorTableData.FromString(
pubsub_msg.data)
error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
process_error(error_data)
except Exception:
logger.exception("Error receiving error info from Redis.")
async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_node_info_stub = \
gcs_service_pb2_grpc.NodeInfoGcsServiceStub(gcs_channel)
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
gcs_channel
)
await asyncio.gather(self._update_nodes(), self._update_node_stats(),
self._update_log_info(),
self._update_error_info())
await asyncio.gather(
self._update_nodes(),
self._update_node_stats(),
self._update_log_info(),
self._update_error_info(),
)
@staticmethod
def is_minimal_module():
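One of the smaller changes in this module tightens the gRPC options tuple from (("grpc.enable_http_proxy", 0), ) to (("grpc.enable_http_proxy", 0),): the space before the closing parenthesis goes away, while the trailing comma that makes it a one-element tuple stays. A tiny sketch of why that comma matters:

# One-element tuple: the trailing comma is what makes it a tuple at all.
options = (("grpc.enable_http_proxy", 0),)
assert isinstance(options, tuple) and len(options) == 1

# Without the comma the outer parentheses are just grouping, and you get the
# inner pair back rather than a tuple of options.
not_wrapped = (("grpc.enable_http_proxy", 0))
assert not_wrapped == ("grpc.enable_http_proxy", 0)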

View file

@ -10,19 +10,23 @@ import ray
import threading
from datetime import datetime, timedelta
from ray.cluster_utils import Cluster
from ray.dashboard.modules.node.node_consts import (LOG_PRUNE_THREASHOLD,
MAX_LOGS_TO_CACHE)
from ray.dashboard.modules.node.node_consts import (
LOG_PRUNE_THREASHOLD,
MAX_LOGS_TO_CACHE,
)
from ray.dashboard.tests.conftest import * # noqa
from ray._private.test_utils import (
format_web_url, wait_until_server_available, wait_for_condition,
wait_until_succeeded_without_exception)
format_web_url,
wait_until_server_available,
wait_for_condition,
wait_until_succeeded_without_exception,
)
logger = logging.getLogger(__name__)
def test_nodes_update(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
@ -44,8 +48,7 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard):
assert len(dump_data["agents"]) == 1
assert len(dump_data["nodeIdToIp"]) == 1
assert len(dump_data["nodeIdToHostname"]) == 1
assert dump_data["nodes"].keys() == dump_data[
"nodeIdToHostname"].keys()
assert dump_data["nodes"].keys() == dump_data["nodeIdToHostname"].keys()
response = requests.get(webui_url + "/test/notified_agents")
response.raise_for_status()
@ -77,8 +80,7 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
actor_pids = [actor.getpid.remote() for actor in actors]
actor_pids = set(ray.get(actor_pids))
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
node_id = ray_start_with_dashboard["node_id"]
@ -134,15 +136,19 @@ def test_node_info(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
@ray.remote
class ActorWithObjs:
@ -156,8 +162,7 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
actors = [ActorWithObjs.remote() for _ in range(2)] # noqa
results = ray.get([actor.get_obj.remote() for actor in actors]) # noqa
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
resp = requests.get(
webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
resp = requests.get(webui_url + "/memory/set_fetch", params={"shouldFetch": "true"})
resp.raise_for_status()
def check_mem_table():
@ -172,11 +177,12 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
assert summary["totalLocalRefCount"] == 3
assert wait_until_succeeded_without_exception(
check_mem_table, (AssertionError, ), timeout_ms=10000)
check_mem_table, (AssertionError,), timeout_ms=10000
)
def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]))
assert wait_until_server_available(ray_start_with_dashboard["webui_url"])
webui_url = format_web_url(ray_start_with_dashboard["webui_url"])
@ -220,21 +226,25 @@ def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
last_ex = ex
finally:
if time.time() > start_time + timeout_seconds:
ex_stack = traceback.format_exception(
type(last_ex), last_ex,
last_ex.__traceback__) if last_ex else []
ex_stack = (
traceback.format_exception(
type(last_ex), last_ex, last_ex.__traceback__
)
if last_ex
else []
)
ex_stack = "".join(ex_stack)
raise Exception(f"Timed out while testing, {ex_stack}")
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_multi_nodes_info(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
cluster.add_node()
@ -269,13 +279,13 @@ def test_multi_nodes_info(enable_test_module, disable_aiohttp_cache,
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_multi_node_churn(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
cluster: Cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = format_web_url(cluster.webui_url)
def cluster_chaos_monkey():
@ -315,13 +325,11 @@ def test_multi_node_churn(enable_test_module, disable_aiohttp_cache,
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs(enable_test_module, disable_aiohttp_cache, ray_start_cluster_head):
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()
@ -348,21 +356,18 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert type(node_logs["data"]["logs"]) is dict
assert all(
pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
assert all(pid in node_logs["data"]["logs"] for pid in (la_pid, la2_pid))
assert len(node_logs["data"]["logs"][la2_pid]) == 1
actor_one_logs_response = requests.get(
f"{webui_url}/node_logs",
params={
"ip": node_ip,
"pid": str(la_pid)
})
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
)
actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"]
@ -370,19 +375,19 @@ def test_logs(enable_test_module, disable_aiohttp_cache,
assert len(actor_one_logs["data"]["logs"][la_pid]) == 4
assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000)
check_logs, (AssertionError,), timeout_ms=1000
)
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"""Check if logs from the dead pids are GC'ed.
"""
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs_clean_up(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
"""Check if logs from the dead pids are GC'ed."""
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()
@ -406,38 +411,41 @@ def test_logs_clean_up(enable_test_module, disable_aiohttp_cache,
def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert la_pid in node_logs["data"]["logs"]
assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=1000)
check_logs, (AssertionError,), timeout_ms=1000
)
ray.kill(la)
def check_logs_not_exist():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
assert la_pid not in node_logs["data"]["logs"]
assert wait_until_succeeded_without_exception(
check_logs_not_exist, (AssertionError, ), timeout_ms=10000)
check_logs_not_exist, (AssertionError,), timeout_ms=10000
)
@pytest.mark.parametrize(
"ray_start_cluster_head", [{
"include_dashboard": True
}], indirect=True)
def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
ray_start_cluster_head):
"""Test that each Ray worker cannot cache more than 1000 logs at a time.
"""
"ray_start_cluster_head", [{"include_dashboard": True}], indirect=True
)
def test_logs_max_count(
enable_test_module, disable_aiohttp_cache, ray_start_cluster_head
):
"""Test that each Ray worker cannot cache more than 1000 logs at a time."""
cluster = ray_start_cluster_head
assert (wait_until_server_available(cluster.webui_url) is True)
assert wait_until_server_available(cluster.webui_url) is True
webui_url = cluster.webui_url
webui_url = format_web_url(webui_url)
nodes = ray.nodes()
@ -461,7 +469,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
def check_logs():
node_logs_response = requests.get(
f"{webui_url}/node_logs", params={"ip": node_ip})
f"{webui_url}/node_logs", params={"ip": node_ip}
)
node_logs_response.raise_for_status()
node_logs = node_logs_response.json()
assert node_logs["result"]
@ -472,11 +481,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
actor_one_logs_response = requests.get(
f"{webui_url}/node_logs",
params={
"ip": node_ip,
"pid": str(la_pid)
})
f"{webui_url}/node_logs", params={"ip": node_ip, "pid": str(la_pid)}
)
actor_one_logs_response.raise_for_status()
actor_one_logs = actor_one_logs_response.json()
assert actor_one_logs["result"]
@ -486,7 +492,8 @@ def test_logs_max_count(enable_test_module, disable_aiohttp_cache,
assert log_lengths <= MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD
assert wait_until_succeeded_without_exception(
check_logs, (AssertionError, ), timeout_ms=10000)
check_logs, (AssertionError,), timeout_ms=10000
)
if __name__ == "__main__":
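A pattern repeated throughout these dashboard tests is assert (wait_until_server_available(...) is True) becoming assert wait_until_server_available(...) is True: once the statement fits on one line, Black drops the parentheses that were only there for manual wrapping. A short sketch with a stand-in predicate:

def server_available(url: str) -> bool:
    # Stand-in for wait_until_server_available(); always succeeds here.
    return bool(url)


webui_url = "127.0.0.1:8265"

# Pre-Black spelling: parentheses left over from hand wrapping.
assert (server_available(webui_url) is True)

# Post-Black spelling: the same assertion without the redundant parentheses.
assert server_available(webui_url) is True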

View file

@ -39,10 +39,12 @@ try:
except (ModuleNotFoundError, ImportError):
gpustat = None
if log_once("gpustat_import_warning"):
warnings.warn("`gpustat` package is not installed. GPU monitoring is "
"not available. To have full functionality of the "
"dashboard please install `pip install ray["
"default]`.)")
warnings.warn(
"`gpustat` package is not installed. GPU monitoring is "
"not available. To have full functionality of the "
"dashboard please install `pip install ray["
"default]`.)"
)
def recursive_asdict(o):
@ -68,68 +70,81 @@ def jsonify_asdict(o) -> str:
# A list of gauges to record and export metrics.
METRICS_GAUGES = {
"node_cpu_utilization": Gauge("node_cpu_utilization",
"Total CPU usage on a ray node",
"percentage", ["ip"]),
"node_cpu_count": Gauge("node_cpu_count",
"Total CPUs available on a ray node", "cores",
["ip"]),
"node_mem_used": Gauge("node_mem_used", "Memory usage on a ray node",
"bytes", ["ip"]),
"node_mem_available": Gauge("node_mem_available",
"Memory available on a ray node", "bytes",
["ip"]),
"node_mem_total": Gauge("node_mem_total", "Total memory on a ray node",
"bytes", ["ip"]),
"node_gpus_available": Gauge("node_gpus_available",
"Total GPUs available on a ray node",
"percentage", ["ip"]),
"node_gpus_utilization": Gauge("node_gpus_utilization",
"Total GPUs usage on a ray node",
"percentage", ["ip"]),
"node_gram_used": Gauge("node_gram_used",
"Total GPU RAM usage on a ray node", "bytes",
["ip"]),
"node_gram_available": Gauge("node_gram_available",
"Total GPU RAM available on a ray node",
"bytes", ["ip"]),
"node_disk_usage": Gauge("node_disk_usage",
"Total disk usage (bytes) on a ray node", "bytes",
["ip"]),
"node_disk_free": Gauge("node_disk_free",
"Total disk free (bytes) on a ray node", "bytes",
["ip"]),
"node_cpu_utilization": Gauge(
"node_cpu_utilization", "Total CPU usage on a ray node", "percentage", ["ip"]
),
"node_cpu_count": Gauge(
"node_cpu_count", "Total CPUs available on a ray node", "cores", ["ip"]
),
"node_mem_used": Gauge(
"node_mem_used", "Memory usage on a ray node", "bytes", ["ip"]
),
"node_mem_available": Gauge(
"node_mem_available", "Memory available on a ray node", "bytes", ["ip"]
),
"node_mem_total": Gauge(
"node_mem_total", "Total memory on a ray node", "bytes", ["ip"]
),
"node_gpus_available": Gauge(
"node_gpus_available",
"Total GPUs available on a ray node",
"percentage",
["ip"],
),
"node_gpus_utilization": Gauge(
"node_gpus_utilization", "Total GPUs usage on a ray node", "percentage", ["ip"]
),
"node_gram_used": Gauge(
"node_gram_used", "Total GPU RAM usage on a ray node", "bytes", ["ip"]
),
"node_gram_available": Gauge(
"node_gram_available", "Total GPU RAM available on a ray node", "bytes", ["ip"]
),
"node_disk_usage": Gauge(
"node_disk_usage", "Total disk usage (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_free": Gauge(
"node_disk_free", "Total disk free (bytes) on a ray node", "bytes", ["ip"]
),
"node_disk_utilization_percentage": Gauge(
"node_disk_utilization_percentage",
"Total disk utilization (percentage) on a ray node", "percentage",
["ip"]),
"node_network_sent": Gauge("node_network_sent", "Total network sent",
"bytes", ["ip"]),
"node_network_received": Gauge("node_network_received",
"Total network received", "bytes", ["ip"]),
"Total disk utilization (percentage) on a ray node",
"percentage",
["ip"],
),
"node_network_sent": Gauge(
"node_network_sent", "Total network sent", "bytes", ["ip"]
),
"node_network_received": Gauge(
"node_network_received", "Total network received", "bytes", ["ip"]
),
"node_network_send_speed": Gauge(
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]),
"node_network_receive_speed": Gauge("node_network_receive_speed",
"Network receive speed", "bytes/sec",
["ip"]),
"raylet_cpu": Gauge("raylet_cpu", "CPU usage of the raylet on a node.",
"percentage", ["ip", "pid"]),
"raylet_mem": Gauge("raylet_mem", "Memory usage of the raylet on a node",
"mb", ["ip", "pid"]),
"cluster_active_nodes": Gauge("cluster_active_nodes",
"Active nodes on the cluster", "count",
["node_type"]),
"cluster_failed_nodes": Gauge("cluster_failed_nodes",
"Failed nodes on the cluster", "count",
["node_type"]),
"cluster_pending_nodes": Gauge("cluster_pending_nodes",
"Pending nodes on the cluster", "count",
["node_type"]),
"node_network_send_speed", "Network send speed", "bytes/sec", ["ip"]
),
"node_network_receive_speed": Gauge(
"node_network_receive_speed", "Network receive speed", "bytes/sec", ["ip"]
),
"raylet_cpu": Gauge(
"raylet_cpu", "CPU usage of the raylet on a node.", "percentage", ["ip", "pid"]
),
"raylet_mem": Gauge(
"raylet_mem", "Memory usage of the raylet on a node", "mb", ["ip", "pid"]
),
"cluster_active_nodes": Gauge(
"cluster_active_nodes", "Active nodes on the cluster", "count", ["node_type"]
),
"cluster_failed_nodes": Gauge(
"cluster_failed_nodes", "Failed nodes on the cluster", "count", ["node_type"]
),
"cluster_pending_nodes": Gauge(
"cluster_pending_nodes", "Pending nodes on the cluster", "count", ["node_type"]
),
}
class ReporterAgent(dashboard_utils.DashboardAgentModule,
reporter_pb2_grpc.ReporterServiceServicer):
class ReporterAgent(
dashboard_utils.DashboardAgentModule, reporter_pb2_grpc.ReporterServiceServicer
):
"""A monitor process for monitoring Ray nodes.
Attributes:
@ -145,37 +160,39 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
cpu_count = ray._private.utils.get_num_cpus()
self._cpu_counts = (cpu_count, cpu_count)
else:
self._cpu_counts = (psutil.cpu_count(),
psutil.cpu_count(logical=False))
self._cpu_counts = (psutil.cpu_count(), psutil.cpu_count(logical=False))
self._ip = dashboard_agent.ip
if not use_gcs_for_bootstrap():
self._redis_address, _ = dashboard_agent.redis_address
self._is_head_node = (self._ip == self._redis_address)
self._is_head_node = self._ip == self._redis_address
else:
self._is_head_node = (
self._ip == dashboard_agent.gcs_address.split(":")[0])
self._is_head_node = self._ip == dashboard_agent.gcs_address.split(":")[0]
self._hostname = socket.gethostname()
self._workers = set()
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
self._metrics_agent = MetricsAgent(
"127.0.0.1" if self._ip == "127.0.0.1" else "",
dashboard_agent.metrics_export_port)
self._key = f"{reporter_consts.REPORTER_PREFIX}" \
f"{self._dashboard_agent.node_id}"
dashboard_agent.metrics_export_port,
)
self._key = (
f"{reporter_consts.REPORTER_PREFIX}" f"{self._dashboard_agent.node_id}"
)
async def GetProfilingStats(self, request, context):
pid = request.pid
duration = request.duration
profiling_file_path = os.path.join(
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt")
ray._private.utils.get_ray_temp_dir(), f"{pid}_profiling.txt"
)
sudo = "sudo" if ray._private.utils.get_user() != "root" else ""
process = await asyncio.create_subprocess_shell(
f"{sudo} $(which py-spy) record "
f"-o {profiling_file_path} -p {pid} -d {duration} -f speedscope",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
shell=True,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
profiling_stats = ""
@ -183,14 +200,14 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
with open(profiling_file_path, "r") as f:
profiling_stats = f.read()
return reporter_pb2.GetProfilingStatsReply(
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr)
profiling_stats=profiling_stats, std_out=stdout, std_err=stderr
)
async def ReportOCMetrics(self, request, context):
# This function receives a GRPC containing OpenCensus (OC) metrics
# from a Ray process, then exposes those metrics to Prometheus.
try:
self._metrics_agent.record_metric_points_from_protobuf(
request.metrics)
self._metrics_agent.record_metric_points_from_protobuf(request.metrics)
except Exception:
logger.error(traceback.format_exc())
return reporter_pb2.ReportOCMetricsReply()
@ -227,10 +244,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
for gpu in gpus:
# Note the keys in this dict have periods which throws
# off javascript so we change .s to _s
gpu_data = {
"_".join(key.split(".")): val
for key, val in gpu.entry.items()
}
gpu_data = {"_".join(key.split(".")): val for key, val in gpu.entry.items()}
gpu_utilizations.append(gpu_data)
return gpu_utilizations
@ -245,8 +259,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
@staticmethod
def _get_network_stats():
ifaces = [
v for k, v in psutil.net_io_counters(pernic=True).items()
if k[0] == "e"
v for k, v in psutil.net_io_counters(pernic=True).items() if k[0] == "e"
]
sent = sum((iface.bytes_sent for iface in ifaces))
@ -266,8 +279,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if IN_KUBERNETES_POD:
# If in a K8s pod, disable disk display by passing in dummy values.
return {
"/": psutil._common.sdiskusage(
total=1, used=0, free=1, percent=0.0)
"/": psutil._common.sdiskusage(total=1, used=0, free=1, percent=0.0)
}
root = os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep
tmp = ray._private.utils.get_user_temp_dir()
@ -286,14 +298,18 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
self._workers.update(workers)
self._workers.discard(psutil.Process())
return [
w.as_dict(attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE
w.as_dict(
attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]
)
for w in self._workers
if w.status() != psutil.STATUS_ZOMBIE
]
@staticmethod
@ -318,14 +334,16 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if raylet_proc is None:
return {}
else:
return raylet_proc.as_dict(attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
])
return raylet_proc.as_dict(
attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]
)
def _get_load_avg(self):
if sys.platform == "win32":
@ -345,8 +363,10 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
then, prev_network_stats = self._network_stats_hist[0]
prev_send, prev_recv = prev_network_stats
now_send, now_recv = network_stats
network_speed_stats = ((now_send - prev_send) / (now - then),
(now_recv - prev_recv) / (now - then))
network_speed_stats = (
(now_send - prev_send) / (now - then),
(now_recv - prev_recv) / (now - then),
)
return {
"now": now,
"hostname": self._hostname,
@ -379,7 +399,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_active_nodes"],
value=active_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)
failed_nodes = cluster_stats["autoscaler_report"]["failed_nodes"]
failed_nodes_dict = {}
@ -394,7 +416,9 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_failed_nodes"],
value=failed_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)
pending_nodes = cluster_stats["autoscaler_report"]["pending_nodes"]
pending_nodes_dict = {}
@ -409,35 +433,36 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
Record(
gauge=METRICS_GAUGES["cluster_pending_nodes"],
value=pending_node_count,
tags={"node_type": node_type}))
tags={"node_type": node_type},
)
)
# -- CPU per node --
cpu_usage = float(stats["cpu"])
cpu_record = Record(
gauge=METRICS_GAUGES["node_cpu_utilization"],
value=cpu_usage,
tags={"ip": ip})
tags={"ip": ip},
)
cpu_count, _ = stats["cpus"]
cpu_count_record = Record(
gauge=METRICS_GAUGES["node_cpu_count"],
value=cpu_count,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_cpu_count"], value=cpu_count, tags={"ip": ip}
)
# -- Mem per node --
mem_total, mem_available, _, mem_used = stats["mem"]
mem_used_record = Record(
gauge=METRICS_GAUGES["node_mem_used"],
value=mem_used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_mem_used"], value=mem_used, tags={"ip": ip}
)
mem_available_record = Record(
gauge=METRICS_GAUGES["node_mem_available"],
value=mem_available,
tags={"ip": ip})
tags={"ip": ip},
)
mem_total_record = Record(
gauge=METRICS_GAUGES["node_mem_total"],
value=mem_total,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_mem_total"], value=mem_total, tags={"ip": ip}
)
# -- GPU per node --
gpus = stats["gpus"]
@ -455,23 +480,29 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
gpus_available_record = Record(
gauge=METRICS_GAUGES["node_gpus_available"],
value=gpus_available,
tags={"ip": ip})
tags={"ip": ip},
)
gpus_utilization_record = Record(
gauge=METRICS_GAUGES["node_gpus_utilization"],
value=gpus_utilization,
tags={"ip": ip})
tags={"ip": ip},
)
gram_used_record = Record(
gauge=METRICS_GAUGES["node_gram_used"],
value=gram_used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_gram_used"], value=gram_used, tags={"ip": ip}
)
gram_available_record = Record(
gauge=METRICS_GAUGES["node_gram_available"],
value=gram_available,
tags={"ip": ip})
records_reported.extend([
gpus_available_record, gpus_utilization_record,
gram_used_record, gram_available_record
])
tags={"ip": ip},
)
records_reported.extend(
[
gpus_available_record,
gpus_utilization_record,
gram_used_record,
gram_available_record,
]
)
# -- Disk per node --
used, free = 0, 0
@ -480,39 +511,42 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
free += entry.free
disk_utilization = float(used / (used + free)) * 100
disk_usage_record = Record(
gauge=METRICS_GAUGES["node_disk_usage"],
value=used,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_disk_usage"], value=used, tags={"ip": ip}
)
disk_free_record = Record(
gauge=METRICS_GAUGES["node_disk_free"],
value=free,
tags={"ip": ip})
gauge=METRICS_GAUGES["node_disk_free"], value=free, tags={"ip": ip}
)
disk_utilization_percentage_record = Record(
gauge=METRICS_GAUGES["node_disk_utilization_percentage"],
value=disk_utilization,
tags={"ip": ip})
tags={"ip": ip},
)
# -- Network speed (send/receive) stats per node --
network_stats = stats["network"]
network_sent_record = Record(
gauge=METRICS_GAUGES["node_network_sent"],
value=network_stats[0],
tags={"ip": ip})
tags={"ip": ip},
)
network_received_record = Record(
gauge=METRICS_GAUGES["node_network_received"],
value=network_stats[1],
tags={"ip": ip})
tags={"ip": ip},
)
# -- Network speed (send/receive) per node --
network_speed_stats = stats["network_speed"]
network_send_speed_record = Record(
gauge=METRICS_GAUGES["node_network_send_speed"],
value=network_speed_stats[0],
tags={"ip": ip})
tags={"ip": ip},
)
network_receive_speed_record = Record(
gauge=METRICS_GAUGES["node_network_receive_speed"],
value=network_speed_stats[1],
tags={"ip": ip})
tags={"ip": ip},
)
raylet_stats = stats["raylet"]
if raylet_stats:
@ -522,29 +556,34 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
raylet_cpu_record = Record(
gauge=METRICS_GAUGES["raylet_cpu"],
value=raylet_cpu_usage,
tags={
"ip": ip,
"pid": raylet_pid
})
tags={"ip": ip, "pid": raylet_pid},
)
# -- raylet mem --
raylet_mem_usage = float(raylet_stats["memory_info"].rss) / 1e6
raylet_mem_record = Record(
gauge=METRICS_GAUGES["raylet_mem"],
value=raylet_mem_usage,
tags={
"ip": ip,
"pid": raylet_pid
})
tags={"ip": ip, "pid": raylet_pid},
)
records_reported.extend([raylet_cpu_record, raylet_mem_record])
records_reported.extend([
cpu_record, cpu_count_record, mem_used_record,
mem_available_record, mem_total_record, disk_usage_record,
disk_free_record, disk_utilization_percentage_record,
network_sent_record, network_received_record,
network_send_speed_record, network_receive_speed_record
])
records_reported.extend(
[
cpu_record,
cpu_count_record,
mem_used_record,
mem_available_record,
mem_total_record,
disk_usage_record,
disk_free_record,
disk_utilization_percentage_record,
network_sent_record,
network_received_record,
network_send_speed_record,
network_receive_speed_record,
]
)
return records_reported
async def _perform_iteration(self, publish):
@ -552,9 +591,13 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
while True:
try:
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS)
cluster_stats = json.loads(formatted_status_string.decode(
)) if formatted_status_string else {}
DEBUG_AUTOSCALING_STATUS
)
cluster_stats = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
stats = self._get_all_stats()
records_reported = self._record_stats(stats, cluster_stats)
@ -563,8 +606,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
except Exception:
logger.exception("Error publishing node physical stats.")
await asyncio.sleep(
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
await asyncio.sleep(reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
async def run(self, server):
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
@ -573,17 +615,20 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
if gcs_addr is None:
aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password)
password=self._dashboard_agent.redis_password,
)
gcs_addr = await aioredis_client.get("GcsServerAddress")
gcs_addr = gcs_addr.decode()
publisher = GcsAioPublisher(address=gcs_addr)
async def publish(key: str, data: str):
await publisher.publish_resource_usage(key, data)
else:
aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password)
password=self._dashboard_agent.redis_password,
)
async def publish(key: str, data: str):
await aioredis_client.publish(key, data)
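The Record(...) and records_reported.extend([...]) rewrites above show Black's second wrapping step: when the arguments do not all fit on a single indented line, each one gets its own line and a trailing comma is added after the last. A standalone sketch using an illustrative dataclass in place of the real Record type:

from dataclasses import dataclass
from typing import Dict


@dataclass
class Record:  # illustrative stand-in, not the metrics Record used above
    gauge: str
    value: float
    tags: Dict[str, str]


# The arguments fit together on one indented line, so they stay together.
mem_available_record = Record(
    gauge="node_mem_available", value=5723353088, tags={"ip": "127.0.0.1"}
)

# They do not fit on one indented line here, so each argument gets its own
# line and a trailing comma is appended after the last one.
disk_utilization_percentage_record = Record(
    gauge="node_disk_utilization_percentage",
    value=33.2,
    tags={"ip": "a-rather-long-hostname.internal.example.com"},
)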

View file

@ -3,4 +3,5 @@ import ray.ray_constants as ray_constants
REPORTER_PREFIX = "RAY_REPORTER:"
# The reporter will report its statistics this often (milliseconds).
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
"REPORTER_UPDATE_INTERVAL_MS", 2500)
"REPORTER_UPDATE_INTERVAL_MS", 2500
)
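For context on the wrapped call above: judging only from its name and arguments, ray_constants.env_integer appears to read an integer setting from the environment and fall back to the given default; its exact behaviour is not shown in this diff. A hedged plain-Python sketch of that pattern:

import os


def env_integer(key: str, default: int) -> int:
    # Assumed behaviour only: parse the environment variable as an int if it is
    # set, otherwise return the default. The real Ray helper may differ.
    value = os.environ.get(key)
    if value is None:
        return default
    return int(value)


REPORTER_UPDATE_INTERVAL_MS = env_integer("REPORTER_UPDATE_INTERVAL_MS", 2500)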

View file

@ -9,13 +9,14 @@ import ray
import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_pubsub import gcs_pubsub_enabled, \
GcsAioResourceUsageSubscriber
from ray._private.gcs_pubsub import gcs_pubsub_enabled, GcsAioResourceUsageSubscriber
import ray._private.services
import ray._private.utils
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR)
from ray.ray_constants import (
DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR,
)
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
import ray.experimental.internal_kv as internal_kv
@ -40,9 +41,10 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
if change.new:
node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0), )
options = (("grpc.enable_http_proxy", 0),)
channel = ray._private.utils.init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True)
f"{ip}:{ports[1]}", options=options, asynchronous=True
)
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub
@ -53,13 +55,16 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
duration = int(req.query["duration"])
reporter_stub = self._stubs[ip]
reply = await reporter_stub.GetProfilingStats(
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
profiling_info = (json.loads(reply.profiling_stats)
if reply.profiling_stats else reply.std_out)
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration)
)
profiling_info = (
json.loads(reply.profiling_stats)
if reply.profiling_stats
else reply.std_out
)
return dashboard_optional_utils.rest_response(
success=True,
message="Profiling success.",
profiling_info=profiling_info)
success=True, message="Profiling success.", profiling_info=profiling_info
)
@routes.get("/api/ray_config")
async def get_ray_config(self, req) -> aiohttp.web.Response:
@ -75,12 +80,12 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
)
except FileNotFoundError:
return dashboard_optional_utils.rest_response(
success=False,
message="Invalid config, could not load YAML.")
success=False, message="Invalid config, could not load YAML."
)
payload = {
"min_workers": cfg.get("min_workers", "unspecified"),
"max_workers": cfg.get("max_workers", "unspecified")
"max_workers": cfg.get("max_workers", "unspecified"),
}
try:
@ -115,18 +120,18 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
"""
assert ray.experimental.internal_kv._internal_kv_initialized()
legacy_status = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(
DEBUG_AUTOSCALING_STATUS)
formatted_status = json.loads(formatted_status_string.decode()
) if formatted_status_string else {}
legacy_status = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS_LEGACY)
formatted_status_string = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_STATUS)
formatted_status = (
json.loads(formatted_status_string.decode())
if formatted_status_string
else {}
)
error = internal_kv._internal_kv_get(DEBUG_AUTOSCALING_ERROR)
return dashboard_optional_utils.rest_response(
success=True,
message="Got cluster status.",
autoscaling_status=legacy_status.decode()
if legacy_status else None,
autoscaling_status=legacy_status.decode() if legacy_status else None,
autoscaling_error=error.decode() if error else None,
cluster_status=formatted_status if formatted_status else None,
)
@ -148,8 +153,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception("Error receiving node physical stats "
"from reporter agent.")
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)
else:
receiver = Receiver()
aioredis_client = self._dashboard_head.aioredis_client
@ -165,8 +171,9 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = data
except Exception:
logger.exception("Error receiving node physical stats "
"from reporter agent.")
logger.exception(
"Error receiving node physical stats " "from reporter agent."
)
@staticmethod
def is_minimal_module():
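The logger.exception rewrites above also show something Black deliberately leaves alone: adjacent string literals written for manual wrapping ("Error receiving node physical stats " "from reporter agent.") are kept as two literals rather than merged, and only the surrounding call is re-wrapped. A minimal sketch, assuming default Black settings:

import logging

logger = logging.getLogger(__name__)

try:
    raise ConnectionError("subscriber went away")  # stand-in for a failed poll
except Exception:
    # The two adjacent literals stay exactly as written; joining them into one
    # string is left to the author, Black only reflows the call around them.
    logger.exception("Error receiving node physical stats " "from reporter agent.")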

View file

@ -10,9 +10,13 @@ from ray import ray_constants
from ray.dashboard.tests.conftest import * # noqa
from ray.dashboard.utils import Bunch
from ray.dashboard.modules.reporter.reporter_agent import ReporterAgent
from ray._private.test_utils import (format_web_url, RayTestTimeoutException,
wait_until_server_available,
wait_for_condition, fetch_prometheus)
from ray._private.test_utils import (
format_web_url,
RayTestTimeoutException,
wait_until_server_available,
wait_for_condition,
fetch_prometheus,
)
try:
import prometheus_client
@ -34,7 +38,7 @@ def test_profiling(shutdown_only):
actor_pid = ray.get(c.getpid.remote())
webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True)
assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url)
start_time = time.time()
@ -44,14 +48,16 @@ def test_profiling(shutdown_only):
if time.time() - start_time > 15:
raise RayTestTimeoutException(
"Timed out while collecting profiling stats, "
f"launch_profiling: {launch_profiling}")
f"launch_profiling: {launch_profiling}"
)
launch_profiling = requests.get(
webui_url + "/api/launch_profiling",
params={
"ip": ray.nodes()[0]["NodeManagerAddress"],
"pid": actor_pid,
"duration": 5
}).json()
"duration": 5,
},
).json()
if launch_profiling["result"]:
profiling_info = launch_profiling["data"]["profilingInfo"]
break
@ -72,13 +78,12 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
actor_pids = set(actor_pids)
webui_url = addresses["webui_url"]
assert (wait_until_server_available(webui_url) is True)
assert wait_until_server_available(webui_url) is True
webui_url = format_web_url(webui_url)
def _check_workers():
try:
resp = requests.get(webui_url +
"/test/dump?key=node_physical_stats")
resp = requests.get(webui_url + "/test/dump?key=node_physical_stats")
resp.raise_for_status()
result = resp.json()
assert result["result"] is True
@ -101,8 +106,7 @@ def test_node_physical_stats(enable_test_module, shutdown_only):
wait_for_condition(_check_workers, timeout=10)
@pytest.mark.skipif(
prometheus_client is None, reason="prometheus_client not installed")
@pytest.mark.skipif(prometheus_client is None, reason="prometheus_client not installed")
def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
addresses = ray.init(include_dashboard=True, num_cpus=1)
metrics_export_port = addresses["metrics_export_port"]
@ -110,29 +114,31 @@ def test_prometheus_physical_stats_record(enable_test_module, shutdown_only):
prom_addresses = [f"{addr}:{metrics_export_port}"]
def test_case_stats_exist():
components_dict, metric_names, metric_samples = fetch_prometheus(
prom_addresses)
return all([
"ray_node_cpu_utilization" in metric_names,
"ray_node_cpu_count" in metric_names,
"ray_node_mem_used" in metric_names,
"ray_node_mem_available" in metric_names,
"ray_node_mem_total" in metric_names,
"ray_raylet_cpu" in metric_names, "ray_raylet_mem" in metric_names,
"ray_node_disk_usage" in metric_names,
"ray_node_disk_free" in metric_names,
"ray_node_disk_utilization_percentage" in metric_names,
"ray_node_network_sent" in metric_names,
"ray_node_network_received" in metric_names,
"ray_node_network_send_speed" in metric_names,
"ray_node_network_receive_speed" in metric_names
])
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
return all(
[
"ray_node_cpu_utilization" in metric_names,
"ray_node_cpu_count" in metric_names,
"ray_node_mem_used" in metric_names,
"ray_node_mem_available" in metric_names,
"ray_node_mem_total" in metric_names,
"ray_raylet_cpu" in metric_names,
"ray_raylet_mem" in metric_names,
"ray_node_disk_usage" in metric_names,
"ray_node_disk_free" in metric_names,
"ray_node_disk_utilization_percentage" in metric_names,
"ray_node_network_sent" in metric_names,
"ray_node_network_received" in metric_names,
"ray_node_network_send_speed" in metric_names,
"ray_node_network_receive_speed" in metric_names,
]
)
def test_case_ip_correct():
components_dict, metric_names, metric_samples = fetch_prometheus(
prom_addresses)
components_dict, metric_names, metric_samples = fetch_prometheus(prom_addresses)
raylet_proc = ray.worker._global_node.all_processes[
ray_constants.PROCESS_TYPE_RAYLET][0]
ray_constants.PROCESS_TYPE_RAYLET
][0]
raylet_pid = None
# Find the raylet pid recorded in the tag.
for sample in metric_samples:
@ -159,24 +165,25 @@ def test_report_stats():
"cpu": 57.4,
"cpus": (8, 4),
"mem": (17179869184, 5723353088, 66.7, 9234341888),
"workers": [{
"memory_info": Bunch(
rss=55934976, vms=7026937856, pfaults=15354, pageins=0),
"cpu_percent": 0.0,
"cmdline": [
"ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""
],
"create_time": 1614826391.338613,
"pid": 7174,
"cpu_times": Bunch(
user=0.607899328,
system=0.274044032,
children_user=0.0,
children_system=0.0)
}],
"workers": [
{
"memory_info": Bunch(
rss=55934976, vms=7026937856, pfaults=15354, pageins=0
),
"cpu_percent": 0.0,
"cmdline": ["ray::IDLE", "", "", "", "", "", "", "", "", "", "", ""],
"create_time": 1614826391.338613,
"pid": 7174,
"cpu_times": Bunch(
user=0.607899328,
system=0.274044032,
children_user=0.0,
children_system=0.0,
),
}
],
"raylet": {
"memory_info": Bunch(
rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
"memory_info": Bunch(rss=18354176, vms=6921486336, pfaults=6206, pageins=3),
"cpu_percent": 0.0,
"cmdline": ["fake raylet cmdline"],
"create_time": 1614826390.274854,
@ -185,22 +192,18 @@ def test_report_stats():
user=0.03683138,
system=0.035913716,
children_user=0.0,
children_system=0.0)
children_system=0.0,
),
},
"bootTime": 1612934656.0,
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45,
0.44)),
"loadAvg": ((4.4521484375, 3.61083984375, 3.5400390625), (0.56, 0.45, 0.44)),
"disk": {
"/": Bunch(
total=250790436864,
used=11316781056,
free=22748921856,
percent=33.2),
total=250790436864, used=11316781056, free=22748921856, percent=33.2
),
"/tmp": Bunch(
total=250790436864,
used=209532035072,
free=22748921856,
percent=90.2)
total=250790436864, used=209532035072, free=22748921856, percent=90.2
),
},
"gpus": [],
"network": (13621160960, 11914936320),
@ -209,13 +212,10 @@ def test_report_stats():
cluster_stats = {
"autoscaler_report": {
"active_nodes": {
"head_node": 1,
"worker-node-0": 2
},
"active_nodes": {"head_node": 1, "worker-node-0": 2},
"failed_nodes": [],
"pending_launches": {},
"pending_nodes": []
"pending_nodes": [],
}
}
@ -226,11 +226,9 @@ def test_report_stats():
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
assert len(records) == 14
# Test stats with gpus
test_stats["gpus"] = [{
"utilization_gpu": 1,
"memory_used": 100,
"memory_total": 1000
}]
test_stats["gpus"] = [
{"utilization_gpu": 1, "memory_used": 100, "memory_total": 1000}
]
records = ReporterAgent._record_stats(obj, test_stats, cluster_stats)
assert len(records) == 18
# Test stats without autoscaler report
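The skipif decorator above goes the other direction: because the whole call fits within Black's 88-character limit, the hand-wrapped two-line form is joined back onto a single line. A short sketch of that joining behaviour in a test module like this one:

import pytest

try:
    import prometheus_client
except ImportError:
    prometheus_client = None


# Hand-wrapped across two lines before the reformat; the call fits in 88
# columns, so Black collapses it onto one.
@pytest.mark.skipif(prometheus_client is None, reason="prometheus_client not installed")
def test_prometheus_metrics_exported():
    assert prometheus_client is not None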

View file

@ -12,10 +12,11 @@ from ray.core.generated import runtime_env_agent_pb2
from ray.core.generated import runtime_env_agent_pb2_grpc
from ray.core.generated import agent_manager_pb2
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.modules.runtime_env.runtime_env_consts \
as runtime_env_consts
from ray.experimental.internal_kv import _internal_kv_initialized, \
_initialize_internal_kv
import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
from ray.experimental.internal_kv import (
_internal_kv_initialized,
_initialize_internal_kv,
)
from ray._private.ray_logging import setup_component_logger
from ray._private.runtime_env.pip import PipManager
from ray._private.runtime_env.conda import CondaManager
@ -42,8 +43,10 @@ class CreatedEnvResult:
result: str
class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer):
class RuntimeEnvAgent(
dashboard_utils.DashboardAgentModule,
runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
):
"""An RPC server to create and delete runtime envs.
Attributes:
@ -86,32 +89,33 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
return self._per_job_logger_cache[job_id]
async def CreateRuntimeEnv(self, request, context):
async def _setup_runtime_env(serialized_runtime_env,
serialized_allocated_resource_instances):
async def _setup_runtime_env(
serialized_runtime_env, serialized_allocated_resource_instances
):
# This function will be ran inside a thread
def run_setup_with_logger():
runtime_env = RuntimeEnv(
serialized_runtime_env=serialized_runtime_env)
runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
allocated_resource: dict = json.loads(
serialized_allocated_resource_instances or "{}")
serialized_allocated_resource_instances or "{}"
)
# Use a separate logger for each job.
per_job_logger = self.get_or_create_logger(request.job_id)
# TODO(chenk008): Add log about allocated_resource to
# avoid lint error. That will be moved to cgroup plugin.
per_job_logger.debug(f"Worker has resource :"
f"{allocated_resource}")
per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
self._pip_manager.setup(
runtime_env, context, logger=per_job_logger)
self._conda_manager.setup(
runtime_env, context, logger=per_job_logger)
self._pip_manager.setup(runtime_env, context, logger=per_job_logger)
self._conda_manager.setup(runtime_env, context, logger=per_job_logger)
self._py_modules_manager.setup(
runtime_env, context, logger=per_job_logger)
runtime_env, context, logger=per_job_logger
)
self._working_dir_manager.setup(
runtime_env, context, logger=per_job_logger)
runtime_env, context, logger=per_job_logger
)
self._container_manager.setup(
runtime_env, context, logger=per_job_logger)
runtime_env, context, logger=per_job_logger
)
# Add the mapping of URIs -> the serialized environment to be
# used for cache invalidation.
@ -133,14 +137,15 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
logger.debug(
f"Setting up runtime env plugin {plugin_class_path}")
logger.debug(f"Setting up runtime env plugin {plugin_class_path}")
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create("uri not implemented",
json.loads(config), context)
plugin_class.modify_context("uri not implemented",
json.loads(config), context)
plugin_class.create(
"uri not implemented", json.loads(config), context
)
plugin_class.modify_context(
"uri not implemented", json.loads(config), context
)
return context
@ -159,18 +164,24 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
result = self._env_cache[serialized_env]
if result.success:
context = result.result
logger.info("Runtime env already created successfully. "
f"Env: {serialized_env}, context: {context}")
logger.info(
"Runtime env already created successfully. "
f"Env: {serialized_env}, context: {context}"
)
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
serialized_runtime_env_context=context)
serialized_runtime_env_context=context,
)
else:
error_message = result.result
logger.info("Runtime env already failed. "
f"Env: {serialized_env}, err: {error_message}")
logger.info(
"Runtime env already failed. "
f"Env: {serialized_env}, err: {error_message}"
)
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message=error_message)
error_message=error_message,
)
if SLEEP_FOR_TESTING_S:
logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
@ -182,8 +193,8 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
for _ in range(runtime_env_consts.RUNTIME_ENV_RETRY_TIMES):
try:
runtime_env_context = await _setup_runtime_env(
serialized_env,
request.serialized_allocated_resource_instances)
serialized_env, request.serialized_allocated_resource_instances
)
break
except Exception as ex:
logger.exception("Runtime env creation failed.")
@ -195,22 +206,25 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
logger.error(
"Runtime env creation failed for %d times, "
"don't retry any more.",
runtime_env_consts.RUNTIME_ENV_RETRY_TIMES)
self._env_cache[serialized_env] = CreatedEnvResult(
False, error_message)
runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
)
self._env_cache[serialized_env] = CreatedEnvResult(False, error_message)
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message=error_message)
error_message=error_message,
)
serialized_context = runtime_env_context.serialize()
self._env_cache[serialized_env] = CreatedEnvResult(
True, serialized_context)
self._env_cache[serialized_env] = CreatedEnvResult(True, serialized_context)
logger.info(
"Successfully created runtime env: %s, the context: %s",
serialized_env, serialized_context)
serialized_env,
serialized_context,
)
return runtime_env_agent_pb2.CreateRuntimeEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
serialized_runtime_env_context=serialized_context)
serialized_runtime_env_context=serialized_context,
)
async def DeleteURIs(self, request, context):
logger.info(f"Got request to delete URIs: {request.uris}.")
@ -239,20 +253,21 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
else:
raise ValueError(
"RuntimeEnvAgent received DeleteURI request "
f"for unsupported plugin {plugin}. URI: {uri}")
f"for unsupported plugin {plugin}. URI: {uri}"
)
if failed_uris:
return runtime_env_agent_pb2.DeleteURIsReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message="Local files for URI(s) "
f"{failed_uris} not found.")
error_message="Local files for URI(s) " f"{failed_uris} not found.",
)
else:
return runtime_env_agent_pb2.DeleteURIsReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_OK)
status=agent_manager_pb2.AGENT_RPC_STATUS_OK
)
async def run(self, server):
runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(
self, server)
runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(self, server)
@staticmethod
def is_minimal_module():


@ -1,7 +1,7 @@
import ray.ray_constants as ray_constants
RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES",
3)
RUNTIME_ENV_RETRY_TIMES = ray_constants.env_integer("RUNTIME_ENV_RETRY_TIMES", 3)
RUNTIME_ENV_RETRY_INTERVAL_MS = ray_constants.env_integer(
"RUNTIME_ENV_RETRY_INTERVAL_MS", 1000)
"RUNTIME_ENV_RETRY_INTERVAL_MS", 1000
)


@ -8,13 +8,19 @@ from ray import ray_constants
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.experimental.internal_kv import (_internal_kv_initialized,
_internal_kv_get, _internal_kv_list)
from ray.experimental.internal_kv import (
_internal_kv_initialized,
_internal_kv_get,
_internal_kv_list,
)
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.runtime_env.validation import ParsedRuntimeEnv
from ray.dashboard.modules.job.common import (
JobStatusInfo, JobStatusStorageClient, JOB_ID_METADATA_KEY)
JobStatusInfo,
JobStatusStorageClient,
JOB_ID_METADATA_KEY,
)
import json
import aiohttp.web
@ -32,7 +38,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._job_status_client = JobStatusStorageClient()
# For offloading CPU intensive work.
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2, thread_name_prefix="api_head")
max_workers=2, thread_name_prefix="api_head"
)
@routes.get("/api/actors/kill")
async def kill_actor_gcs(self, req) -> aiohttp.web.Response:
@ -41,7 +48,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
no_restart = req.query.get("no_restart", False) in ("true", "True")
if not actor_id:
return dashboard_optional_utils.rest_response(
success=False, message="actor_id is required.")
success=False, message="actor_id is required."
)
request = gcs_service_pb2.KillActorViaGcsRequest()
request.actor_id = bytes.fromhex(actor_id)
@ -49,31 +57,36 @@ class APIHead(dashboard_utils.DashboardHeadModule):
request.no_restart = no_restart
await self._gcs_actor_info_stub.KillActorViaGcs(request, timeout=5)
message = (f"Force killed actor with id {actor_id}" if force_kill else
f"Requested actor with id {actor_id} to terminate. " +
"It will exit once running tasks complete")
message = (
f"Force killed actor with id {actor_id}"
if force_kill
else f"Requested actor with id {actor_id} to terminate. "
+ "It will exit once running tasks complete"
)
return dashboard_optional_utils.rest_response(
success=True, message=message)
return dashboard_optional_utils.rest_response(success=True, message=message)
@routes.get("/api/snapshot")
async def snapshot(self, req):
job_data, actor_data, serve_data, session_name = await asyncio.gather(
self.get_job_info(), self.get_actor_info(), self.get_serve_info(),
self.get_session_name())
self.get_job_info(),
self.get_actor_info(),
self.get_serve_info(),
self.get_session_name(),
)
snapshot = {
"jobs": job_data,
"actors": actor_data,
"deployments": serve_data,
"session_name": session_name,
"ray_version": ray.__version__,
"ray_commit": ray.__commit__
"ray_commit": ray.__commit__,
}
return dashboard_optional_utils.rest_response(
success=True, message="hello", snapshot=snapshot)
success=True, message="hello", snapshot=snapshot
)
def _get_job_status(self,
metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
def _get_job_status(self, metadata: Dict[str, str]) -> Optional[JobStatusInfo]:
# If a job submission ID has been added to a job, the status is
# guaranteed to be returned.
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
@ -91,8 +104,8 @@ class APIHead(dashboard_utils.DashboardHeadModule):
"namespace": job_table_entry.config.ray_namespace,
"metadata": metadata,
"runtime_env": ParsedRuntimeEnv.deserialize(
job_table_entry.config.runtime_env_info.
serialized_runtime_env),
job_table_entry.config.runtime_env_info.serialized_runtime_env
),
}
status = self._get_job_status(metadata)
entry = {
@ -111,8 +124,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
# TODO (Alex): GCS still needs to return actors from dead jobs.
request = gcs_service_pb2.GetAllActorInfoRequest()
request.show_dead_jobs = True
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=5)
reply = await self._gcs_actor_info_stub.GetAllActorInfo(request, timeout=5)
actors = {}
for actor_table_entry in reply.actor_table_data:
actor_id = actor_table_entry.actor_id.hex()
@ -120,37 +132,33 @@ class APIHead(dashboard_utils.DashboardHeadModule):
entry = {
"job_id": actor_table_entry.job_id.hex(),
"state": gcs_pb2.ActorTableData.ActorState.Name(
actor_table_entry.state),
actor_table_entry.state
),
"name": actor_table_entry.name,
"namespace": actor_table_entry.ray_namespace,
"runtime_env": runtime_env,
"start_time": actor_table_entry.start_time,
"end_time": actor_table_entry.end_time,
"is_detached": actor_table_entry.is_detached,
"resources": dict(
actor_table_entry.task_spec.required_resources),
"resources": dict(actor_table_entry.task_spec.required_resources),
"actor_class": actor_table_entry.class_name,
"current_worker_id": actor_table_entry.address.worker_id.hex(),
"current_raylet_id": actor_table_entry.address.raylet_id.hex(),
"ip_address": actor_table_entry.address.ip_address,
"port": actor_table_entry.address.port,
"metadata": dict()
"metadata": dict(),
}
actors[actor_id] = entry
deployments = await self.get_serve_info()
for _, deployment_info in deployments.items():
for replica_actor_id, actor_info in deployment_info[
"actors"].items():
for replica_actor_id, actor_info in deployment_info["actors"].items():
if replica_actor_id in actors:
serve_metadata = dict()
serve_metadata["replica_tag"] = actor_info[
"replica_tag"]
serve_metadata["deployment_name"] = deployment_info[
"name"]
serve_metadata["replica_tag"] = actor_info["replica_tag"]
serve_metadata["deployment_name"] = deployment_info["name"]
serve_metadata["version"] = actor_info["version"]
actors[replica_actor_id]["metadata"][
"serve"] = serve_metadata
actors[replica_actor_id]["metadata"]["serve"] = serve_metadata
return actors
async def get_serve_info(self) -> Dict[str, Any]:
@ -168,22 +176,21 @@ class APIHead(dashboard_utils.DashboardHeadModule):
# TODO: Convert to async GRPC, if CPU usage is not a concern.
def get_deployments():
serve_keys = _internal_kv_list(
SERVE_CONTROLLER_NAME,
namespace=ray_constants.KV_NAMESPACE_SERVE)
SERVE_CONTROLLER_NAME, namespace=ray_constants.KV_NAMESPACE_SERVE
)
serve_snapshot_keys = filter(
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys)
lambda k: SERVE_SNAPSHOT_KEY in str(k), serve_keys
)
deployments_per_controller: List[Dict[str, Any]] = []
for key in serve_snapshot_keys:
val_bytes = _internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_SERVE
) or "{}".encode("utf-8")
deployments_per_controller.append(
json.loads(val_bytes.decode("utf-8")))
deployments_per_controller.append(json.loads(val_bytes.decode("utf-8")))
# Merge the deployments dicts of all controllers.
deployments: Dict[str, Any] = {
k: v
for d in deployments_per_controller for k, v in d.items()
k: v for d in deployments_per_controller for k, v in d.items()
}
# Replace the keys (deployment names) with their hashes to prevent
# collisions caused by the automatic conversion to camelcase by the
@ -194,24 +201,27 @@ class APIHead(dashboard_utils.DashboardHeadModule):
}
return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_deployments)
executor=self._thread_pool, func=get_deployments
)
async def get_session_name(self):
# TODO(yic): Convert to async GRPC.
def get_session():
return ray.experimental.internal_kv._internal_kv_get(
"session_name",
namespace=ray_constants.KV_NAMESPACE_SESSION).decode()
"session_name", namespace=ray_constants.KV_NAMESPACE_SESSION
).decode()
return await asyncio.get_event_loop().run_in_executor(
executor=self._thread_pool, func=get_session)
executor=self._thread_pool, func=get_session
)
async def run(self, server):
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel)
self._gcs_actor_info_stub = \
gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel)
self._dashboard_head.aiogrpc_gcs_channel
)
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
)
@staticmethod
def is_minimal_module():

View file

@ -36,15 +36,14 @@ def _actor_killed_loop(worker_pid: str, timeout_secs=3) -> bool:
return dead
def _kill_actor_using_dashboard_gcs(webui_url: str,
actor_id: str,
force_kill=False):
def _kill_actor_using_dashboard_gcs(webui_url: str, actor_id: str, force_kill=False):
resp = requests.get(
webui_url + KILL_ACTOR_ENDPOINT,
params={
"actor_id": actor_id,
"force_kill": force_kill,
})
},
)
resp.raise_for_status()
resp_json = resp.json()
assert resp_json["result"] is True, "msg" in resp_json


@ -8,8 +8,11 @@ import pprint
import pytest
import requests
from ray._private.test_utils import (format_web_url, wait_for_condition,
wait_until_server_available)
from ray._private.test_utils import (
format_web_url,
wait_for_condition,
wait_until_server_available,
)
from ray.dashboard import dashboard
from ray.dashboard.tests.conftest import * # noqa
from ray.dashboard.modules.job.sdk import JobSubmissionClient
@ -22,25 +25,23 @@ def _get_snapshot(address: str):
response.raise_for_status()
data = response.json()
schema_path = os.path.join(
os.path.dirname(dashboard.__file__),
"modules/snapshot/snapshot_schema.json")
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
)
pprint.pprint(data)
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
return data
def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
enable_test_module):
def test_successful_job_status(
ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
address = ray_start_with_dashboard["webui_url"]
assert wait_until_server_available(address)
address = format_web_url(address)
entrypoint_cmd = ("python -c\""
"import ray;"
"ray.init();"
"import time;"
"time.sleep(5);"
"\"")
entrypoint_cmd = (
'python -c"' "import ray;" "ray.init();" "import time;" "time.sleep(5);" '"'
)
client = JobSubmissionClient(address)
job_id = client.submit_job(entrypoint=entrypoint_cmd)
@ -49,11 +50,8 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
data = _get_snapshot(address)
for job_entry in data["data"]["snapshot"]["jobs"].values():
if job_entry["status"] is not None:
assert job_entry["config"]["metadata"][
"jobSubmissionId"] == job_id
assert job_entry["status"] in {
"PENDING", "RUNNING", "SUCCEEDED"
}
assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
assert job_entry["status"] in {"PENDING", "RUNNING", "SUCCEEDED"}
assert job_entry["statusMessage"] is not None
return job_entry["status"] == "SUCCEEDED"
@ -62,20 +60,23 @@ def test_successful_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
wait_for_condition(wait_for_job_to_succeed, timeout=30)
def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
enable_test_module):
def test_failed_job_status(
ray_start_with_dashboard, disable_aiohttp_cache, enable_test_module
):
address = ray_start_with_dashboard["webui_url"]
assert wait_until_server_available(address)
address = format_web_url(address)
entrypoint_cmd = ("python -c\""
"import ray;"
"ray.init();"
"import time;"
"time.sleep(5);"
"import sys;"
"sys.exit(1);"
"\"")
entrypoint_cmd = (
'python -c"'
"import ray;"
"ray.init();"
"import time;"
"time.sleep(5);"
"import sys;"
"sys.exit(1);"
'"'
)
client = JobSubmissionClient(address)
job_id = client.submit_job(entrypoint=entrypoint_cmd)
@ -83,8 +84,7 @@ def test_failed_job_status(ray_start_with_dashboard, disable_aiohttp_cache,
data = _get_snapshot(address)
for job_entry in data["data"]["snapshot"]["jobs"].values():
if job_entry["status"] is not None:
assert job_entry["config"]["metadata"][
"jobSubmissionId"] == job_id
assert job_entry["config"]["metadata"]["jobSubmissionId"] == job_id
assert job_entry["status"] in {"PENDING", "RUNNING", "FAILED"}
assert job_entry["statusMessage"] is not None
return job_entry["status"] == "FAILED"


@ -34,11 +34,14 @@ ray.get(a.ping.remote())
"""
address = ray_start_with_dashboard["address"]
detached_driver = driver_template.format(
address=address, lifetime="'detached'", name="'abc'")
address=address, lifetime="'detached'", name="'abc'"
)
named_driver = driver_template.format(
address=address, lifetime="None", name="'xyz'")
address=address, lifetime="None", name="'xyz'"
)
unnamed_driver = driver_template.format(
address=address, lifetime="None", name="None")
address=address, lifetime="None", name="None"
)
run_string_as_driver(detached_driver)
run_string_as_driver(named_driver)
@ -50,8 +53,8 @@ ray.get(a.ping.remote())
response.raise_for_status()
data = response.json()
schema_path = os.path.join(
os.path.dirname(dashboard.__file__),
"modules/snapshot/snapshot_schema.json")
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
)
pprint.pprint(data)
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
@ -72,10 +75,7 @@ ray.get(a.ping.remote())
assert data["data"]["snapshot"]["rayVersion"] == ray.__version__
@pytest.mark.parametrize(
"ray_start_with_dashboard", [{
"num_cpus": 4
}], indirect=True)
@pytest.mark.parametrize("ray_start_with_dashboard", [{"num_cpus": 4}], indirect=True)
def test_serve_snapshot(ray_start_with_dashboard):
"""Test detached and nondetached Serve instances running concurrently."""
@ -115,8 +115,7 @@ my_func_deleted.delete()
my_func_nondetached.deploy()
assert requests.get(
"http://127.0.0.1:8123/my_func_nondetached").text == "hello"
assert requests.get("http://127.0.0.1:8123/my_func_nondetached").text == "hello"
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
@ -124,15 +123,16 @@ my_func_deleted.delete()
response.raise_for_status()
data = response.json()
schema_path = os.path.join(
os.path.dirname(dashboard.__file__),
"modules/snapshot/snapshot_schema.json")
os.path.dirname(dashboard.__file__), "modules/snapshot/snapshot_schema.json"
)
pprint.pprint(data)
jsonschema.validate(instance=data, schema=json.load(open(schema_path)))
assert len(data["data"]["snapshot"]["deployments"]) == 3
entry = data["data"]["snapshot"]["deployments"][hashlib.sha1(
"my_func".encode()).hexdigest()]
entry = data["data"]["snapshot"]["deployments"][
hashlib.sha1("my_func".encode()).hexdigest()
]
assert entry["name"] == "my_func"
assert entry["version"] is None
assert entry["namespace"] == "serve"
@ -145,14 +145,14 @@ my_func_deleted.delete()
assert len(entry["actors"]) == 1
actor_id = next(iter(entry["actors"]))
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
"serve"]
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
assert metadata["deploymentName"] == "my_func"
assert metadata["version"] is None
assert len(metadata["replicaTag"]) > 0
entry_deleted = data["data"]["snapshot"]["deployments"][hashlib.sha1(
"my_func_deleted".encode()).hexdigest()]
entry_deleted = data["data"]["snapshot"]["deployments"][
hashlib.sha1("my_func_deleted".encode()).hexdigest()
]
assert entry_deleted["name"] == "my_func_deleted"
assert entry_deleted["version"] == "v1"
assert entry_deleted["namespace"] == "serve"
@ -163,8 +163,9 @@ my_func_deleted.delete()
assert entry_deleted["startTime"] > 0
assert entry_deleted["endTime"] > entry_deleted["startTime"]
entry_nondetached = data["data"]["snapshot"]["deployments"][hashlib.sha1(
"my_func_nondetached".encode()).hexdigest()]
entry_nondetached = data["data"]["snapshot"]["deployments"][
hashlib.sha1("my_func_nondetached".encode()).hexdigest()
]
assert entry_nondetached["name"] == "my_func_nondetached"
assert entry_nondetached["version"] == "v1"
assert entry_nondetached["namespace"] == "default_test_namespace"
@ -177,8 +178,7 @@ my_func_deleted.delete()
assert len(entry_nondetached["actors"]) == 1
actor_id = next(iter(entry_nondetached["actors"]))
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"][
"serve"]
metadata = data["data"]["snapshot"]["actors"][actor_id]["metadata"]["serve"]
assert metadata["deploymentName"] == "my_func_nondetached"
assert metadata["version"] == "v1"
assert len(metadata["replicaTag"]) > 0


@ -13,7 +13,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
@dashboard_utils.dashboard_module(
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestAgent(dashboard_utils.DashboardAgentModule):
def __init__(self, dashboard_agent):
super().__init__(dashboard_agent)
@ -25,8 +26,7 @@ class TestAgent(dashboard_utils.DashboardAgentModule):
@routes.get("/test/http_get_from_agent")
async def get_url(self, req) -> aiohttp.web.Response:
url = req.query.get("url")
result = await test_utils.http_get(self._dashboard_agent.http_session,
url)
result = await test_utils.http_get(self._dashboard_agent.http_session, url)
return aiohttp.web.json_response(result)
@routes.head("/test/route_head")


@ -15,7 +15,8 @@ routes = dashboard_optional_utils.ClassMethodRouteTable
@dashboard_utils.dashboard_module(
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False))
enable=env_bool(test_consts.TEST_MODULE_ENVIRONMENT_KEY, False)
)
class TestHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
@ -62,26 +63,28 @@ class TestHead(dashboard_utils.DashboardHeadModule):
return dashboard_optional_utils.rest_response(
success=True,
message="Fetch all data from datacenter success.",
**all_data)
**all_data,
)
else:
data = dict(DataSource.__dict__.get(key))
return dashboard_optional_utils.rest_response(
success=True,
message=f"Fetch {key} from datacenter success.",
**{key: data})
**{key: data},
)
@routes.get("/test/notified_agents")
async def get_notified_agents(self, req) -> aiohttp.web.Response:
return dashboard_optional_utils.rest_response(
success=True,
message="Fetch notified agents success.",
**self._notified_agents)
**self._notified_agents,
)
@routes.get("/test/http_get")
async def get_url(self, req) -> aiohttp.web.Response:
url = req.query.get("url")
result = await test_utils.http_get(self._dashboard_head.http_session,
url)
result = await test_utils.http_get(self._dashboard_head.http_session, url)
return aiohttp.web.json_response(result)
@routes.get("/test/aiohttp_cache/{sub_path}")
@ -89,14 +92,16 @@ class TestHead(dashboard_utils.DashboardHeadModule):
async def test_aiohttp_cache(self, req) -> aiohttp.web.Response:
value = req.query["value"]
return dashboard_optional_utils.rest_response(
success=True, message="OK", value=value, timestamp=time.time())
success=True, message="OK", value=value, timestamp=time.time()
)
@routes.get("/test/aiohttp_cache_lru/{sub_path}")
@dashboard_optional_utils.aiohttp_cache(ttl_seconds=60, maxsize=5)
async def test_aiohttp_cache_lru(self, req) -> aiohttp.web.Response:
value = req.query.get("value")
return dashboard_optional_utils.rest_response(
success=True, message="OK", value=value, timestamp=time.time())
success=True, message="OK", value=value, timestamp=time.time()
)
@routes.get("/test/file")
async def test_file(self, req) -> aiohttp.web.FileResponse:


@ -4,8 +4,7 @@ import copy
import os
import aiohttp.web
import ray.dashboard.modules.tune.tune_consts \
as tune_consts
import ray.dashboard.modules.tune.tune_consts as tune_consts
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.utils import async_loop_forever
@ -45,19 +44,17 @@ class TuneController(dashboard_utils.DashboardHeadModule):
@routes.get("/tune/info")
async def tune_info(self, req) -> aiohttp.web.Response:
stats = self.get_stats()
return rest_response(
success=True, message="Fetched tune info", result=stats)
return rest_response(success=True, message="Fetched tune info", result=stats)
@routes.get("/tune/availability")
async def get_availability(self, req) -> aiohttp.web.Response:
availability = {
"available": ExperimentAnalysis is not None,
"trials_available": self._trials_available
"trials_available": self._trials_available,
}
return rest_response(
success=True,
message="Fetched tune availability",
result=availability)
success=True, message="Fetched tune availability", result=availability
)
@routes.get("/tune/set_experiment")
async def set_tune_experiment(self, req) -> aiohttp.web.Response:
@ -66,25 +63,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
if err:
return rest_response(success=False, error=err)
return rest_response(
success=True, message="Successfully set experiment", **experiment)
success=True, message="Successfully set experiment", **experiment
)
@routes.get("/tune/enable_tensorboard")
async def enable_tensorboard(self, req) -> aiohttp.web.Response:
self._enable_tensorboard()
if not self._tensor_board_dir:
return rest_response(
success=False, message="Error enabling tensorboard")
return rest_response(success=False, message="Error enabling tensorboard")
return rest_response(success=True, message="Enabled tensorboard")
def get_stats(self):
tensor_board_info = {
"tensorboard_current": self._logdir == self._tensor_board_dir,
"tensorboard_enabled": self._tensor_board_dir != ""
"tensorboard_enabled": self._tensor_board_dir != "",
}
return {
"trial_records": copy.deepcopy(self._trial_records),
"errors": copy.deepcopy(self._errors),
"tensorboard": tensor_board_info
"tensorboard": tensor_board_info,
}
def set_experiment(self, experiment):
@ -104,7 +101,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
def collect_errors(self, df):
sub_dirs = os.listdir(self._logdir)
trial_names = filter(
lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs)
lambda d: os.path.isdir(os.path.join(self._logdir, d)), sub_dirs
)
for trial in trial_names:
error_path = os.path.join(self._logdir, trial, "error.txt")
if os.path.isfile(error_path):
@ -114,7 +112,7 @@ class TuneController(dashboard_utils.DashboardHeadModule):
self._errors[str(trial)] = {
"text": text,
"job_id": os.path.basename(self._logdir),
"trial_id": "No Trial ID"
"trial_id": "No Trial ID",
}
other_data = df[df["logdir"].str.contains(trial)]
if len(other_data) > 0:
@ -175,12 +173,25 @@ class TuneController(dashboard_utils.DashboardHeadModule):
# list of static attributes for trial
default_names = {
"logdir", "time_this_iter_s", "done", "episodes_total",
"training_iteration", "timestamp", "timesteps_total",
"experiment_id", "date", "timestamp", "time_total_s", "pid",
"hostname", "node_ip", "time_since_restore",
"timesteps_since_restore", "iterations_since_restore",
"experiment_tag", "trial_id"
"logdir",
"time_this_iter_s",
"done",
"episodes_total",
"training_iteration",
"timestamp",
"timesteps_total",
"experiment_id",
"date",
"timestamp",
"time_total_s",
"pid",
"hostname",
"node_ip",
"time_since_restore",
"timesteps_since_restore",
"iterations_since_restore",
"experiment_tag",
"trial_id",
}
# filter attributes into floats, metrics, and config variables
@ -196,7 +207,8 @@ class TuneController(dashboard_utils.DashboardHeadModule):
for trial, details in trial_details.items():
ts = os.path.getctime(details["logdir"])
formatted_time = datetime.datetime.fromtimestamp(ts).strftime(
"%Y-%m-%d %H:%M:%S")
"%Y-%m-%d %H:%M:%S"
)
details["start_time"] = formatted_time
details["params"] = {}
details["metrics"] = {}


@ -25,7 +25,7 @@ except AttributeError:
# All third-party dependencies that are not included in the minimal Ray
# installation must be included in this file. This allows us to determine if
# the agent has the necessary dependencies to be started.
from ray.dashboard.optional_deps import (aiohttp, hdrs, PathLike, RouteDef)
from ray.dashboard.optional_deps import aiohttp, hdrs, PathLike, RouteDef
from ray.dashboard.utils import to_google_style, CustomEncoder
logger = logging.getLogger(__name__)
@ -68,12 +68,15 @@ class ClassMethodRouteTable:
def _wrapper(handler):
if path in cls._bind_map[method]:
bind_info = cls._bind_map[method][path]
raise Exception(f"Duplicated route path: {path}, "
f"previous one registered at "
f"{bind_info.filename}:{bind_info.lineno}")
raise Exception(
f"Duplicated route path: {path}, "
f"previous one registered at "
f"{bind_info.filename}:{bind_info.lineno}"
)
bind_info = cls._BindInfo(handler.__code__.co_filename,
handler.__code__.co_firstlineno, None)
bind_info = cls._BindInfo(
handler.__code__.co_filename, handler.__code__.co_firstlineno, None
)
@functools.wraps(handler)
async def _handler_route(*args) -> aiohttp.web.Response:
@ -86,8 +89,7 @@ class ClassMethodRouteTable:
return await handler(bind_info.instance, req)
except Exception:
logger.exception("Handle %s %s failed.", method, path)
return rest_response(
success=False, message=traceback.format_exc())
return rest_response(success=False, message=traceback.format_exc())
cls._bind_map[method][path] = bind_info
_handler_route.__route_method__ = method
@ -132,18 +134,19 @@ class ClassMethodRouteTable:
def bind(cls, instance):
def predicate(o):
if inspect.ismethod(o):
return hasattr(o, "__route_method__") and hasattr(
o, "__route_path__")
return hasattr(o, "__route_method__") and hasattr(o, "__route_path__")
return False
handler_routes = inspect.getmembers(instance, predicate)
for _, h in handler_routes:
cls._bind_map[h.__func__.__route_method__][
h.__func__.__route_path__].instance = instance
h.__func__.__route_path__
].instance = instance
def rest_response(success, message, convert_google_style=True,
**kwargs) -> aiohttp.web.Response:
def rest_response(
success, message, convert_google_style=True, **kwargs
) -> aiohttp.web.Response:
# In the dev context we allow a dev server running on a
# different port to consume the API, meaning we need to allow
# cross-origin access
@ -155,24 +158,24 @@ def rest_response(success, message, convert_google_style=True,
{
"result": success,
"msg": message,
"data": to_google_style(kwargs) if convert_google_style else kwargs
"data": to_google_style(kwargs) if convert_google_style else kwargs,
},
dumps=functools.partial(json.dumps, cls=CustomEncoder),
headers=headers)
headers=headers,
)
# The cache value type used by aiohttp_cache.
_AiohttpCacheValue = namedtuple("AiohttpCacheValue",
["data", "expiration", "task"])
_AiohttpCacheValue = namedtuple("AiohttpCacheValue", ["data", "expiration", "task"])
# The methods with no request body used by aiohttp_cache.
_AIOHTTP_CACHE_NOBODY_METHODS = {hdrs.METH_GET, hdrs.METH_DELETE}
def aiohttp_cache(
ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
enable=not env_bool(
dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False)):
ttl_seconds=dashboard_consts.AIOHTTP_CACHE_TTL_SECONDS,
maxsize=dashboard_consts.AIOHTTP_CACHE_MAX_SIZE,
enable=not env_bool(dashboard_consts.AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY, False),
):
assert maxsize > 0
cache = collections.OrderedDict()
@ -195,8 +198,7 @@ def aiohttp_cache(
value = cache.get(key)
if value is not None:
cache.move_to_end(key)
if (not value.task.done()
or value.expiration >= time.time()):
if not value.task.done() or value.expiration >= time.time():
# Update task not done or the data is not expired.
return aiohttp.web.Response(**value.data)
@ -205,15 +207,16 @@ def aiohttp_cache(
response = task.result()
except Exception:
response = rest_response(
success=False, message=traceback.format_exc())
success=False, message=traceback.format_exc()
)
data = {
"status": response.status,
"headers": dict(response.headers),
"body": response.body,
}
cache[key] = _AiohttpCacheValue(data,
time.time() + ttl_seconds,
task)
cache[key] = _AiohttpCacheValue(
data, time.time() + ttl_seconds, task
)
cache.move_to_end(key)
if len(cache) > maxsize:
cache.popitem(last=False)


@ -19,11 +19,14 @@ import requests
from ray import ray_constants
from ray._private.test_utils import (
format_web_url, wait_for_condition, wait_until_server_available,
run_string_as_driver, wait_until_succeeded_without_exception)
format_web_url,
wait_for_condition,
wait_until_server_available,
run_string_as_driver,
wait_until_succeeded_without_exception,
)
from ray._private.gcs_pubsub import gcs_pubsub_enabled
from ray.ray_constants import (DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR)
from ray.ray_constants import DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR
from ray.dashboard import dashboard
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
@ -43,7 +46,8 @@ def make_gcs_client(address_info):
client = redis.StrictRedis(
host=address[0],
port=int(address[1]),
password=ray_constants.REDIS_DEFAULT_PASSWORD)
password=ray_constants.REDIS_DEFAULT_PASSWORD,
)
gcs_client = ray._private.gcs_utils.GcsClient.create_from_redis(client)
else:
address = address_info["gcs_address"]
@ -73,17 +77,14 @@ cleanup_test_files()
@pytest.mark.parametrize(
"ray_start_with_dashboard", [{
"_system_config": {
"agent_register_timeout_ms": 5000
}
}],
indirect=True)
"ray_start_with_dashboard",
[{"_system_config": {"agent_register_timeout_ms": 5000}}],
indirect=True,
)
def test_basic(ray_start_with_dashboard):
"""Dashboard test that starts a Ray cluster with a dashboard server running,
then hits the dashboard API and asserts that it receives sensible data."""
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
address_info = ray_start_with_dashboard
node_id = address_info["node_id"]
gcs_client = make_gcs_client(address_info)
@ -92,11 +93,12 @@ def test_basic(ray_start_with_dashboard):
all_processes = ray.worker._global_node.all_processes
assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][
0]
dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
assert dashboard_proc.status() in [
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
]
raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
raylet_proc = psutil.Process(raylet_proc_info.process.pid)
@ -140,9 +142,7 @@ def test_basic(ray_start_with_dashboard):
logger.info("Test agent register is OK.")
wait_for_condition(lambda: _search_agent(raylet_proc.children()))
assert dashboard_proc.status() in [
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING
]
assert dashboard_proc.status() in [psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING]
agent_proc = _search_agent(raylet_proc.children())
agent_pid = agent_proc.pid
@ -161,40 +161,39 @@ def test_basic(ray_start_with_dashboard):
# Check kv keys are set.
logger.info("Check kv keys are set.")
dashboard_address = ray.experimental.internal_kv._internal_kv_get(
ray_constants.DASHBOARD_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
ray_constants.DASHBOARD_ADDRESS, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
assert dashboard_address is not None
dashboard_rpc_address = ray.experimental.internal_kv._internal_kv_get(
dashboard_consts.DASHBOARD_RPC_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
assert dashboard_rpc_address is not None
key = f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{node_id}"
agent_ports = ray.experimental.internal_kv._internal_kv_get(
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
key, namespace=ray_constants.KV_NAMESPACE_DASHBOARD
)
assert agent_ports is not None
@pytest.mark.parametrize(
"ray_start_with_dashboard", [{
"dashboard_host": "127.0.0.1"
}, {
"dashboard_host": "0.0.0.0"
}, {
"dashboard_host": "::"
}],
indirect=True)
"ray_start_with_dashboard",
[
{"dashboard_host": "127.0.0.1"},
{"dashboard_host": "0.0.0.0"},
{"dashboard_host": "::"},
],
indirect=True,
)
def test_dashboard_address(ray_start_with_dashboard):
webui_url = ray_start_with_dashboard["webui_url"]
webui_ip = webui_url.split(":")[0]
assert not ipaddress.ip_address(webui_ip).is_unspecified
assert webui_ip in [
"127.0.0.1", ray_start_with_dashboard["node_ip_address"]
]
assert webui_ip in ["127.0.0.1", ray_start_with_dashboard["node_ip_address"]]
def test_http_get(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
@ -205,8 +204,7 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
while True:
time.sleep(3)
try:
response = requests.get(webui_url + "/test/http_get?url=" +
target_url)
response = requests.get(webui_url + "/test/http_get?url=" + target_url)
response.raise_for_status()
try:
dump_info = response.json()
@ -221,8 +219,8 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
http_port, grpc_port = ports
response = requests.get(
f"http://{ip}:{http_port}"
f"/test/http_get_from_agent?url={target_url}")
f"http://{ip}:{http_port}" f"/test/http_get_from_agent?url={target_url}"
)
response.raise_for_status()
try:
dump_info = response.json()
@ -239,10 +237,10 @@ def test_http_get(enable_test_module, ray_start_with_dashboard):
def test_class_method_route_table(enable_test_module):
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
dashboard_utils.DashboardAgentModule
)
test_head_cls = None
for cls in head_cls_list:
if cls.__name__ == "TestHead":
@ -274,28 +272,23 @@ def test_class_method_route_table(enable_test_module):
assert any(_has_route(r, "POST", "/test/route_post") for r in all_routes)
assert any(_has_route(r, "PUT", "/test/route_put") for r in all_routes)
assert any(_has_route(r, "PATCH", "/test/route_patch") for r in all_routes)
assert any(
_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
assert any(_has_route(r, "DELETE", "/test/route_delete") for r in all_routes)
assert any(_has_route(r, "*", "/test/route_view") for r in all_routes)
# Test bind()
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
)
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
assert len(bound_routes) == 0
dashboard_optional_utils.ClassMethodRouteTable.bind(
test_agent_cls.__new__(test_agent_cls))
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
test_agent_cls.__new__(test_agent_cls)
)
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
assert any(_has_route(r, "POST", "/test/route_post") for r in bound_routes)
assert all(
not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
assert all(not _has_route(r, "PUT", "/test/route_put") for r in bound_routes)
# Static def should be in bound routes.
routes.static("/test/route_static", "/path")
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes(
)
assert any(
_has_static(r, "/path", "/test/route_static") for r in bound_routes)
bound_routes = dashboard_optional_utils.ClassMethodRouteTable.bound_routes()
assert any(_has_static(r, "/path", "/test/route_static") for r in bound_routes)
# Test duplicated routes should raise exception.
try:
@ -358,10 +351,10 @@ def test_async_loop_forever():
def test_dashboard_module_decorator(enable_test_module):
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
head_cls_list = dashboard_utils.get_all_modules(dashboard_utils.DashboardHeadModule)
agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
dashboard_utils.DashboardAgentModule
)
assert any(cls.__name__ == "TestHead" for cls in head_cls_list)
assert any(cls.__name__ == "TestAgent" for cls in agent_cls_list)
@ -385,8 +378,7 @@ print("success")
def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
webui_url = ray_start_with_dashboard["webui_url"]
webui_url = format_web_url(webui_url)
@ -397,8 +389,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
time.sleep(1)
try:
for x in range(10):
response = requests.get(webui_url +
"/test/aiohttp_cache/t1?value=1")
response = requests.get(webui_url + "/test/aiohttp_cache/t1?value=1")
response.raise_for_status()
timestamp = response.json()["data"]["timestamp"]
value1_timestamps.append(timestamp)
@ -412,8 +403,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
sub_path_timestamps = []
for x in range(10):
response = requests.get(webui_url +
f"/test/aiohttp_cache/tt{x}?value=1")
response = requests.get(webui_url + f"/test/aiohttp_cache/tt{x}?value=1")
response.raise_for_status()
timestamp = response.json()["data"]["timestamp"]
sub_path_timestamps.append(timestamp)
@ -421,8 +411,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
volatile_value_timestamps = []
for x in range(10):
response = requests.get(webui_url +
f"/test/aiohttp_cache/tt?value={x}")
response = requests.get(webui_url + f"/test/aiohttp_cache/tt?value={x}")
response.raise_for_status()
timestamp = response.json()["data"]["timestamp"]
volatile_value_timestamps.append(timestamp)
@ -436,8 +425,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
volatile_value_timestamps = []
for x in range(10):
response = requests.get(webui_url +
f"/test/aiohttp_cache_lru/tt{x % 4}")
response = requests.get(webui_url + f"/test/aiohttp_cache_lru/tt{x % 4}")
response.raise_for_status()
timestamp = response.json()["data"]["timestamp"]
volatile_value_timestamps.append(timestamp)
@ -446,8 +434,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
volatile_value_timestamps = []
data = collections.defaultdict(set)
for x in [0, 1, 2, 3, 4, 5, 2, 1, 0, 3]:
response = requests.get(webui_url +
f"/test/aiohttp_cache_lru/t1?value={x}")
response = requests.get(webui_url + f"/test/aiohttp_cache_lru/t1?value={x}")
response.raise_for_status()
timestamp = response.json()["data"]["timestamp"]
data[x].add(timestamp)
@ -458,8 +445,7 @@ def test_aiohttp_cache(enable_test_module, ray_start_with_dashboard):
def test_get_cluster_status(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
address_info = ray_start_with_dashboard
webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url)
@ -478,14 +464,15 @@ def test_get_cluster_status(ray_start_with_dashboard):
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
assert wait_until_succeeded_without_exception(
get_cluster_status, (requests.RequestException, ))
get_cluster_status, (requests.RequestException,)
)
gcs_client = make_gcs_client(address_info)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
ray.experimental.internal_kv._internal_kv_put(
DEBUG_AUTOSCALING_STATUS_LEGACY, "hello")
ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR,
"world")
DEBUG_AUTOSCALING_STATUS_LEGACY, "hello"
)
ray.experimental.internal_kv._internal_kv_put(DEBUG_AUTOSCALING_ERROR, "world")
response = requests.get(f"{webui_url}/api/cluster_status")
response.raise_for_status()
@ -508,20 +495,19 @@ def test_immutable_types():
assert immutable_dict == dashboard_utils.ImmutableDict(d)
assert immutable_dict == d
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
assert dashboard_utils.ImmutableList(
immutable_dict["list"]) == immutable_dict["list"]
assert (
dashboard_utils.ImmutableList(immutable_dict["list"]) == immutable_dict["list"]
)
assert "512" in d
assert "512" in d["list"][0]
assert "512" in d["dict"]
# Test type conversion
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
assert type(list(
immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
# Test json dumps / loads
json_str = json.dumps(
immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
deserialized_immutable_dict = json.loads(json_str)
assert type(deserialized_immutable_dict) == dict
assert type(deserialized_immutable_dict["list"]) == list
@ -577,7 +563,7 @@ def test_immutable_types():
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
address_info = ray.init(num_cpus=1, include_dashboard=True)
assert (wait_until_server_available(address_info["webui_url"]) is True)
assert wait_until_server_available(address_info["webui_url"]) is True
webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url)
@ -588,11 +574,8 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
time.sleep(1)
try:
response = requests.get(
webui_url + "/test/dump",
proxies={
"http": None,
"https": None
})
webui_url + "/test/dump", proxies={"http": None, "https": None}
)
response.raise_for_status()
try:
response.json()
@ -609,8 +592,7 @@ def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
def test_dashboard_port_conflict(ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
address_info = ray_start_with_dashboard
gcs_client = make_gcs_client(address_info)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
@ -618,11 +600,15 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
temp_dir = "/tmp/ray"
log_dir = "/tmp/ray/session_latest/logs"
dashboard_cmd = [
sys.executable, dashboard.__file__, f"--host={host}", f"--port={port}",
f"--temp-dir={temp_dir}", f"--log-dir={log_dir}",
sys.executable,
dashboard.__file__,
f"--host={host}",
f"--port={port}",
f"--temp-dir={temp_dir}",
f"--log-dir={log_dir}",
f"--redis-address={address_info['redis_address']}",
f"--redis-password={ray_constants.REDIS_DEFAULT_PASSWORD}",
f"--gcs-address={address_info['gcs_address']}"
f"--gcs-address={address_info['gcs_address']}",
]
logger.info("The dashboard should be exit: %s", dashboard_cmd)
p = subprocess.Popen(dashboard_cmd)
@ -638,7 +624,8 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
try:
dashboard_url = ray.experimental.internal_kv._internal_kv_get(
ray_constants.DASHBOARD_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD)
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
if dashboard_url:
new_port = int(dashboard_url.split(b":")[-1])
assert new_port > int(port)
@ -651,8 +638,7 @@ def test_dashboard_port_conflict(ray_start_with_dashboard):
def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
is True)
assert wait_until_server_available(ray_start_with_dashboard["webui_url"]) is True
all_processes = ray.worker._global_node.all_processes
dashboard_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][0]
@ -661,7 +647,9 @@ def test_gcs_check_alive(fast_gcs_failure_detection, ray_start_with_dashboard):
gcs_server_proc = psutil.Process(gcs_server_info.process.pid)
assert dashboard_proc.status() in [
psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP
psutil.STATUS_RUNNING,
psutil.STATUS_SLEEPING,
psutil.STATUS_DISK_SLEEP,
]
gcs_server_proc.kill()


@ -1,7 +1,12 @@
import ray
from ray.dashboard.memory_utils import (
ReferenceType, decode_object_ref_if_needed, MemoryTableEntry, MemoryTable,
SortingType)
ReferenceType,
decode_object_ref_if_needed,
MemoryTableEntry,
MemoryTable,
SortingType,
)
"""Memory Table Unit Test"""
NODE_ADDRESS = "127.0.0.1"
@ -14,15 +19,17 @@ DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
OBJECT_SIZE = 100
def build_memory_entry(*,
local_ref_count,
pinned_in_memory,
submitted_task_reference_count,
contained_in_owned,
object_size,
pid,
object_id=OBJECT_ID,
node_address=NODE_ADDRESS):
def build_memory_entry(
*,
local_ref_count,
pinned_in_memory,
submitted_task_reference_count,
contained_in_owned,
object_size,
pid,
object_id=OBJECT_ID,
node_address=NODE_ADDRESS
):
object_ref = {
"objectId": object_id,
"callSite": "(task call) /Users:458",
@ -30,18 +37,16 @@ def build_memory_entry(*,
"localRefCount": local_ref_count,
"pinnedInMemory": pinned_in_memory,
"submittedTaskRefCount": submitted_task_reference_count,
"containedInOwned": contained_in_owned
"containedInOwned": contained_in_owned,
}
return MemoryTableEntry(
object_ref=object_ref,
node_address=node_address,
is_driver=IS_DRIVER,
pid=pid)
object_ref=object_ref, node_address=node_address, is_driver=IS_DRIVER, pid=pid
)
def build_local_reference_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS):
def build_local_reference_entry(
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
return build_memory_entry(
local_ref_count=1,
pinned_in_memory=False,
@ -49,12 +54,13 @@ def build_local_reference_entry(object_size=OBJECT_SIZE,
contained_in_owned=[],
object_size=object_size,
pid=pid,
node_address=node_address)
node_address=node_address,
)
def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS):
def build_used_by_pending_task_entry(
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
return build_memory_entry(
local_ref_count=0,
pinned_in_memory=False,
@ -62,12 +68,13 @@ def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
contained_in_owned=[],
object_size=object_size,
pid=pid,
node_address=node_address)
node_address=node_address,
)
def build_captured_in_object_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS):
def build_captured_in_object_entry(
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
return build_memory_entry(
local_ref_count=0,
pinned_in_memory=False,
@ -75,12 +82,13 @@ def build_captured_in_object_entry(object_size=OBJECT_SIZE,
contained_in_owned=[OBJECT_ID],
object_size=object_size,
pid=pid,
node_address=node_address)
node_address=node_address,
)
def build_actor_handle_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS):
def build_actor_handle_entry(
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
return build_memory_entry(
local_ref_count=1,
pinned_in_memory=False,
@ -89,12 +97,13 @@ def build_actor_handle_entry(object_size=OBJECT_SIZE,
object_size=object_size,
pid=pid,
node_address=node_address,
object_id=ACTOR_ID)
object_id=ACTOR_ID,
)
def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS):
def build_pinned_in_memory_entry(
object_size=OBJECT_SIZE, pid=PID, node_address=NODE_ADDRESS
):
return build_memory_entry(
local_ref_count=0,
pinned_in_memory=True,
@ -102,28 +111,36 @@ def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
contained_in_owned=[],
object_size=object_size,
pid=pid,
node_address=node_address)
node_address=node_address,
)
def build_entry(object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS,
reference_type=ReferenceType.PINNED_IN_MEMORY):
def build_entry(
object_size=OBJECT_SIZE,
pid=PID,
node_address=NODE_ADDRESS,
reference_type=ReferenceType.PINNED_IN_MEMORY,
):
if reference_type == ReferenceType.USED_BY_PENDING_TASK:
return build_used_by_pending_task_entry(
pid=pid, object_size=object_size, node_address=node_address)
pid=pid, object_size=object_size, node_address=node_address
)
elif reference_type == ReferenceType.LOCAL_REFERENCE:
return build_local_reference_entry(
pid=pid, object_size=object_size, node_address=node_address)
pid=pid, object_size=object_size, node_address=node_address
)
elif reference_type == ReferenceType.PINNED_IN_MEMORY:
return build_pinned_in_memory_entry(
pid=pid, object_size=object_size, node_address=node_address)
pid=pid, object_size=object_size, node_address=node_address
)
elif reference_type == ReferenceType.ACTOR_HANDLE:
return build_actor_handle_entry(
pid=pid, object_size=object_size, node_address=node_address)
pid=pid, object_size=object_size, node_address=node_address
)
elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
return build_captured_in_object_entry(
pid=pid, object_size=object_size, node_address=node_address)
pid=pid, object_size=object_size, node_address=node_address
)
def test_invalid_memory_entry():
@ -133,7 +150,8 @@ def test_invalid_memory_entry():
submitted_task_reference_count=0,
contained_in_owned=[],
object_size=OBJECT_SIZE,
pid=PID)
pid=PID,
)
assert memory_entry.is_valid() is False
memory_entry = build_memory_entry(
local_ref_count=0,
@ -141,7 +159,8 @@ def test_invalid_memory_entry():
submitted_task_reference_count=0,
contained_in_owned=[],
object_size=-1,
pid=PID)
pid=PID,
)
assert memory_entry.is_valid() is False
@ -149,7 +168,8 @@ def test_valid_reference_memory_entry():
memory_entry = build_local_reference_entry()
assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
assert memory_entry.object_ref == ray.ObjectRef(
decode_object_ref_if_needed(OBJECT_ID))
decode_object_ref_if_needed(OBJECT_ID)
)
assert memory_entry.is_valid() is True
@ -178,15 +198,14 @@ def test_memory_table_summary():
build_captured_in_object_entry(),
build_actor_handle_entry(),
build_local_reference_entry(),
build_local_reference_entry()
build_local_reference_entry(),
]
memory_table = MemoryTable(entries)
assert len(memory_table.group) == 1
assert memory_table.summary["total_actor_handles"] == 1
assert memory_table.summary["total_captured_in_objects"] == 1
assert memory_table.summary["total_local_ref_count"] == 2
assert memory_table.summary[
"total_object_size"] == len(entries) * OBJECT_SIZE
assert memory_table.summary["total_object_size"] == len(entries) * OBJECT_SIZE
assert memory_table.summary["total_pinned_in_memory"] == 1
assert memory_table.summary["total_used_by_pending_task"] == 1
@ -202,14 +221,13 @@ def test_memory_table_sort_by_pid():
def test_memory_table_sort_by_reference_type():
unsort = [
ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
ReferenceType.USED_BY_PENDING_TASK,
ReferenceType.LOCAL_REFERENCE,
ReferenceType.LOCAL_REFERENCE,
ReferenceType.PINNED_IN_MEMORY,
]
entries = [
build_entry(reference_type=reference_type) for reference_type in unsort
]
memory_table = MemoryTable(
entries, sort_by_type=SortingType.REFERENCE_TYPE)
entries = [build_entry(reference_type=reference_type) for reference_type in unsort]
memory_table = MemoryTable(entries, sort_by_type=SortingType.REFERENCE_TYPE)
sort = sorted(unsort)
for reference_type, entry in zip(sort, memory_table.table):
assert reference_type == entry.reference_type
@ -231,7 +249,7 @@ def test_group_by():
build_entry(node_address=node_second, pid=2),
build_entry(node_address=node_second, pid=1),
build_entry(node_address=node_first, pid=2),
build_entry(node_address=node_first, pid=1)
build_entry(node_address=node_first, pid=1),
]
memory_table = MemoryTable(entries)
@ -250,4 +268,5 @@ def test_group_by():
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))


@ -18,8 +18,7 @@ import aiosignal # noqa: F401
from google.protobuf.json_format import MessageToDict
from frozenlist import FrozenList # noqa: F401
from ray._private.utils import (binary_to_hex,
check_dashboard_dependencies_installed)
from ray._private.utils import binary_to_hex, check_dashboard_dependencies_installed
try:
create_task = asyncio.create_task
@@ -97,23 +96,26 @@ def get_all_modules(module_type):
"""
logger.info(f"Get all modules by type: {module_type.__name__}")
import ray.dashboard.modules
should_only_load_minimal_modules = (
not check_dashboard_dependencies_installed())
should_only_load_minimal_modules = not check_dashboard_dependencies_installed()
for module_loader, name, ispkg in pkgutil.walk_packages(
ray.dashboard.modules.__path__,
ray.dashboard.modules.__name__ + "."):
ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
):
try:
importlib.import_module(name)
except ModuleNotFoundError as e:
logger.info(f"Module {name} cannot be loaded because "
"we cannot import all dependencies. Download "
"`pip install ray[default]` for the full "
f"dashboard functionality. Error: {e}")
logger.info(
f"Module {name} cannot be loaded because "
"we cannot import all dependencies. Download "
"`pip install ray[default]` for the full "
f"dashboard functionality. Error: {e}"
)
if not should_only_load_minimal_modules:
logger.info(
"Although `pip install ray[default] is downloaded, "
"module couldn't be imported`")
"module couldn't be imported`"
)
raise e
imported_modules = []
@@ -202,7 +204,8 @@ def message_to_dict(message, decode_keys=None, **kwargs):
if decode_keys:
return _decode_keys(
MessageToDict(message, use_integers_for_enums=False, **kwargs))
MessageToDict(message, use_integers_for_enums=False, **kwargs)
)
else:
return MessageToDict(message, use_integers_for_enums=False, **kwargs)
@@ -251,8 +254,9 @@ class Change:
self.new = new
def __str__(self):
return f"Change(owner: {type(self.owner)}), " \
f"old: {self.old}, new: {self.new}"
return (
f"Change(owner: {type(self.owner)}), " f"old: {self.old}, new: {self.new}"
)
class NotifyQueue:
@@ -289,10 +293,7 @@ https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
| None | null |
+-------------------+---------------+
"""
_json_compatible_types = {
dict, list, tuple, str, int, float, bool,
type(None), bytes
}
_json_compatible_types = {dict, list, tuple, str, int, float, bool, type(None), bytes}
def is_immutable(self):
@@ -318,8 +319,7 @@ class Immutable(metaclass=ABCMeta):
class ImmutableList(Immutable, Sequence):
"""Makes a :class:`list` immutable.
"""
"""Makes a :class:`list` immutable."""
__slots__ = ("_list", "_proxy")
@@ -332,7 +332,7 @@ class ImmutableList(Immutable, Sequence):
self._proxy = [None] * len(list_value)
def __reduce_ex__(self, protocol):
return type(self), (self._list, )
return type(self), (self._list,)
def mutable(self):
return self._list
@@ -366,8 +366,7 @@ class ImmutableList(Immutable, Sequence):
class ImmutableDict(Immutable, Mapping):
"""Makes a :class:`dict` immutable.
"""
"""Makes a :class:`dict` immutable."""
__slots__ = ("_dict", "_proxy")
@@ -380,7 +379,7 @@ class ImmutableDict(Immutable, Mapping):
self._proxy = {}
def __reduce_ex__(self, protocol):
return type(self), (self._dict, )
return type(self), (self._dict,)
def mutable(self):
return self._dict
@@ -443,21 +442,23 @@ class Dict(ImmutableDict, MutableMapping):
if len(self.signal) and old != value:
if old is None:
co = self.signal.send(
Change(owner=self, new=Dict.ChangeItem(key, value)))
Change(owner=self, new=Dict.ChangeItem(key, value))
)
else:
co = self.signal.send(
Change(
owner=self,
old=Dict.ChangeItem(key, old),
new=Dict.ChangeItem(key, value)))
new=Dict.ChangeItem(key, value),
)
)
NotifyQueue.put(co)
def __delitem__(self, key):
old = self._dict.pop(key, None)
self._proxy.pop(key, None)
if len(self.signal) and old is not None:
co = self.signal.send(
Change(owner=self, old=Dict.ChangeItem(key, old)))
co = self.signal.send(Change(owner=self, old=Dict.ChangeItem(key, old)))
NotifyQueue.put(co)
def reset(self, d):
@@ -482,12 +483,15 @@ def async_loop_forever(interval_seconds, cancellable=False):
await coro(*args, **kwargs)
except asyncio.CancelledError as ex:
if cancellable:
logger.info(f"An async loop forever coroutine "
f"is cancelled {coro}.")
logger.info(
f"An async loop forever coroutine " f"is cancelled {coro}."
)
raise ex
else:
logger.exception(f"Can not cancel the async loop "
f"forever coroutine {coro}.")
logger.exception(
f"Can not cancel the async loop "
f"forever coroutine {coro}."
)
except Exception:
logger.exception(f"Error looping coroutine {coro}.")
await asyncio.sleep(interval_seconds)
@@ -497,15 +501,18 @@ def async_loop_forever(interval_seconds, cancellable=False):
return _wrapper
async def get_aioredis_client(redis_address, redis_password,
retry_interval_seconds, retry_times):
async def get_aioredis_client(
redis_address, redis_password, retry_interval_seconds, retry_times
):
for x in range(retry_times):
try:
return await aioredis.create_redis_pool(
address=redis_address, password=redis_password)
address=redis_address, password=redis_password
)
except (socket.gaierror, ConnectionError) as ex:
logger.error("Connect to Redis failed: %s, retry...", ex)
await asyncio.sleep(retry_interval_seconds)
# Raise exception from create_redis_pool
return await aioredis.create_redis_pool(
address=redis_address, password=redis_password)
address=redis_address, password=redis_password
)

View file

@@ -2,6 +2,7 @@ from collections import Counter
import sys
import time
import ray
""" This script is meant to be run from a pod in the same Kubernetes namespace
as your Ray cluster.
"""
@@ -11,8 +12,9 @@ as your Ray cluster.
def gethostname(x):
import platform
import time
time.sleep(0.01)
return x + (platform.node(), )
return x + (platform.node(),)
def wait_for_nodes(expected):
@@ -22,8 +24,11 @@ def wait_for_nodes(expected):
node_keys = [key for key in resources if "node" in key]
num_nodes = sum(resources[node_key] for node_key in node_keys)
if num_nodes < expected:
print("{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes))
print(
"{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes
)
)
sys.stdout.flush()
time.sleep(1)
else:
@@ -36,9 +41,7 @@ def main():
# Check that objects can be transferred from each node to each other node.
for i in range(10):
print("Iteration {}".format(i))
results = [
gethostname.remote(gethostname.remote(())) for _ in range(100)
]
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
print(Counter(ray.get(results)))
sys.stdout.flush()

View file

@@ -2,6 +2,7 @@ from collections import Counter
import sys
import time
import ray
""" Run this script locally to execute a Ray program on your Ray cluster on
Kubernetes.
@@ -18,8 +19,9 @@ LOCAL_PORT = 10001
def gethostname(x):
import platform
import time
time.sleep(0.01)
return x + (platform.node(), )
return x + (platform.node(),)
def wait_for_nodes(expected):
@@ -29,8 +31,11 @@ def wait_for_nodes(expected):
node_keys = [key for key in resources if "node" in key]
num_nodes = sum(resources[node_key] for node_key in node_keys)
if num_nodes < expected:
print("{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes))
print(
"{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes
)
)
sys.stdout.flush()
time.sleep(1)
else:
@@ -43,9 +48,7 @@ def main():
# Check that objects can be transferred from each node to each other node.
for i in range(10):
print("Iteration {}".format(i))
results = [
gethostname.remote(gethostname.remote(())) for _ in range(100)
]
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
print(Counter(ray.get(results)))
sys.stdout.flush()

View file

@@ -10,8 +10,9 @@ import ray
def gethostname(x):
import platform
import time
time.sleep(0.01)
return x + (platform.node(), )
return x + (platform.node(),)
def wait_for_nodes(expected):
@@ -21,8 +22,11 @@ def wait_for_nodes(expected):
node_keys = [key for key in resources if "node" in key]
num_nodes = sum(resources[node_key] for node_key in node_keys)
if num_nodes < expected:
print("{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes))
print(
"{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes
)
)
sys.stdout.flush()
time.sleep(1)
else:
@@ -35,9 +39,7 @@ def main():
# Check that objects can be transferred from each node to each other node.
for i in range(10):
print("Iteration {}".format(i))
results = [
gethostname.remote(gethostname.remote(())) for _ in range(100)
]
results = [gethostname.remote(gethostname.remote(())) for _ in range(100)]
print(Counter(ray.get(results)))
sys.stdout.flush()

View file

@@ -25,24 +25,24 @@ if __name__ == "__main__":
"--exp-name",
type=str,
required=True,
help="The job name and path to logging file (exp_name.log).")
help="The job name and path to logging file (exp_name.log).",
)
parser.add_argument(
"--num-nodes",
"-n",
type=int,
default=1,
help="Number of nodes to use.")
"--num-nodes", "-n", type=int, default=1, help="Number of nodes to use."
)
parser.add_argument(
"--node",
"-w",
type=str,
help="The specified nodes to use. Same format as the "
"return of 'sinfo'. Default: ''.")
"return of 'sinfo'. Default: ''.",
)
parser.add_argument(
"--num-gpus",
type=int,
default=0,
help="Number of GPUs to use in each node. (Default: 0)")
help="Number of GPUs to use in each node. (Default: 0)",
)
parser.add_argument(
"--partition",
"-p",
@@ -51,14 +51,16 @@ if __name__ == "__main__":
parser.add_argument(
"--load-env",
type=str,
help="The script to load your environment ('module load cuda/10.1')")
help="The script to load your environment ('module load cuda/10.1')",
)
parser.add_argument(
"--command",
type=str,
required=True,
help="The command you wish to execute. For example: "
" --command 'python test.py'. "
"Note that the command must be a string.")
"Note that the command must be a string.",
)
args = parser.parse_args()
if args.node:
@@ -67,11 +69,13 @@ if __name__ == "__main__":
else:
node_info = ""
job_name = "{}_{}".format(args.exp_name,
time.strftime("%m%d-%H%M", time.localtime()))
job_name = "{}_{}".format(
args.exp_name, time.strftime("%m%d-%H%M", time.localtime())
)
partition_option = "#SBATCH --partition={}".format(
args.partition) if args.partition else ""
partition_option = (
"#SBATCH --partition={}".format(args.partition) if args.partition else ""
)
# ===== Modified the template script =====
with open(template_file, "r") as f:
@@ -84,10 +88,10 @@ if __name__ == "__main__":
text = text.replace(LOAD_ENV, str(args.load_env))
text = text.replace(GIVEN_NODE, node_info)
text = text.replace(
"# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO "
"PRODUCTION!",
"# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
"# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
"RUNNABLE!")
"RUNNABLE!",
)
# ===== Save the script =====
script_file = "{}.sh".format(job_name)
@@ -99,5 +103,7 @@ if __name__ == "__main__":
subprocess.Popen(["sbatch", script_file])
print(
"Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
script_file, "{}.log".format(job_name)))
script_file, "{}.log".format(job_name)
)
)
sys.exit(0)

View file

@@ -89,7 +89,7 @@ myst_enable_extensions = [
]
external_toc_exclude_missing = False
external_toc_path = '_toc.yml'
external_toc_path = "_toc.yml"
# There's a flaky autodoc import for "TensorFlowVariables" that fails depending on the doc structure / order
# of imports.
@@ -112,7 +112,8 @@ versionwarning_messages = {
"<b>Got questions?</b> Join "
f'<a href="{FORUM_LINK}">the Ray Community forum</a> '
"for Q&A on all things Ray, as well as to share and learn use cases "
"and best practices with the Ray community."),
"and best practices with the Ray community."
),
}
versionwarning_body_selector = "#main-content"
@@ -189,11 +190,16 @@ exclude_patterns += sphinx_gallery_conf["examples_dirs"]
# If "DOC_LIB" is found, only build that top-level navigation item.
build_one_lib = os.getenv("DOC_LIB")
all_toc_libs = [
f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path
]
all_toc_libs = [f.path for f in os.scandir(".") if f.is_dir() and "ray-" in f.path]
all_toc_libs += [
"cluster", "tune", "data", "raysgd", "train", "rllib", "serve", "workflows"
"cluster",
"tune",
"data",
"raysgd",
"train",
"rllib",
"serve",
"workflows",
]
if build_one_lib and build_one_lib in all_toc_libs:
all_toc_libs.remove(build_one_lib)
@@ -405,7 +411,8 @@ def setup(app):
# Custom JS
app.add_js_file(
"https://cdn.jsdelivr.net/npm/docsearch.js@2/dist/cdn/docsearch.min.js",
defer="defer")
defer="defer",
)
app.add_js_file("js/docsearch.js", defer="defer")
# Custom Sphinx directives
app.add_directive("customgalleryitem", CustomGalleryItemDirective)

View file

@@ -6,13 +6,17 @@ from docutils import nodes
import os
import sphinx_gallery
import urllib
# Note: the scipy import has to stay here, it's used implicitly down the line
import scipy.stats # noqa: F401
import scipy.linalg # noqa: F401
__all__ = [
"CustomGalleryItemDirective", "fix_xgb_lgbm_docs", "MOCK_MODULES",
"CHILD_MOCK_MODULES", "update_context"
"CustomGalleryItemDirective",
"fix_xgb_lgbm_docs",
"MOCK_MODULES",
"CHILD_MOCK_MODULES",
"update_context",
]
try:
@@ -60,7 +64,7 @@ class CustomGalleryItemDirective(Directive):
option_spec = {
"tooltip": directives.unchanged,
"figure": directives.unchanged,
"description": directives.unchanged
"description": directives.unchanged,
}
has_content = False
@@ -73,8 +77,9 @@ class CustomGalleryItemDirective(Directive):
if len(self.options["tooltip"]) > 195:
tooltip = tooltip[:195] + "..."
else:
raise ValueError("Need to provide :tooltip: under "
"`.. customgalleryitem::`.")
raise ValueError(
"Need to provide :tooltip: under " "`.. customgalleryitem::`."
)
# Generate `thumbnail` used in the gallery.
if "figure" in self.options:
@@ -95,11 +100,13 @@ class CustomGalleryItemDirective(Directive):
if "description" in self.options:
description = self.options["description"]
else:
raise ValueError("Need to provide :description: under "
"`customgalleryitem::`.")
raise ValueError(
"Need to provide :description: under " "`customgalleryitem::`."
)
thumbnail_rst = GALLERY_TEMPLATE.format(
tooltip=tooltip, thumbnail=thumbnail, description=description)
tooltip=tooltip, thumbnail=thumbnail, description=description
)
thumbnail = StringList(thumbnail_rst.split("\n"))
thumb = nodes.paragraph()
self.state.nested_parse(thumbnail, self.content_offset, thumb)
@@ -146,29 +153,30 @@ def fix_xgb_lgbm_docs(app, what, name, obj, options, lines):
# Taken from https://github.com/edx/edx-documentation
FEEDBACK_FORM_FMT = "https://github.com/ray-project/ray/issues/new?" \
"title={title}&labels=docs&body={body}"
FEEDBACK_FORM_FMT = (
"https://github.com/ray-project/ray/issues/new?"
"title={title}&labels=docs&body={body}"
)
def feedback_form_url(project, page):
"""Create a URL for feedback on a particular page in a project."""
return FEEDBACK_FORM_FMT.format(
title=urllib.parse.quote(
"[docs] Issue on `{page}.rst`".format(page=page)),
title=urllib.parse.quote("[docs] Issue on `{page}.rst`".format(page=page)),
body=urllib.parse.quote(
"# Documentation Problem/Question/Comment\n"
"<!-- Describe your issue/question/comment below. -->\n"
"<!-- If there are typos or errors in the docs, feel free "
"to create a pull-request. -->\n"
"\n\n\n\n"
"(Created directly from the docs)\n"),
"(Created directly from the docs)\n"
),
)
def update_context(app, pagename, templatename, context, doctree):
"""Update the page rendering context to include ``feedback_form_url``."""
context["feedback_form_url"] = feedback_form_url(app.config.project,
pagename)
context["feedback_form_url"] = feedback_form_url(app.config.project, pagename)
MOCK_MODULES = [
@@ -187,8 +195,7 @@ MOCK_MODULES = [
"horovod.ray.runner",
"horovod.ray.utils",
"hyperopt",
"hyperopt.hp"
"kubernetes",
"hyperopt.hp" "kubernetes",
"mlflow",
"modin",
"mxnet",

View file

@@ -59,13 +59,16 @@ from ray.data.datasource.datasource import RandomIntRowDatasource
# Lets see how we implement such pipeline using Ray Dataset:
def create_shuffle_pipeline(training_data_dir: str, num_epochs: int,
num_shards: int) -> List[DatasetPipeline]:
def create_shuffle_pipeline(
training_data_dir: str, num_epochs: int, num_shards: int
) -> List[DatasetPipeline]:
return ray.data.read_parquet(training_data_dir) \
.repeat(num_epochs) \
.random_shuffle_each_window() \
return (
ray.data.read_parquet(training_data_dir)
.repeat(num_epochs)
.random_shuffle_each_window()
.split(num_shards, equal=True)
)
############################################################################
@@ -117,7 +120,8 @@ parser = argparse.ArgumentParser()
parser.add_argument(
"--large-scale-test",
action="store_true",
help="Run large scale test (500GiB of data).")
help="Run large scale test (500GiB of data).",
)
args, _ = parser.parse_known_args()
@@ -142,13 +146,15 @@ if not args.large_scale_test:
ray.data.read_datasource(
RandomIntRowDatasource(),
n=size_bytes // 8 // NUM_COLUMNS,
num_columns=NUM_COLUMNS).write_parquet(tmpdir)
num_columns=NUM_COLUMNS,
).write_parquet(tmpdir)
return tmpdir
example_files_dir = generate_example_files(SIZE_100MiB)
splits = create_shuffle_pipeline(example_files_dir, NUM_EPOCHS,
NUM_TRAINING_WORKERS)
splits = create_shuffle_pipeline(
example_files_dir, NUM_EPOCHS, NUM_TRAINING_WORKERS
)
training_workers = [
TrainingWorker.remote(rank, shard) for rank, shard in enumerate(splits)
@@ -198,18 +204,22 @@ if not args.large_scale_test:
# generated data.
def create_large_shuffle_pipeline(data_size_bytes: int, num_epochs: int,
num_columns: int,
num_shards: int) -> List[DatasetPipeline]:
def create_large_shuffle_pipeline(
data_size_bytes: int, num_epochs: int, num_columns: int, num_shards: int
) -> List[DatasetPipeline]:
# _spread_resource_prefix is used to ensure tasks are evenly spread to all
# CPU nodes.
return ray.data.read_datasource(
RandomIntRowDatasource(), n=data_size_bytes // 8 // num_columns,
return (
ray.data.read_datasource(
RandomIntRowDatasource(),
n=data_size_bytes // 8 // num_columns,
num_columns=num_columns,
_spread_resource_prefix="node:") \
.repeat(num_epochs) \
.random_shuffle_each_window(_spread_resource_prefix="node:") \
_spread_resource_prefix="node:",
)
.repeat(num_epochs)
.random_shuffle_each_window(_spread_resource_prefix="node:")
.split(num_shards, equal=True)
)
#################################################################################
@@ -229,19 +239,18 @@ if args.large_scale_test:
# waiting for cluster nodes to come up.
while len(ray.nodes()) < TOTAL_NUM_NODES:
print(
f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}"
)
print(f"waiting for nodes to start up: {len(ray.nodes())}/{TOTAL_NUM_NODES}")
time.sleep(5)
splits = create_large_shuffle_pipeline(SIZE_500GiB, NUM_EPOCHS,
NUM_COLUMNS, NUM_TRAINING_WORKERS)
splits = create_large_shuffle_pipeline(
SIZE_500GiB, NUM_EPOCHS, NUM_COLUMNS, NUM_TRAINING_WORKERS
)
# Note we set num_gpus=1 for workers so that
# the workers will only run on GPU nodes.
training_workers = [
TrainingWorker.options(num_gpus=1) \
.remote(rank, shard) for rank, shard in enumerate(splits)
TrainingWorker.options(num_gpus=1).remote(rank, shard)
for rank, shard in enumerate(splits)
]
start = time.time()

View file

@@ -4,21 +4,30 @@ import pyximport
pyximport.install(setup_args={"include_dirs": numpy.get_include()})
from .cython_simple import simple_func, fib, fib_int, \
fib_cpdef, fib_cdef, simple_class
from .cython_simple import simple_func, fib, fib_int, fib_cpdef, fib_cdef, simple_class
from .masked_log import masked_log
from .cython_blas import \
compute_self_corr_for_voxel_sel, \
compute_kernel_matrix, \
compute_single_self_corr_syrk, \
compute_single_self_corr_gemm, \
compute_corr_vectors, \
compute_single_matrix_multiplication
from .cython_blas import (
compute_self_corr_for_voxel_sel,
compute_kernel_matrix,
compute_single_self_corr_syrk,
compute_single_self_corr_gemm,
compute_corr_vectors,
compute_single_matrix_multiplication,
)
__all__ = [
"simple_func", "fib", "fib_int", "fib_cpdef", "fib_cdef", "simple_class",
"masked_log", "compute_self_corr_for_voxel_sel", "compute_kernel_matrix",
"compute_single_self_corr_syrk", "compute_single_self_corr_gemm",
"compute_corr_vectors", "compute_single_matrix_multiplication"
"simple_func",
"fib",
"fib_int",
"fib_cpdef",
"fib_cdef",
"simple_class",
"masked_log",
"compute_self_corr_for_voxel_sel",
"compute_kernel_matrix",
"compute_single_self_corr_syrk",
"compute_single_self_corr_gemm",
"compute_corr_vectors",
"compute_single_matrix_multiplication",
]

View file

@@ -94,11 +94,11 @@ def example8():
# See cython_blas.pyx for argument documentation
mat = np.array(
[[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32)
[[[2.0, 2.0], [2.0, 2.0]], [[2.0, 2.0], [2.0, 2.0]]], dtype=np.float32
)
result = np.zeros((2, 2), np.float32, order="C")
run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0,
result, 2)
run_func(cyth.compute_kernel_matrix, "L", "T", 2, 2, 1.0, mat, 0, 2, 1.0, result, 2)
if __name__ == "__main__":

View file

@@ -13,6 +13,7 @@ include_dirs = [numpy.get_include()]
# dependencies
try:
import scipy # noqa
modules.append("cython_blas.pyx")
install_requires.append("scipy")
except ImportError as e: # noqa
@@ -27,4 +28,5 @@ setup(
packages=[pkg_dir],
ext_modules=cythonize(modules),
install_requires=install_requires,
include_dirs=include_dirs)
include_dirs=include_dirs,
)

View file

@@ -56,4 +56,5 @@ class CythonTest(unittest.TestCase):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@@ -65,31 +65,34 @@ from ray.util.dask import ray_dask_get
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", type=str, default="auto", help="The address to use for Ray.")
"--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument(
"--smoke-test",
action="store_true",
help="Read a smaller dataset for quick testing purposes.")
help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument(
"--num-actors",
type=int,
default=4,
help="Sets number of actors for training.")
"--num-actors", type=int, default=4, help="Sets number of actors for training."
)
parser.add_argument(
"--cpus-per-actor",
type=int,
default=6,
help="The number of CPUs per actor for training.")
help="The number of CPUs per actor for training.",
)
parser.add_argument(
"--num-actors-inference",
type=int,
default=16,
help="Sets number of actors for inference.")
help="Sets number of actors for inference.",
)
parser.add_argument(
"--cpus-per-actor-inference",
type=int,
default=2,
help="The number of CPUs per actor for inference.")
help="The number of CPUs per actor for inference.",
)
# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()
@@ -125,12 +128,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
# Test dataset with only 10,000 records.
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
".csv"
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
# Full dataset. This may take a couple of minutes to load.
FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
"/00280/HIGGS.csv.gz"
FILE_URL = (
"https://archive.ics.uci.edu/ml/machine-learning-databases"
"/00280/HIGGS.csv.gz"
)
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
dask.config.set(scheduler=ray_dask_get)
@@ -192,7 +196,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
dtrain=train_set,
evals=[(test_set, "eval")],
evals_result=evals_result,
ray_params=ray_params)
ray_params=ray_params,
)
train_end_time = time.time()
train_duration = train_end_time - train_start_time
@@ -200,8 +205,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
model_path = "model.xgb"
bst.save_model(model_path)
print("Final validation error: {:.4f}".format(
evals_result["eval"]["error"][-1]))
print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))
return bst, evals_result
@@ -221,8 +225,12 @@ config = {
}
bst, evals_result = train_xgboost(
config, train_df, eval_df, LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
config,
train_df,
eval_df,
LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")
###############################################################################
@@ -260,13 +268,12 @@ def tune_xgboost(train_df, test_df, target_column):
"eval_metric": ["logloss", "error"],
"eta": tune.loguniform(1e-4, 1e-1),
"subsample": tune.uniform(0.5, 1.0),
"max_depth": tune.randint(1, 9)
"max_depth": tune.randint(1, 9),
}
ray_params = RayParams(
max_actor_restarts=1,
cpus_per_actor=cpus_per_actor,
num_actors=num_actors)
max_actor_restarts=1, cpus_per_actor=cpus_per_actor, num_actors=num_actors
)
tune_start_time = time.time()
@@ -276,19 +283,21 @@ def tune_xgboost(train_df, test_df, target_column):
train_df=train_df,
test_df=test_df,
target_column=target_column,
ray_params=ray_params),
ray_params=ray_params,
),
# Use the `get_tune_resources` helper function to set the resources.
resources_per_trial=ray_params.get_tune_resources(),
config=config,
num_samples=10,
metric="eval-error",
mode="min")
mode="min",
)
tune_end_time = time.time()
tune_duration = tune_end_time - tune_start_time
print(f"Total time taken: {tune_duration} seconds.")
accuracy = 1. - analysis.best_result["eval-error"]
accuracy = 1.0 - analysis.best_result["eval-error"]
print(f"Best model parameters: {analysis.best_config}")
print(f"Best model total accuracy: {accuracy:.4f}")
@@ -315,7 +324,8 @@ results = predict(
bst,
inference_df,
ray_params=RayParams(
cpus_per_actor=cpus_per_actor_inference,
num_actors=num_actors_inference))
cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
),
)
print(results)

View file

@@ -59,7 +59,8 @@ def make_and_upload_dataset(dir_path):
shift=0.0,
scale=1.0,
shuffle=False,
random_state=seed)
random_state=seed,
)
# turn into dataframe with column names
col_names = ["feature_%0d" % i for i in range(1, d + 1, 1)]
@@ -91,10 +92,8 @@ def make_and_upload_dataset(dir_path):
path = os.path.join(data_path, f"data_{i:05d}.parquet.snappy")
if not os.path.exists(path):
tmp_df = create_data_chunk(
n=PARQUET_FILE_CHUNK_SIZE,
d=NUM_FEATURES,
seed=i,
include_label=True)
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=True
)
tmp_df.to_parquet(path, compression="snappy", index=False)
print(f"Wrote {path} to disk...")
# todo: at large enough scale we might want to upload the rest after
@@ -108,10 +107,8 @@ def make_and_upload_dataset(dir_path):
path = os.path.join(inference_path, f"data_{i:05d}.parquet.snappy")
if not os.path.exists(path):
tmp_df = create_data_chunk(
n=PARQUET_FILE_CHUNK_SIZE,
d=NUM_FEATURES,
seed=i,
include_label=False)
n=PARQUET_FILE_CHUNK_SIZE, d=NUM_FEATURES, seed=i, include_label=False
)
tmp_df.to_parquet(path, compression="snappy", index=False)
print(f"Wrote {path} to disk...")
# todo: at large enough scale we might want to upload the rest after
@@ -124,8 +121,9 @@ def make_and_upload_dataset(dir_path):
def read_dataset(path: str) -> ray.data.Dataset:
print(f"reading data from {path}")
return ray.data.read_parquet(path, _spread_resource_prefix="node:") \
.random_shuffle(_spread_resource_prefix="node:")
return ray.data.read_parquet(path, _spread_resource_prefix="node:").random_shuffle(
_spread_resource_prefix="node:"
)
class DataPreprocessor:
@@ -141,20 +139,20 @@ class DataPreprocessor:
# columns.
self.standard_stats = None
def preprocess_train_data(self, ds: ray.data.Dataset
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
def preprocess_train_data(
self, ds: ray.data.Dataset
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print("\n\nPreprocessing training dataset.\n")
return self._preprocess(ds, False)
def preprocess_inference_data(self,
df: ray.data.Dataset) -> ray.data.Dataset:
def preprocess_inference_data(self, df: ray.data.Dataset) -> ray.data.Dataset:
print("\n\nPreprocessing inference dataset.\n")
return self._preprocess(df, True)[0]
def _preprocess(self, ds: ray.data.Dataset, inferencing: bool
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print(
"\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")
def _preprocess(
self, ds: ray.data.Dataset, inferencing: bool
) -> Tuple[ray.data.Dataset, ray.data.Dataset]:
print("\nStep 1: Dropping nulls, creating new_col, updating feature_1\n")
def batch_transformer(df: pd.DataFrame):
# Disable chained assignment warning.
@@ -165,25 +163,27 @@ class DataPreprocessor:
# Add new column.
df["new_col"] = (
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]) / 3.
df["feature_1"] - 2 * df["feature_2"] + df["feature_3"]
) / 3.0
# Transform column.
df["feature_1"] = 2. * df["feature_1"] + 0.1
df["feature_1"] = 2.0 * df["feature_1"] + 0.1
return df
ds = ds.map_batches(batch_transformer, batch_format="pandas")
print("\nStep 2: Precalculating fruit-grouped mean for new column and "
"for one-hot encoding (latter only uses fruit groups)\n")
print(
"\nStep 2: Precalculating fruit-grouped mean for new column and "
"for one-hot encoding (latter only uses fruit groups)\n"
)
agg_ds = ds.groupby("fruit").mean("feature_1")
fruit_means = {
r["fruit"]: r["mean(feature_1)"]
for r in agg_ds.take_all()
}
fruit_means = {r["fruit"]: r["mean(feature_1)"] for r in agg_ds.take_all()}
print("\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
"fruit; one-hot encode fruit column\n")
print(
"\nStep 3: create mean_by_fruit as mean of feature_1 groupby "
"fruit; one-hot encode fruit column\n"
)
if inferencing:
assert self.fruits is not None
@@ -192,8 +192,7 @@ class DataPreprocessor:
self.fruits = list(fruit_means.keys())
fruit_one_hots = {
fruit: collections.defaultdict(int, fruit=1)
for fruit in self.fruits
fruit: collections.defaultdict(int, fruit=1) for fruit in self.fruits
}
def batch_transformer(df: pd.DataFrame):
@@ -224,12 +223,12 @@ class DataPreprocessor:
# Split into 90% training set, 10% test set.
train_ds, test_ds = ds.split_at_indices([split_index])
print("\nStep 4b: Precalculate training dataset stats for "
"standard scaling\n")
print(
"\nStep 4b: Precalculate training dataset stats for "
"standard scaling\n"
)
# Calculate stats needed for standard scaling feature columns.
feature_columns = [
col for col in train_ds.schema().names if col != "label"
]
feature_columns = [col for col in train_ds.schema().names if col != "label"]
standard_aggs = [
agg(on=col) for col in feature_columns for agg in (Mean, Std)
]
@@ -252,30 +251,29 @@ class DataPreprocessor:
if inferencing:
# Apply standard scaling to inference dataset.
inference_ds = ds.map_batches(
batch_standard_scaler, batch_format="pandas")
inference_ds = ds.map_batches(batch_standard_scaler, batch_format="pandas")
return inference_ds, None
else:
# Apply standard scaling to both training dataset and test dataset.
train_ds = train_ds.map_batches(
batch_standard_scaler, batch_format="pandas")
test_ds = test_ds.map_batches(
batch_standard_scaler, batch_format="pandas")
batch_standard_scaler, batch_format="pandas"
)
test_ds = test_ds.map_batches(batch_standard_scaler, batch_format="pandas")
return train_ds, test_ds
def inference(dataset, model_cls: type, batch_size: int, result_path: str,
use_gpu: bool):
def inference(
dataset, model_cls: type, batch_size: int, result_path: str, use_gpu: bool
):
print("inferencing...")
num_gpus = 1 if use_gpu else 0
dataset \
.map_batches(
model_cls,
compute="actors",
batch_size=batch_size,
num_gpus=num_gpus,
num_cpus=0) \
.write_parquet(result_path)
dataset.map_batches(
model_cls,
compute="actors",
batch_size=batch_size,
num_gpus=num_gpus,
num_cpus=0,
).write_parquet(result_path)
"""
@@ -295,8 +293,7 @@ P1:
class Net(nn.Module):
def __init__(self, n_layers, n_features, num_hidden, dropout_every,
drop_prob):
def __init__(self, n_layers, n_features, num_hidden, dropout_every, drop_prob):
super().__init__()
self.n_layers = n_layers
self.dropout_every = dropout_every
@@ -406,8 +403,9 @@ def train_func(config):
print("Defining model, loss, and optimizer...")
# Setup device.
device = torch.device(f"cuda:{train.local_rank()}"
if use_gpu and torch.cuda.is_available() else "cpu")
device = torch.device(
f"cuda:{train.local_rank()}" if use_gpu and torch.cuda.is_available() else "cpu"
)
print(f"Device: {device}")
# Setup data.
@@ -415,7 +413,8 @@ def train_func(config):
train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
test_dataset = train.get_dataset_shard("test_dataset")
test_torch_dataset = test_dataset.to_torch(
label_column="label", batch_size=batch_size)
label_column="label", batch_size=batch_size
)
net = Net(
n_layers=num_layers,
@@ -436,30 +435,37 @@ def train_func(config):
train_dataset = next(train_dataset_epoch_iterator)
train_torch_dataset = train_dataset.to_torch(
label_column="label", batch_size=batch_size)
label_column="label", batch_size=batch_size
)
train_running_loss, train_num_correct, train_num_total = train_epoch(
train_torch_dataset, net, device, criterion, optimizer)
train_torch_dataset, net, device, criterion, optimizer
)
train_acc = train_num_correct / train_num_total
print(f"epoch [{epoch + 1}]: training accuracy: "
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
print(
f"epoch [{epoch + 1}]: training accuracy: "
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
)
test_running_loss, test_num_correct, test_num_total = test_epoch(
test_torch_dataset, net, device, criterion)
test_torch_dataset, net, device, criterion
)
test_acc = test_num_correct / test_num_total
print(f"epoch [{epoch + 1}]: testing accuracy: "
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}")
print(
f"epoch [{epoch + 1}]: testing accuracy: "
f"{test_num_correct} / {test_num_total} = {test_acc:.4f}"
)
# Record and log stats.
train.report(
train_acc=train_acc,
train_loss=train_running_loss,
test_acc=test_acc,
test_loss=test_running_loss)
test_loss=test_running_loss,
)
# Checkpoint model.
module = (net.module
if isinstance(net, DistributedDataParallel) else net)
module = net.module if isinstance(net, DistributedDataParallel) else net
train.save_checkpoint(model_state_dict=module.state_dict())
if train.world_rank() == 0:
@@ -469,46 +475,44 @@ def train_func(config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dir-path",
default=".",
type=str,
help="Path to read and write data from")
"--dir-path", default=".", type=str, help="Path to read and write data from"
)
parser.add_argument(
"--use-s3",
action="store_true",
default=False,
help="Use data from s3 for testing.")
help="Use data from s3 for testing.",
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.")
help="Finish quickly for testing.",
)
parser.add_argument(
"--address",
required=False,
type=str,
help="The address to use for Ray. "
"`auto` if running through `ray submit.")
help="The address to use for Ray. " "`auto` if running through `ray submit.",
)
parser.add_argument(
"--num-workers",
default=1,
type=int,
help="The number of Ray workers to use for distributed training")
help="The number of Ray workers to use for distributed training",
)
parser.add_argument(
"--large-dataset",
action="store_true",
default=False,
help="Use 500GB dataset")
"--large-dataset", action="store_true", default=False, help="Use 500GB dataset"
)
parser.add_argument(
"--use-gpu",
action="store_true",
default=False,
help="Use GPU for training.")
"--use-gpu", action="store_true", default=False, help="Use GPU for training."
)
parser.add_argument(
"--mlflow-register-model",
action="store_true",
help="Whether to use mlflow model registry. If set, a local MLflow "
"tracking server is expected to have already been started.")
"tracking server is expected to have already been started.",
)
args = parser.parse_args()
smoke_test = args.smoke_test
@@ -553,8 +557,11 @@ if __name__ == "__main__":
if len(list(count)) == 0:
print("please run `python make_and_upload_dataset.py` first")
sys.exit(1)
data_path = ("s3://cuj-big-data/big-data/"
if large_dataset else "s3://cuj-big-data/data/")
data_path = (
"s3://cuj-big-data/big-data/"
if large_dataset
else "s3://cuj-big-data/data/"
)
inference_path = "s3://cuj-big-data/inference/"
inference_output_path = "s3://cuj-big-data/output/"
else:
@@ -562,20 +569,19 @@ if __name__ == "__main__":
inference_path = os.path.join(dir_path, "inference")
inference_output_path = "/tmp"
if len(os.listdir(data_path)) <= 1 or len(
os.listdir(inference_path)) <= 1:
if len(os.listdir(data_path)) <= 1 or len(os.listdir(inference_path)) <= 1:
print("please run `python make_and_upload_dataset.py` first")
sys.exit(1)
if smoke_test:
# Only read a single file.
data_path = os.path.join(data_path, "data_00000.parquet.snappy")
inference_path = os.path.join(inference_path,
"data_00000.parquet.snappy")
inference_path = os.path.join(inference_path, "data_00000.parquet.snappy")
preprocessor = DataPreprocessor()
train_dataset, test_dataset = preprocessor.preprocess_train_data(
read_dataset(data_path))
read_dataset(data_path)
)
num_columns = len(train_dataset.schema().names)
# remove label column and internal Arrow column.
@@ -589,14 +595,12 @@ if __name__ == "__main__":
DROPOUT_PROB = 0.2
# Random global shuffle
train_dataset_pipeline = train_dataset.repeat() \
.random_shuffle_each_window(_spread_resource_prefix="node:")
train_dataset_pipeline = train_dataset.repeat().random_shuffle_each_window(
_spread_resource_prefix="node:"
)
del train_dataset
datasets = {
"train_dataset": train_dataset_pipeline,
"test_dataset": test_dataset
}
datasets = {"train_dataset": train_dataset_pipeline, "test_dataset": test_dataset}
config = {
"use_gpu": use_gpu,
@@ -606,7 +610,7 @@ if __name__ == "__main__":
"num_layers": NUM_LAYERS,
"dropout_every": DROPOUT_EVERY,
"dropout_prob": DROPOUT_PROB,
"num_features": num_features
"num_features": num_features,
}
# Create 2 callbacks: one for Tensorboard Logging and one for MLflow
@@ -619,7 +623,8 @@ if __name__ == "__main__":
callbacks = [
TBXLoggerCallback(logdir=tbx_logdir),
MLflowLoggerCallback(
experiment_name="cuj-big-data-training", save_artifact=True)
experiment_name="cuj-big-data-training", save_artifact=True
),
]
# Remove CPU resource so Datasets can be scheduled.
@@ -629,19 +634,19 @@ if __name__ == "__main__":
backend="torch",
num_workers=num_workers,
use_gpu=use_gpu,
resources_per_worker=resources_per_worker)
resources_per_worker=resources_per_worker,
)
trainer.start()
results = trainer.run(
train_func=train_func,
config=config,
callbacks=callbacks,
dataset=datasets)
train_func=train_func, config=config, callbacks=callbacks, dataset=datasets
)
model = results[0]
trainer.shutdown()
if args.mlflow_register_model:
mlflow.pytorch.log_model(
model, artifact_path="models", registered_model_name="torch_model")
model, artifact_path="models", registered_model_name="torch_model"
)
# Get the latest model from mlflow model registry.
client = mlflow.tracking.MlflowClient()
@@ -649,12 +654,14 @@ if __name__ == "__main__":
# Get the info for the latest model.
# By default, registered models are in stage "None".
latest_model_info = client.get_latest_versions(
registered_model_name, stages=["None"])[0]
registered_model_name, stages=["None"]
)[0]
latest_version = latest_model_info.version
def load_model_func():
model_uri = f"models:/torch_model/{latest_version}"
return mlflow.pytorch.load_model(model_uri)
else:
state_dict = model.state_dict()
@@ -670,25 +677,30 @@ if __name__ == "__main__":
n_features=num_features,
num_hidden=num_hidden,
dropout_every=dropout_every,
drop_prob=dropout_prob)
drop_prob=dropout_prob,
)
model.load_state_dict(state_dict)
return model
class BatchInferModel:
def __init__(self, load_model_func):
self.device = torch.device("cuda:0"
if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = load_model_func().to(self.device)
def __call__(self, batch) -> "pd.DataFrame":
tensor = torch.FloatTensor(batch.to_pandas().values).to(
self.device)
tensor = torch.FloatTensor(batch.to_pandas().values).to(self.device)
return pd.DataFrame(self.model(tensor).cpu().detach().numpy())
inference_dataset = preprocessor.preprocess_inference_data(
read_dataset(inference_path))
inference(inference_dataset, BatchInferModel(load_model_func), 100,
inference_output_path, use_gpu)
read_dataset(inference_path)
)
inference(
inference_dataset,
BatchInferModel(load_model_func),
100,
inference_output_path,
use_gpu,
)
end_time = time.time()

View file

@@ -14,20 +14,23 @@ class MyActor:
self.counter = Counter(
"num_requests",
description="Number of requests processed by the actor.",
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.counter.set_default_tags({"actor_name": name})
self.gauge = Gauge(
"curr_count",
description="Current count held by the actor. Goes up and down.",
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.gauge.set_default_tags({"actor_name": name})
self.histogram = Histogram(
"request_latency",
description="Latencies of requests in ms.",
boundaries=[0.1, 1],
tag_keys=("actor_name", ))
tag_keys=("actor_name",),
)
self.histogram.set_default_tags({"actor_name": name})
def process_request(self, num):

View file

@@ -46,7 +46,8 @@ class LinearModel(object):
y_ = tf.placeholder(tf.float32, [None, shape[1]])
self.y_ = y_
cross_entropy = tf.reduce_mean(
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
)
self.cross_entropy = cross_entropy
self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
self.sess = tf.Session()
@@ -54,24 +55,20 @@ class LinearModel(object):
# Ray's TensorFlowVariables to automatically create methods to modify
# the weights.
self.variables = ray.experimental.tf_utils.TensorFlowVariables(
cross_entropy, self.sess)
cross_entropy, self.sess
)
def loss(self, xs, ys):
"""Computes the loss of the network."""
return float(
self.sess.run(
self.cross_entropy, feed_dict={
self.x: xs,
self.y_: ys
}))
self.sess.run(self.cross_entropy, feed_dict={self.x: xs, self.y_: ys})
)
def grad(self, xs, ys):
"""Computes the gradients of the network."""
return self.sess.run(
self.cross_entropy_grads, feed_dict={
self.x: xs,
self.y_: ys
})
self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys}
)
@ray.remote
@@ -143,4 +140,5 @@ if __name__ == "__main__":
# Use L-BFGS to minimize the loss function.
print("Running L-BFGS.")
result = scipy.optimize.fmin_l_bfgs_b(
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True)
full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True
)

View file

@@ -26,8 +26,7 @@ class RayDistributedActor:
"""
# Set the init_method and rank of the process for distributed training.
print("Ray worker at {url} rank {rank}".format(
url=url, rank=world_rank))
print("Ray worker at {url} rank {rank}".format(url=url, rank=world_rank))
self.url = url
self.world_rank = world_rank
args.distributed_rank = world_rank
@@ -55,8 +54,10 @@ class RayDistributedActor:
n_cpus = int(ray.cluster_resources()["CPU"])
if n_cpus > original_n_cpus:
raise Exception(
"New CPUs find (original %d CPUs, now %d CPUs)" %
(original_n_cpus, n_cpus))
"New CPUs find (original %d CPUs, now %d CPUs)"
% (original_n_cpus, n_cpus)
)
else:
original_n_gpus = args.distributed_world_size
@@ -65,8 +66,9 @@ class RayDistributedActor:
n_gpus = int(ray.cluster_resources().get("GPU", 0))
if n_gpus > original_n_gpus:
raise Exception(
"New GPUs find (original %d GPUs, now %d GPUs)" %
(original_n_gpus, n_gpus))
"New GPUs find (original %d GPUs, now %d GPUs)"
% (original_n_gpus, n_gpus)
)
fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint
@@ -103,8 +105,7 @@ def run_fault_tolerant_loop():
set_batch_size(args)
# Set up Ray distributed actors.
Actor = ray.remote(
num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
Actor = ray.remote(num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
workers = [Actor.remote() for i in range(args.distributed_world_size)]
# Get the IP address and a free port of actor 0, which is used for
@@ -116,8 +117,7 @@ def run_fault_tolerant_loop():
    # Start the remote processes, and check whether there are any process
    # failures. If so, restart all the processes.
unfinished = [
worker.run.remote(address, i, args)
for i, worker in enumerate(workers)
worker.run.remote(address, i, args) for i, worker in enumerate(workers)
]
try:
while len(unfinished) > 0:
@@ -135,10 +135,8 @@ def add_ray_args(parser):
"""Add ray and fault-tolerance related parser arguments to the parser."""
group = parser.add_argument_group("Ray related arguments")
group.add_argument(
"--ray-address",
default="auto",
type=str,
help="address for ray initialization")
"--ray-address", default="auto", type=str, help="address for ray initialization"
)
group.add_argument(
"--fix-batch-size",
default=None,
@@ -147,7 +145,8 @@ def add_ray_args(parser):
help="fix the actual batch size (max_sentences * update_freq "
"* n_GPUs) to be the fixed input values by adjusting update_freq "
"accroding to actual n_GPUs; the batch size is fixed to B_i for "
"epoch i; all epochs >N are fixed to B_N")
"epoch i; all epochs >N are fixed to B_N",
)
return group
@@ -168,13 +167,13 @@ def set_batch_size(args):
"""Fixes the total batch_size to be agnostic to the GPU count."""
if args.fix_batch_size is not None:
args.update_freq = [
math.ceil(batch_size /
(args.max_sentences * args.distributed_world_size))
math.ceil(batch_size / (args.max_sentences * args.distributed_world_size))
for batch_size in args.fix_batch_size
]
print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
(args.distributed_world_size, args.max_sentences,
repr(args.update_freq)))
print(
"Training on %d GPUs, max_sentences=%d, update_freq=%s"
% (args.distributed_world_size, args.max_sentences, repr(args.update_freq))
)
if __name__ == "__main__":

View file

@@ -62,31 +62,34 @@ import ray
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", type=str, default="auto", help="The address to use for Ray.")
"--address", type=str, default="auto", help="The address to use for Ray."
)
parser.add_argument(
"--smoke-test",
action="store_true",
help="Read a smaller dataset for quick testing purposes.")
help="Read a smaller dataset for quick testing purposes.",
)
parser.add_argument(
"--num-actors",
type=int,
default=4,
help="Sets number of actors for training.")
"--num-actors", type=int, default=4, help="Sets number of actors for training."
)
parser.add_argument(
"--cpus-per-actor",
type=int,
default=8,
help="The number of CPUs per actor for training.")
help="The number of CPUs per actor for training.",
)
parser.add_argument(
"--num-actors-inference",
type=int,
default=16,
help="Sets number of actors for inference.")
help="Sets number of actors for inference.",
)
parser.add_argument(
"--cpus-per-actor-inference",
type=int,
default=2,
help="The number of CPUs per actor for inference.")
help="The number of CPUs per actor for inference.",
)
# Ignore -f from ipykernel_launcher
args, _ = parser.parse_known_args()
@@ -119,12 +122,13 @@ if not ray.is_initialized():
LABEL_COLUMN = "label"
if smoke_test:
# Test dataset with only 10,000 records.
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" \
".csv"
FILE_URL = "https://ray-ci-higgs.s3.us-west-2.amazonaws.com/simpleHIGGS" ".csv"
else:
# Full dataset. This may take a couple of minutes to load.
FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" \
"/00280/HIGGS.csv.gz"
FILE_URL = (
"https://archive.ics.uci.edu/ml/machine-learning-databases"
"/00280/HIGGS.csv.gz"
)
colnames = [LABEL_COLUMN] + ["feature-%02d" % i for i in range(1, 29)]
@@ -182,7 +186,8 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
evals_result=evals_result,
verbose_eval=False,
num_boost_round=100,
ray_params=ray_params)
ray_params=ray_params,
)
train_end_time = time.time()
train_duration = train_end_time - train_start_time
@@ -190,8 +195,7 @@ def train_xgboost(config, train_df, test_df, target_column, ray_params):
model_path = "model.xgb"
bst.save_model(model_path)
print("Final validation error: {:.4f}".format(
evals_result["eval"]["error"][-1]))
print("Final validation error: {:.4f}".format(evals_result["eval"]["error"][-1]))
return bst, evals_result
@@ -208,8 +212,12 @@ config = {
}
bst, evals_result = train_xgboost(
config, df_train, df_validation, LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors))
config,
df_train,
df_validation,
LABEL_COLUMN,
RayParams(cpus_per_actor=cpus_per_actor, num_actors=num_actors),
)
print(f"Results: {evals_result}")
###############################################################################
@@ -227,7 +235,8 @@ results = predict(
bst,
inference_df,
ray_params=RayParams(
cpus_per_actor=cpus_per_actor_inference,
num_actors=num_actors_inference))
cpus_per_actor=cpus_per_actor_inference, num_actors=num_actors_inference
),
)
print(results)

View file

@@ -12,10 +12,12 @@ class NewsServer(object):
def __init__(self):
self.conn = sqlite3.connect("newsreader.db")
c = self.conn.cursor()
c.execute("""CREATE TABLE IF NOT EXISTS news
c.execute(
"""CREATE TABLE IF NOT EXISTS news
(title text, link text,
description text, published timestamp,
feed url, liked bool)""")
feed url, liked bool)"""
)
self.conn.commit()
def retrieve_feed(self, url):
@@ -24,36 +26,41 @@ class NewsServer(object):
items = []
c = self.conn.cursor()
for item in feed.items:
items.append({
"title": item.title,
"link": item.link,
"description": item.description,
"description_text": item.description,
"pubDate": str(item.pub_date)
})
items.append(
{
"title": item.title,
"link": item.link,
"description": item.description,
"description_text": item.description,
"pubDate": str(item.pub_date),
}
)
c.execute(
"""INSERT INTO news (title, link, description,
published, feed, liked) values
(?, ?, ?, ?, ?, ?)""",
(item.title, item.link, item.description, item.pub_date,
feed.link, False))
(
item.title,
item.link,
item.description,
item.pub_date,
feed.link,
False,
),
)
self.conn.commit()
return {
"channel": {
"title": feed.title,
"link": feed.link,
"url": feed.link
},
"items": items
"channel": {"title": feed.title, "link": feed.link, "url": feed.link},
"items": items,
}
def like_item(self, url, is_faved):
c = self.conn.cursor()
if is_faved:
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url, ))
c.execute("UPDATE news SET liked = 1 WHERE link = ?", (url,))
else:
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url, ))
c.execute("UPDATE news SET liked = 0 WHERE link = ?", (url,))
self.conn.commit()
@@ -77,9 +84,7 @@ def dispatcher():
result = ray.get(method.remote(*method_args))
return jsonify(result)
else:
return jsonify({
"error": "method_name '" + method_name + "' not found"
})
return jsonify({"error": "method_name '" + method_name + "' not found"})
if __name__ == "__main__":

View file

@@ -44,16 +44,16 @@ num_evaluations = 10
# A function for generating random hyperparameters.
def generate_hyperparameters():
return {
"learning_rate": 10**np.random.uniform(-5, 1),
"learning_rate": 10 ** np.random.uniform(-5, 1),
"batch_size": np.random.randint(1, 100),
"momentum": np.random.uniform(0, 1)
"momentum": np.random.uniform(0, 1),
}
def get_data_loaders(batch_size):
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since
@@ -61,16 +61,16 @@ def get_data_loaders(batch_size):
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=mnist_transforms),
"~/data", train=True, download=True, transform=mnist_transforms
),
batch_size=batch_size,
shuffle=True)
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=batch_size,
shuffle=True)
shuffle=True,
)
return train_loader, test_loader
@@ -152,9 +152,8 @@ def evaluate_hyperparameters(config):
model = ConvNet()
train_loader, test_loader = get_data_loaders(config["batch_size"])
optimizer = optim.SGD(
model.parameters(),
lr=config["learning_rate"],
momentum=config["momentum"])
model.parameters(), lr=config["learning_rate"], momentum=config["momentum"]
)
train(model, optimizer, train_loader)
return test(model, test_loader)
@@ -202,22 +201,33 @@ while remaining_ids:
hyperparameters = hyperparameters_mapping[result_id]
accuracy = ray.get(result_id)
print("""We achieve accuracy {:.3}% with
print(
"""We achieve accuracy {:.3}% with
learning_rate: {:.2}
batch_size: {}
momentum: {:.2}
""".format(100 * accuracy, hyperparameters["learning_rate"],
hyperparameters["batch_size"], hyperparameters["momentum"]))
""".format(
100 * accuracy,
hyperparameters["learning_rate"],
hyperparameters["batch_size"],
hyperparameters["momentum"],
)
)
if accuracy > best_accuracy:
best_hyperparameters = hyperparameters
best_accuracy = accuracy
# Record the best performing set of hyperparameters.
print("""Best accuracy over {} trials was {:.3} with
print(
"""Best accuracy over {} trials was {:.3} with
learning_rate: {:.2}
batch_size: {}
momentum: {:.2}
""".format(num_evaluations, 100 * best_accuracy,
best_hyperparameters["learning_rate"],
best_hyperparameters["batch_size"],
best_hyperparameters["momentum"]))
""".format(
num_evaluations,
100 * best_accuracy,
best_hyperparameters["learning_rate"],
best_hyperparameters["batch_size"],
best_hyperparameters["momentum"],
)
)

View file

@@ -39,8 +39,8 @@ import ray
def get_data_loader():
"""Safely downloads data. Returns training/validation set dataloader."""
mnist_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since
@@ -48,16 +48,16 @@ def get_data_loader():
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=mnist_transforms),
"~/data", train=True, download=True, transform=mnist_transforms
),
batch_size=128,
shuffle=True)
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=128,
shuffle=True)
shuffle=True,
)
return train_loader, test_loader
@@ -75,7 +75,7 @@ def evaluate(model, test_loader):
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return 100. * correct / total
return 100.0 * correct / total
#######################################################################
@@ -144,8 +144,7 @@ class ParameterServer(object):
def apply_gradients(self, *gradients):
summed_gradients = [
np.stack(gradient_zip).sum(axis=0)
for gradient_zip in zip(*gradients)
np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)
]
self.optimizer.zero_grad()
self.model.set_gradients(summed_gradients)
@@ -215,9 +214,7 @@ test_loader = get_data_loader()[1]
print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations):
gradients = [
worker.compute_gradients.remote(current_weights) for worker in workers
]
gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
# Calculate update after all gradients are available.
current_weights = ps.apply_gradients.remote(*gradients)

View file

@@ -197,7 +197,7 @@ class Model(object):
"""Applies the gradients to the model parameters with RMSProp."""
for k, v in self.weights.items():
g = grad_buffer[k]
rmsprop_cache[k] = (decay * rmsprop_cache[k] + (1 - decay) * g**2)
rmsprop_cache[k] = decay * rmsprop_cache[k] + (1 - decay) * g ** 2
self.weights[k] += lr * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)
@@ -278,20 +278,24 @@ for i in range(1, 1 + iterations):
gradient_ids = []
# Launch tasks to compute gradients from multiple rollouts in parallel.
start_time = time.time()
gradient_ids = [
actor.compute_gradient.remote(model_id) for actor in actors
]
gradient_ids = [actor.compute_gradient.remote(model_id) for actor in actors]
for batch in range(batch_size):
[grad_id], gradient_ids = ray.wait(gradient_ids)
grad, reward_sum = ray.get(grad_id)
# Accumulate the gradient over batch.
for k in model.weights:
grad_buffer[k] += grad[k]
running_reward = (reward_sum if running_reward is None else
running_reward * 0.99 + reward_sum * 0.01)
running_reward = (
reward_sum
if running_reward is None
else running_reward * 0.99 + reward_sum * 0.01
)
end_time = time.time()
print("Batch {} computed {} rollouts in {} seconds, "
"running mean is {}".format(i, batch_size, end_time - start_time,
running_reward))
print(
"Batch {} computed {} rollouts in {} seconds, "
"running mean is {}".format(
i, batch_size, end_time - start_time, running_reward
)
)
model.update(grad_buffer, rmsprop_cache, learning_rate, decay_rate)
zero_grads(grad_buffer)
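
The running-reward hunk above shows how Black lays out a conditional expression that no longer fits on one line: the value, the if clause, and the else clause each get their own line inside parentheses. (The same file also shows g**2 picking up spaces around the power operator, which is how the Black release used here appears to format it; newer releases tend to leave simple powers unspaced.) A self-contained sketch of the layout, with illustrative rewards and the same 0.99/0.01 moving-average constants:

# Minimal sketch of the multi-line conditional-expression layout seen above.
rewards = [21.0, 19.0, 23.0]  # illustrative values
running_reward = None
for reward_sum in rewards:
    running_reward = (
        reward_sum
        if running_reward is None
        else running_reward * 0.99 + reward_sum * 0.01
    )
print(round(running_reward, 3))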

View file

@ -23,6 +23,7 @@ from typing import Tuple
from time import sleep
import ray
# For typing purposes
from ray.actor import ActorHandle
from tqdm import tqdm
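
The tiny hunk above only adds a comment over the ActorHandle import, but that import exists purely for type annotations. Below is a hypothetical, self-contained sketch of the pattern; the Counter actor and the bump helper are illustrative, not from the diff.

# Hypothetical sketch: ActorHandle lets helper functions annotate the actor
# handles they receive. Everything below is illustrative.
import ray
from ray.actor import ActorHandle


@ray.remote
class Counter:
    def __init__(self):
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value


def bump(counter: ActorHandle) -> int:
    return ray.get(counter.increment.remote())


ray.init()
counter = Counter.remote()
print(bump(counter))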

View file

@ -8,12 +8,11 @@ import wikipedia
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mappers", help="number of mapper actors used", default=3, type=int)
"--num-mappers", help="number of mapper actors used", default=3, type=int
)
parser.add_argument(
"--num-reducers",
help="number of reducer actors used",
default=4,
type=int)
"--num-reducers", help="number of reducer actors used", default=4, type=int
)
@ray.remote
@ -36,8 +35,11 @@ class Mapper(object):
while self.num_articles_processed < article_index + 1:
self.get_new_article()
# Return the word counts from within a given character range.
return [(k, v) for k, v in self.word_counts[article_index].items()
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]]
return [
(k, v)
for k, v in self.word_counts[article_index].items()
if len(k) >= 1 and k[0] >= keys[0] and k[0] <= keys[1]
]
@ray.remote
@ -51,8 +53,7 @@ class Reducer(object):
# Get the word counts for this Reducer's keys from all of the Mappers
# and aggregate the results.
count_ids = [
mapper.get_range.remote(article_index, self.keys)
for mapper in self.mappers
mapper.get_range.remote(article_index, self.keys) for mapper in self.mappers
]
while len(count_ids) > 0:
@ -87,9 +88,9 @@ if __name__ == "__main__":
streams.append(Stream([line.strip() for line in f.readlines()]))
# Partition the keys among the reducers.
chunks = np.array_split([chr(i)
for i in range(ord("a"),
ord("z") + 1)], args.num_reducers)
chunks = np.array_split(
[chr(i) for i in range(ord("a"), ord("z") + 1)], args.num_reducers
)
keys = [[chunk[0], chunk[-1]] for chunk in chunks]
# Create a number of mappers.
@ -103,14 +104,12 @@ if __name__ == "__main__":
while True:
print("article index = {}".format(article_index))
wordcounts = {}
counts = ray.get([
reducer.next_reduce_result.remote(article_index)
for reducer in reducers
])
counts = ray.get(
[reducer.next_reduce_result.remote(article_index) for reducer in reducers]
)
for count in counts:
wordcounts.update(count)
most_frequent_words = heapq.nlargest(
10, wordcounts, key=wordcounts.get)
most_frequent_words = heapq.nlargest(10, wordcounts, key=wordcounts.get)
for word in most_frequent_words:
print(" ", word, wordcounts[word])
article_index += 1
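
The argparse hunks above show the intermediate case of the wrapping rule: the call is too long for a single line, but its arguments fit together on one indented line, so Black stops there instead of going one argument per line. A runnable sketch mirroring the two calls from the example:

# Minimal sketch: arguments that fit on one indented line stay on one line,
# with no trailing comma; compare the --num-mappers/--num-reducers hunk above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-mappers", help="number of mapper actors used", default=3, type=int
)
parser.add_argument(
    "--num-reducers", help="number of reducer actors used", default=4, type=int
)
args = parser.parse_args([])  # empty argv so the sketch runs standalone
print(args.num_mappers, args.num_reducers)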

View file

@ -125,12 +125,12 @@ class MNISTDataInterface(object):
self.data_dir = data_dir
self.max_days = max_days
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
self.dataset = MNIST(
self.data_dir, train=True, download=True, transform=transform)
self.data_dir, train=True, download=True, transform=transform
)
def _get_day_slice(self, day=0):
if day < 0:
@ -154,8 +154,7 @@ class MNISTDataInterface(object):
end = self._get_day_slice(day)
available_data = Subset(self.dataset, list(range(start, end)))
train_n = int(
0.8 * (end - start)) # 80% train data, 20% validation data
train_n = int(0.8 * (end - start)) # 80% train data, 20% validation data
return random_split(available_data, [train_n, end - start - train_n])
@ -223,13 +222,15 @@ def test(model, data_loader, device=None):
# will take care of creating the model and optimizer and repeatedly
# call the ``train`` function to train the model. Also, this function
# will report the training progress back to Tune.
def train_mnist(config,
start_model=None,
checkpoint_dir=None,
num_epochs=10,
use_gpus=False,
data_fn=None,
day=0):
def train_mnist(
config,
start_model=None,
checkpoint_dir=None,
num_epochs=10,
use_gpus=False,
data_fn=None,
day=0,
):
# Create model
use_cuda = use_gpus and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
@ -237,7 +238,8 @@ def train_mnist(config,
# Create optimizer
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
model.parameters(), lr=config["lr"], momentum=config["momentum"]
)
# Load checkpoint, or load start model if no checkpoint has been
# passed and a start model is specified
@ -248,8 +250,7 @@ def train_mnist(config,
load_dir = start_model
if load_dir:
model_state, optimizer_state = torch.load(
os.path.join(load_dir, "checkpoint"))
model_state, optimizer_state = torch.load(os.path.join(load_dir, "checkpoint"))
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
@ -257,18 +258,22 @@ def train_mnist(config,
train_dataset, validation_dataset = data_fn(day=day)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=config["batch_size"], shuffle=True)
train_dataset, batch_size=config["batch_size"], shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
validation_dataset, batch_size=config["batch_size"], shuffle=True)
validation_dataset, batch_size=config["batch_size"], shuffle=True
)
for i in range(num_epochs):
train(model, optimizer, train_loader, device)
acc = test(model, validation_loader, device)
if i == num_epochs - 1:
with tune.checkpoint_dir(step=i) as checkpoint_dir:
torch.save((model.state_dict(), optimizer.state_dict()),
os.path.join(checkpoint_dir, "checkpoint"))
torch.save(
(model.state_dict(), optimizer.state_dict()),
os.path.join(checkpoint_dir, "checkpoint"),
)
tune.report(mean_accuracy=acc, done=True)
else:
tune.report(mean_accuracy=acc)
@ -286,7 +291,7 @@ def train_mnist(config,
# until the given day. Our search space can thus also contain parameters
# that affect the model complexity (such as the layer size), since it
# does not have to be compatible to an existing model.
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):
data_interface = MNISTDataInterface("~/data", max_days=10)
num_examples = data_interface._get_day_slice(day)
@ -302,11 +307,13 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
mode="max",
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reduction_factor=2,
)
reporter = CLIReporter(
parameter_columns=["layer_size", "lr", "momentum", "batch_size"],
metric_columns=["mean_accuracy", "training_iteration"])
metric_columns=["mean_accuracy", "training_iteration"],
)
analysis = tune.run(
partial(
@ -315,17 +322,16 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
data_fn=data_interface.get_data,
num_epochs=num_epochs,
use_gpus=True if gpus_per_trial > 0 else False,
day=day),
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
day=day,
),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
verbose=0,
name="tune_serve_mnist_fromscratch")
name="tune_serve_mnist_fromscratch",
)
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
@ -344,33 +350,35 @@ def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0., day=0):
# layer size parameter. Since we continue to train an existing model,
# we cannot change the layer size mid training, so we just continue
# to use the existing one.
def tune_from_existing(start_model,
start_config,
num_samples=10,
num_epochs=10,
gpus_per_trial=0.,
day=0):
def tune_from_existing(
start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0
):
data_interface = MNISTDataInterface("/tmp/mnist_data", max_days=10)
num_examples = data_interface._get_day_slice(day) - \
data_interface._get_day_slice(day - 1)
num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(
day - 1
)
config = start_config.copy()
config.update({
"batch_size": tune.choice([16, 32, 64]),
"lr": tune.loguniform(1e-4, 1e-1),
"momentum": tune.uniform(0.1, 0.9),
})
config.update(
{
"batch_size": tune.choice([16, 32, 64]),
"lr": tune.loguniform(1e-4, 1e-1),
"momentum": tune.uniform(0.1, 0.9),
}
)
scheduler = ASHAScheduler(
metric="mean_accuracy",
mode="max",
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reduction_factor=2,
)
reporter = CLIReporter(
parameter_columns=["lr", "momentum", "batch_size"],
metric_columns=["mean_accuracy", "training_iteration"])
metric_columns=["mean_accuracy", "training_iteration"],
)
analysis = tune.run(
partial(
@ -379,17 +387,16 @@ def tune_from_existing(start_model,
data_fn=data_interface.get_incremental_data,
num_epochs=num_epochs,
use_gpus=True if gpus_per_trial > 0 else False,
day=day),
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
day=day,
),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
verbose=0,
name="tune_serve_mnist_fromsexisting")
name="tune_serve_mnist_fromsexisting",
)
best_trial = analysis.get_best_trial("mean_accuracy", "max", "last")
best_accuracy = best_trial.metric_analysis["mean_accuracy"]["last"]
@ -423,8 +430,8 @@ class MNISTDeployment:
model = ConvNet(layer_size=self.config["layer_size"]).to(self.device)
model_state, optimizer_state = torch.load(
os.path.join(self.checkpoint_dir, "checkpoint"),
map_location=self.device)
os.path.join(self.checkpoint_dir, "checkpoint"), map_location=self.device
)
model.load_state_dict(model_state)
self.model = model
@ -442,12 +449,12 @@ class MNISTDeployment:
# active model. We call this directory ``model_dir``. Every time we
# would like to update our model, we copy the checkpoint of the new
# model to this directory. We then update the deployment to the new version.
def serve_new_model(model_dir, checkpoint, config, metrics, day,
use_gpu=False):
def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):
print("Serving checkpoint: {}".format(checkpoint))
checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint,
config, metrics)
checkpoint_path = _move_checkpoint_to_model_dir(
model_dir, checkpoint, config, metrics
)
serve.start(detached=True)
MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)
@ -482,8 +489,7 @@ def get_current_model(model_dir):
checkpoint_path = os.path.join(model_dir, "checkpoint")
meta_path = os.path.join(model_dir, "meta.json")
if not os.path.exists(checkpoint_path) or \
not os.path.exists(meta_path):
if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):
return None, None, None
with open(meta_path, "rt") as fp:
@ -559,28 +565,33 @@ if __name__ == "__main__":
"--from_scratch",
action="store_true",
help="Train and select best model from scratch",
default=False)
default=False,
)
parser.add_argument(
"--from_existing",
action="store_true",
help="Train and select best model from existing model",
default=False)
default=False,
)
parser.add_argument(
"--day",
help="Indicate the day to simulate the amount of data available to us",
type=int,
default=0)
default=0,
)
parser.add_argument(
"--query", help="Query endpoint with example", type=int, default=-1)
"--query", help="Query endpoint with example", type=int, default=-1
)
parser.add_argument(
"--smoke-test",
action="store_true",
help="Finish quickly for testing",
default=False)
default=False,
)
args = parser.parse_args()
@ -600,20 +611,23 @@ if __name__ == "__main__":
# Query our model
response = requests.post(
"http://localhost:8000/mnist",
json={"images": [data[0].numpy().tolist()]})
"http://localhost:8000/mnist", json={"images": [data[0].numpy().tolist()]}
)
try:
pred = response.json()["result"][0]
except: # noqa: E722
pred = -1
print("Querying model with example #{}. "
"Label = {}, Response = {}, Correct = {}".format(
args.query, label, pred, label == pred))
print(
"Querying model with example #{}. "
"Label = {}, Response = {}, Correct = {}".format(
args.query, label, pred, label == pred
)
)
sys.exit(0)
gpus_per_trial = 0.5 if not args.smoke_test else 0.
gpus_per_trial = 0.5 if not args.smoke_test else 0.0
serve_gpu = True if gpus_per_trial > 0 else False
num_samples = 8 if not args.smoke_test else 1
num_epochs = 10 if not args.smoke_test else 1
@ -621,23 +635,22 @@ if __name__ == "__main__":
if args.from_scratch: # train everyday from scratch
print("Start training job from scratch on day {}.".format(args.day))
acc, config, best_checkpoint, num_examples = tune_from_scratch(
num_samples, num_epochs, gpus_per_trial, day=args.day)
print("Trained day {} from scratch on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config))
num_samples, num_epochs, gpus_per_trial, day=args.day
)
print(
"Trained day {} from scratch on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model(
model_dir,
best_checkpoint,
config,
acc,
args.day,
use_gpu=serve_gpu)
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
)
if args.from_existing:
old_checkpoint, old_config, old_acc = get_current_model(model_dir)
if not old_checkpoint or not old_config or not old_acc:
print("No existing model found. Train one with --from_scratch "
"first.")
print("No existing model found. Train one with --from_scratch " "first.")
sys.exit(1)
acc, config, best_checkpoint, num_examples = tune_from_existing(
old_checkpoint,
@ -645,17 +658,17 @@ if __name__ == "__main__":
num_samples,
num_epochs,
gpus_per_trial,
day=args.day)
print("Trained day {} from existing on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config))
day=args.day,
)
print(
"Trained day {} from existing on {} samples. "
"Best accuracy: {:.4f}. Best config: {}".format(
args.day, num_examples, acc, config
)
)
serve_new_model(
model_dir,
best_checkpoint,
config,
acc,
args.day,
use_gpu=serve_gpu)
model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu
)
#######################################################################
# That's it! We now have an end-to-end workflow to train and update a
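
One quirk worth noting in the reformatted file above: Black never joins adjacent string literals, so a message that used to be split across two lines can end up as an implicit concatenation on a single line, as in the "No existing model found..." print near the end. A minimal sketch showing that the two forms are equivalent:

# Minimal sketch: adjacent string literals are concatenated at parse time, so
# the one-line form Black produces is equivalent to the original two-line form.
message = "No existing model found. Train one with --from_scratch " "first."
assert message == "No existing model found. Train one with --from_scratch first."
print(message)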

View file

@ -38,6 +38,7 @@ To start out, change the import statement to get tune-scikit-learn's grid search
"""
# Keep this here for https://github.com/ray-project/ray/issues/11547
from sklearn.model_selection import GridSearchCV
# Replace above line with:
from ray.tune.sklearn import TuneGridSearchCV
@ -60,7 +61,8 @@ X, y = make_classification(
n_informative=50,
n_redundant=0,
n_classes=10,
class_sep=2.5)
class_sep=2.5,
)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=1000)
# Example parameters to tune from SGDClassifier
@ -70,9 +72,11 @@ parameter_grid = {"alpha": [1e-4, 1e-1, 1], "epsilon": [0.01, 0.1]}
# As you can see, the setup here is exactly how you would do it for Scikit-Learn. Now, let's try fitting a model.
tune_search = TuneGridSearchCV(
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10)
SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10
)
import time # Just to compare fit times
start = time.time()
tune_search.fit(x_train, y_train)
end = time.time()
@ -93,6 +97,7 @@ print("Tune GridSearch Fit Time:", end - start)
# Try running this compared to the GridSearchCV equivalent, and see the speedup for yourself!
from sklearn.model_selection import GridSearchCV
# n_jobs=-1 enables use of all cores like Tune does
sklearn_search = GridSearchCV(SGDClassifier(), parameter_grid, n_jobs=-1)
@ -120,7 +125,7 @@ import numpy as np
digits = datasets.load_digits()
x = digits.data
y = digits.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.2)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
clf = SGDClassifier()
parameter_grid = {"alpha": (1e-4, 1), "epsilon": (0.01, 0.1)}
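
Beyond the literal normalization (test_size=.2 becoming test_size=0.2), the walkthrough above is about swapping GridSearchCV for TuneGridSearchCV. The sketch below condenses that drop-in usage into one runnable block; the constructor arguments mirror the example, while best_params_ is an assumption based on TuneGridSearchCV following the scikit-learn estimator interface rather than something shown in the diff.

# Condensed sketch of the drop-in replacement shown above.
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from ray.tune.sklearn import TuneGridSearchCV

X, y = make_classification(n_samples=1000, n_features=20, n_classes=2)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

parameter_grid = {"alpha": [1e-4, 1e-1, 1], "epsilon": [0.01, 0.1]}
tune_search = TuneGridSearchCV(
    SGDClassifier(), parameter_grid, early_stopping=True, max_iters=10
)
tune_search.fit(x_train, y_train)
print(tune_search.best_params_)  # assumed attribute, mirroring scikit-learn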

View file

@ -8,8 +8,9 @@ import ray
def get_host_name(x):
import platform
import time
time.sleep(0.01)
return x + (platform.node(), )
return x + (platform.node(),)
def wait_for_nodes(expected):
@ -17,8 +18,11 @@ def wait_for_nodes(expected):
while True:
num_nodes = len(ray.nodes())
if num_nodes < expected:
print("{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes))
print(
"{} nodes have joined so far, waiting for {} more.".format(
num_nodes, expected - num_nodes
)
)
sys.stdout.flush()
time.sleep(1)
else:
@ -31,9 +35,7 @@ def main():
# Check that objects can be transferred from each node to each other node.
for i in range(10):
print("Iteration {}".format(i))
results = [
get_host_name.remote(get_host_name.remote(())) for _ in range(100)
]
results = [get_host_name.remote(get_host_name.remote(())) for _ in range(100)]
print(Counter(ray.get(results)))
sys.stdout.flush()

View file

@ -20,8 +20,9 @@ def setup_logging() -> None:
setup_component_logger(
logging_level=ray_constants.LOGGER_LEVEL, # info
logging_format=ray_constants.LOGGER_FORMAT,
log_dir=os.path.join(ray._private.utils.get_ray_temp_dir(),
ray.node.SESSION_LATEST, "logs"),
log_dir=os.path.join(
ray._private.utils.get_ray_temp_dir(), ray.node.SESSION_LATEST, "logs"
),
filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log
max_bytes=ray_constants.LOGGING_ROTATE_BYTES,
backup_count=ray_constants.LOGGING_ROTATE_BACKUP_COUNT,
@ -47,11 +48,11 @@ if __name__ == "__main__":
required=False,
type=str,
default=None,
help="The password to use for Redis")
help="The password to use for Redis",
)
args = parser.parse_args()
cluster_name = yaml.safe_load(
open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
cluster_name = yaml.safe_load(open(AUTOSCALING_CONFIG_PATH).read())["cluster_name"]
head_ip = get_node_ip_address()
Monitor(
address=f"{head_ip}:6379",

Some files were not shown because too many files have changed in this diff.
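
A closing note on how a tree-wide reformat like this is usually kept stable: the formatter is pinned and CI re-checks the tree with Black's --check/--diff flags. The sketch below is hypothetical and not taken from this commit; the black, --check, and --diff CLI options are real, but the directories checked and the failure message are assumptions.

# Hypothetical CI-style formatting gate; paths and message are assumptions.
import subprocess
import sys

result = subprocess.run(
    ["black", "--check", "--diff", "python/", "release/"],
    capture_output=True,
    text=True,
)
if result.returncode != 0:
    print(result.stdout)
    sys.exit("Black formatting check failed; run `black .` and commit the result.")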