[core] API for pre-run customized functions (#15749)

* run customer setup fn

* fix

* lint

* skip on w32

* fix comment

* up

* up
This commit is contained in:
Yi Cheng 2021-05-17 22:52:36 -07:00 committed by GitHub
parent 69f228d22d
commit 863532af0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 1 deletions

View file

@ -150,5 +150,23 @@ __all__ += [
"PlacementGroupID",
]
# Remove modules from top-level ray
def _ray_user_setup_function():
import os
user_setup_fn = os.environ.get("RAY_USER_SETUP_FUNCTION")
if user_setup_fn is not None:
try:
module_name, fn_name = user_setup_fn.rsplit(".", 1)
m = __import__(module_name, globals(), locals(), [fn_name])
getattr(m, fn_name)()
except Exception as e:
logger.exception(
f"Failed to run user setup function: {user_setup_fn}. "
f"Error message {e}")
_ray_user_setup_function()
del logging
del _ray_user_setup_function

View file

@ -537,3 +537,8 @@ def load_test_config(config_file_name):
config_file_name)
config = yaml.safe_load(open(config_path).read())
return config
def set_setup_func():
import ray._private.runtime_env as runtime_env
runtime_env.VAR = "hello world"

View file

@ -9,7 +9,8 @@ import numpy as np
import pytest
import ray.cluster_utils
from ray.test_utils import (client_test_enabled, get_error_message)
from ray.test_utils import (client_test_enabled, get_error_message,
run_string_as_driver)
import ray
@ -165,6 +166,28 @@ def test_invalid_arguments(shutdown_only):
x = 1
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows")
def test_user_setup_function():
script = """
import ray
ray.init()
@ray.remote
def get_pkg_dir():
return ray._private.runtime_env.VAR
print("remote", ray.get(get_pkg_dir.remote()))
print("local", ray._private.runtime_env.VAR)
"""
out = run_string_as_driver(
script, {"RAY_USER_SETUP_FUNCTION": "ray.test_utils.set_setup_func"})
(remote_out, local_out) = out.strip().split("\n")[-2:]
assert remote_out == "remote hello world"
assert local_out == "local hello world"
def test_put_get(shutdown_only):
ray.init(num_cpus=0)