mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[core] API for pre-run customized functions (#15749)
* run customer setup fn * fix * lint * skip on w32 * fix comment * up * up
This commit is contained in:
parent
69f228d22d
commit
863532af0a
3 changed files with 47 additions and 1 deletions
|
@ -150,5 +150,23 @@ __all__ += [
|
|||
"PlacementGroupID",
|
||||
]
|
||||
|
||||
|
||||
# Remove modules from top-level ray
|
||||
def _ray_user_setup_function():
|
||||
import os
|
||||
user_setup_fn = os.environ.get("RAY_USER_SETUP_FUNCTION")
|
||||
if user_setup_fn is not None:
|
||||
try:
|
||||
module_name, fn_name = user_setup_fn.rsplit(".", 1)
|
||||
m = __import__(module_name, globals(), locals(), [fn_name])
|
||||
getattr(m, fn_name)()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to run user setup function: {user_setup_fn}. "
|
||||
f"Error message {e}")
|
||||
|
||||
|
||||
_ray_user_setup_function()
|
||||
|
||||
del logging
|
||||
del _ray_user_setup_function
|
||||
|
|
|
@ -537,3 +537,8 @@ def load_test_config(config_file_name):
|
|||
config_file_name)
|
||||
config = yaml.safe_load(open(config_path).read())
|
||||
return config
|
||||
|
||||
|
||||
def set_setup_func():
|
||||
import ray._private.runtime_env as runtime_env
|
||||
runtime_env.VAR = "hello world"
|
||||
|
|
|
@ -9,7 +9,8 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import ray.cluster_utils
|
||||
from ray.test_utils import (client_test_enabled, get_error_message)
|
||||
from ray.test_utils import (client_test_enabled, get_error_message,
|
||||
run_string_as_driver)
|
||||
|
||||
import ray
|
||||
|
||||
|
@ -165,6 +166,28 @@ def test_invalid_arguments(shutdown_only):
|
|||
x = 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows")
|
||||
def test_user_setup_function():
|
||||
script = """
|
||||
import ray
|
||||
ray.init()
|
||||
@ray.remote
|
||||
def get_pkg_dir():
|
||||
return ray._private.runtime_env.VAR
|
||||
|
||||
print("remote", ray.get(get_pkg_dir.remote()))
|
||||
print("local", ray._private.runtime_env.VAR)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
out = run_string_as_driver(
|
||||
script, {"RAY_USER_SETUP_FUNCTION": "ray.test_utils.set_setup_func"})
|
||||
(remote_out, local_out) = out.strip().split("\n")[-2:]
|
||||
assert remote_out == "remote hello world"
|
||||
assert local_out == "local hello world"
|
||||
|
||||
|
||||
def test_put_get(shutdown_only):
|
||||
ray.init(num_cpus=0)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue