[runtime_env] Fix the some bugs related with runtime_env (#15286)

This commit is contained in:
Yi Cheng 2021-04-21 11:31:21 -07:00 committed by GitHub
parent c7f6ffb70c
commit b63e493c04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 6 deletions

View file

@ -128,7 +128,8 @@ def _zip_module(root: Path, relative_path: Path, excludes: Set[Path],
def handler(path: Path):
# Pack this path if it's an empty directory or it's a file.
if path.is_dir() and next(path.iterdir()) is None or path.is_file():
if path.is_dir() and next(path.iterdir(),
None) is None or path.is_file():
file_size = path.stat().st_size
if file_size >= FILE_SIZE_WARNING:
logger.warning(
@ -213,13 +214,21 @@ def get_project_package_name(working_dir: str, py_modules: List[str],
hash_val = None
excludes = {Path(p).absolute() for p in excludes}
if working_dir:
assert isinstance(working_dir, str)
assert Path(working_dir).exists()
if not isinstance(working_dir, str):
raise TypeError("`working_dir` must be a string.")
working_dir = Path(working_dir).absolute()
if not working_dir.exists() or not working_dir.is_dir():
raise ValueError(f"working_dir {working_dir} must be an existing"
" directory")
hash_val = _xor_bytes(
hash_val, _hash_modules(working_dir, working_dir, excludes))
for py_module in py_modules or []:
if not isinstance(py_module, str):
raise TypeError("`py_module` must be a string.")
module_dir = Path(py_module).absolute()
if not module_dir.exists() or not module_dir.is_dir():
raise ValueError(f"py_module {py_module} must be an existing"
" directory")
hash_val = _xor_bytes(
hash_val, _hash_modules(module_dir, module_dir.parent, excludes))
return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None

View file

@ -14,11 +14,15 @@ from time import sleep
import sys
import logging
sys.path.insert(0, "{working_dir}")
import test_module
import ray
import ray.util
import os
try:
import test_module
except:
pass
job_config = ray.job_config.JobConfig(
runtime_env={runtime_env}
)
@ -33,10 +37,20 @@ try:
ray.init(address="{address}",
job_config=job_config,
logging_level=logging.DEBUG)
except ValueError:
print("ValueError")
sys.exit(0)
except TypeError:
print("TypeError")
sys.exit(0)
except:
print("ERROR")
sys.exit(0)
if os.environ.get("EXIT_AFTER_INIT"):
sys.exit(0)
@ray.remote
def run_test():
return test_module.one()
@ -91,7 +105,7 @@ from test_module.test import one
def start_client_server(cluster, client_mode):
from ray._private.runtime_env import PKG_DIR
if not client_mode:
return (cluster.address, None, PKG_DIR)
return (cluster.address, {}, PKG_DIR)
ray.worker._global_node._ray_params.ray_client_server_port = "10003"
ray.worker._global_node.start_ray_client_server()
return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR)
@ -106,6 +120,60 @@ The following test cases are related with runtime env. It following these steps
"""
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@pytest.mark.parametrize("client_mode", [True, False])
def test_empty_working_dir(ray_start_cluster_head, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
env["EXIT_AFTER_INIT"] = "1"
with tempfile.TemporaryDirectory() as working_dir:
runtime_env = f"""{{
"working_dir": r"{working_dir}",
"py_modules": [r"{working_dir}"]
}}"""
# Execute the following cmd in driver with runtime_env
execute_statement = "sys.exit(0)"
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out != "ERROR"
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@pytest.mark.parametrize("client_mode", [True, False])
def test_invalid_working_dir(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
env["EXIT_AFTER_INIT"] = "1"
runtime_env = "{ 'working_dir': 10 }"
# Execute the following cmd in driver with runtime_env
execute_statement = ""
script = driver_script.format(**locals())
out = run_string_as_driver(script, env).strip().split()[-1]
assert out == "TypeError"
runtime_env = "{ 'py_modules': [10] }"
# Execute the following cmd in driver with runtime_env
execute_statement = ""
script = driver_script.format(**locals())
out = run_string_as_driver(script, env).strip().split()[-1]
assert out == "TypeError"
runtime_env = f"{{ 'working_dir': os.path.join(r'{working_dir}', 'na') }}"
# Execute the following cmd in driver with runtime_env
execute_statement = ""
script = driver_script.format(**locals())
out = run_string_as_driver(script, env).strip().split()[-1]
assert out == "ValueError"
runtime_env = f"{{ 'py_modules': [os.path.join(r'{working_dir}', 'na')] }}"
# Execute the following cmd in driver with runtime_env
execute_statement = ""
script = driver_script.format(**locals())
out = run_string_as_driver(script, env).strip().split()[-1]
assert out == "ValueError"
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@pytest.mark.parametrize("client_mode", [True, False])
def test_single_node(ray_start_cluster_head, working_dir, client_mode):
@ -377,6 +445,42 @@ sleep(600)
assert out.strip().split()[-1] == "ERROR"
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_util_without_job_config(shutdown_only):
from ray.cluster_utils import Cluster
with tempfile.TemporaryDirectory() as tmp_dir:
with (Path(tmp_dir) / "lib.py").open("w") as f:
f.write("""
def one():
return 1
""")
old_dir = os.getcwd()
os.chdir(tmp_dir)
cluster = Cluster()
cluster.add_node(num_cpus=1)
ray.init(address=cluster.address)
(address, env, PKG_DIR) = start_client_server(cluster, True)
script = f"""
import ray
import ray.util
import os
ray.util.connect("{address}", job_config=None)
@ray.remote
def run():
from lib import one
return one()
print(ray.get([run.remote()])[0])
"""
out = run_string_as_driver(script, env)
print(out)
os.chdir(old_dir)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-sv", __file__]))

View file

@ -1291,7 +1291,8 @@ def connect(node,
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, script_directory))
if not job_config.client_job and job_config.get_runtime_env_uris():
if not job_config.client_job and len(
job_config.get_runtime_env_uris()) == 0:
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, current_directory))