mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[projects] Add named commands to sessions (#5525)
This commit is contained in:
parent
97ccd75952
commit
f1dcce5a47
10 changed files with 159 additions and 69 deletions
|
@ -18,7 +18,7 @@ Quick start (CLI)
|
|||
|
||||
# Create a new session from the given project.
|
||||
# Launch a cluster and run the appropriate command.
|
||||
$ ray session start
|
||||
$ ray session start <command> [arguments]
|
||||
|
||||
# Open a console for the given session.
|
||||
$ ray session attach
|
||||
|
@ -63,8 +63,19 @@ Here is an example for a minimal project format:
|
|||
# and the environment is set up.
|
||||
# A command can also specify a cluster that overrides the default cluster.
|
||||
commands:
|
||||
- name: default
|
||||
command: python default.py
|
||||
help: "The command that will be executed if no command name is specified"
|
||||
- name: test
|
||||
command: python test.py
|
||||
command: python test.py --param1={{param1}} --param2={{param2}}
|
||||
help: "A test command"
|
||||
params:
|
||||
- name: "param1"
|
||||
help: "The first parameter"
|
||||
# The following line indicates possible values this parameter can take.
|
||||
choices: ["1", "2"]
|
||||
- name: "param2"
|
||||
help: "The second parameter"
|
||||
|
||||
Project files have to adhere to the following schema:
|
||||
|
||||
|
|
|
@ -10,21 +10,25 @@ environment:
|
|||
requirements: requirements.txt
|
||||
|
||||
commands:
|
||||
- name: train_sst_2
|
||||
- name: train
|
||||
command: |
|
||||
wget https://raw.githubusercontent.com/nyu-mll/GLUE-baselines/master/download_glue_data.py && \
|
||||
python download_glue_data.py -d /tmp -t SST && \
|
||||
wget https://raw.githubusercontent.com/ray-project/project-data/master/download_glue_data.py && \
|
||||
python download_glue_data.py -d /tmp -t {{dataset}} && \
|
||||
python ./examples/run_glue.py \
|
||||
--model_type bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--task_name SST-2 \
|
||||
--task_name {{dataset}} \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir /tmp/SST-2 \
|
||||
--data_dir /tmp/{{dataset}} \
|
||||
--max_seq_length 128 \
|
||||
--per_gpu_eval_batch_size=8 \
|
||||
--per_gpu_train_batch_size=8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/output/
|
||||
params:
|
||||
- name: "dataset"
|
||||
help: "The GLUE dataset to fine-tune on"
|
||||
choices: ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE", "WNLI"]
|
||||
|
|
|
@ -55,6 +55,27 @@
|
|||
"command": {
|
||||
"description": "Shell command to run on the cluster",
|
||||
"type": "string"
|
||||
},
|
||||
"params" : {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"description": "Possible parameters in the command",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"description": "Name of the parameter",
|
||||
"type": "string"
|
||||
},
|
||||
"help": {
|
||||
"description": "Help string for the parameter",
|
||||
"type": "string"
|
||||
},
|
||||
"choices": {
|
||||
"description": "Possible values the parameter can take",
|
||||
"type": "array"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
@ -153,10 +154,20 @@ def stop():
|
|||
override_cluster_name=None)
|
||||
|
||||
|
||||
@session_cli.command(help="Start a session based on current project config")
|
||||
def start():
|
||||
@session_cli.command(
|
||||
context_settings=dict(ignore_unknown_options=True, ),
|
||||
help="Start a session based on current project config")
|
||||
@click.argument("command", required=False)
|
||||
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
||||
def start(command, args):
|
||||
project_definition = load_project_or_throw()
|
||||
|
||||
if command:
|
||||
command_to_run = _get_command_to_run(command, project_definition, args)
|
||||
else:
|
||||
command_to_run = _get_command_to_run("default", project_definition,
|
||||
args)
|
||||
|
||||
# Check for features we don't support right now
|
||||
project_environment = project_definition["environment"]
|
||||
need_docker = ("dockerfile" in project_environment
|
||||
|
@ -202,9 +213,9 @@ def start():
|
|||
_setup_environment(
|
||||
cluster_yaml, project_definition["environment"], cwd=working_directory)
|
||||
|
||||
logger.info("[4/4] Running commands")
|
||||
_run_commands(
|
||||
cluster_yaml, project_definition["commands"], cwd=working_directory)
|
||||
logger.info("[4/4] Running command")
|
||||
logger.debug("Running {}".format(command))
|
||||
session_exec_cluster(cluster_yaml, command_to_run, cwd=working_directory)
|
||||
|
||||
|
||||
def session_exec_cluster(cluster_yaml, cmd, cwd=None):
|
||||
|
@ -249,7 +260,30 @@ def _setup_environment(cluster_yaml, project_environment, cwd):
|
|||
session_exec_cluster(cluster_yaml, cmd, cwd=cwd)
|
||||
|
||||
|
||||
def _run_commands(cluster_yaml, commands, cwd):
|
||||
for cmd in commands:
|
||||
logger.debug("Running {}".format(cmd["name"]))
|
||||
session_exec_cluster(cluster_yaml, cmd["command"], cwd=cwd)
|
||||
def _get_command_to_run(command, project_definition, args):
|
||||
command_to_run = None
|
||||
params = None
|
||||
|
||||
for command_definition in project_definition["commands"]:
|
||||
if command_definition["name"] == command:
|
||||
command_to_run = command_definition["command"]
|
||||
params = command_definition.get("params", [])
|
||||
if not command_to_run:
|
||||
raise click.ClickException(
|
||||
"Cannot find the command '" + command +
|
||||
"' in commmands section of the project file.")
|
||||
|
||||
# Build argument parser dynamically to parse parameter arguments.
|
||||
parser = argparse.ArgumentParser(prog=command)
|
||||
for param in params:
|
||||
parser.add_argument(
|
||||
"--" + param["name"],
|
||||
required=True,
|
||||
help=param.get("help"),
|
||||
choices=param.get("choices"))
|
||||
|
||||
result = parser.parse_args(list(args))
|
||||
for key, val in result.__dict__.items():
|
||||
command_to_run = command_to_run.replace("{{" + key + "}}", val)
|
||||
|
||||
return command_to_run
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# This file is generated by `ray project create`.
|
||||
|
||||
name: commands-test
|
||||
|
||||
# description: A short description of the project.
|
||||
repo: https://github.com/ray-project/not-exist
|
||||
|
||||
cluster: .rayproject/cluster.yaml
|
||||
|
||||
environment:
|
||||
shell:
|
||||
- echo "Setting up"
|
||||
|
||||
commands:
|
||||
- name: first
|
||||
command: echo "Starting ray job with {{a}} and {{b}}"
|
||||
params:
|
||||
- name: a
|
||||
help: "This is the first parameter"
|
||||
choices: ["1", "2"]
|
||||
- name: b
|
||||
help: "This is the second parameter"
|
||||
choices: ["1", "2"]
|
||||
|
||||
- name: second
|
||||
command: echo "Some command"
|
|
@ -16,5 +16,5 @@ environment:
|
|||
- echo "Setting up the environment"
|
||||
|
||||
commands:
|
||||
- name: first-command
|
||||
- name: default
|
||||
command: echo "Starting ray job"
|
||||
|
|
|
@ -16,5 +16,5 @@ environment:
|
|||
- echo "Setting up the environment"
|
||||
|
||||
commands:
|
||||
- name: first-command
|
||||
- name: default
|
||||
command: echo "Starting ray job"
|
||||
|
|
|
@ -15,5 +15,5 @@ environment:
|
|||
- echo "Setting up the environment"
|
||||
|
||||
commands:
|
||||
- name: first-command
|
||||
- name: default
|
||||
command: echo "Starting ray job"
|
||||
|
|
|
@ -87,10 +87,9 @@ def _chdir_and_back(d):
|
|||
os.chdir(old_dir)
|
||||
|
||||
|
||||
def test_session_start_default_project():
|
||||
def run_test_project(project_dir, command, args):
|
||||
# Run the CLI commands with patching
|
||||
test_dir = os.path.join(TEST_DIR,
|
||||
"project_files/session-tests/project-pass")
|
||||
test_dir = os.path.join(TEST_DIR, "project_files", project_dir)
|
||||
with _chdir_and_back(test_dir):
|
||||
runner = CliRunner()
|
||||
with patch.multiple(
|
||||
|
@ -99,11 +98,17 @@ def test_session_start_default_project():
|
|||
rsync=DEFAULT,
|
||||
exec_cluster=DEFAULT,
|
||||
) as mock_calls:
|
||||
result = runner.invoke(start, [])
|
||||
assert result.exit_code == 0
|
||||
result = runner.invoke(command, args)
|
||||
|
||||
return result, mock_calls, test_dir
|
||||
|
||||
|
||||
def test_session_start_default_project():
|
||||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/project-pass", start, [])
|
||||
|
||||
# Check we are calling autoscaler correctly
|
||||
loaded_project = ray.projects.load_project(test_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Part 1/3: Cluster Launching Call
|
||||
create_or_update_cluster_call = mock_calls["create_or_update_cluster"]
|
||||
|
@ -150,39 +155,20 @@ def test_session_start_default_project():
|
|||
|
||||
|
||||
def test_session_start_docker_fail():
|
||||
# Run the CLI commands with patching
|
||||
test_dir = os.path.join(TEST_DIR,
|
||||
"project_files/session-tests/with-docker-fail")
|
||||
with _chdir_and_back(test_dir):
|
||||
runner = CliRunner()
|
||||
with patch.multiple(
|
||||
"ray.projects.scripts",
|
||||
create_or_update_cluster=DEFAULT,
|
||||
rsync=DEFAULT,
|
||||
exec_cluster=DEFAULT,
|
||||
) as _:
|
||||
result = runner.invoke(start, [])
|
||||
result, _, _ = run_test_project("session-tests/with-docker-fail", start,
|
||||
[])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert ("Docker support in session is currently "
|
||||
"not implemented") in result.output
|
||||
|
||||
|
||||
def test_session_git_repo_cloned():
|
||||
# Run the CLI commands with patching
|
||||
test_dir = os.path.join(TEST_DIR,
|
||||
"project_files/session-tests/git-repo-pass")
|
||||
with _chdir_and_back(test_dir):
|
||||
runner = CliRunner()
|
||||
with patch.multiple(
|
||||
"ray.projects.scripts",
|
||||
create_or_update_cluster=DEFAULT,
|
||||
rsync=DEFAULT,
|
||||
exec_cluster=DEFAULT,
|
||||
) as mock_calls:
|
||||
result = runner.invoke(start, [])
|
||||
assert result.exit_code == 0
|
||||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/git-repo-pass", start, [])
|
||||
|
||||
loaded_project = ray.projects.load_project(test_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
exec_cluster_call = mock_calls["exec_cluster"]
|
||||
commands_executed = []
|
||||
|
@ -197,19 +183,27 @@ def test_session_git_repo_cloned():
|
|||
|
||||
|
||||
def test_session_invalid_config_errored():
|
||||
# Run the CLI commands with patching
|
||||
test_dir = os.path.join(TEST_DIR,
|
||||
"project_files/session-tests/invalid-config-fail")
|
||||
with _chdir_and_back(test_dir):
|
||||
runner = CliRunner()
|
||||
with patch.multiple(
|
||||
"ray.projects.scripts",
|
||||
create_or_update_cluster=DEFAULT,
|
||||
rsync=DEFAULT,
|
||||
exec_cluster=DEFAULT,
|
||||
) as _:
|
||||
result = runner.invoke(start, [])
|
||||
result, _, _ = run_test_project("session-tests/invalid-config-fail", start,
|
||||
[])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "validation failed" in result.output
|
||||
# check that we are displaying an actionable error message
|
||||
assert "ray project validate" in result.output
|
||||
|
||||
|
||||
def test_session_create_command():
    """Start the "first" command with --a/--b and check what gets executed.

    The substituted command string must show up in the "cmd" keyword
    argument of at least one patched exec_cluster call.
    """
    cli_args = ["first", "--a", "1", "--b", "2"]
    result, mock_calls, test_dir = run_test_project(
        "session-tests/commands-test", start, cli_args)

    # Verify the project can be loaded.
    ray.projects.load_project(test_dir)
    assert result.exit_code == 0

    # Inspect every recorded call (no short-circuit) before asserting.
    expected = "Starting ray job with 1 and 2"
    matches = [
        expected in call_kwargs["cmd"]
        for _, call_kwargs in mock_calls["exec_cluster"].call_args_list
    ]
    assert any(matches)
|
||||
|
|
Loading…
Add table
Reference in a new issue