[tune] logger refactor part 2: Add SyncerCallback (#11748)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Kai Fricke 2020-11-04 06:04:40 +01:00 committed by GitHub
parent 05c4e3fb2a
commit 007634fd1b
10 changed files with 211 additions and 193 deletions

View file

@ -1,7 +1,9 @@
from typing import Dict, List
from typing import TYPE_CHECKING, Dict, List
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.trial import Trial
if TYPE_CHECKING:
from ray.tune.trial import Trial
class Callback:
@ -39,7 +41,7 @@ class Callback:
"""
def on_step_begin(self, iteration: int, trials: List[Trial], **info):
def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
"""Called at the start of each tuning loop step.
Arguments:
@ -49,7 +51,7 @@ class Callback:
"""
pass
def on_step_end(self, iteration: int, trials: List[Trial], **info):
def on_step_end(self, iteration: int, trials: List["Trial"], **info):
"""Called at the end of each tuning loop step.
The iteration counter is increased before this hook is called.
@ -61,8 +63,8 @@ class Callback:
"""
pass
def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial,
**info):
def on_trial_start(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
"""Called after starting a trial instance.
Arguments:
@ -74,8 +76,8 @@ class Callback:
"""
pass
def on_trial_restore(self, iteration: int, trials: List[Trial],
trial: Trial, **info):
def on_trial_restore(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
"""Called after restoring a trial instance.
Arguments:
@ -86,8 +88,8 @@ class Callback:
"""
pass
def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial,
**info):
def on_trial_save(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
"""Called after receiving a checkpoint from a trial.
Arguments:
@ -98,8 +100,8 @@ class Callback:
"""
pass
def on_trial_result(self, iteration: int, trials: List[Trial],
trial: Trial, result: Dict, **info):
def on_trial_result(self, iteration: int, trials: List["Trial"],
trial: "Trial", result: Dict, **info):
"""Called after receiving a result from a trial.
The search algorithm and scheduler are notified before this
@ -114,8 +116,8 @@ class Callback:
"""
pass
def on_trial_complete(self, iteration: int, trials: List[Trial],
trial: Trial, **info):
def on_trial_complete(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
"""Called after a trial instance completed.
The search algorithm and scheduler are notified before this
@ -129,8 +131,8 @@ class Callback:
"""
pass
def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial,
**info):
def on_trial_error(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
"""Called after a trial instance failed (errored).
The search algorithm and scheduler are notified before this
@ -144,8 +146,8 @@ class Callback:
"""
pass
def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial,
checkpoint: Checkpoint, **info):
def on_checkpoint(self, iteration: int, trials: List["Trial"],
trial: "Trial", checkpoint: Checkpoint, **info):
"""Called after a trial saved a checkpoint with Tune.
Arguments:
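
For illustration (not part of this diff): a minimal sketch of a user-defined callback written against the updated hook signatures, which now use forward-referenced "Trial" types guarded by TYPE_CHECKING. The callback name and the metric key are hypothetical.

from typing import Dict, List, TYPE_CHECKING

from ray.tune.callback import Callback

if TYPE_CHECKING:
    from ray.tune.trial import Trial


class PrintResultCallback(Callback):
    """Hypothetical callback that prints one metric per reported result."""

    def on_trial_result(self, iteration: int, trials: List["Trial"],
                        trial: "Trial", result: Dict, **info):
        # "mean_loss" is just an example key; use whatever your
        # trainable actually reports.
        print(f"{trial}: mean_loss={result.get('mean_loss')}")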

View file

@ -193,6 +193,5 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
loggers=spec.get("loggers"),
log_to_file=spec.get("log_to_file"),
# str(None) doesn't create None
sync_to_driver_fn=spec.get("sync_to_driver"),
max_failures=args.max_failures,
**trial_kwargs)

View file

@ -1,21 +1,19 @@
import csv
import json
import logging
import numbers
import numpy as np
import os
import yaml
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Type
import ray.cloudpickle as cloudpickle
from ray.tune.utils.util import SafeFallbackEncoder
from ray.util.debug import log_once
from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE,
EXPR_RESULT_FILE)
from ray.tune.syncer import get_node_syncer
from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE,
EXPR_PROGRESS_FILE, EXPR_RESULT_FILE)
from ray.tune.utils import flatten_dict
if TYPE_CHECKING:
@ -317,16 +315,13 @@ class UnifiedLogger(Logger):
logdir: Directory for all logger creators to log to.
loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
and JSON loggers.
sync_function (func|str): Optional function for syncer to run.
See ray/python/ray/tune/syncer.py
"""
def __init__(self,
config: Dict,
logdir: str,
trial: Optional["Trial"] = None,
loggers: Optional[List[Type[Logger]]] = None,
sync_function: Union[None, Callable, str] = None):
loggers: Optional[List[Type[Logger]]] = None):
if loggers is None:
self._logger_cls_list = DEFAULT_LOGGERS
else:
@ -336,8 +331,6 @@ class UnifiedLogger(Logger):
logger.warning(
"JsonLogger not provided. The ExperimentAnalysis tool is "
"disabled.")
self._sync_function = sync_function
self._log_syncer = None
super(UnifiedLogger, self).__init__(config, logdir, trial)
@ -350,16 +343,10 @@ class UnifiedLogger(Logger):
if log_once(f"instantiate:{cls.__name__}"):
logger.warning("Could not instantiate %s: %s.",
cls.__name__, str(exc))
self._log_syncer = get_node_syncer(
self.logdir,
remote_dir=self.logdir,
sync_function=self._sync_function)
def on_result(self, result):
for _logger in self._loggers:
_logger.on_result(result)
self._log_syncer.set_worker_ip(result.get(NODE_IP))
self._log_syncer.sync_down_if_needed()
def update_config(self, config):
for _logger in self._loggers:
@ -369,68 +356,9 @@ class UnifiedLogger(Logger):
for _logger in self._loggers:
_logger.close()
def flush(self, sync_down=True):
def flush(self):
for _logger in self._loggers:
_logger.flush()
if sync_down:
if not self._log_syncer.sync_down():
logger.warning("Trial %s: Post-flush sync skipped.",
self.trial)
def sync_up(self):
return self._log_syncer.sync_up()
def sync_down(self):
return self._log_syncer.sync_down()
def wait(self):
self._log_syncer.wait()
def sync_results_to_new_location(self, worker_ip):
"""Sends the current log directory to the remote node.
Syncing will not occur if the cluster is not started
with the Ray autoscaler.
"""
if worker_ip != self._log_syncer.worker_ip:
logger.info("Trial %s: Syncing (blocking) results to %s",
self.trial, worker_ip)
self._log_syncer.reset()
self._log_syncer.set_worker_ip(worker_ip)
if not self._log_syncer.sync_up():
logger.error(
"Trial %s: Sync up to new location skipped. "
"This should not occur.", self.trial)
self._log_syncer.wait()
else:
logger.error(
"Trial %s: Sync attempted to same IP %s. This "
"should not occur.", self.trial, worker_ip)
class _SafeFallbackEncoder(json.JSONEncoder):
def __init__(self, nan_str="null", **kwargs):
super(_SafeFallbackEncoder, self).__init__(**kwargs)
self.nan_str = nan_str
def default(self, value):
try:
if np.isnan(value):
return self.nan_str
if (type(value).__module__ == np.__name__
and isinstance(value, np.ndarray)):
return value.tolist()
if issubclass(type(value), numbers.Integral):
return int(value)
if issubclass(type(value), numbers.Number):
return float(value)
return super(_SafeFallbackEncoder, self).default(value)
except Exception:
return str(value) # give up, just stringify it (ok for logs)
def pretty_print(result):
@ -442,5 +370,5 @@ def pretty_print(result):
if v is not None:
out[k] = v
cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
cleaned = json.dumps(out, cls=SafeFallbackEncoder)
return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
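
For illustration (not part of this diff): a minimal sketch of the slimmed-down UnifiedLogger, which no longer accepts a sync_function; syncing is now handled by the SyncerCallback. The config keys and result values are hypothetical.

import tempfile

from ray.tune.logger import UnifiedLogger

logdir = tempfile.mkdtemp()
unified_logger = UnifiedLogger(config={"lr": 0.01}, logdir=logdir)

# Results are only written to the CSV/JSON/TensorBoard loggers;
# no sync_down happens here anymore.
unified_logger.on_result({"training_iteration": 1, "mean_loss": 0.5})
unified_logger.flush()  # flush() has lost its sync_down argument
unified_logger.close()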

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union
import distutils
import logging
@ -9,13 +9,21 @@ from dataclasses import dataclass
from inspect import isclass
from shlex import quote
import ray
from ray import services
from ray.tune import TuneError
from ray.tune.callback import Callback
from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.result import NODE_IP
from ray.util.debug import log_once
from ray.tune.utils.util import env_integer
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
get_cloud_sync_client, NOOP)
if TYPE_CHECKING:
from ray.tune.trial import Trial
logger = logging.getLogger(__name__)
# Syncing period for syncing local checkpoints to cloud.
@ -355,3 +363,88 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):
_syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)
return _syncers[key]
class SyncerCallback(Callback):
def __init__(self, sync_function: Union[None, bool, Callable]):
self._sync_function = sync_function
self._syncers: Dict["Trial", NodeSyncer] = {}
def _get_trial_syncer(self, trial: "Trial"):
if trial not in self._syncers:
self._syncers[trial] = self._create_trial_syncer(trial)
return self._syncers[trial]
def _create_trial_syncer(self, trial: "Trial"):
return get_node_syncer(
trial.logdir,
remote_dir=trial.logdir,
sync_function=self._sync_function)
def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
if checkpoint.storage == Checkpoint.MEMORY:
return
# Local import to avoid circular dependencies between syncer and
# trainable
from ray.tune.durable_trainable import DurableTrainable
trial_syncer = self._get_trial_syncer(trial)
if trial.sync_on_checkpoint:
try:
# Wait for any other syncs to finish. We need to sync again
# after this to handle checkpoints taken mid-sync.
trial_syncer.wait()
except TuneError as e:
# Errors occurring during this wait are not fatal for this
# checkpoint, so it should just be logged.
logger.error(
"Trial %s: An error occurred during the "
"checkpoint pre-sync wait - %s", trial, str(e))
# Force sync down and wait before tracking the new checkpoint.
try:
if trial_syncer.sync_down():
trial_syncer.wait()
else:
logger.error(
"Trial %s: Checkpoint sync skipped. "
"This should not happen.", trial)
except TuneError as e:
if issubclass(trial.get_trainable_cls(), DurableTrainable):
# Even though rsync failed the trainable can restore
# from remote durable storage.
logger.error("Trial %s: Sync error - %s", trial, str(e))
else:
# If the trainable didn't have remote storage to upload
# to then this checkpoint may have been lost, so we
# shouldn't track it with the checkpoint_manager.
raise e
if not issubclass(trial.get_trainable_cls(), DurableTrainable):
if not os.path.exists(checkpoint.value):
raise TuneError("Trial {}: Checkpoint path {} not "
"found after successful sync down.".format(
trial, checkpoint.value))
def on_trial_start(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
self._get_trial_syncer(trial)
def on_trial_result(self, iteration: int, trials: List["Trial"],
trial: "Trial", result: Dict, **info):
trial_syncer = self._get_trial_syncer(trial)
trial_syncer.set_worker_ip(result.get(NODE_IP))
trial_syncer.sync_down_if_needed()
def on_trial_complete(self, iteration: int, trials: List["Trial"],
trial: "Trial", **info):
trial_syncer = self._get_trial_syncer(trial)
if NODE_IP in trial.last_result:
trainable_ip = trial.last_result[NODE_IP]
else:
trainable_ip = ray.get(trial.runner.get_current_ip.remote())
trial_syncer.set_worker_ip(trainable_ip)
trial_syncer.sync_down_if_needed()
def on_checkpoint(self, iteration: int, trials: List["Trial"],
trial: "Trial", checkpoint: Checkpoint, **info):
self._sync_trial_checkpoint(trial, checkpoint)
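
For illustration (not part of this diff): a sketch of wiring a SyncerCallback explicitly into tune.run. The trainable name and the rsync template are hypothetical; string templates follow the existing sync_to_driver convention ({source}/{target} handled by CommandBasedClient), even though the annotation above only lists None, bool, and Callable.

from ray import tune
from ray.tune.syncer import SyncerCallback

# Hypothetical custom sync command; {source} and {target} are filled in
# by the underlying sync client.
syncer_callback = SyncerCallback(sync_function="rsync -avz {source} {target}")

# Because a SyncerCallback is already present, create_default_callbacks
# (see ray/tune/utils/callback.py below) will not add a second one.
tune.run("my_trainable", callbacks=[syncer_callback])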

View file

@ -1,12 +1,15 @@
import inspect
import time
import os
import pytest
import shutil
import subprocess
import sys
from unittest.mock import MagicMock, patch
from typing import Callable, Union
import ray
from ray import tune
from ray.rllib import _register_all
@ -18,7 +21,7 @@ from ray.tune.error import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.resources import Resources
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.syncer import CloudSyncer
from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
@ -55,6 +58,19 @@ def _start_new_cluster():
return cluster
class _PerTrialSyncerCallback(SyncerCallback):
def __init__(
self,
get_sync_fn: Callable[["Trial"], Union[None, bool, Callable]]):
self._get_sync_fn = get_sync_fn
super(_PerTrialSyncerCallback, self).__init__(None)
def _create_trial_syncer(self, trial: "Trial"):
sync_fn = self._get_sync_fn(trial)
return get_node_syncer(
trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn)
@pytest.fixture
def start_connected_cluster():
# Start the Ray processes.
@ -255,7 +271,9 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake")
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {
"training_iteration": 4
@ -263,7 +281,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
"checkpoint_freq": 2,
"max_failures": 2,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
# Test recovery of trial that hasn't been checkpointed
@ -316,7 +333,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
"training_iteration": 3
},
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
t3 = Trial(trainable_id, **kwargs)
runner.add_trial(t3)
@ -341,7 +357,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake")
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {
"training_iteration": 5
@ -349,7 +367,6 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
"checkpoint_freq": 1,
"max_failures": 1,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
@ -382,7 +399,13 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
class _SyncerCallback(SyncerCallback):
def _create_trial_syncer(self, trial: "Trial"):
client = mock_storage_client()
return MockNodeSyncer(trial.logdir, trial.logdir, client)
syncer_callback = _SyncerCallback(None)
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {
"training_iteration": 4
@ -390,7 +413,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
"checkpoint_freq": 2,
"max_failures": 2,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake_remote",
}
# The following patches only affect __fake_remote.
@ -415,33 +437,26 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
# TrainableUtil will not check this path unless we mock it.
mock_find.side_effect = hide_remote_path(find_func)
mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
client = mock_storage_client()
return MockNodeSyncer(local_dir, remote_dir, client)
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
mock_get_node_syncer.side_effect = mock_get_syncer_fn
# Start trial, process result (x2), process save
for _ in range(4):
runner.step()
assert t1.has_checkpoint()
# Test recovery of trial that has been checkpointed
t1 = Trial(trainable_id, **kwargs)
runner.add_trial(t1)
# Start trial, process result (x2), process save
for _ in range(4):
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
assert t1.has_checkpoint()
cluster.add_node(num_cpus=1)
cluster.remove_node(node)
cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1.checkpoint.value))
runner.step() # Collect result 3, kick off + fail result 4
runner.step() # Dispatch restore
runner.step() # Process restore + step 4
for _ in range(3):
if t1.status != Trial.TERMINATED:
runner.step()
assert t1.status == Trial.TERMINATED, runner.debug_string()
@ -454,7 +469,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
cluster.wait_for_nodes()
dirpath = str(tmpdir)
runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
syncer_callback = _PerTrialSyncerCallback(
lambda trial: trial.trainable_name == "__fake")
runner = TrialRunner(
local_checkpoint_dir=dirpath,
checkpoint_period=0,
callbacks=[syncer_callback])
kwargs = {
"stopping_criterion": {
"training_iteration": 2
@ -462,7 +482,6 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
"checkpoint_freq": 1,
"max_failures": 1,
"remote_checkpoint_dir": MOCK_REMOTE_DIR,
"sync_to_driver_fn": trainable_id == "__fake",
}
trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
for t in trials:

View file

@ -11,7 +11,6 @@ import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import TuneError
from ray.tune.syncer import CommandBasedClient
@ -87,8 +86,8 @@ class TestSyncFunctionality(unittest.TestCase):
def testClusterProperString(self):
"""Tests that invalid commands throw.."""
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
with self.assertRaises(ValueError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
[trial] = tune.run(
"__fake",
@ -100,8 +99,8 @@ class TestSyncFunctionality(unittest.TestCase):
sync_config=sync_config,
).trials
with self.assertRaises(TuneError):
# This raises TuneError because logger is init in safe zone.
with self.assertRaises(ValueError):
# This raises ValueError because logger is init in safe zone.
sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
[trial] = tune.run(
"__fake",

View file

@ -124,11 +124,7 @@ class TuneExampleTest(unittest.TestCase):
class AutoInitTest(unittest.TestCase):
def testTuneRestore(self):
self.assertFalse(ray.is_initialized())
tune.run(
"__fake",
name="TestAutoInit",
stop={"training_iteration": 1},
ray_auto_init=True)
tune.run("__fake", name="TestAutoInit", stop={"training_iteration": 1})
self.assertTrue(ray.is_initialized())
def tearDown(self):

View file

@ -12,7 +12,6 @@ import os
from numbers import Number
from ray.tune import TuneError
from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.logger import pretty_print, UnifiedLogger
# NOTE(rkn): We import ray.tune.registry here instead of importing the names we
# need because there are cyclic imports that may cause specific names to not
@ -192,7 +191,6 @@ class Trial:
trial_dirname_creator=None,
loggers=None,
log_to_file=None,
sync_to_driver_fn=None,
max_failures=0):
"""Initialize a new trial.
@ -232,7 +230,6 @@ class Trial:
or not len(self.log_to_file) == 2:
self.log_to_file = (None, None)
self.sync_to_driver_fn = sync_to_driver_fn
self.verbose = True
self.max_failures = max_failures
@ -289,7 +286,6 @@ class Trial:
self._nonjson_fields = [
"loggers",
"sync_to_driver_fn",
"results",
"best_result",
"param_config",
@ -356,7 +352,6 @@ class Trial:
trial_name_creator=self.trial_name_creator,
loggers=self.loggers,
log_to_file=self.log_to_file,
sync_to_driver_fn=self.sync_to_driver_fn,
max_failures=self.max_failures,
)
@ -370,11 +365,7 @@ class Trial:
os.makedirs(self.logdir, exist_ok=True)
self.result_logger = UnifiedLogger(
self.config,
self.logdir,
trial=self,
loggers=self.loggers,
sync_function=self.sync_to_driver_fn)
self.config, self.logdir, trial=self, loggers=self.loggers)
def update_resources(self, cpu, gpu, **kwargs):
"""EXPERIMENTAL: Updates the resource requirements.
@ -459,43 +450,6 @@ class Trial:
Args:
checkpoint (Checkpoint): Checkpoint taken.
"""
if checkpoint.storage == Checkpoint.MEMORY:
self.checkpoint_manager.on_checkpoint(checkpoint)
return
if self.sync_on_checkpoint:
try:
# Wait for any other syncs to finish. We need to sync again
# after this to handle checkpoints taken mid-sync.
self.result_logger.wait()
except TuneError as e:
# Errors occurring during this wait are not fatal for this
# checkpoint, so it should just be logged.
logger.error(
"Trial %s: An error occurred during the "
"checkpoint pre-sync wait - %s", self, str(e))
# Force sync down and wait before tracking the new checkpoint.
try:
if self.result_logger.sync_down():
self.result_logger.wait()
else:
logger.error(
"Trial %s: Checkpoint sync skipped. "
"This should not happen.", self)
except TuneError as e:
if issubclass(self.get_trainable_cls(), DurableTrainable):
# Even though rsync failed the trainable can restore
# from remote durable storage.
logger.error("Trial %s: Sync error - %s", self, str(e))
else:
# If the trainable didn't have remote storage to upload
# to then this checkpoint may have been lost, so we
# shouldn't track it with the checkpoint_manager.
raise e
if not issubclass(self.get_trainable_cls(), DurableTrainable):
if not os.path.exists(checkpoint.value):
raise TuneError("Trial {}: Checkpoint path {} not "
"found after successful sync down.".format(
self, checkpoint.value))
self.checkpoint_manager.on_checkpoint(checkpoint)
def on_restore(self):
@ -515,7 +469,6 @@ class Trial:
return self.num_failures < self.max_failures or self.max_failures < 0
def update_last_result(self, result, terminate=False):
result.update(trial_id=self.trial_id, done=terminate)
if self.experiment_tag:
result.update(experiment_tag=self.experiment_tag)
if self.verbose and (terminate or time.time() - self.last_debug >
@ -634,7 +587,7 @@ class Trial:
state["resuming_from"] = None
state["saving_to"] = None
if self.result_logger:
self.result_logger.flush(sync_down=False)
self.result_logger.flush()
state["__logger_started__"] = True
else:
state["__logger_started__"] = False

View file

@ -11,8 +11,10 @@ from ray.tune.suggest.variant_generator import has_unresolved_values
from ray.tune.trial import Trial
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.registry import get_trainable_cls
from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
from ray.tune.syncer import wait_for_sync, set_sync_periods, \
SyncConfig
from ray.tune.trial_runner import TrialRunner
from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import FIFOScheduler
@ -353,6 +355,9 @@ def run(
"own `metric` and `mode` parameters. Either remove the arguments "
"from your scheduler or from your call to `tune.run()`")
# Create syncer callbacks
callbacks = create_default_callbacks(callbacks, sync_config)
runner = TrialRunner(
search_alg=search_alg,
scheduler=scheduler,

View file

@ -0,0 +1,24 @@
import os
from typing import List, Optional
from ray.tune.callback import Callback
from ray.tune.syncer import SyncConfig
from ray.tune.syncer import SyncerCallback
def create_default_callbacks(callbacks: Optional[List[Callback]],
sync_config: SyncConfig):
callbacks = callbacks or []
# Check if there is a SyncerCallback
has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
# If no SyncerCallback was found, add
if not has_syncer_callback and os.environ.get(
"TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
syncer_callback = SyncerCallback(
sync_function=sync_config.sync_to_driver)
callbacks.append(syncer_callback)
return callbacks
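
For illustration (not part of this diff): a small sketch of the helper's behavior.

import os

from ray.tune.syncer import SyncConfig, SyncerCallback
from ray.tune.utils.callback import create_default_callbacks

# With no user callbacks, a SyncerCallback driven by
# sync_config.sync_to_driver is appended automatically.
callbacks = create_default_callbacks(None, SyncConfig())
assert any(isinstance(c, SyncerCallback) for c in callbacks)

# Setting TUNE_DISABLE_AUTO_CALLBACK_SYNCER=1 opts out of the
# automatic syncer callback.
os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"
assert not any(
    isinstance(c, SyncerCallback)
    for c in create_default_callbacks([], SyncConfig()))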