From 007634fd1b7f52d6418043b52e01141f72d5d6b4 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Wed, 4 Nov 2020 06:04:40 +0100 Subject: [PATCH] [tune] logger refactor part 2: Add SyncerCallback (#11748) Co-authored-by: Richard Liaw --- python/ray/tune/callback.py | 38 +++++---- python/ray/tune/config_parser.py | 1 - python/ray/tune/logger.py | 88 ++------------------ python/ray/tune/syncer.py | 95 +++++++++++++++++++++- python/ray/tune/tests/test_cluster.py | 85 +++++++++++-------- python/ray/tune/tests/test_sync.py | 9 +- python/ray/tune/tests/test_tune_restore.py | 6 +- python/ray/tune/trial.py | 51 +----------- python/ray/tune/tune.py | 7 +- python/ray/tune/utils/callback.py | 24 ++++++ 10 files changed, 211 insertions(+), 193 deletions(-) create mode 100644 python/ray/tune/utils/callback.py diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index c20070cfa..b26a48b84 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -1,7 +1,9 @@ -from typing import Dict, List +from typing import TYPE_CHECKING, Dict, List from ray.tune.checkpoint_manager import Checkpoint -from ray.tune.trial import Trial + +if TYPE_CHECKING: + from ray.tune.trial import Trial class Callback: @@ -39,7 +41,7 @@ class Callback: """ - def on_step_begin(self, iteration: int, trials: List[Trial], **info): + def on_step_begin(self, iteration: int, trials: List["Trial"], **info): """Called at the start of each tuning loop step. Arguments: @@ -49,7 +51,7 @@ class Callback: """ pass - def on_step_end(self, iteration: int, trials: List[Trial], **info): + def on_step_end(self, iteration: int, trials: List["Trial"], **info): """Called at the end of each tuning loop step. The iteration counter is increased before this hook is called. @@ -61,8 +63,8 @@ class Callback: """ pass - def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, - **info): + def on_trial_start(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): """Called after starting a trial instance. Arguments: @@ -74,8 +76,8 @@ class Callback: """ pass - def on_trial_restore(self, iteration: int, trials: List[Trial], - trial: Trial, **info): + def on_trial_restore(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): """Called after restoring a trial instance. Arguments: @@ -86,8 +88,8 @@ class Callback: """ pass - def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial, - **info): + def on_trial_save(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): """Called after receiving a checkpoint from a trial. Arguments: @@ -98,8 +100,8 @@ class Callback: """ pass - def on_trial_result(self, iteration: int, trials: List[Trial], - trial: Trial, result: Dict, **info): + def on_trial_result(self, iteration: int, trials: List["Trial"], + trial: "Trial", result: Dict, **info): """Called after receiving a result from a trial. The search algorithm and scheduler are notified before this @@ -114,8 +116,8 @@ class Callback: """ pass - def on_trial_complete(self, iteration: int, trials: List[Trial], - trial: Trial, **info): + def on_trial_complete(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): """Called after a trial instance completed. 
The search algorithm and scheduler are notified before this @@ -129,8 +131,8 @@ class Callback: """ pass - def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial, - **info): + def on_trial_error(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): """Called after a trial instance failed (errored). The search algorithm and scheduler are notified before this @@ -144,8 +146,8 @@ class Callback: """ pass - def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial, - checkpoint: Checkpoint, **info): + def on_checkpoint(self, iteration: int, trials: List["Trial"], + trial: "Trial", checkpoint: Checkpoint, **info): """Called after a trial saved a checkpoint with Tune. Arguments: diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 4da1860b3..c3c501974 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -193,6 +193,5 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): loggers=spec.get("loggers"), log_to_file=spec.get("log_to_file"), # str(None) doesn't create None - sync_to_driver_fn=spec.get("sync_to_driver"), max_failures=args.max_failures, **trial_kwargs) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 5df2a0d31..867fae618 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -1,21 +1,19 @@ import csv import json import logging -import numbers import numpy as np import os import yaml -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type import ray.cloudpickle as cloudpickle + from ray.tune.utils.util import SafeFallbackEncoder from ray.util.debug import log_once -from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, - TIMESTEPS_TOTAL, EXPR_PARAM_FILE, - EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE, - EXPR_RESULT_FILE) -from ray.tune.syncer import get_node_syncer +from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, + EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, + EXPR_PROGRESS_FILE, EXPR_RESULT_FILE) from ray.tune.utils import flatten_dict if TYPE_CHECKING: @@ -317,16 +315,13 @@ class UnifiedLogger(Logger): logdir: Directory for all logger creators to log to. loggers (list): List of logger creators. Defaults to CSV, Tensorboard, and JSON loggers. - sync_function (func|str): Optional function for syncer to run. - See ray/python/ray/tune/syncer.py """ def __init__(self, config: Dict, logdir: str, trial: Optional["Trial"] = None, - loggers: Optional[List[Type[Logger]]] = None, - sync_function: Union[None, Callable, str] = None): + loggers: Optional[List[Type[Logger]]] = None): if loggers is None: self._logger_cls_list = DEFAULT_LOGGERS else: @@ -336,8 +331,6 @@ class UnifiedLogger(Logger): logger.warning( "JsonLogger not provided. 
The ExperimentAnalysis tool is " "disabled.") - self._sync_function = sync_function - self._log_syncer = None super(UnifiedLogger, self).__init__(config, logdir, trial) @@ -350,16 +343,10 @@ class UnifiedLogger(Logger): if log_once(f"instantiate:{cls.__name__}"): logger.warning("Could not instantiate %s: %s.", cls.__name__, str(exc)) - self._log_syncer = get_node_syncer( - self.logdir, - remote_dir=self.logdir, - sync_function=self._sync_function) def on_result(self, result): for _logger in self._loggers: _logger.on_result(result) - self._log_syncer.set_worker_ip(result.get(NODE_IP)) - self._log_syncer.sync_down_if_needed() def update_config(self, config): for _logger in self._loggers: @@ -369,68 +356,9 @@ class UnifiedLogger(Logger): for _logger in self._loggers: _logger.close() - def flush(self, sync_down=True): + def flush(self): for _logger in self._loggers: _logger.flush() - if sync_down: - if not self._log_syncer.sync_down(): - logger.warning("Trial %s: Post-flush sync skipped.", - self.trial) - - def sync_up(self): - return self._log_syncer.sync_up() - - def sync_down(self): - return self._log_syncer.sync_down() - - def wait(self): - self._log_syncer.wait() - - def sync_results_to_new_location(self, worker_ip): - """Sends the current log directory to the remote node. - - Syncing will not occur if the cluster is not started - with the Ray autoscaler. - """ - if worker_ip != self._log_syncer.worker_ip: - logger.info("Trial %s: Syncing (blocking) results to %s", - self.trial, worker_ip) - self._log_syncer.reset() - self._log_syncer.set_worker_ip(worker_ip) - if not self._log_syncer.sync_up(): - logger.error( - "Trial %s: Sync up to new location skipped. " - "This should not occur.", self.trial) - self._log_syncer.wait() - else: - logger.error( - "Trial %s: Sync attempted to same IP %s. 
This " - "should not occur.", self.trial, worker_ip) - - -class _SafeFallbackEncoder(json.JSONEncoder): - def __init__(self, nan_str="null", **kwargs): - super(_SafeFallbackEncoder, self).__init__(**kwargs) - self.nan_str = nan_str - - def default(self, value): - try: - if np.isnan(value): - return self.nan_str - - if (type(value).__module__ == np.__name__ - and isinstance(value, np.ndarray)): - return value.tolist() - - if issubclass(type(value), numbers.Integral): - return int(value) - if issubclass(type(value), numbers.Number): - return float(value) - - return super(_SafeFallbackEncoder, self).default(value) - - except Exception: - return str(value) # give up, just stringify it (ok for logs) def pretty_print(result): @@ -442,5 +370,5 @@ def pretty_print(result): if v is not None: out[k] = v - cleaned = json.dumps(out, cls=_SafeFallbackEncoder) + cleaned = json.dumps(out, cls=SafeFallbackEncoder) return yaml.safe_dump(json.loads(cleaned), default_flow_style=False) diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index ea74c41de..e6180baa7 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union import distutils import logging @@ -9,13 +9,21 @@ from dataclasses import dataclass from inspect import isclass from shlex import quote +import ray from ray import services +from ray.tune import TuneError +from ray.tune.callback import Callback +from ray.tune.checkpoint_manager import Checkpoint +from ray.tune.result import NODE_IP from ray.util.debug import log_once from ray.tune.utils.util import env_integer from ray.tune.cluster_info import get_ssh_key, get_ssh_user from ray.tune.sync_client import (CommandBasedClient, get_sync_client, get_cloud_sync_client, NOOP) +if TYPE_CHECKING: + from ray.tune.trial import Trial + logger = logging.getLogger(__name__) # Syncing period for syncing local checkpoints to cloud. @@ -355,3 +363,88 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None): _syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client) return _syncers[key] + + +class SyncerCallback(Callback): + def __init__(self, sync_function: Union[None, bool, Callable]): + self._sync_function = sync_function + self._syncers: Dict["Trial", NodeSyncer] = {} + + def _get_trial_syncer(self, trial: "Trial"): + if trial not in self._syncers: + self._syncers[trial] = self._create_trial_syncer(trial) + return self._syncers[trial] + + def _create_trial_syncer(self, trial: "Trial"): + return get_node_syncer( + trial.logdir, + remote_dir=trial.logdir, + sync_function=self._sync_function) + + def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint): + if checkpoint.storage == Checkpoint.MEMORY: + return + + # Local import to avoid circular dependencies between syncer and + # trainable + from ray.tune.durable_trainable import DurableTrainable + + trial_syncer = self._get_trial_syncer(trial) + if trial.sync_on_checkpoint: + try: + # Wait for any other syncs to finish. We need to sync again + # after this to handle checkpoints taken mid-sync. + trial_syncer.wait() + except TuneError as e: + # Errors occurring during this wait are not fatal for this + # checkpoint, so it should just be logged. + logger.error( + "Trial %s: An error occurred during the " + "checkpoint pre-sync wait - %s", trial, str(e)) + # Force sync down and wait before tracking the new checkpoint. 
+ try: + if trial_syncer.sync_down(): + trial_syncer.wait() + else: + logger.error( + "Trial %s: Checkpoint sync skipped. " + "This should not happen.", trial) + except TuneError as e: + if issubclass(trial.get_trainable_cls(), DurableTrainable): + # Even though rsync failed the trainable can restore + # from remote durable storage. + logger.error("Trial %s: Sync error - %s", trial, str(e)) + else: + # If the trainable didn't have remote storage to upload + # to then this checkpoint may have been lost, so we + # shouldn't track it with the checkpoint_manager. + raise e + if not issubclass(trial.get_trainable_cls(), DurableTrainable): + if not os.path.exists(checkpoint.value): + raise TuneError("Trial {}: Checkpoint path {} not " + "found after successful sync down.".format( + trial, checkpoint.value)) + + def on_trial_start(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self._get_trial_syncer(trial) + + def on_trial_result(self, iteration: int, trials: List["Trial"], + trial: "Trial", result: Dict, **info): + trial_syncer = self._get_trial_syncer(trial) + trial_syncer.set_worker_ip(result.get(NODE_IP)) + trial_syncer.sync_down_if_needed() + + def on_trial_complete(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + trial_syncer = self._get_trial_syncer(trial) + if NODE_IP in trial.last_result: + trainable_ip = trial.last_result[NODE_IP] + else: + trainable_ip = ray.get(trial.runner.get_current_ip.remote()) + trial_syncer.set_worker_ip(trainable_ip) + trial_syncer.sync_down_if_needed() + + def on_checkpoint(self, iteration: int, trials: List["Trial"], + trial: "Trial", checkpoint: Checkpoint, **info): + self._sync_trial_checkpoint(trial, checkpoint) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index c7ea4faf4..29408d2fe 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -1,12 +1,15 @@ import inspect import time import os + import pytest import shutil import subprocess import sys from unittest.mock import MagicMock, patch +from typing import Callable, Union + import ray from ray import tune from ray.rllib import _register_all @@ -18,7 +21,7 @@ from ray.tune.error import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.resources import Resources from ray.tune.suggest import BasicVariantGenerator -from ray.tune.syncer import CloudSyncer +from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer from ray.tune.utils.trainable import TrainableUtil from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner @@ -55,6 +58,19 @@ def _start_new_cluster(): return cluster +class _PerTrialSyncerCallback(SyncerCallback): + def __init__( + self, + get_sync_fn: Callable[["Trial"], Union[None, bool, Callable]]): + self._get_sync_fn = get_sync_fn + super(_PerTrialSyncerCallback, self).__init__(None) + + def _create_trial_syncer(self, trial: "Trial"): + sync_fn = self._get_sync_fn(trial) + return get_node_syncer( + trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn) + + @pytest.fixture def start_connected_cluster(): # Start the Ray processes. 
@@ -255,7 +271,9 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): node = cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - runner = TrialRunner(BasicVariantGenerator()) + syncer_callback = _PerTrialSyncerCallback( + lambda trial: trial.trainable_name == "__fake") + runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback]) kwargs = { "stopping_criterion": { "training_iteration": 4 @@ -263,7 +281,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): "checkpoint_freq": 2, "max_failures": 2, "remote_checkpoint_dir": MOCK_REMOTE_DIR, - "sync_to_driver_fn": trainable_id == "__fake", } # Test recovery of trial that hasn't been checkpointed @@ -316,7 +333,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id): "training_iteration": 3 }, "remote_checkpoint_dir": MOCK_REMOTE_DIR, - "sync_to_driver_fn": trainable_id == "__fake", } t3 = Trial(trainable_id, **kwargs) runner.add_trial(t3) @@ -341,7 +357,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id): node = cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - runner = TrialRunner(BasicVariantGenerator()) + syncer_callback = _PerTrialSyncerCallback( + lambda trial: trial.trainable_name == "__fake") + runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback]) kwargs = { "stopping_criterion": { "training_iteration": 5 @@ -349,7 +367,6 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id): "checkpoint_freq": 1, "max_failures": 1, "remote_checkpoint_dir": MOCK_REMOTE_DIR, - "sync_to_driver_fn": trainable_id == "__fake", } trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)] @@ -382,7 +399,13 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, node = cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - runner = TrialRunner(BasicVariantGenerator()) + class _SyncerCallback(SyncerCallback): + def _create_trial_syncer(self, trial: "Trial"): + client = mock_storage_client() + return MockNodeSyncer(trial.logdir, trial.logdir, client) + + syncer_callback = _SyncerCallback(None) + runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback]) kwargs = { "stopping_criterion": { "training_iteration": 4 @@ -390,7 +413,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, "checkpoint_freq": 2, "max_failures": 2, "remote_checkpoint_dir": MOCK_REMOTE_DIR, - "sync_to_driver_fn": trainable_id == "__fake_remote", } # The following patches only affect __fake_remote. @@ -415,33 +437,26 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster, # TrainableUtil will not check this path unless we mock it. 
mock_find.side_effect = hide_remote_path(find_func) mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func) - with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer: - def mock_get_syncer_fn(local_dir, remote_dir, sync_function): - client = mock_storage_client() - return MockNodeSyncer(local_dir, remote_dir, client) + # Test recovery of trial that has been checkpointed + t1 = Trial(trainable_id, **kwargs) + runner.add_trial(t1) - mock_get_node_syncer.side_effect = mock_get_syncer_fn + # Start trial, process result (x2), process save + for _ in range(4): + runner.step() + assert t1.has_checkpoint() - # Test recovery of trial that has been checkpointed - t1 = Trial(trainable_id, **kwargs) - runner.add_trial(t1) - - # Start trial, process result (x2), process save - for _ in range(4): + cluster.add_node(num_cpus=1) + cluster.remove_node(node) + cluster.wait_for_nodes() + shutil.rmtree(os.path.dirname(t1.checkpoint.value)) + runner.step() # Collect result 3, kick off + fail result 4 + runner.step() # Dispatch restore + runner.step() # Process restore + step 4 + for _ in range(3): + if t1.status != Trial.TERMINATED: runner.step() - assert t1.has_checkpoint() - - cluster.add_node(num_cpus=1) - cluster.remove_node(node) - cluster.wait_for_nodes() - shutil.rmtree(os.path.dirname(t1.checkpoint.value)) - runner.step() # Collect result 3, kick off + fail result 4 - runner.step() # Dispatch restore - runner.step() # Process restore + step 4 - for _ in range(3): - if t1.status != Trial.TERMINATED: - runner.step() assert t1.status == Trial.TERMINATED, runner.debug_string() @@ -454,7 +469,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id): cluster.wait_for_nodes() dirpath = str(tmpdir) - runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0) + syncer_callback = _PerTrialSyncerCallback( + lambda trial: trial.trainable_name == "__fake") + runner = TrialRunner( + local_checkpoint_dir=dirpath, + checkpoint_period=0, + callbacks=[syncer_callback]) kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -462,7 +482,6 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id): "checkpoint_freq": 1, "max_failures": 1, "remote_checkpoint_dir": MOCK_REMOTE_DIR, - "sync_to_driver_fn": trainable_id == "__fake", } trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)] for t in trials: diff --git a/python/ray/tune/tests/test_sync.py b/python/ray/tune/tests/test_sync.py index 4571104b9..a6159b890 100644 --- a/python/ray/tune/tests/test_sync.py +++ b/python/ray/tune/tests/test_sync.py @@ -11,7 +11,6 @@ import ray from ray.rllib import _register_all from ray import tune -from ray.tune import TuneError from ray.tune.syncer import CommandBasedClient @@ -87,8 +86,8 @@ class TestSyncFunctionality(unittest.TestCase): def testClusterProperString(self): """Tests that invalid commands throw..""" - with self.assertRaises(TuneError): - # This raises TuneError because logger is init in safe zone. + with self.assertRaises(ValueError): + # This raises ValueError because logger is init in safe zone. sync_config = tune.SyncConfig(sync_to_driver="ls {target}") [trial] = tune.run( "__fake", @@ -100,8 +99,8 @@ class TestSyncFunctionality(unittest.TestCase): sync_config=sync_config, ).trials - with self.assertRaises(TuneError): - # This raises TuneError because logger is init in safe zone. + with self.assertRaises(ValueError): + # This raises ValueError because logger is init in safe zone. 
sync_config = tune.SyncConfig(sync_to_driver="ls {source}") [trial] = tune.run( "__fake", diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index d94e1a7ae..556c28e8f 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -124,11 +124,7 @@ class TuneExampleTest(unittest.TestCase): class AutoInitTest(unittest.TestCase): def testTuneRestore(self): self.assertFalse(ray.is_initialized()) - tune.run( - "__fake", - name="TestAutoInit", - stop={"training_iteration": 1}, - ray_auto_init=True) + tune.run("__fake", name="TestAutoInit", stop={"training_iteration": 1}) self.assertTrue(ray.is_initialized()) def tearDown(self): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index c69e3b362..686526d6e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -12,7 +12,6 @@ import os from numbers import Number from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager -from ray.tune.durable_trainable import DurableTrainable from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not @@ -192,7 +191,6 @@ class Trial: trial_dirname_creator=None, loggers=None, log_to_file=None, - sync_to_driver_fn=None, max_failures=0): """Initialize a new trial. @@ -232,7 +230,6 @@ class Trial: or not len(self.log_to_file) == 2: self.log_to_file = (None, None) - self.sync_to_driver_fn = sync_to_driver_fn self.verbose = True self.max_failures = max_failures @@ -289,7 +286,6 @@ class Trial: self._nonjson_fields = [ "loggers", - "sync_to_driver_fn", "results", "best_result", "param_config", @@ -356,7 +352,6 @@ class Trial: trial_name_creator=self.trial_name_creator, loggers=self.loggers, log_to_file=self.log_to_file, - sync_to_driver_fn=self.sync_to_driver_fn, max_failures=self.max_failures, ) @@ -370,11 +365,7 @@ class Trial: os.makedirs(self.logdir, exist_ok=True) self.result_logger = UnifiedLogger( - self.config, - self.logdir, - trial=self, - loggers=self.loggers, - sync_function=self.sync_to_driver_fn) + self.config, self.logdir, trial=self, loggers=self.loggers) def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. @@ -459,43 +450,6 @@ class Trial: Args: checkpoint (Checkpoint): Checkpoint taken. """ - if checkpoint.storage == Checkpoint.MEMORY: - self.checkpoint_manager.on_checkpoint(checkpoint) - return - if self.sync_on_checkpoint: - try: - # Wait for any other syncs to finish. We need to sync again - # after this to handle checkpoints taken mid-sync. - self.result_logger.wait() - except TuneError as e: - # Errors occurring during this wait are not fatal for this - # checkpoint, so it should just be logged. - logger.error( - "Trial %s: An error occurred during the " - "checkpoint pre-sync wait - %s", self, str(e)) - # Force sync down and wait before tracking the new checkpoint. - try: - if self.result_logger.sync_down(): - self.result_logger.wait() - else: - logger.error( - "Trial %s: Checkpoint sync skipped. " - "This should not happen.", self) - except TuneError as e: - if issubclass(self.get_trainable_cls(), DurableTrainable): - # Even though rsync failed the trainable can restore - # from remote durable storage. 
- logger.error("Trial %s: Sync error - %s", self, str(e)) - else: - # If the trainable didn't have remote storage to upload - # to then this checkpoint may have been lost, so we - # shouldn't track it with the checkpoint_manager. - raise e - if not issubclass(self.get_trainable_cls(), DurableTrainable): - if not os.path.exists(checkpoint.value): - raise TuneError("Trial {}: Checkpoint path {} not " - "found after successful sync down.".format( - self, checkpoint.value)) self.checkpoint_manager.on_checkpoint(checkpoint) def on_restore(self): @@ -515,7 +469,6 @@ class Trial: return self.num_failures < self.max_failures or self.max_failures < 0 def update_last_result(self, result, terminate=False): - result.update(trial_id=self.trial_id, done=terminate) if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) if self.verbose and (terminate or time.time() - self.last_debug > @@ -634,7 +587,7 @@ class Trial: state["resuming_from"] = None state["saving_to"] = None if self.result_logger: - self.result_logger.flush(sync_down=False) + self.result_logger.flush() state["__logger_started__"] = True else: state["__logger_started__"] = False diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 4a0486e7e..9c3aefd14 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -11,8 +11,10 @@ from ray.tune.suggest.variant_generator import has_unresolved_values from ray.tune.trial import Trial from ray.tune.trainable import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor +from ray.tune.utils.callback import create_default_callbacks from ray.tune.registry import get_trainable_cls -from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig +from ray.tune.syncer import wait_for_sync, set_sync_periods, \ + SyncConfig from ray.tune.trial_runner import TrialRunner from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter from ray.tune.schedulers import FIFOScheduler @@ -353,6 +355,9 @@ def run( "own `metric` and `mode` parameters. Either remove the arguments " "from your scheduler or from your call to `tune.run()`") + # Create syncer callbacks + callbacks = create_default_callbacks(callbacks, sync_config) + runner = TrialRunner( search_alg=search_alg, scheduler=scheduler, diff --git a/python/ray/tune/utils/callback.py b/python/ray/tune/utils/callback.py new file mode 100644 index 000000000..e54d9d842 --- /dev/null +++ b/python/ray/tune/utils/callback.py @@ -0,0 +1,24 @@ +import os +from typing import List, Optional + +from ray.tune.callback import Callback +from ray.tune.syncer import SyncConfig +from ray.tune.syncer import SyncerCallback + + +def create_default_callbacks(callbacks: Optional[List[Callback]], + sync_config: SyncConfig): + + callbacks = callbacks or [] + + # Check if there is a SyncerCallback + has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks) + + # If no SyncerCallback was found, add + if not has_syncer_callback and os.environ.get( + "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1": + syncer_callback = SyncerCallback( + sync_function=sync_config.sync_to_driver) + callbacks.append(syncer_callback) + + return callbacks