diff --git a/python/ray/serve/tests/test_runtime_env.py b/python/ray/serve/tests/test_runtime_env.py index fc85e02f0..01c2feea4 100644 --- a/python/ray/serve/tests/test_runtime_env.py +++ b/python/ray/serve/tests/test_runtime_env.py @@ -2,6 +2,7 @@ import pytest import sys import ray +from ray import serve from ray._private.test_utils import run_string_as_driver @@ -42,45 +43,32 @@ except FileNotFoundError: run_string_as_driver(driver) -def connect_with_working_dir(use_ray_client: bool, ray_client_addr: str): - job_config = ray.job_config.JobConfig(runtime_env={"working_dir": "."}) - if use_ray_client: - ray.util.connect(ray_client_addr, namespace="serve", job_config=job_config) - else: - ray.init(address="auto", namespace="serve", job_config=job_config) - - @pytest.mark.parametrize("use_ray_client", [False, True]) @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") -def test_working_dir_basic(ray_start, tmp_dir, use_ray_client): +def test_working_dir_basic(ray_start, tmp_dir, use_ray_client, ray_shutdown): with open("hello", "w") as f: f.write("world") + print("Wrote file") + if use_ray_client: + ray.init( + f"ray://{ray_start}", namespace="serve", runtime_env={"working_dir": "."} + ) + else: + ray.init(address="auto", namespace="serve", runtime_env={"working_dir": "."}) + print("Initialized Ray") + serve.start() + print("Started Serve") - driver = """ -import ray -from ray import serve + @serve.deployment + class Test: + def __call__(self, *args): + return open("hello").read() -job_config = ray.job_config.JobConfig(runtime_env={{"working_dir": "."}}) -if {use_ray_client}: - ray.util.connect("{client_addr}", job_config=job_config) -else: - ray.init(address="auto", job_config=job_config) - -serve.start() - -@serve.deployment -class Test: - def __call__(self, *args): - return open("hello").read() - -Test.deploy() -handle = Test.get_handle() -assert ray.get(handle.remote()) == "world" -""".format( - use_ray_client=use_ray_client, client_addr=ray_start - ) - - run_string_as_driver(driver) + Test.deploy() + print("Deployed") + handle = Test.get_handle() + print("Got handle") + assert ray.get(handle.remote()) == "world" @pytest.mark.parametrize("use_ray_client", [False, True])