Don't include script directory in sys.path if it's started via python -m (#28043)

According to https://peps.python.org/pep-0338/
> The -m switch provides a benefit here, as it inserts the current directory into sys.path, instead of the directory contain the main module.

We should follow this and not add the driver script directory to the workers' sys.path. I couldn't find a way to detect that the driver was run via `python -m`, so instead we skip adding the script directory to the workers' sys.path when it is not present in the driver's sys.path.
This commit is contained in:
Jiajun Yao 2022-08-26 13:27:08 -07:00 committed by GitHub
parent ce99cf1b71
commit b41ee37c3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 4 deletions

View file

@ -2067,10 +2067,15 @@ def connect(
# are the same.
# When using an interactive shell, there is no script directory.
if not interactive_mode:
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, script_directory)
)
script_directory = os.path.realpath(os.path.dirname(sys.argv[0]))
# If driver's sys.path doesn't include the script directory
# (e.g driver is started via `python -m`,
# see https://peps.python.org/pep-0338/),
# then we shouldn't add it to the workers.
if script_directory in sys.path:
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, script_directory)
)
# In client mode, if we use runtime envs with "working_dir", then
# it'll be handled automatically. Otherwise, add the current dir.
if not job_config.client_job and not job_config.runtime_env_has_working_dir():

View file

@ -4,6 +4,7 @@ import logging
import os
import sys
import time
import subprocess
import pytest
@ -180,6 +181,48 @@ ray.get(ready.remote())
run_string_as_driver(driver_script)
def test_worker_sys_path_contains_driver_script_directory(tmp_path, monkeypatch):
    """Check when the driver script directory shows up in a worker's sys.path.

    Two driver modules are written into a temporary ``package`` directory:

    * ``module1`` is run as a plain script (``python path/to/module1.py``) and
      asserts that the package directory IS in the ``sys.path`` returned by a
      remote task.
    * ``module2`` is run as ``python -m package.module2`` from ``tmp_path`` and
      asserts that the package directory is NOT in the ``sys.path`` returned
      by a remote task.

    Each driver performs its own assertion; a failure surfaces here as a
    non-zero exit status, which makes ``subprocess.check_call`` raise
    ``CalledProcessError``.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
        monkeypatch: pytest fixture, used to change the working directory so
            that ``package`` is importable for the ``-m`` invocation.
    """
    package_folder = tmp_path / "package"
    package_folder.mkdir()
    (package_folder / "__init__.py").write_text("")

    module1_file = package_folder / "module1.py"
    module1_file.write_text(
        f"""
import sys
import ray
ray.init()
@ray.remote
def sys_path():
    return sys.path
assert r'{str(package_folder)}' in ray.get(sys_path.remote())
"""
    )
    # Launch the driver with the interpreter running this test rather than
    # whatever "python" resolves to on PATH, which may be a different
    # installation.
    subprocess.check_call([sys.executable, str(module1_file)])

    # If the driver script is run via `python -m`,
    # the script directory is not included in sys.path.
    module2_file = package_folder / "module2.py"
    module2_file.write_text(
        f"""
import sys
import ray
ray.init()
@ray.remote
def sys_path():
    return sys.path
assert r'{str(package_folder)}' not in ray.get(sys_path.remote())
"""
    )
    # `-m package.module2` is resolved relative to the current directory.
    monkeypatch.chdir(str(tmp_path))
    subprocess.check_call([sys.executable, "-m", "package.module2"])
if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))