From d4a1bc7bc7e81060cef66fd1638de768975d0838 Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Fri, 25 Feb 2022 06:42:30 -0800 Subject: [PATCH] Revert "[runtime env] runtime env inheritance refactor (#22244)" (#22626) Breaks train_torch_linear_test.py. --- doc/source/ray-core/handling-dependencies.rst | 52 +++---- python/ray/__init__.py | 2 - python/ray/runtime_env.py | 26 ---- python/ray/tests/test_client.py | 2 +- python/ray/tests/test_runtime_env.py | 32 ----- .../ray/tests/test_runtime_env_complicated.py | 30 +--- python/ray/tests/test_runtime_env_env_vars.py | 12 +- python/ray/util/client/api.py | 8 -- src/ray/core_worker/context.cc | 2 +- src/ray/core_worker/context.h | 3 +- src/ray/core_worker/core_worker.cc | 111 +++++++++++---- src/ray/core_worker/test/core_worker_test.cc | 134 ++++++++++++++++++ 12 files changed, 244 insertions(+), 170 deletions(-) delete mode 100644 python/ray/runtime_env.py diff --git a/doc/source/ray-core/handling-dependencies.rst b/doc/source/ray-core/handling-dependencies.rst index 50499f9ee..32648c88b 100644 --- a/doc/source/ray-core/handling-dependencies.rst +++ b/doc/source/ray-core/handling-dependencies.rst @@ -363,48 +363,30 @@ To disable all deletion behavior (for example, for debugging purposes) you may s Inheritance """"""""""" -The runtime environment is inheritable, so it will apply to all tasks/actors within a job and all child tasks/actors of a task or actor once set, unless it is overridden by explicitly specifying a runtime environment for the child task/actor. +The runtime environment is inheritable, so it will apply to all tasks/actors within a job and all child tasks/actors of a task or actor once set, unless it is overridden. -1. By default, all actors and tasks inherit the current runtime env. +If an actor or task specifies a new ``runtime_env``, it will override the parent’s ``runtime_env`` (i.e., the parent actor/task's ``runtime_env``, or the job's ``runtime_env`` if there is no parent actor or task) as follows: + +* The ``runtime_env["env_vars"]`` field will be merged with the ``runtime_env["env_vars"]`` field of the parent. + This allows for environment variables set in the parent's runtime environment to be automatically propagated to the child, even if new environment variables are set in the child's runtime environment. +* Every other field in the ``runtime_env`` will be *overridden* by the child, not merged. For example, if ``runtime_env["py_modules"]`` is specified, it will replace the ``runtime_env["py_modules"]`` field of the parent. + +Example: .. code-block:: python - # Current `runtime_env` - ray.init(runtime_env={"pip": ["requests", "chess"]}) - - # Create child actor - ChildActor.remote() - - # ChildActor's actual `runtime_env` (inherit from current runtime env) - {"pip": ["requests", "chess"]} - -2. However, if you specify runtime_env for task/actor, it will override current runtime env. - -.. code-block:: python - - # Current `runtime_env` - ray.init(runtime_env={"pip": ["requests", "chess"]}) - - # Create child actor - ChildActor.options(runtime_env={"env_vars": {"A": "a", "B": "b"}}).remote() - - # ChildActor's actual `runtime_env` (specify runtime_env overrides) - {"env_vars": {"A": "a", "B": "b"}} - -3. If you'd like to still use current runtime env, you can use the :ref:`ray.get_runtime_context() ` API to get the current runtime env and modify it by yourself. - -.. code-block:: python - - # Current `runtime_env` - ray.init(runtime_env={"pip": ["requests", "chess"]}) - - # Child updates `runtime_env` - Actor.options(runtime_env=ray.get_current_runtime_env().update({"env_vars": {"A": "a", "B": "b"}})) - - # Child's actual `runtime_env` (merged with current runtime env) + # Parent's `runtime_env` {"pip": ["requests", "chess"], "env_vars": {"A": "a", "B": "b"}} + # Child's specified `runtime_env` + {"pip": ["torch", "ray[serve]"], + "env_vars": {"B": "new", "C", "c"}} + + # Child's actual `runtime_env` (merged with parent's) + {"pip": ["torch", "ray[serve]"], + "env_vars": {"A": "a", "B": "new", "C", "c"}} + .. _remote-uris: diff --git a/python/ray/__init__.py b/python/ray/__init__.py index fb9031ec4..c82d5532c 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -162,7 +162,6 @@ from ray import data # noqa: E402,F401 from ray import util # noqa: E402 from ray import _private # noqa: E402,F401 from ray import workflow # noqa: E402,F401 -from ray import runtime_env # noqa: E402,F401 # We import ClientBuilder so that modules can inherit from `ray.ClientBuilder`. from ray.client_builder import client, ClientBuilder # noqa: E402 @@ -201,7 +200,6 @@ __all__ = [ "LOCAL_MODE", "SCRIPT_MODE", "WORKER_MODE", - "runtime_env", ] # ID types diff --git a/python/ray/runtime_env.py b/python/ray/runtime_env.py deleted file mode 100644 index 3fe85b342..000000000 --- a/python/ray/runtime_env.py +++ /dev/null @@ -1,26 +0,0 @@ -import ray - -from ray._private.client_mode_hook import client_mode_hook - - -@client_mode_hook(auto_init=False) -def get_current_runtime_env(): - """Get the runtime env of the current job/worker. - - If this API is called in driver or ray client, returns the job level runtime env. - If this API is called in workers/actors, returns the worker level runtime env. - - Returns: - A dict of the current runtime env - - To merge from the parent runtime env in some specific cases, you can get the parent - runtime env by this API and modify it by yourself. - - Example: - - >>> # Inherit parent runtime env, except `env_vars` - >>> Actor.options(runtime_env=ray.get_current_runtime_env().update( - {"env_vars": {"A": "a", "B": "b"}})) - """ - - return dict(ray.get_runtime_context().runtime_env) diff --git a/python/ray/tests/test_client.py b/python/ray/tests/test_client.py index add2e7bef..6ec289016 100644 --- a/python/ray/tests/test_client.py +++ b/python/ray/tests/test_client.py @@ -746,7 +746,7 @@ def test_wrapped_actor_creation(call_ray_start): def test_init_requires_no_resources(call_ray_start, use_client): import ray - if not use_client: + if use_client: address = call_ray_start ray.init(address) else: diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index c81f724c6..baf068455 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -536,38 +536,6 @@ async def test_check_output_cmd(): assert "cmd[5]" in str(e.value) -@pytest.mark.parametrize( - "call_ray_start", - ["ray start --head --ray-client-server-port 25553"], - indirect=True, -) -@pytest.mark.parametrize("use_client", [False, True]) -def test_get_current_runtime_env(call_ray_start, use_client): - job_runtime_env = {"env_vars": {"a": "b"}} - - if not use_client: - address = call_ray_start - ray.init(address, runtime_env=job_runtime_env) - else: - ray.init("ray://localhost:25553", runtime_env=job_runtime_env) - - current_runtime_env = ray.runtime_env.get_current_runtime_env() - assert type(current_runtime_env) is dict - assert current_runtime_env == job_runtime_env - - @ray.remote - def get_runtime_env(): - return ray.runtime_env.get_current_runtime_env() - - assert ray.get(get_runtime_env.remote()) == job_runtime_env - - task_runtime_env = {"env_vars": {"a": "c"}} - assert ( - ray.get(get_runtime_env.options(runtime_env=task_runtime_env).remote()) - == task_runtime_env - ) - - MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin" success_retry_number = 3 runtime_env_retry_times = 0 diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py index 2daae6061..4cd3cc29b 100644 --- a/python/ray/tests/test_runtime_env_complicated.py +++ b/python/ray/tests/test_runtime_env_complicated.py @@ -841,16 +841,16 @@ def test_e2e_complex(call_ray_start, tmp_path): return Path("./test").read_text() - a = TestActor.remote() + a = TestActor.options(runtime_env={"pip": str(requirement_path)}).remote() assert ray.get(a.test.remote()) == "Hello" # Check that per-task pip specification works and that the job's - # working_dir is not inherited. + # working_dir is still inherited. @ray.remote def test_pip(): import pip_install_test # noqa - return "Hello" + return Path("./test").read_text() assert ( ray.get( @@ -859,44 +859,22 @@ def test_e2e_complex(call_ray_start, tmp_path): == "Hello" ) - @ray.remote - def test_working_dir(): - import pip_install_test # noqa - - return Path("./test").read_text() - - with pytest.raises(ray.exceptions.RayTaskError) as excinfo: - ray.get( - test_working_dir.options( - runtime_env={"pip": ["pip-install-test"]} - ).remote() - ) - assert "FileNotFoundError" in str(excinfo.value) - # Check that pip_install_test is not in the job's pip requirements. with pytest.raises(ray.exceptions.RayTaskError) as excinfo: ray.get(test_pip.remote()) assert "ModuleNotFoundError" in str(excinfo.value) # Check that per-actor pip specification works and that the job's - # working_dir is not inherited. + # working_dir is still inherited. @ray.remote class TestActor: def test(self): import pip_install_test # noqa - return "Hello" - - def test_working_dir(self): - import pip_install_test # noqa - return Path("./test").read_text() a = TestActor.options(runtime_env={"pip": ["pip-install-test"]}).remote() assert ray.get(a.test.remote()) == "Hello" - with pytest.raises(ray.exceptions.RayTaskError) as excinfo: - ray.get(a.test_working_dir.remote()) - assert "FileNotFoundError" in str(excinfo.value) @pytest.mark.skipif( diff --git a/python/ray/tests/test_runtime_env_env_vars.py b/python/ray/tests/test_runtime_env_env_vars.py index bdb2e5da9..4991f11d8 100644 --- a/python/ray/tests/test_runtime_env_env_vars.py +++ b/python/ray/tests/test_runtime_env_env_vars.py @@ -106,7 +106,7 @@ def test_environment_variables_multitenancy(shutdown_only): } ).remote("foo2") ) - is None + == "bar2" ) @@ -163,7 +163,7 @@ def test_environment_variables_complex(shutdown_only): assert ray.get(a.get.remote("a")) == "b" assert ray.get(a.get_task.remote("a")) == "b" - assert ray.get(a.nested_get.remote("a")) is None + assert ray.get(a.nested_get.remote("a")) == "b" assert ray.get(a.nested_get.remote("c")) == "e" assert ray.get(a.nested_get.remote("d")) == "dd" assert ( @@ -179,9 +179,9 @@ def test_environment_variables_complex(shutdown_only): == "b" ) - assert ray.get(a.get.remote("z")) is None - assert ray.get(a.get_task.remote("z")) is None - assert ray.get(a.nested_get.remote("z")) is None + assert ray.get(a.get.remote("z")) == "job_z" + assert ray.get(a.get_task.remote("z")) == "job_z" + assert ray.get(a.nested_get.remote("z")) == "job_z" assert ( ray.get( get_env.options( @@ -192,7 +192,7 @@ def test_environment_variables_complex(shutdown_only): } ).remote("z") ) - is None + == "job_z" ) diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index 09073e62e..e1dbe9f7f 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -277,14 +277,6 @@ class ClientAPI: """ return ClientWorkerPropertyAPI(self.worker).build_runtime_context() - def get_current_runtime_env(self): - """Get the runtime env of the current client/driver. - - Returns: - A dict of current runtime env. - """ - return dict(self.get_runtime_context().runtime_env) - # Client process isn't assigned any GPUs. def get_gpu_ids(self) -> list: return [] diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 08593f0cc..d71f312b0 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -210,7 +210,7 @@ const std::string &WorkerContext::GetCurrentSerializedRuntimeEnv() const { return runtime_env_info_.serialized_runtime_env(); } -std::shared_ptr WorkerContext::GetCurrentRuntimeEnv() const { +std::shared_ptr WorkerContext::GetCurrentRuntimeEnv() const { absl::ReaderMutexLock lock(&mutex_); return runtime_env_; } diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 973defe35..5fc96ccff 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -44,8 +44,7 @@ class WorkerContext { const std::string &GetCurrentSerializedRuntimeEnv() const LOCKS_EXCLUDED(mutex_); - std::shared_ptr GetCurrentRuntimeEnv() const - LOCKS_EXCLUDED(mutex_); + std::shared_ptr GetCurrentRuntimeEnv() const LOCKS_EXCLUDED(mutex_); // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentTaskId(const TaskID &task_id, uint64_t attempt_number); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 531951e4a..96ef91abf 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1366,6 +1366,43 @@ std::unordered_map AddPlacementGroupConstraint( return resources; } +rpc::RuntimeEnv CoreWorker::OverrideRuntimeEnv( + const rpc::RuntimeEnv &child, const std::shared_ptr parent) { + // By default, the child runtime env inherits non-specified options from the + // parent. There is one exception to this: + // - The env_vars dictionaries are merged, so environment variables + // not specified by the child are still inherited from the parent. + + // Override environment variables. + google::protobuf::Map result_env_vars(parent->env_vars()); + result_env_vars.insert(child.env_vars().begin(), child.env_vars().end()); + // Inherit all other non-specified options from the parent. + rpc::RuntimeEnv result_runtime_env(*parent); + // TODO(SongGuyang): avoid dupliacated fields. + result_runtime_env.MergeFrom(child); + if (child.python_runtime_env().py_modules().size() > 0 && + parent->python_runtime_env().py_modules().size() > 0) { + result_runtime_env.mutable_python_runtime_env()->clear_py_modules(); + for (auto &module : child.python_runtime_env().py_modules()) { + result_runtime_env.mutable_python_runtime_env()->add_py_modules(module); + } + result_runtime_env.mutable_uris()->clear_py_modules_uris(); + result_runtime_env.mutable_uris()->mutable_py_modules_uris()->CopyFrom( + child.uris().py_modules_uris()); + } + if (child.python_runtime_env().has_pip_runtime_env() && + parent->python_runtime_env().has_pip_runtime_env()) { + result_runtime_env.mutable_python_runtime_env()->clear_pip_runtime_env(); + result_runtime_env.mutable_python_runtime_env()->mutable_pip_runtime_env()->CopyFrom( + child.python_runtime_env().pip_runtime_env()); + } + if (!result_env_vars.empty()) { + result_runtime_env.mutable_env_vars()->insert(result_env_vars.begin(), + result_env_vars.end()); + } + return result_runtime_env; +} + // TODO(SongGuyang): This function exists in both C++ and Python. We should make this // logic clearly. static std::string encode_plugin_uri(std::string plugin, std::string uri) { @@ -1399,45 +1436,57 @@ static std::vector GetUrisFromRuntimeEnv( return result; } -std::string CoreWorker::OverrideTaskOrActorRuntimeEnv( - const std::string &serialized_runtime_env, - std::vector *runtime_env_uris) { - std::shared_ptr parent_runtime_env; - std::string parent_serialized_runtime_env; - if (options_.worker_type == WorkerType::DRIVER) { - parent_runtime_env = job_runtime_env_; - parent_serialized_runtime_env = - job_config_->runtime_env_info().serialized_runtime_env(); - } else { - parent_runtime_env = worker_context_.GetCurrentRuntimeEnv(); - parent_serialized_runtime_env = worker_context_.GetCurrentSerializedRuntimeEnv(); - } - if (IsRuntimeEnvEmpty(serialized_runtime_env)) { - // Try to inherit runtime env from job or worker. - *runtime_env_uris = GetUrisFromRuntimeEnv(parent_runtime_env.get()); - return parent_serialized_runtime_env; - } - - if (!IsRuntimeEnvEmpty(parent_serialized_runtime_env)) { - // TODO(SongGuyang): We add this warning log because of the change of API behavior. - // Refer to https://github.com/ray-project/ray/issues/21818. - // Modify this log level to `INFO` or `DEBUG` after a few release versions. - RAY_LOG(WARNING) << "Runtime env already exists and the parent runtime env is " - << parent_serialized_runtime_env << ". It will be overridden by " - << serialized_runtime_env << "."; - } - +static std::vector GetUrisFromSerializedRuntimeEnv( + const std::string &serialized_runtime_env) { rpc::RuntimeEnv runtime_env; if (!google::protobuf::util::JsonStringToMessage(serialized_runtime_env, &runtime_env) .ok()) { RAY_LOG(WARNING) << "Parse runtime env failed for " << serialized_runtime_env; // TODO(SongGuyang): We pass the raw string here and the task will fail after an // exception raised in runtime env agent. Actually, we can fail the task here. + return {}; + } + return GetUrisFromRuntimeEnv(&runtime_env); +} + +std::string CoreWorker::OverrideTaskOrActorRuntimeEnv( + const std::string &serialized_runtime_env, + std::vector *runtime_env_uris) { + std::shared_ptr parent = nullptr; + if (options_.worker_type == WorkerType::DRIVER) { + if (IsRuntimeEnvEmpty(serialized_runtime_env)) { + *runtime_env_uris = GetUrisFromRuntimeEnv(job_runtime_env_.get()); + return job_config_->runtime_env_info().serialized_runtime_env(); + } + parent = job_runtime_env_; + } else { + if (IsRuntimeEnvEmpty(serialized_runtime_env)) { + *runtime_env_uris = + GetUrisFromRuntimeEnv(worker_context_.GetCurrentRuntimeEnv().get()); + return worker_context_.GetCurrentSerializedRuntimeEnv(); + } + parent = worker_context_.GetCurrentRuntimeEnv(); + } + if (parent) { + rpc::RuntimeEnv child_runtime_env; + if (!google::protobuf::util::JsonStringToMessage(serialized_runtime_env, + &child_runtime_env) + .ok()) { + RAY_LOG(WARNING) << "Parse runtime env failed for " << serialized_runtime_env; + // TODO(SongGuyang): We pass the raw string here and the task will fail after an + // exception raised in runtime env agent. Actually, we can fail the task here. + return serialized_runtime_env; + } + auto override_runtime_env = OverrideRuntimeEnv(child_runtime_env, parent); + std::string result; + RAY_CHECK( + google::protobuf::util::MessageToJsonString(override_runtime_env, &result).ok()); + *runtime_env_uris = GetUrisFromRuntimeEnv(&override_runtime_env); + return result; + } else { + *runtime_env_uris = GetUrisFromSerializedRuntimeEnv(serialized_runtime_env); return serialized_runtime_env; } - - *runtime_env_uris = GetUrisFromRuntimeEnv(&runtime_env); - return serialized_runtime_env; } void CoreWorker::BuildCommonTaskSpec( diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 6a5b9e80c..4db832172 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -935,6 +935,140 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodesFailure) { TestActorFailure(resources); } +TEST(TestOverrideRuntimeEnv, TestOverrideEnvVars) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + // child {"a": "b"}, parent {}, expected {"a": "b"} + (*child.mutable_env_vars())["a"] = "b"; + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.env_vars().size(), 1); + ASSERT_EQ(result.env_vars().count("a"), 1); + ASSERT_EQ(result.env_vars().at("a"), "b"); + child.clear_env_vars(); + parent->clear_env_vars(); + // child {}, parent {"a": "b"}, expected {"a": "b"} + (*(parent->mutable_env_vars()))["a"] = "b"; + result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.env_vars().size(), 1); + ASSERT_EQ(result.env_vars().count("a"), 1); + ASSERT_EQ(result.env_vars().at("a"), "b"); + child.clear_env_vars(); + parent->clear_env_vars(); + // child {"a": "b"}, parent {"a": "d"}, expected {"a": "b"} + (*child.mutable_env_vars())["a"] = "b"; + (*(parent->mutable_env_vars()))["a"] = "d"; + result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.env_vars().size(), 1); + ASSERT_EQ(result.env_vars().count("a"), 1); + ASSERT_EQ(result.env_vars().at("a"), "b"); + child.clear_env_vars(); + parent->clear_env_vars(); + // child {"a": "b"}, parent {"c": "d"}, expected {"a": "b", "c": "d"} + (*child.mutable_env_vars())["a"] = "b"; + (*(parent->mutable_env_vars()))["c"] = "d"; + result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.env_vars().size(), 2); + ASSERT_EQ(result.env_vars().count("a"), 1); + ASSERT_EQ(result.env_vars().at("a"), "b"); + ASSERT_EQ(result.env_vars().count("c"), 1); + ASSERT_EQ(result.env_vars().at("c"), "d"); + child.clear_env_vars(); + parent->clear_env_vars(); + // child {"a": "b"}, parent {"a": "e", "c": "d"}, expected {"a": "b", "c": "d"} + (*child.mutable_env_vars())["a"] = "b"; + (*(parent->mutable_env_vars()))["a"] = "e"; + (*(parent->mutable_env_vars()))["c"] = "d"; + result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.env_vars().size(), 2); + ASSERT_EQ(result.env_vars().count("a"), 1); + ASSERT_EQ(result.env_vars().at("a"), "b"); + ASSERT_EQ(result.env_vars().count("c"), 1); + ASSERT_EQ(result.env_vars().at("c"), "d"); + child.clear_env_vars(); + parent->clear_env_vars(); +} + +TEST(TestOverrideRuntimeEnv, TestPyModulesInherit) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + parent->mutable_python_runtime_env()->mutable_dependent_modules()->Add("s3://456"); + parent->mutable_uris()->mutable_py_modules_uris()->Add("s3://456"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.python_runtime_env().dependent_modules().size(), 1); + ASSERT_EQ(result.python_runtime_env().dependent_modules()[0], "s3://456"); + ASSERT_EQ(result.uris().py_modules_uris().size(), 1); + ASSERT_EQ(result.uris().py_modules_uris()[0], "s3://456"); +} + +TEST(TestOverrideRuntimeEnv, TestOverridePyModules) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + child.mutable_python_runtime_env()->mutable_dependent_modules()->Add("s3://123"); + child.mutable_uris()->mutable_py_modules_uris()->Add("s3://123"); + parent->mutable_python_runtime_env()->mutable_dependent_modules()->Add("s3://456"); + parent->mutable_python_runtime_env()->mutable_dependent_modules()->Add("s3://789"); + parent->mutable_uris()->mutable_py_modules_uris()->Add("s3://456"); + parent->mutable_uris()->mutable_py_modules_uris()->Add("s3://789"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.python_runtime_env().dependent_modules().size(), 1); + ASSERT_EQ(result.python_runtime_env().dependent_modules()[0], "s3://123"); + ASSERT_EQ(result.uris().py_modules_uris().size(), 1); + ASSERT_EQ(result.uris().py_modules_uris()[0], "s3://123"); +} + +TEST(TestOverrideRuntimeEnv, TestWorkingDirInherit) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + parent->set_working_dir("uri://abc"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.working_dir(), "uri://abc"); +} + +TEST(TestOverrideRuntimeEnv, TestWorkingDirOverride) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + child.set_working_dir("uri://abc"); + parent->set_working_dir("uri://def"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.working_dir(), "uri://abc"); +} + +TEST(TestOverrideRuntimeEnv, TestCondaInherit) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + child.mutable_uris()->set_working_dir_uri("gcs://abc"); + parent->mutable_uris()->set_working_dir_uri("gcs://def"); + parent->mutable_uris()->set_conda_uri("conda://456"); + parent->mutable_python_runtime_env()->mutable_conda_runtime_env()->set_conda_env_name( + "my-env-name"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.uris().working_dir_uri(), "gcs://abc"); + ASSERT_EQ(result.uris().conda_uri(), "conda://456"); + ASSERT_TRUE(result.python_runtime_env().has_conda_runtime_env()); + ASSERT_TRUE(result.python_runtime_env().conda_runtime_env().has_conda_env_name()); + ASSERT_EQ(result.python_runtime_env().conda_runtime_env().conda_env_name(), + "my-env-name"); +} + +TEST(TestOverrideRuntimeEnv, TestCondaOverride) { + rpc::RuntimeEnv child; + auto parent = std::make_shared(); + child.mutable_uris()->set_conda_uri("conda://123"); + child.mutable_python_runtime_env()->mutable_conda_runtime_env()->set_conda_env_name( + "my-env-name-123"); + parent->mutable_uris()->set_conda_uri("conda://456"); + parent->mutable_python_runtime_env()->mutable_conda_runtime_env()->set_conda_env_name( + "my-env-name-456"); + parent->mutable_uris()->set_working_dir_uri("gcs://def"); + auto result = CoreWorker::OverrideRuntimeEnv(child, parent); + ASSERT_EQ(result.uris().conda_uri(), "conda://123"); + ASSERT_TRUE(result.python_runtime_env().has_conda_runtime_env()); + ASSERT_TRUE(result.python_runtime_env().conda_runtime_env().has_conda_env_name()); + ASSERT_EQ(result.python_runtime_env().conda_runtime_env().conda_env_name(), + "my-env-name-123"); + ASSERT_EQ(result.uris().working_dir_uri(), "gcs://def"); +} + } // namespace core } // namespace ray