ray/dashboard/modules/job/tests/test_cli.py
2021-11-11 15:59:13 -08:00

102 lines
3.3 KiB
Python

from contextlib import contextmanager
import logging
import os
from unittest import mock
from click.testing import CliRunner
import pytest
from ray.dashboard.modules.job.cli import job_cli_group
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@pytest.fixture
def mock_sdk_client():
    """Yield a mock that replaces ``JobSubmissionClient`` in the job CLI module.

    ``RAY_ADDRESS`` is removed from the environment while the fixture is
    active, so tests start without an ambient address. The environment is
    restored to its previous contents when the fixture is torn down.

    Yields:
        The mock object patched in place of
        ``ray.dashboard.modules.job.cli.JobSubmissionClient``.
    """
    # patch.dict snapshots os.environ and restores it on exit, so removing
    # RAY_ADDRESS here does not leak into tests that run afterwards.
    with mock.patch.dict(os.environ):
        os.environ.pop("RAY_ADDRESS", None)
        with mock.patch(
            "ray.dashboard.modules.job.cli.JobSubmissionClient"
        ) as mock_client:
            yield mock_client
@contextmanager
def set_env_var(key: str, val: str):
    """Temporarily set the environment variable *key* to *val*.

    On exit the variable is put back to its previous value, or removed if it
    was not set beforehand. The restoration also runs when the ``with`` body
    raises, so a failing test cannot leak the override into later tests.

    Args:
        key: Name of the environment variable to override.
        val: Value the variable holds inside the ``with`` block.
    """
    old_val = os.environ.get(key)
    os.environ[key] = val
    try:
        yield
    finally:
        # pop() instead of del: the body may already have removed the key.
        os.environ.pop(key, None)
        if old_val is not None:
            os.environ[key] = old_val
class TestSubmit:
    """Tests for the ``submit`` sub-command of ``job_cli_group``.

    Every test uses the ``mock_sdk_client`` fixture, so the checks are made
    against the calls recorded on the mocked client class and its instance.

    The mock checks go through the helpers below rather than
    ``mock.called_with(...)``, which is not an assertion method of
    ``unittest.mock`` and therefore verifies nothing.

    NOTE(review): the helpers assume the CLI passes the address to the client
    constructor and forwards ``runtime_env`` / ``job_id`` as keyword arguments
    to a method of the client instance -- confirm against the CLI module.
    """

    @staticmethod
    def _constructed_with(mock_client_cls, address):
        """Return True if the latest client construction received *address*.

        The address is accepted either positionally or as a keyword value, so
        the check does not depend on the constructor's exact signature.
        """
        last_call = mock_client_cls.call_args
        if last_call is None:
            return False
        args, kwargs = last_call
        return address in args or address in kwargs.values()

    @staticmethod
    def _method_got_kwarg(mock_client_instance, name, expected):
        """Return True if any client method call passed ``name=expected``."""
        return any(
            name in kwargs and kwargs[name] == expected
            for _method, _args, kwargs in mock_client_instance.method_calls
        )

    def test_address(self, mock_sdk_client):
        """The address comes from ``--address`` or ``RAY_ADDRESS``."""
        runner = CliRunner()

        # Test passing address via command line.
        result = runner.invoke(
            job_cli_group,
            ["submit", "--address=arg_addr", "--", "echo hello"],
        )
        assert result.exit_code == 0
        assert self._constructed_with(mock_sdk_client, "arg_addr")

        # Test passing address via env var.
        with set_env_var("RAY_ADDRESS", "env_addr"):
            result = runner.invoke(job_cli_group,
                                   ["submit", "--", "echo hello"])
            assert result.exit_code == 0
            assert self._constructed_with(mock_sdk_client, "env_addr")

        # Test passing no address.
        result = runner.invoke(job_cli_group, ["submit", "--", "echo hello"])
        assert result.exit_code == 1
        assert "Address must be specified" in str(result.exception)

    def test_runtime_env(self, mock_sdk_client):
        """``--working-dir`` ends up in the runtime env given to the client."""
        mock_client_instance = mock_sdk_client.return_value
        runner = CliRunner()

        with set_env_var("RAY_ADDRESS", "env_addr"):
            result = runner.invoke(job_cli_group,
                                   ["submit", "--", "echo hello"])
            assert result.exit_code == 0
            assert self._method_got_kwarg(
                mock_client_instance, "runtime_env", {})

            # Options must precede the "--" separator; anything after it is
            # parsed as positional arguments rather than as an option.
            result = runner.invoke(
                job_cli_group,
                ["submit", "--working-dir", "blah", "--", "echo hello"],
            )
            assert result.exit_code == 0
            assert self._method_got_kwarg(
                mock_client_instance, "runtime_env", {"working_dir": "blah"})

            # No shell is involved here, so quoting the value would make the
            # quote characters part of it; pass the bare "." instead.
            result = runner.invoke(
                job_cli_group,
                ["submit", "--working-dir=.", "--", "echo hello"],
            )
            assert result.exit_code == 0
            assert self._method_got_kwarg(
                mock_client_instance, "runtime_env", {"working_dir": "."})

    def test_job_id(self, mock_sdk_client):
        """``--job-id`` is forwarded to the client, defaulting to ``None``."""
        mock_client_instance = mock_sdk_client.return_value
        runner = CliRunner()

        with set_env_var("RAY_ADDRESS", "env_addr"):
            result = runner.invoke(job_cli_group,
                                   ["submit", "--", "echo hello"])
            assert result.exit_code == 0
            assert self._method_got_kwarg(
                mock_client_instance, "job_id", None)

            # The option goes before "--" so it is parsed as an option and
            # not as part of the entrypoint.
            result = runner.invoke(
                job_cli_group,
                ["submit", "--job-id=my_job_id", "--", "echo hello"],
            )
            assert result.exit_code == 0
            assert self._method_got_kwarg(
                mock_client_instance, "job_id", "my_job_id")
if __name__ == "__main__":
    # When executed as a script, hand this file to pytest in verbose mode and
    # use pytest's return value as the process exit status.
    import sys

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)