mirror of
https://github.com/vale981/ray
synced 2025-03-10 05:16:49 -04:00
102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
from contextlib import contextmanager
|
|
import logging
|
|
import os
|
|
from unittest import mock
|
|
|
|
from click.testing import CliRunner
|
|
|
|
import pytest
|
|
|
|
from ray.dashboard.modules.job.cli import job_cli_group
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_sdk_client():
|
|
if "RAY_ADDRESS" in os.environ:
|
|
del os.environ["RAY_ADDRESS"]
|
|
with mock.patch("ray.dashboard.modules.job.cli.JobSubmissionClient"
|
|
) as mock_client:
|
|
yield mock_client
|
|
|
|
|
|
@contextmanager
|
|
def set_env_var(key: str, val: str):
|
|
old_val = os.environ.get(key, None)
|
|
os.environ[key] = val
|
|
yield
|
|
del os.environ[key]
|
|
if old_val is not None:
|
|
os.environ[key] = old_val
|
|
|
|
|
|
class TestSubmit:
|
|
def test_address(self, mock_sdk_client):
|
|
runner = CliRunner()
|
|
|
|
# Test passing address via command line.
|
|
result = runner.invoke(
|
|
job_cli_group,
|
|
["submit", "--address=arg_addr", "--", "echo hello"])
|
|
assert mock_sdk_client.called_with("arg_addr")
|
|
assert result.exit_code == 0
|
|
|
|
# Test passing address via env var.
|
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
|
result = runner.invoke(job_cli_group,
|
|
["submit", "--", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_sdk_client.called_with("env_addr")
|
|
|
|
# Test passing no address.
|
|
result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
|
|
assert result.exit_code == 1
|
|
assert "Address must be specified" in str(result.exception)
|
|
|
|
def test_runtime_env(self, mock_sdk_client):
|
|
mock_client_instance = mock_sdk_client.return_value
|
|
|
|
runner = CliRunner()
|
|
|
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
|
result = runner.invoke(job_cli_group,
|
|
["submit", "--", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_client_instance.called_with(runtime_env={})
|
|
|
|
result = runner.invoke(
|
|
job_cli_group,
|
|
["submit", "--", "--working-dir", "blah", "--", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_client_instance.called_with(
|
|
runtime_env={"working_dir": "blah"})
|
|
|
|
result = runner.invoke(
|
|
job_cli_group,
|
|
["submit", "--", "--working-dir='.'", "--", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_client_instance.called_with(
|
|
runtime_env={"working_dir": "."})
|
|
|
|
def test_job_id(self, mock_sdk_client):
|
|
mock_client_instance = mock_sdk_client.return_value
|
|
|
|
runner = CliRunner()
|
|
|
|
with set_env_var("RAY_ADDRESS", "env_addr"):
|
|
result = runner.invoke(job_cli_group,
|
|
["submit", "--", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_client_instance.called_with(job_id=None)
|
|
|
|
result = runner.invoke(
|
|
job_cli_group,
|
|
["submit", "--", "--job-id=my_job_id", "echo hello"])
|
|
assert result.exit_code == 0
|
|
assert mock_client_instance.called_with(job_id="my_job_id")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import sys
|
|
sys.exit(pytest.main(["-v", __file__]))
|