[runtime env] Introduce async Manager.create (#22311)

This commit is contained in:
Jialing He 2022-02-15 06:26:47 +08:00 committed by GitHub
parent 845861fdc1
commit 192f9de421
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 91 additions and 99 deletions

View file

@ -122,100 +122,91 @@ class RuntimeEnvAgent(
async def _setup_runtime_env(
serialized_runtime_env, serialized_allocated_resource_instances
):
# This function will be ran inside a thread
def run_setup_with_logger():
runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
allocated_resource: dict = json.loads(
serialized_allocated_resource_instances or "{}"
runtime_env = RuntimeEnv(serialized_runtime_env=serialized_runtime_env)
allocated_resource: dict = json.loads(
serialized_allocated_resource_instances or "{}"
)
# Use a separate logger for each job.
per_job_logger = self.get_or_create_logger(request.job_id)
# TODO(chenk008): Add log about allocated_resource to
# avoid lint error. That will be moved to cgroup plugin.
per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
await self._container_manager.setup(
runtime_env, context, logger=per_job_logger
)
for (manager, uri_cache) in [
(self._working_dir_manager, self._working_dir_uri_cache),
(self._conda_manager, self._conda_uri_cache),
(self._pip_manager, self._pip_uri_cache),
]:
uri = manager.get_uri(runtime_env)
if uri is not None:
if uri not in uri_cache:
per_job_logger.debug(f"Cache miss for URI {uri}.")
size_bytes = await manager.create(
uri, runtime_env, context, logger=per_job_logger
)
uri_cache.add(uri, size_bytes, logger=per_job_logger)
else:
per_job_logger.debug(f"Cache hit for URI {uri}.")
uri_cache.mark_used(uri, logger=per_job_logger)
manager.modify_context(uri, runtime_env, context)
# Set up py_modules. For now, py_modules uses multiple URIs so
# the logic is slightly different from working_dir, conda, and
# pip above.
py_modules_uris = self._py_modules_manager.get_uris(runtime_env)
if py_modules_uris is not None:
for uri in py_modules_uris:
if uri not in self._py_modules_uri_cache:
per_job_logger.debug(f"Cache miss for URI {uri}.")
size_bytes = await self._py_modules_manager.create(
uri, runtime_env, context, logger=per_job_logger
)
self._py_modules_uri_cache.add(
uri, size_bytes, logger=per_job_logger
)
else:
per_job_logger.debug(f"Cache hit for URI {uri}.")
self._py_modules_uri_cache.mark_used(uri, logger=per_job_logger)
self._py_modules_manager.modify_context(
py_modules_uris, runtime_env, context
)
# Add the mapping of URIs -> the serialized environment to be
# used for cache invalidation.
if runtime_env.working_dir_uri():
uri = runtime_env.working_dir_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.py_modules_uris():
for uri in runtime_env.py_modules_uris():
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.conda_uri():
uri = runtime_env.conda_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.pip_uri():
uri = runtime_env.pip_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.plugin_uris():
for uri in runtime_env.plugin_uris():
self._uris_to_envs[uri].add(serialized_runtime_env)
# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
per_job_logger.debug(
f"Setting up runtime env plugin {plugin_class_path}"
)
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create("uri not implemented", json.loads(config), context)
plugin_class.modify_context(
"uri not implemented", json.loads(config), context
)
# Use a separate logger for each job.
per_job_logger = self.get_or_create_logger(request.job_id)
# TODO(chenk008): Add log about allocated_resource to
# avoid lint error. That will be moved to cgroup plugin.
per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
self._container_manager.setup(
runtime_env, context, logger=per_job_logger
)
for (manager, uri_cache) in [
(self._working_dir_manager, self._working_dir_uri_cache),
(self._conda_manager, self._conda_uri_cache),
(self._pip_manager, self._pip_uri_cache),
]:
uri = manager.get_uri(runtime_env)
if uri is not None:
if uri not in uri_cache:
per_job_logger.debug(f"Cache miss for URI {uri}.")
size_bytes = manager.create(
uri, runtime_env, context, logger=per_job_logger
)
uri_cache.add(uri, size_bytes, logger=per_job_logger)
else:
per_job_logger.debug(f"Cache hit for URI {uri}.")
uri_cache.mark_used(uri, logger=per_job_logger)
manager.modify_context(uri, runtime_env, context)
# Set up py_modules. For now, py_modules uses multiple URIs so
# the logic is slightly different from working_dir, conda, and
# pip above.
py_modules_uris = self._py_modules_manager.get_uris(runtime_env)
if py_modules_uris is not None:
for uri in py_modules_uris:
if uri not in self._py_modules_uri_cache:
per_job_logger.debug(f"Cache miss for URI {uri}.")
size_bytes = self._py_modules_manager.create(
uri, runtime_env, context, logger=per_job_logger
)
self._py_modules_uri_cache.add(
uri, size_bytes, logger=per_job_logger
)
else:
per_job_logger.debug(f"Cache hit for URI {uri}.")
self._py_modules_uri_cache.mark_used(
uri, logger=per_job_logger
)
self._py_modules_manager.modify_context(
py_modules_uris, runtime_env, context
)
# Add the mapping of URIs -> the serialized environment to be
# used for cache invalidation.
if runtime_env.working_dir_uri():
uri = runtime_env.working_dir_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.py_modules_uris():
for uri in runtime_env.py_modules_uris():
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.conda_uri():
uri = runtime_env.conda_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.pip_uri():
uri = runtime_env.pip_uri()
self._uris_to_envs[uri].add(serialized_runtime_env)
if runtime_env.plugin_uris():
for uri in runtime_env.plugin_uris():
self._uris_to_envs[uri].add(serialized_runtime_env)
# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
per_job_logger.debug(
f"Setting up runtime env plugin {plugin_class_path}"
)
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create(
"uri not implemented", json.loads(config), context
)
plugin_class.modify_context(
"uri not implemented", json.loads(config), context
)
return context
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, run_setup_with_logger)
return context
serialized_env = request.serialized_runtime_env

View file

@ -299,7 +299,7 @@ class CondaManager:
return local_dir_size
def create(
async def create(
self,
uri: Optional[str],
runtime_env: RuntimeEnv,

View file

@ -50,7 +50,7 @@ class ContainerManager:
# can connect to raylet.
self._ray_tmp_dir = tmp_dir
def setup(
async def setup(
self,
runtime_env: RuntimeEnv,
context: RuntimeEnvContext,

View file

@ -286,7 +286,7 @@ class PipManager:
return local_dir_size
def create(
async def create(
self,
uri: str,
runtime_env: RuntimeEnv,

View file

@ -123,7 +123,7 @@ class PyModulesManager:
def get_uris(self, runtime_env: dict) -> Optional[List[str]]:
return runtime_env.py_modules()
def create(
async def create(
self,
uri: str,
runtime_env: RuntimeEnv,

View file

@ -109,7 +109,7 @@ class WorkingDirManager:
return working_dir_uri
return None
def create(
async def create(
self,
uri: str,
runtime_env: dict,

View file

@ -31,7 +31,8 @@ GS_PACKAGE_URI = "gs://public-runtime-env-test/test_module.zip"
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
def test_create_delete_size_equal(tmpdir, ray_start_regular):
@pytest.mark.asyncio
async def test_create_delete_size_equal(tmpdir, ray_start_regular):
"""Tests that `create` and `delete_uri` return the same size for a URI."""
# Create an arbitrary nonempty directory to upload.
@ -50,7 +51,7 @@ def test_create_delete_size_equal(tmpdir, ray_start_regular):
manager = WorkingDirManager(tmpdir)
created_size_bytes = manager.create(uri, {}, RuntimeEnvContext())
created_size_bytes = await manager.create(uri, {}, RuntimeEnvContext())
deleted_size_bytes = manager.delete_uri(uri)
assert created_size_bytes == deleted_size_bytes