mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[projects] Wrap ProjectDefinition in a class (#5654)
This commit is contained in:
parent
d0125d4212
commit
cb7102f31e
20 changed files with 161 additions and 144 deletions
|
@ -2,10 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.projects.projects import (check_project_definition, find_root,
|
||||
load_project, validate_project_schema)
|
||||
from ray.projects.projects import ProjectDefinition
|
||||
|
||||
__all__ = [
|
||||
"check_project_definition", "find_root", "load_project",
|
||||
"validate_project_schema"
|
||||
"ProjectDefinition",
|
||||
]
|
||||
|
|
|
@ -2,12 +2,98 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import jsonschema
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
class ProjectDefinition(object):
    """A ray project located through its ``.rayproject`` directory.

    Instances carry two attributes, both set by ``__init__``:

    * ``root``: the project root directory, always ending in a path
      separator.
    * ``config``: the parsed contents of ``.rayproject/project.yaml``.
    """

    def __init__(self, current_dir):
        """Find, parse and validate the project definition.

        Args:
            current_dir (str): Path from which to search for .rayproject.

        Raises:
            jsonschema.exceptions.ValidationError: This exception is raised
                if the project file is not valid.
            ValueError: This exception is raised if there are other errors in
                the project definition (e.g. files not existing).
        """
        root = find_root(current_dir)
        if root is None:
            raise ValueError("No project root found")
        # Add an empty pathname to the end so that rsync will copy the project
        # directory to the correct target.
        self.root = os.path.join(root, "")

        # Parse the project YAML.
        project_file = os.path.join(self.root, ".rayproject", "project.yaml")
        if not os.path.exists(project_file):
            raise ValueError("Project file {} not found".format(project_file))
        with open(project_file) as f:
            self.config = yaml.safe_load(f)

        check_project_config(self.root, self.config)

    def cluster_yaml(self):
        """Return the project's cluster configuration filename."""
        return self.config["cluster"]

    def working_directory(self):
        """Return the project's working directory on a cluster session."""
        # Add an empty pathname to the end so that rsync will copy the project
        # directory to the correct target.
        directory = os.path.join("~", self.config["name"], "")
        return directory

    def get_command_to_run(self, command=None, args=tuple()):
        """Get and format a command to run.

        Args:
            command (str): Name of the command to run. The command definition
                should be available in project.yaml. ``None`` selects the
                command named "default".
            args (tuple): Tuple containing arguments to format the command
                with, given as ``--<param name> <value>`` pairs.

        Returns:
            The raw shell command to run, formatted with the given arguments.

        Raises:
            ValueError: This exception is raised if the given command is not
                found in project.yaml.
            SystemExit: Raised by argparse when ``args`` do not match the
                parameters declared for the command.
        """
        if command is None:
            command = "default"

        command_to_run = None
        params = None
        # A missing or empty "commands" section is handled like an unknown
        # command, so callers only ever have to catch ValueError.
        for command_definition in self.config.get("commands") or []:
            if command_definition["name"] == command:
                command_to_run = command_definition["command"]
                params = command_definition.get("params", [])
        if not command_to_run:
            raise ValueError(
                "Cannot find the command '{}' in commands section of the "
                "project file.".format(command))

        # Build argument parser dynamically to parse parameter arguments.
        parser = argparse.ArgumentParser(prog=command)
        for param in params:
            parser.add_argument(
                "--" + param["name"],
                # Pin dest to the declared name: argparse would otherwise
                # rewrite "-" to "_" and the "{{name}}" placeholder of a
                # hyphenated parameter would never be substituted.
                dest=param["name"],
                required=True,
                help=param.get("help"),
                choices=param.get("choices"))

        result = parser.parse_args(list(args))
        for key, val in vars(result).items():
            command_to_run = command_to_run.replace("{{" + key + "}}", val)

        return command_to_run
|
||||
|
||||
|
||||
def find_root(directory):
|
||||
"""Find root directory of the ray project.
|
||||
|
||||
|
@ -27,11 +113,11 @@ def find_root(directory):
|
|||
return None
|
||||
|
||||
|
||||
def validate_project_schema(project_definition):
|
||||
"""Validate a project file against the official ray project schema.
|
||||
def validate_project_schema(project_config):
|
||||
"""Validate a project config against the official ray project schema.
|
||||
|
||||
Args:
|
||||
project_definition (dict): Parsed project yaml.
|
||||
project_config (dict): Parsed project yaml.
|
||||
|
||||
Raises:
|
||||
jsonschema.exceptions.ValidationError: This exception is raised
|
||||
|
@ -41,15 +127,15 @@ def validate_project_schema(project_definition):
|
|||
with open(os.path.join(dir, "schema.json")) as f:
|
||||
schema = json.load(f)
|
||||
|
||||
jsonschema.validate(instance=project_definition, schema=schema)
|
||||
jsonschema.validate(instance=project_config, schema=schema)
|
||||
|
||||
|
||||
def check_project_definition(project_root, project_definition):
|
||||
def check_project_config(project_root, project_config):
|
||||
"""Checks if the project definition is valid.
|
||||
|
||||
Args:
|
||||
project_root (str): Path containing the .rayproject
|
||||
project_definition (dict): Project definition
|
||||
project_config (dict): Project config definition
|
||||
|
||||
Raises:
|
||||
jsonschema.exceptions.ValidationError: This exception is raised
|
||||
|
@ -57,19 +143,17 @@ def check_project_definition(project_root, project_definition):
|
|||
ValueError: This exception is raised if there are other errors in
|
||||
the project definition (e.g. files not existing).
|
||||
"""
|
||||
|
||||
validate_project_schema(project_definition)
|
||||
validate_project_schema(project_config)
|
||||
|
||||
# Make sure the cluster yaml file exists
|
||||
if "cluster" in project_definition:
|
||||
cluster_file = os.path.join(project_root,
|
||||
project_definition["cluster"])
|
||||
if "cluster" in project_config:
|
||||
cluster_file = os.path.join(project_root, project_config["cluster"])
|
||||
if not os.path.exists(cluster_file):
|
||||
raise ValueError("'cluster' file does not exist "
|
||||
"in {}".format(project_root))
|
||||
|
||||
if "environment" in project_definition:
|
||||
env = project_definition["environment"]
|
||||
if "environment" in project_config:
|
||||
env = project_config["environment"]
|
||||
|
||||
if sum(["dockerfile" in env, "dockerimage" in env]) > 1:
|
||||
raise ValueError("Cannot specify both 'dockerfile' and "
|
||||
|
@ -86,36 +170,3 @@ def check_project_definition(project_root, project_definition):
|
|||
if not os.path.exists(docker_file):
|
||||
raise ValueError("'dockerfile' file in 'environment' does "
|
||||
"not exist in {}".format(project_root))
|
||||
|
||||
|
||||
def load_project(current_dir):
    """Locate, parse and validate the project definition for a directory.

    Args:
        current_dir (str): Path from which to search for .rayproject.

    Returns:
        Dictionary containing the project definition.

    Raises:
        jsonschema.exceptions.ValidationError: This exception is raised
            if the project file is not valid.
        ValueError: This exception is raised if there are other errors in
            the project definition (e.g. files not existing).
    """
    root = find_root(current_dir)
    if not root:
        raise ValueError("No project root found")

    yaml_path = os.path.join(root, ".rayproject", "project.yaml")
    if not os.path.exists(yaml_path):
        raise ValueError("Project file {} not found".format(yaml_path))

    with open(yaml_path) as stream:
        definition = yaml.safe_load(stream)

    # Schema validation plus the file-existence checks happen here.
    check_project_definition(root, definition)
    return definition
|
||||
|
|
|
@ -2,7 +2,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import click
|
||||
import jsonschema
|
||||
import logging
|
||||
|
@ -59,10 +58,10 @@ def project_cli():
|
|||
"--verbose", help="If set, print the validated file", is_flag=True)
|
||||
def validate(verbose):
|
||||
try:
|
||||
project = ray.projects.load_project(os.getcwd())
|
||||
project = ray.projects.ProjectDefinition(os.getcwd())
|
||||
print("Project files validated!", file=sys.stderr)
|
||||
if verbose:
|
||||
print(project)
|
||||
print(project.config)
|
||||
except (jsonschema.exceptions.ValidationError, ValueError) as e:
|
||||
print("Validation failed for the following reason", file=sys.stderr)
|
||||
raise click.ClickException(e)
|
||||
|
@ -139,7 +138,7 @@ def session_cli():
|
|||
def load_project_or_throw():
|
||||
# Validate the project file
|
||||
try:
|
||||
return ray.projects.load_project(os.getcwd())
|
||||
return ray.projects.ProjectDefinition(os.getcwd())
|
||||
except (jsonschema.exceptions.ValidationError, ValueError):
|
||||
raise click.ClickException(
|
||||
"Project file validation failed. Please run "
|
||||
|
@ -150,7 +149,7 @@ def load_project_or_throw():
|
|||
def attach():
|
||||
project_definition = load_project_or_throw()
|
||||
attach_cluster(
|
||||
project_definition["cluster"],
|
||||
project_definition.cluster_yaml(),
|
||||
start=False,
|
||||
use_tmux=False,
|
||||
override_cluster_name=None,
|
||||
|
@ -162,7 +161,7 @@ def attach():
|
|||
def stop():
|
||||
project_definition = load_project_or_throw()
|
||||
teardown_cluster(
|
||||
project_definition["cluster"],
|
||||
project_definition.cluster_yaml(),
|
||||
yes=True,
|
||||
workers_only=False,
|
||||
override_cluster_name=None)
|
||||
|
@ -184,14 +183,15 @@ def start(command, args, shell):
|
|||
|
||||
if shell:
|
||||
command_to_run = command
|
||||
elif command:
|
||||
command_to_run = _get_command_to_run(command, project_definition, args)
|
||||
else:
|
||||
command_to_run = _get_command_to_run("default", project_definition,
|
||||
args)
|
||||
try:
|
||||
command_to_run = project_definition.get_command_to_run(
|
||||
command=command, args=args)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(e)
|
||||
|
||||
# Check for features we don't support right now
|
||||
project_environment = project_definition["environment"]
|
||||
project_environment = project_definition.config["environment"]
|
||||
need_docker = ("dockerfile" in project_environment
|
||||
or "dockerimage" in project_environment)
|
||||
if need_docker:
|
||||
|
@ -200,12 +200,9 @@ def start(command, args, shell):
|
|||
"Please file an feature request at"
|
||||
"https://github.com/ray-project/ray/issues")
|
||||
|
||||
cluster_yaml = project_definition["cluster"]
|
||||
working_directory = project_definition["name"]
|
||||
|
||||
logger.info("[1/4] Creating cluster")
|
||||
create_or_update_cluster(
|
||||
config_file=cluster_yaml,
|
||||
config_file=project_definition.cluster_yaml(),
|
||||
override_min_workers=None,
|
||||
override_max_workers=None,
|
||||
no_restart=False,
|
||||
|
@ -215,26 +212,26 @@ def start(command, args, shell):
|
|||
)
|
||||
|
||||
logger.info("[2/4] Syncing the project")
|
||||
project_root = ray.projects.find_root(os.getcwd())
|
||||
# This is so that rsync syncs directly to the target directory, instead of
|
||||
# nesting inside the target directory.
|
||||
if not project_root.endswith("/"):
|
||||
project_root += "/"
|
||||
rsync(
|
||||
cluster_yaml,
|
||||
source=project_root,
|
||||
target="~/{}/".format(working_directory),
|
||||
project_definition.cluster_yaml(),
|
||||
source=project_definition.root,
|
||||
target=project_definition.working_directory(),
|
||||
override_cluster_name=None,
|
||||
down=False,
|
||||
)
|
||||
|
||||
logger.info("[3/4] Setting up environment")
|
||||
_setup_environment(
|
||||
cluster_yaml, project_definition["environment"], cwd=working_directory)
|
||||
project_definition.cluster_yaml(),
|
||||
project_environment,
|
||||
cwd=project_definition.working_directory())
|
||||
|
||||
logger.info("[4/4] Running command")
|
||||
logger.debug("Running {}".format(command))
|
||||
session_exec_cluster(cluster_yaml, command_to_run, cwd=working_directory)
|
||||
session_exec_cluster(
|
||||
project_definition.cluster_yaml(),
|
||||
command_to_run,
|
||||
cwd=project_definition.working_directory())
|
||||
|
||||
|
||||
def session_exec_cluster(cluster_yaml, cmd, cwd=None):
|
||||
|
@ -277,32 +274,3 @@ def _setup_environment(cluster_yaml, project_environment, cwd):
|
|||
if "shell" in project_environment:
|
||||
for cmd in project_environment["shell"]:
|
||||
session_exec_cluster(cluster_yaml, cmd, cwd=cwd)
|
||||
|
||||
|
||||
def _get_command_to_run(command, project_definition, args):
    """Look up a project command and fill in its ``{{param}}`` placeholders.

    Args:
        command (str): Name of the command to look up in the "commands"
            section of the project definition.
        project_definition (dict): Parsed project definition.
        args (tuple): Arguments for the command, given as
            ``--<param name> <value>`` pairs.

    Returns:
        The command string with every declared parameter substituted.

    Raises:
        click.ClickException: If no command with the given name is defined.
        SystemExit: Raised by argparse when ``args`` do not match the
            parameters declared for the command.
    """
    command_to_run = None
    params = None

    # A missing or empty "commands" section is handled like an unknown
    # command instead of failing with KeyError/TypeError.
    for command_definition in project_definition.get("commands") or []:
        if command_definition["name"] == command:
            command_to_run = command_definition["command"]
            params = command_definition.get("params", [])
    if not command_to_run:
        raise click.ClickException(
            "Cannot find the command '" + command +
            "' in commands section of the project file.")

    # Build argument parser dynamically to parse parameter arguments.
    parser = argparse.ArgumentParser(prog=command)
    for param in params:
        parser.add_argument(
            "--" + param["name"],
            # Pin dest to the declared name: argparse would otherwise rewrite
            # "-" to "_" and the "{{name}}" placeholder of a hyphenated
            # parameter would never be substituted.
            dest=param["name"],
            required=True,
            help=param.get("help"),
            choices=param.get("choices"))

    result = parser.parse_args(list(args))
    for key, val in vars(result).items():
        command_to_run = command_to_run.replace("{{" + key + "}}", val)

    return command_to_run
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
name: testproject2
|
||||
name: testmissingcluster
|
||||
|
||||
environment:
|
||||
shell: "one command"
|
|
@ -0,0 +1,3 @@
|
|||
name: testmissingyaml
|
||||
|
||||
cluster: "cluster.yaml"
|
|
@ -20,61 +20,57 @@ if sys.version_info >= (3, 3):
|
|||
else:
|
||||
from mock import patch, DEFAULT
|
||||
|
||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
TEST_DIR = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "project_files")
|
||||
|
||||
|
||||
def load_project_description(project_file):
|
||||
path = os.path.join(TEST_DIR, "project_files", project_file)
|
||||
path = os.path.join(TEST_DIR, project_file)
|
||||
with open(path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def test_validation_success():
|
||||
project_files = [
|
||||
"docker_project.yaml", "requirements_project.yaml",
|
||||
"shell_project.yaml"
|
||||
]
|
||||
for project_file in project_files:
|
||||
project_definition = load_project_description(project_file)
|
||||
ray.projects.validate_project_schema(project_definition)
|
||||
def test_validation():
|
||||
project_dirs = ["docker_project", "requirements_project", "shell_project"]
|
||||
for project_dir in project_dirs:
|
||||
project_dir = os.path.join(TEST_DIR, project_dir)
|
||||
ray.projects.ProjectDefinition(project_dir)
|
||||
|
||||
|
||||
def test_validation_failure():
|
||||
project_files = ["no_project1.yaml", "no_project2.yaml"]
|
||||
for project_file in project_files:
|
||||
project_definition = load_project_description(project_file)
|
||||
bad_schema_dirs = ["no_project1"]
|
||||
for project_dir in bad_schema_dirs:
|
||||
project_dir = os.path.join(TEST_DIR, project_dir)
|
||||
with pytest.raises(jsonschema.exceptions.ValidationError):
|
||||
ray.projects.validate_project_schema(project_definition)
|
||||
ray.projects.ProjectDefinition(project_dir)
|
||||
|
||||
|
||||
def test_check_failure():
|
||||
project_files = ["no_project3.yaml"]
|
||||
for project_file in project_files:
|
||||
project_definition = load_project_description(project_file)
|
||||
bad_project_dirs = ["no_project2", "noproject3"]
|
||||
for project_dir in bad_project_dirs:
|
||||
project_dir = os.path.join(TEST_DIR, project_dir)
|
||||
with pytest.raises(ValueError):
|
||||
ray.projects.check_project_definition("", project_definition)
|
||||
ray.projects.ProjectDefinition(project_dir)
|
||||
|
||||
|
||||
def test_project_root():
|
||||
path = os.path.join(TEST_DIR, "project_files", "project1")
|
||||
assert ray.projects.find_root(path) == path
|
||||
path = os.path.join(TEST_DIR, "project1")
|
||||
project_definition = ray.projects.ProjectDefinition(path)
|
||||
assert os.path.normpath(project_definition.root) == os.path.normpath(path)
|
||||
|
||||
path2 = os.path.join(TEST_DIR, "project_files", "project1", "subdir")
|
||||
assert ray.projects.find_root(path2) == path
|
||||
path2 = os.path.join(TEST_DIR, "project1", "subdir")
|
||||
project_definition = ray.projects.ProjectDefinition(path2)
|
||||
assert os.path.normpath(project_definition.root) == os.path.normpath(path)
|
||||
|
||||
path3 = "/tmp/"
|
||||
assert ray.projects.find_root(path3) is None
|
||||
with pytest.raises(ValueError):
|
||||
project_definition = ray.projects.ProjectDefinition(path3)
|
||||
|
||||
|
||||
def test_project_validation():
|
||||
path = os.path.join(TEST_DIR, "project_files", "project1")
|
||||
path = os.path.join(TEST_DIR, "project1")
|
||||
subprocess.check_call(["ray", "project", "validate"], cwd=path)
|
||||
|
||||
|
||||
def test_project_no_validation():
|
||||
path = os.path.join(TEST_DIR, "project_files")
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
subprocess.check_call(["ray", "project", "validate"], cwd=path)
|
||||
subprocess.check_call(["ray", "project", "validate"], cwd=TEST_DIR)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -89,7 +85,7 @@ def _chdir_and_back(d):
|
|||
|
||||
def run_test_project(project_dir, command, args):
|
||||
# Run the CLI commands with patching
|
||||
test_dir = os.path.join(TEST_DIR, "project_files", project_dir)
|
||||
test_dir = os.path.join(TEST_DIR, project_dir)
|
||||
with _chdir_and_back(test_dir):
|
||||
runner = CliRunner()
|
||||
with patch.multiple(
|
||||
|
@ -107,14 +103,14 @@ def test_session_start_default_project():
|
|||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/project-pass", start, [])
|
||||
|
||||
loaded_project = ray.projects.load_project(test_dir)
|
||||
loaded_project = ray.projects.ProjectDefinition(test_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Part 1/3: Cluster Launching Call
|
||||
create_or_update_cluster_call = mock_calls["create_or_update_cluster"]
|
||||
assert create_or_update_cluster_call.call_count == 1
|
||||
_, kwargs = create_or_update_cluster_call.call_args
|
||||
assert kwargs["config_file"] == loaded_project["cluster"]
|
||||
assert kwargs["config_file"] == loaded_project.cluster_yaml()
|
||||
|
||||
# Part 2/3: Rsync Calls
|
||||
rsync_call = mock_calls["rsync"]
|
||||
|
@ -122,21 +118,22 @@ def test_session_start_default_project():
|
|||
# requirements.txt.
|
||||
assert rsync_call.call_count == 2
|
||||
_, kwargs = rsync_call.call_args
|
||||
assert kwargs["source"] == loaded_project["environment"]["requirements"]
|
||||
assert kwargs["source"] == loaded_project.config["environment"][
|
||||
"requirements"]
|
||||
|
||||
# Part 3/3: Exec Calls
|
||||
exec_cluster_call = mock_calls["exec_cluster"]
|
||||
commands_executed = []
|
||||
for _, kwargs in exec_cluster_call.call_args_list:
|
||||
commands_executed.append(kwargs["cmd"].replace(
|
||||
"cd {}; ".format(loaded_project["name"]), ""))
|
||||
"cd {}; ".format(loaded_project.working_directory()), ""))
|
||||
|
||||
expected_commands = loaded_project["environment"]["shell"]
|
||||
expected_commands = loaded_project.config["environment"]["shell"]
|
||||
expected_commands += [
|
||||
command["command"] for command in loaded_project["commands"]
|
||||
command["command"] for command in loaded_project.config["commands"]
|
||||
]
|
||||
|
||||
if "requirements" in loaded_project["environment"]:
|
||||
if "requirements" in loaded_project.config["environment"]:
|
||||
assert any("pip install -r" for cmd in commands_executed)
|
||||
# pop the `pip install` off commands executed
|
||||
commands_executed = [
|
||||
|
@ -171,7 +168,7 @@ def test_session_create_command():
|
|||
["first", "--a", "1", "--b", "2"])
|
||||
|
||||
# Verify the project can be loaded.
|
||||
ray.projects.load_project(test_dir)
|
||||
ray.projects.ProjectDefinition(test_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
exec_cluster_call = mock_calls["exec_cluster"]
|
||||
|
|
Loading…
Add table
Reference in a new issue