[tune] Local Mode support (#4138)

This commit is contained in:
Richard Liaw 2019-03-03 14:05:59 -08:00 committed by GitHub
parent e2e6ef198b
commit 3483282254
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 1 deletions

View file

@ -17,6 +17,15 @@ from ray.tune.trial_executor import TrialExecutor
logger = logging.getLogger(__name__)
class _LocalWrapper(object):
def __init__(self, result):
self._result = result
def unwrap(self):
"""Returns the wrapped result."""
return self._result
class RayTrialExecutor(TrialExecutor):
"""An implementation of TrialExecutor based on Ray."""
@ -61,6 +70,11 @@ class RayTrialExecutor(TrialExecutor):
assert trial.status == Trial.RUNNING, trial.status
remote = trial.runner.train.remote()
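# Local mode: the call returned the result dict itself; a dict is
# unhashable, so wrap it before using it as a key in self._running.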
if isinstance(remote, dict):
remote = _LocalWrapper(remote)
self._running[remote] = trial
def _start_trial(self, trial, checkpoint=None):
@ -229,6 +243,10 @@ class RayTrialExecutor(TrialExecutor):
raise ValueError("Trial was not running.")
self._running.pop(trial_future[0])
result = ray.get(trial_future[0])
# For local mode
if isinstance(result, _LocalWrapper):
result = result.unwrap()
return result
def _commit_resources(self, resources):
@ -266,7 +284,14 @@ class RayTrialExecutor(TrialExecutor):
def _update_avail_resources(self, num_retries=5):
for i in range(num_retries):
resources = ray.global_state.cluster_resources()
try:
resources = ray.global_state.cluster_resources()
except Exception:
# TODO(rliaw): Remove this when local mode is fixed.
# https://github.com/ray-project/ray/issues/4147
logger.debug("Using resources for local machine.")
resources = ray.services.check_and_update_resources(
None, None, None)
if not resources:
logger.warning("Cluster resources not detected. Retrying...")
time.sleep(0.5)

View file

@ -110,5 +110,15 @@ class RayTrialExecutorTest(unittest.TestCase):
return suggester.next_trials()
class LocalModeExecutorTest(RayTrialExecutorTest):
    """Runs the RayTrialExecutorTest cases with Ray started in local mode."""

    def setUp(self):
        """Build the executor, then start Ray with ``local_mode=True``."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init(local_mode=True)

    def tearDown(self):
        """Stop Ray and restore the registrations dropped by the shutdown."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects
if __name__ == "__main__":
    # Executed as a script: run every test case in this module, reporting
    # each test individually (verbosity level 2).
    unittest.main(verbosity=2)