mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Local Mode support (#4138)
This commit is contained in:
parent
e2e6ef198b
commit
3483282254
2 changed files with 36 additions and 1 deletions
|
@ -17,6 +17,15 @@ from ray.tune.trial_executor import TrialExecutor
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _LocalWrapper(object):
|
||||
def __init__(self, result):
|
||||
self._result = result
|
||||
|
||||
def unwrap(self):
|
||||
"""Returns the wrapped result."""
|
||||
return self._result
|
||||
|
||||
|
||||
class RayTrialExecutor(TrialExecutor):
|
||||
"""An implemention of TrialExecutor based on Ray."""
|
||||
|
||||
|
@ -61,6 +70,11 @@ class RayTrialExecutor(TrialExecutor):
|
|||
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
remote = trial.runner.train.remote()
|
||||
|
||||
# Local Mode
|
||||
if isinstance(remote, dict):
|
||||
remote = _LocalWrapper(remote)
|
||||
|
||||
self._running[remote] = trial
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None):
|
||||
|
@ -229,6 +243,10 @@ class RayTrialExecutor(TrialExecutor):
|
|||
raise ValueError("Trial was not running.")
|
||||
self._running.pop(trial_future[0])
|
||||
result = ray.get(trial_future[0])
|
||||
|
||||
# For local mode
|
||||
if isinstance(result, _LocalWrapper):
|
||||
result = result.unwrap()
|
||||
return result
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
|
@ -266,7 +284,14 @@ class RayTrialExecutor(TrialExecutor):
|
|||
|
||||
def _update_avail_resources(self, num_retries=5):
|
||||
for i in range(num_retries):
|
||||
resources = ray.global_state.cluster_resources()
|
||||
try:
|
||||
resources = ray.global_state.cluster_resources()
|
||||
except Exception:
|
||||
# TODO(rliaw): Remove this when local mode is fixed.
|
||||
# https://github.com/ray-project/ray/issues/4147
|
||||
logger.debug("Using resources for local machine.")
|
||||
resources = ray.services.check_and_update_resources(
|
||||
None, None, None)
|
||||
if not resources:
|
||||
logger.warning("Cluster resources not detected. Retrying...")
|
||||
time.sleep(0.5)
|
||||
|
|
|
@ -110,5 +110,15 @@ class RayTrialExecutorTest(unittest.TestCase):
|
|||
return suggester.next_trials()
|
||||
|
||||
|
||||
class LocalModeExecutorTest(RayTrialExecutorTest):
|
||||
def setUp(self):
|
||||
self.trial_executor = RayTrialExecutor(queue_trials=False)
|
||||
ray.init(local_mode=True)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
Loading…
Add table
Reference in a new issue