diff --git a/python/ray/__init__.py b/python/ray/__init__.py index df1a496ea..b49384944 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -150,5 +150,23 @@ __all__ += [ "PlacementGroupID", ] + # Remove modules from top-level ray +def _ray_user_setup_function(): + import os + user_setup_fn = os.environ.get("RAY_USER_SETUP_FUNCTION") + if user_setup_fn is not None: + try: + module_name, fn_name = user_setup_fn.rsplit(".", 1) + m = __import__(module_name, globals(), locals(), [fn_name]) + getattr(m, fn_name)() + except Exception as e: + logger.exception( + f"Failed to run user setup function: {user_setup_fn}. " + f"Error message {e}") + + +_ray_user_setup_function() + del logging +del _ray_user_setup_function diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index c3761038b..f097aaeab 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -537,3 +537,8 @@ def load_test_config(config_file_name): config_file_name) config = yaml.safe_load(open(config_path).read()) return config + + +def set_setup_func(): + import ray._private.runtime_env as runtime_env + runtime_env.VAR = "hello world" diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 3da0a3083..75eb9593a 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -9,7 +9,8 @@ import numpy as np import pytest import ray.cluster_utils -from ray.test_utils import (client_test_enabled, get_error_message) +from ray.test_utils import (client_test_enabled, get_error_message, + run_string_as_driver) import ray @@ -165,6 +166,28 @@ def test_invalid_arguments(shutdown_only): x = 1 +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows") +def test_user_setup_function(): + script = """ +import ray +ray.init() +@ray.remote +def get_pkg_dir(): + return ray._private.runtime_env.VAR + +print("remote", ray.get(get_pkg_dir.remote())) +print("local", ray._private.runtime_env.VAR) + + +""" + + out = run_string_as_driver( + script, {"RAY_USER_SETUP_FUNCTION": "ray.test_utils.set_setup_func"}) + (remote_out, local_out) = out.strip().split("\n")[-2:] + assert remote_out == "remote hello world" + assert local_out == "local hello world" + + def test_put_get(shutdown_only): ray.init(num_cpus=0)