[tune] Fix file descriptor leak by syncer (#12590)

This commit is contained in:
Richard Liaw 2020-12-03 13:39:04 -08:00 committed by GitHub
parent 36e46ed923
commit 1ce5e0e99f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 3 deletions

View file

@@ -135,6 +135,10 @@ class SyncClient:
"""Resets state.""" """Resets state."""
pass pass
def close(self):
    """Release any resources held by the client.

    The base implementation is a no-op hook; subclasses that hold
    resources (such as an open log file) override it.
    """
    return None
class FunctionBasedClient(SyncClient): class FunctionBasedClient(SyncClient):
def __init__(self, sync_up_func, sync_down_func, delete_func=None): def __init__(self, sync_up_func, sync_down_func, delete_func=None):
@@ -179,6 +183,7 @@ class CommandBasedClient(SyncClient):
self.sync_down_template = sync_down_template self.sync_down_template = sync_down_template
self.delete_template = delete_template self.delete_template = delete_template
self.logfile = None self.logfile = None
self._closed = False
self.cmd_process = None self.cmd_process = None
def set_logdir(self, logdir): def set_logdir(self, logdir):
@@ -189,6 +194,16 @@ class CommandBasedClient(SyncClient):
""" """
self.logfile = tempfile.NamedTemporaryFile( self.logfile = tempfile.NamedTemporaryFile(
prefix="log_sync_out", dir=logdir, suffix=".log", delete=False) prefix="log_sync_out", dir=logdir, suffix=".log", delete=False)
self._closed = False
def _get_logfile(self):
if self._closed:
raise RuntimeError(
"[internalerror] The client has been closed. "
"Please report this stacktrace + your cluster configuration "
"on Github!")
else:
return self.logfile
def sync_up(self, source, target): def sync_up(self, source, target):
return self._execute(self.sync_up_template, source, target) return self._execute(self.sync_up_template, source, target)
@@ -203,7 +218,10 @@ class CommandBasedClient(SyncClient):
final_cmd = self.delete_template.format(target=quote(target)) final_cmd = self.delete_template.format(target=quote(target))
logger.debug("Running delete: {}".format(final_cmd)) logger.debug("Running delete: {}".format(final_cmd))
self.cmd_process = subprocess.Popen( self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile) final_cmd,
shell=True,
stderr=subprocess.PIPE,
stdout=self._get_logfile())
return True return True
def wait(self): def wait(self):
@@ -223,6 +241,13 @@ class CommandBasedClient(SyncClient):
logger.warning("Sync process still running but resetting anyways.") logger.warning("Sync process still running but resetting anyways.")
self.cmd_process = None self.cmd_process = None
def close(self):
    """Close the sync log file, if one is open, and mark the client closed.

    After the log file has been closed, ``_get_logfile`` raises, so no
    further sync or delete command can be started with a stale handle.
    A client that never had a log file is left untouched.
    """
    logfile = self.logfile
    if not logfile:
        # Nothing was opened, so there is no descriptor to release.
        return
    logger.debug(f"Closing the logfile: {str(logfile)}")
    logfile.close()
    self.logfile = None
    self._closed = True
@property @property
def is_running(self): def is_running(self):
"""Returns whether a sync or delete process is running.""" """Returns whether a sync or delete process is running."""
@@ -240,7 +265,10 @@ class CommandBasedClient(SyncClient):
source=quote(source), target=quote(target)) source=quote(source), target=quote(target))
logger.debug("Running sync: {}".format(final_cmd)) logger.debug("Running sync: {}".format(final_cmd))
self.cmd_process = subprocess.Popen( self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self.logfile) final_cmd,
shell=True,
stderr=subprocess.PIPE,
stdout=self._get_logfile())
return True return True
@staticmethod @staticmethod

View file

@@ -206,6 +206,9 @@ class Syncer:
self.last_sync_down_time = float("-inf") self.last_sync_down_time = float("-inf")
self.sync_client.reset() self.sync_client.reset()
def close(self):
    """Close the underlying sync client so it can release its resources."""
    client = self.sync_client
    client.close()
@property @property
def _remote_path(self): def _remote_path(self):
return self._remote_dir return self._remote_dir
@@ -445,6 +448,7 @@ class SyncerCallback(Callback):
trainable_ip = ray.get(trial.runner.get_current_ip.remote()) trainable_ip = ray.get(trial.runner.get_current_ip.remote())
trial_syncer.set_worker_ip(trainable_ip) trial_syncer.set_worker_ip(trainable_ip)
trial_syncer.sync_down_if_needed() trial_syncer.sync_down_if_needed()
trial_syncer.close()
def on_checkpoint(self, iteration: int, trials: List["Trial"], def on_checkpoint(self, iteration: int, trials: List["Trial"],
trial: "Trial", checkpoint: Checkpoint, **info): trial: "Trial", checkpoint: Checkpoint, **info):

View file

@@ -2,6 +2,7 @@ import os
import numpy as np import numpy as np
import json import json
import random import random
import uuid
import ray.utils import ray.utils
@@ -20,7 +21,12 @@ LOCAL_DELETE_TEMPLATE = "rm -rf {target}"
def mock_storage_client(): def mock_storage_client():
"""Mocks storage client that treats a local dir as durable storage.""" """Mocks storage client that treats a local dir as durable storage."""
return get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE) client = get_sync_client(LOCAL_SYNC_TEMPLATE, LOCAL_DELETE_TEMPLATE)
path = os.path.join(ray.utils.get_user_temp_dir(),
f"mock-client-{uuid.uuid4().hex[:4]}")
os.makedirs(path, exist_ok=True)
client.set_logdir(path)
return client
class MockNodeSyncer(NodeSyncer): class MockNodeSyncer(NodeSyncer):