[Core] [runtime env] Support specifying runtime env in @ray.remote decorator (#16660)

This commit is contained in:
architkulkarni 2021-06-25 07:37:40 -07:00 committed by GitHub
parent e74d9d3ded
commit b15ab2d60b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 100 additions and 30 deletions

View file

@ -451,7 +451,7 @@ You can specify a runtime environment for your whole job using ``ray.init()`` or
ray.client("localhost:10001").env(runtime_env).connect()
...or specify per-actor or per-task using ``.options()``:
...or specify per-actor or per-task in the ``@ray.remote()`` decorator or by using ``.options()``:
.. literalinclude:: ../examples/doc_code/runtime_env_example.py
:language: python

View file

@ -267,6 +267,9 @@ class ActorClassMetadata:
memory: The heap memory quota for this actor.
object_store_memory: The object store memory quota for this actor.
resources: The default resources required by the actor creation task.
accelerator_type: The specified type of accelerator required for the
node on which this actor runs.
runtime_env: The runtime environment for this actor.
last_export_session_and_job: A pair of the last exported session
and job to help us to know whether this function was exported.
This is an imperfect mechanism used to determine if we need to
@ -279,7 +282,8 @@ class ActorClassMetadata:
def __init__(self, language, modified_class,
actor_creation_function_descriptor, class_id, max_restarts,
max_task_retries, num_cpus, num_gpus, memory,
object_store_memory, resources, accelerator_type):
object_store_memory, resources, accelerator_type,
runtime_env):
self.language = language
self.modified_class = modified_class
self.actor_creation_function_descriptor = \
@ -295,6 +299,7 @@ class ActorClassMetadata:
self.object_store_memory = object_store_memory
self.resources = resources
self.accelerator_type = accelerator_type
self.runtime_env = runtime_env
self.last_export_session_and_job = None
self.method_meta = ActorClassMethodMetadata.create(
modified_class, actor_creation_function_descriptor)
@ -353,7 +358,7 @@ class ActorClass:
def _ray_from_modified_class(cls, modified_class, class_id, max_restarts,
max_task_retries, num_cpus, num_gpus, memory,
object_store_memory, resources,
accelerator_type):
accelerator_type, runtime_env):
for attribute in [
"remote",
"_remote",
@ -385,7 +390,7 @@ class ActorClass:
Language.PYTHON, modified_class,
actor_creation_function_descriptor, class_id, max_restarts,
max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
resources, accelerator_type)
resources, accelerator_type, runtime_env)
return self
@ -393,13 +398,13 @@ class ActorClass:
def _ray_from_function_descriptor(
cls, language, actor_creation_function_descriptor, max_restarts,
max_task_retries, num_cpus, num_gpus, memory, object_store_memory,
resources, accelerator_type):
resources, accelerator_type, runtime_env):
self = ActorClass.__new__(ActorClass)
self.__ray_metadata__ = ActorClassMetadata(
language, None, actor_creation_function_descriptor, None,
max_restarts, max_task_retries, num_cpus, num_gpus, memory,
object_store_memory, resources, accelerator_type)
object_store_memory, resources, accelerator_type, runtime_env)
return self
@ -704,6 +709,8 @@ class ActorClass:
function_signature = meta.method_meta.signatures["__init__"]
creation_args = signature.flatten_args(function_signature, args,
kwargs)
if runtime_env is None:
runtime_env = meta.runtime_env
if runtime_env:
runtime_env_dict = runtime_support.RuntimeEnvDict(
runtime_env).get_parsed_dict()
@ -1035,7 +1042,7 @@ def modify_class(cls):
def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type, max_restarts, max_task_retries):
accelerator_type, max_restarts, max_task_retries, runtime_env):
Class = modify_class(cls)
_inject_tracing_into_class(Class)
@ -1061,7 +1068,7 @@ def make_actor(cls, num_cpus, num_gpus, memory, object_store_memory, resources,
return ActorClass._ray_from_modified_class(
Class, ActorClassID.from_random(), max_restarts, max_task_retries,
num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type)
accelerator_type, runtime_env)
def exit_actor():

View file

@ -75,7 +75,8 @@ def java_function(class_name, function_name):
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None) # max_retries
None, # max_retries
None) # runtime_env
def java_actor_class(class_name):
@ -96,4 +97,4 @@ def java_actor_class(class_name):
object_store_memory=None,
resources=None,
accelerator_type=None,
)
runtime_env=None)

View file

@ -53,6 +53,7 @@ class RemoteFunction:
of this remote function.
_max_calls: The number of times a worker can execute this function
before exiting.
_runtime_env: The runtime environment for this task.
_decorator: An optional decorator that should be applied to the remote
function invocation (as opposed to the function execution) before
invoking the function. The decorator must return a function that
@ -71,7 +72,8 @@ class RemoteFunction:
def __init__(self, language, function, function_descriptor, num_cpus,
num_gpus, memory, object_store_memory, resources,
accelerator_type, num_returns, max_calls, max_retries):
accelerator_type, num_returns, max_calls, max_retries,
runtime_env):
if inspect.iscoroutinefunction(function):
raise ValueError("'async def' should not be used for remote "
"tasks. You can wrap the async function with "
@ -98,6 +100,7 @@ class RemoteFunction:
if max_calls is None else max_calls)
self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
if max_retries is None else max_retries)
self._runtime_env = runtime_env
self._decorator = getattr(function, "__ray_invocation_decorator__",
None)
self._function_signature = ray._private.signature.extract_signature(
@ -271,6 +274,8 @@ class RemoteFunction:
num_cpus, num_gpus, memory, object_store_memory, resources,
accelerator_type)
if runtime_env is None:
runtime_env = self._runtime_env
if runtime_env:
runtime_env_dict = runtime_support.RuntimeEnvDict(
runtime_env).get_parsed_dict()

View file

@ -693,6 +693,70 @@ def test_get_release_wheel_url():
assert requests.head(url).status_code == 200, url
@pytest.mark.skipif(
sys.platform == "win32", reason="runtime_env unsupported on Windows.")
def test_decorator_task(ray_start_cluster_head):
@ray.remote(runtime_env={"env_vars": {"foo": "bar"}})
def f():
return os.environ.get("foo")
assert ray.get(f.remote()) == "bar"
@pytest.mark.skipif(
sys.platform == "win32", reason="runtime_env unsupported on Windows.")
def test_decorator_actor(ray_start_cluster_head):
@ray.remote(runtime_env={"env_vars": {"foo": "bar"}})
class A:
def g(self):
return os.environ.get("foo")
a = A.remote()
assert ray.get(a.g.remote()) == "bar"
@pytest.mark.skipif(
sys.platform == "win32", reason="runtime_env unsupported on Windows.")
def test_decorator_complex(shutdown_only):
ray.init(
job_config=ray.job_config.JobConfig(
runtime_env={"env_vars": {
"foo": "job"
}}))
@ray.remote
def env_from_job():
return os.environ.get("foo")
assert ray.get(env_from_job.remote()) == "job"
@ray.remote(runtime_env={"env_vars": {"foo": "task"}})
def f():
return os.environ.get("foo")
assert ray.get(f.remote()) == "task"
@ray.remote(runtime_env={"env_vars": {"foo": "actor"}})
class A:
def g(self):
return os.environ.get("foo")
a = A.remote()
assert ray.get(a.g.remote()) == "actor"
# Test that runtime_env can be overridden by specifying .options().
assert ray.get(
f.options(runtime_env={
"env_vars": {
"foo": "new"
}
}).remote()) == "new"
a = A.options(runtime_env={"env_vars": {"foo": "new2"}}).remote()
assert ray.get(a.g.remote()) == "new2"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-sv", __file__]))

View file

@ -1753,6 +1753,7 @@ def make_decorator(num_returns=None,
max_retries=None,
max_restarts=None,
max_task_retries=None,
runtime_env=None,
worker=None):
def decorator(function_or_class):
if (inspect.isfunction(function_or_class)
@ -1782,7 +1783,7 @@ def make_decorator(num_returns=None,
return ray.remote_function.RemoteFunction(
Language.PYTHON, function_or_class, None, num_cpus, num_gpus,
memory, object_store_memory, resources, accelerator_type,
num_returns, max_calls, max_retries)
num_returns, max_calls, max_retries, runtime_env)
if inspect.isclass(function_or_class):
if num_returns is not None:
@ -1804,7 +1805,7 @@ def make_decorator(num_returns=None,
return ray.actor.make_actor(function_or_class, num_cpus, num_gpus,
memory, object_store_memory, resources,
accelerator_type, max_restarts,
max_task_retries)
max_task_retries, runtime_env)
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
@ -1909,7 +1910,6 @@ def remote(*args, **kwargs):
runtime_env (Dict[str, Any]): Specifies the runtime environment for
this actor or task and its children. See
:ref:`runtime-environments` for detailed documentation.
Note: can only be set via `.options()`.
override_environment_variables (Dict[str, str]): (Deprecated in Ray
1.4.0, will be removed in Ray 1.5--please use the ``env_vars``
field of :ref:`runtime-environments` instead.) This specifies
@ -1927,29 +1927,20 @@ def remote(*args, **kwargs):
return make_decorator(worker=worker)(args[0])
# Parse the keyword arguments from the decorator.
valid_kwargs = [
"num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
"resources", "accelerator_type", "max_calls", "max_restarts",
"max_task_retries", "max_retries", "runtime_env"
]
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
f"the arguments in the list {valid_kwargs}, for example "
"'@ray.remote(num_returns=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
for key in kwargs:
assert key in [
"num_returns",
"num_cpus",
"num_gpus",
"memory",
"object_store_memory",
"resources",
"accelerator_type",
"max_calls",
"max_restarts",
"max_task_retries",
"max_retries",
], error_string
assert key in valid_kwargs, error_string
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else None
@ -1971,6 +1962,7 @@ def remote(*args, **kwargs):
memory = kwargs.get("memory")
object_store_memory = kwargs.get("object_store_memory")
max_retries = kwargs.get("max_retries")
runtime_env = kwargs.get("runtime_env")
return make_decorator(
num_returns=num_returns,
@ -1984,4 +1976,5 @@ def remote(*args, **kwargs):
max_restarts=max_restarts,
max_task_retries=max_task_retries,
max_retries=max_retries,
runtime_env=runtime_env,
worker=worker)