[projects] Wrap ProjectDefinition in a class (#5654)

This commit is contained in:
Stephanie Wang 2019-09-07 18:30:17 -07:00 committed by Philipp Moritz
parent d0125d4212
commit cb7102f31e
20 changed files with 161 additions and 144 deletions

View file

@ -2,10 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.projects.projects import (check_project_definition, find_root,
load_project, validate_project_schema)
from ray.projects.projects import ProjectDefinition
__all__ = [
"check_project_definition", "find_root", "load_project",
"validate_project_schema"
"ProjectDefinition",
]

View file

@ -2,12 +2,98 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import jsonschema
import os
import yaml
class ProjectDefinition(object):
    """A Ray project loaded from a ``.rayproject`` directory.

    Instances carry two attributes set by ``__init__``:

    * ``root``: the project root directory, ending with a path separator.
    * ``config``: the parsed contents of ``.rayproject/project.yaml``.

    The class derives from ``object`` explicitly so that it is a new-style
    class on Python 2 as well (this module uses ``from __future__`` imports).
    """

    def __init__(self, current_dir):
        """Finds .rayproject folder for current project, parses and validates it.

        Args:
            current_dir (str): Path from which to search for .rayproject.

        Raises:
            jsonschema.exceptions.ValidationError: This exception is raised
                if the project file is not valid.
            ValueError: This exception is raised if there are other errors in
                the project definition (e.g. files not existing).
        """
        root = find_root(current_dir)
        if root is None:
            raise ValueError("No project root found")
        # Add an empty pathname to the end so that rsync will copy the project
        # directory to the correct target.
        self.root = os.path.join(root, "")

        # Parse the project YAML.
        project_file = os.path.join(self.root, ".rayproject", "project.yaml")
        if not os.path.exists(project_file):
            raise ValueError("Project file {} not found".format(project_file))
        with open(project_file) as f:
            self.config = yaml.safe_load(f)

        check_project_config(self.root, self.config)

    def cluster_yaml(self):
        """Return the project's cluster configuration filename."""
        return self.config["cluster"]

    def working_directory(self):
        """Return the project's working directory on a cluster session."""
        # Add an empty pathname to the end so that rsync will copy the project
        # directory to the correct target.
        directory = os.path.join("~", self.config["name"], "")
        return directory

    def get_command_to_run(self, command=None, args=tuple()):
        """Get and format a command to run.

        Args:
            command (str): Name of the command to run. The command definition
                should be available in project.yaml. Defaults to the command
                named "default" when None.
            args (tuple): Tuple containing arguments to format the command
                with, given as ``--<param name> <value>`` pairs. None is
                treated as no arguments.

        Returns:
            The raw shell command to run, formatted with the given arguments.

        Raises:
            ValueError: This exception is raised if the given command is not
                found in project.yaml (including when the project has no
                commands section at all).
            SystemExit: Raised by argparse when a declared parameter is
                missing or a value is not among its choices.
        """
        if command is None:
            command = "default"

        command_to_run = None
        params = None
        # A missing or empty commands section is reported through the
        # documented ValueError below rather than a KeyError/TypeError.
        # The loop deliberately does not break: if several definitions share
        # a name, the last one wins.
        for command_definition in self.config.get("commands") or []:
            if command_definition["name"] == command:
                command_to_run = command_definition["command"]
                params = command_definition.get("params", [])
        if not command_to_run:
            raise ValueError(
                "Cannot find the command '{}' in commands section of the "
                "project file.".format(command))

        # Build argument parser dynamically to parse parameter arguments.
        parser = argparse.ArgumentParser(prog=command)
        for param in params:
            parser.add_argument(
                "--" + param["name"],
                required=True,
                help=param.get("help"),
                choices=param.get("choices"))

        result = parser.parse_args(list(args or ()))
        # Substitute every "{{<param>}}" placeholder with its parsed value.
        # NOTE(review): argparse derives the attribute name from the option
        # string, turning "-" into "_", so a parameter named "my-param" is
        # substituted for "{{my_param}}" -- confirm this is intended.
        for key, val in vars(result).items():
            command_to_run = command_to_run.replace("{{" + key + "}}", val)
        return command_to_run
def find_root(directory):
"""Find root directory of the ray project.
@ -27,11 +113,11 @@ def find_root(directory):
return None
def validate_project_schema(project_definition):
"""Validate a project file against the official ray project schema.
def validate_project_schema(project_config):
"""Validate a project config against the official ray project schema.
Args:
project_definition (dict): Parsed project yaml.
project_config (dict): Parsed project yaml.
Raises:
jsonschema.exceptions.ValidationError: This exception is raised
@ -41,15 +127,15 @@ def validate_project_schema(project_definition):
with open(os.path.join(dir, "schema.json")) as f:
schema = json.load(f)
jsonschema.validate(instance=project_definition, schema=schema)
jsonschema.validate(instance=project_config, schema=schema)
def check_project_definition(project_root, project_definition):
def check_project_config(project_root, project_config):
"""Checks if the project definition is valid.
Args:
project_root (str): Path containing the .rayproject
project_definition (dict): Project definition
project_config (dict): Project config definition
Raises:
jsonschema.exceptions.ValidationError: This exception is raised
@ -57,19 +143,17 @@ def check_project_definition(project_root, project_definition):
ValueError: This exception is raised if there are other errors in
the project definition (e.g. files not existing).
"""
validate_project_schema(project_definition)
validate_project_schema(project_config)
# Make sure the cluster yaml file exists
if "cluster" in project_definition:
cluster_file = os.path.join(project_root,
project_definition["cluster"])
if "cluster" in project_config:
cluster_file = os.path.join(project_root, project_config["cluster"])
if not os.path.exists(cluster_file):
raise ValueError("'cluster' file does not exist "
"in {}".format(project_root))
if "environment" in project_definition:
env = project_definition["environment"]
if "environment" in project_config:
env = project_config["environment"]
if sum(["dockerfile" in env, "dockerimage" in env]) > 1:
raise ValueError("Cannot specify both 'dockerfile' and "
@ -86,36 +170,3 @@ def check_project_definition(project_root, project_definition):
if not os.path.exists(docker_file):
raise ValueError("'dockerfile' file in 'environment' does "
"not exist in {}".format(project_root))
def load_project(current_dir):
    """Locate, parse and validate the project enclosing ``current_dir``.

    Args:
        current_dir (str): Path from which to search for .rayproject.

    Returns:
        Dictionary containing the project definition.

    Raises:
        jsonschema.exceptions.ValidationError: This exception is raised
            if the project file is not valid.
        ValueError: This exception is raised if there are other errors in
            the project definition (e.g. files not existing).
    """
    root = find_root(current_dir)
    if not root:
        raise ValueError("No project root found")

    yaml_path = os.path.join(root, ".rayproject", "project.yaml")
    if not os.path.exists(yaml_path):
        raise ValueError("Project file {} not found".format(yaml_path))

    with open(yaml_path) as stream:
        parsed = yaml.safe_load(stream)

    # Validate before handing the parsed definition back to the caller.
    check_project_definition(root, parsed)
    return parsed

View file

@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import click
import jsonschema
import logging
@ -59,10 +58,10 @@ def project_cli():
"--verbose", help="If set, print the validated file", is_flag=True)
def validate(verbose):
try:
project = ray.projects.load_project(os.getcwd())
project = ray.projects.ProjectDefinition(os.getcwd())
print("Project files validated!", file=sys.stderr)
if verbose:
print(project)
print(project.config)
except (jsonschema.exceptions.ValidationError, ValueError) as e:
print("Validation failed for the following reason", file=sys.stderr)
raise click.ClickException(e)
@ -139,7 +138,7 @@ def session_cli():
def load_project_or_throw():
# Validate the project file
try:
return ray.projects.load_project(os.getcwd())
return ray.projects.ProjectDefinition(os.getcwd())
except (jsonschema.exceptions.ValidationError, ValueError):
raise click.ClickException(
"Project file validation failed. Please run "
@ -150,7 +149,7 @@ def load_project_or_throw():
def attach():
project_definition = load_project_or_throw()
attach_cluster(
project_definition["cluster"],
project_definition.cluster_yaml(),
start=False,
use_tmux=False,
override_cluster_name=None,
@ -162,7 +161,7 @@ def attach():
def stop():
project_definition = load_project_or_throw()
teardown_cluster(
project_definition["cluster"],
project_definition.cluster_yaml(),
yes=True,
workers_only=False,
override_cluster_name=None)
@ -184,14 +183,15 @@ def start(command, args, shell):
if shell:
command_to_run = command
elif command:
command_to_run = _get_command_to_run(command, project_definition, args)
else:
command_to_run = _get_command_to_run("default", project_definition,
args)
try:
command_to_run = project_definition.get_command_to_run(
command=command, args=args)
except ValueError as e:
raise click.ClickException(e)
# Check for features we don't support right now
project_environment = project_definition["environment"]
project_environment = project_definition.config["environment"]
need_docker = ("dockerfile" in project_environment
or "dockerimage" in project_environment)
if need_docker:
@ -200,12 +200,9 @@ def start(command, args, shell):
"Please file an feature request at"
"https://github.com/ray-project/ray/issues")
cluster_yaml = project_definition["cluster"]
working_directory = project_definition["name"]
logger.info("[1/4] Creating cluster")
create_or_update_cluster(
config_file=cluster_yaml,
config_file=project_definition.cluster_yaml(),
override_min_workers=None,
override_max_workers=None,
no_restart=False,
@ -215,26 +212,26 @@ def start(command, args, shell):
)
logger.info("[2/4] Syncing the project")
project_root = ray.projects.find_root(os.getcwd())
# This is so that rsync syncs directly to the target directory, instead of
# nesting inside the target directory.
if not project_root.endswith("/"):
project_root += "/"
rsync(
cluster_yaml,
source=project_root,
target="~/{}/".format(working_directory),
project_definition.cluster_yaml(),
source=project_definition.root,
target=project_definition.working_directory(),
override_cluster_name=None,
down=False,
)
logger.info("[3/4] Setting up environment")
_setup_environment(
cluster_yaml, project_definition["environment"], cwd=working_directory)
project_definition.cluster_yaml(),
project_environment,
cwd=project_definition.working_directory())
logger.info("[4/4] Running command")
logger.debug("Running {}".format(command))
session_exec_cluster(cluster_yaml, command_to_run, cwd=working_directory)
session_exec_cluster(
project_definition.cluster_yaml(),
command_to_run,
cwd=project_definition.working_directory())
def session_exec_cluster(cluster_yaml, cmd, cwd=None):
@ -277,32 +274,3 @@ def _setup_environment(cluster_yaml, project_environment, cwd):
if "shell" in project_environment:
for cmd in project_environment["shell"]:
session_exec_cluster(cluster_yaml, cmd, cwd=cwd)
def _get_command_to_run(command, project_definition, args):
command_to_run = None
params = None
for command_definition in project_definition["commands"]:
if command_definition["name"] == command:
command_to_run = command_definition["command"]
params = command_definition.get("params", [])
if not command_to_run:
raise click.ClickException(
"Cannot find the command '" + command +
"' in commmands section of the project file.")
# Build argument parser dynamically to parse parameter arguments.
parser = argparse.ArgumentParser(prog=command)
for param in params:
parser.add_argument(
"--" + param["name"],
required=True,
help=param.get("help"),
choices=param.get("choices"))
result = parser.parse_args(list(args))
for key, val in result.__dict__.items():
command_to_run = command_to_run.replace("{{" + key + "}}", val)
return command_to_run

View file

@ -1,4 +1,4 @@
name: testproject2
name: testmissingcluster
environment:
shell: "one command"

View file

@ -0,0 +1,3 @@
name: testmissingyaml
cluster: "cluster.yaml"

View file

@ -20,61 +20,57 @@ if sys.version_info >= (3, 3):
else:
from mock import patch, DEFAULT
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "project_files")
def load_project_description(project_file):
path = os.path.join(TEST_DIR, "project_files", project_file)
path = os.path.join(TEST_DIR, project_file)
with open(path) as f:
return yaml.safe_load(f)
def test_validation_success():
project_files = [
"docker_project.yaml", "requirements_project.yaml",
"shell_project.yaml"
]
for project_file in project_files:
project_definition = load_project_description(project_file)
ray.projects.validate_project_schema(project_definition)
def test_validation():
project_dirs = ["docker_project", "requirements_project", "shell_project"]
for project_dir in project_dirs:
project_dir = os.path.join(TEST_DIR, project_dir)
ray.projects.ProjectDefinition(project_dir)
def test_validation_failure():
project_files = ["no_project1.yaml", "no_project2.yaml"]
for project_file in project_files:
project_definition = load_project_description(project_file)
bad_schema_dirs = ["no_project1"]
for project_dir in bad_schema_dirs:
project_dir = os.path.join(TEST_DIR, project_dir)
with pytest.raises(jsonschema.exceptions.ValidationError):
ray.projects.validate_project_schema(project_definition)
ray.projects.ProjectDefinition(project_dir)
def test_check_failure():
project_files = ["no_project3.yaml"]
for project_file in project_files:
project_definition = load_project_description(project_file)
bad_project_dirs = ["no_project2", "noproject3"]
for project_dir in bad_project_dirs:
project_dir = os.path.join(TEST_DIR, project_dir)
with pytest.raises(ValueError):
ray.projects.check_project_definition("", project_definition)
ray.projects.ProjectDefinition(project_dir)
def test_project_root():
path = os.path.join(TEST_DIR, "project_files", "project1")
assert ray.projects.find_root(path) == path
path = os.path.join(TEST_DIR, "project1")
project_definition = ray.projects.ProjectDefinition(path)
assert os.path.normpath(project_definition.root) == os.path.normpath(path)
path2 = os.path.join(TEST_DIR, "project_files", "project1", "subdir")
assert ray.projects.find_root(path2) == path
path2 = os.path.join(TEST_DIR, "project1", "subdir")
project_definition = ray.projects.ProjectDefinition(path2)
assert os.path.normpath(project_definition.root) == os.path.normpath(path)
path3 = "/tmp/"
assert ray.projects.find_root(path3) is None
with pytest.raises(ValueError):
project_definition = ray.projects.ProjectDefinition(path3)
def test_project_validation():
path = os.path.join(TEST_DIR, "project_files", "project1")
path = os.path.join(TEST_DIR, "project1")
subprocess.check_call(["ray", "project", "validate"], cwd=path)
def test_project_no_validation():
path = os.path.join(TEST_DIR, "project_files")
with pytest.raises(subprocess.CalledProcessError):
subprocess.check_call(["ray", "project", "validate"], cwd=path)
subprocess.check_call(["ray", "project", "validate"], cwd=TEST_DIR)
@contextmanager
@ -89,7 +85,7 @@ def _chdir_and_back(d):
def run_test_project(project_dir, command, args):
# Run the CLI commands with patching
test_dir = os.path.join(TEST_DIR, "project_files", project_dir)
test_dir = os.path.join(TEST_DIR, project_dir)
with _chdir_and_back(test_dir):
runner = CliRunner()
with patch.multiple(
@ -107,14 +103,14 @@ def test_session_start_default_project():
result, mock_calls, test_dir = run_test_project(
"session-tests/project-pass", start, [])
loaded_project = ray.projects.load_project(test_dir)
loaded_project = ray.projects.ProjectDefinition(test_dir)
assert result.exit_code == 0
# Part 1/3: Cluster Launching Call
create_or_update_cluster_call = mock_calls["create_or_update_cluster"]
assert create_or_update_cluster_call.call_count == 1
_, kwargs = create_or_update_cluster_call.call_args
assert kwargs["config_file"] == loaded_project["cluster"]
assert kwargs["config_file"] == loaded_project.cluster_yaml()
# Part 2/3: Rsync Calls
rsync_call = mock_calls["rsync"]
@ -122,21 +118,22 @@ def test_session_start_default_project():
# requirements.txt.
assert rsync_call.call_count == 2
_, kwargs = rsync_call.call_args
assert kwargs["source"] == loaded_project["environment"]["requirements"]
assert kwargs["source"] == loaded_project.config["environment"][
"requirements"]
# Part 3/3: Exec Calls
exec_cluster_call = mock_calls["exec_cluster"]
commands_executed = []
for _, kwargs in exec_cluster_call.call_args_list:
commands_executed.append(kwargs["cmd"].replace(
"cd {}; ".format(loaded_project["name"]), ""))
"cd {}; ".format(loaded_project.working_directory()), ""))
expected_commands = loaded_project["environment"]["shell"]
expected_commands = loaded_project.config["environment"]["shell"]
expected_commands += [
command["command"] for command in loaded_project["commands"]
command["command"] for command in loaded_project.config["commands"]
]
if "requirements" in loaded_project["environment"]:
if "requirements" in loaded_project.config["environment"]:
assert any("pip install -r" for cmd in commands_executed)
# pop the `pip install` off commands executed
commands_executed = [
@ -171,7 +168,7 @@ def test_session_create_command():
["first", "--a", "1", "--b", "2"])
# Verify the project can be loaded.
ray.projects.load_project(test_dir)
ray.projects.ProjectDefinition(test_dir)
assert result.exit_code == 0
exec_cluster_call = mock_calls["exec_cluster"]