[tune] logger refactor part 2: Add SyncerCallback (#11748)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Kai Fricke 2020-11-04 06:04:40 +01:00 committed by GitHub
parent 05c4e3fb2a
commit 007634fd1b
10 changed files with 211 additions and 193 deletions
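
In short, driver syncing moves out of UnifiedLogger and Trial and into a new SyncerCallback that hangs off the TrialRunner. A minimal sketch of the new wiring, as the updated tests exercise it (cluster and Ray setup omitted; this is an illustration, not part of the diff):

from ray.tune.suggest import BasicVariantGenerator
from ray.tune.syncer import SyncerCallback
from ray.tune.trial_runner import TrialRunner

# Previously each Trial carried a sync_to_driver_fn and the UnifiedLogger ran
# the node syncer; now a single SyncerCallback owns one NodeSyncer per trial
# and reacts to on_trial_result, on_trial_complete and on_checkpoint.
syncer_callback = SyncerCallback(sync_function=None)  # None falls back to the default node syncer
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])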


@@ -1,7 +1,9 @@
-from typing import Dict, List
+from typing import TYPE_CHECKING, Dict, List

 from ray.tune.checkpoint_manager import Checkpoint
-from ray.tune.trial import Trial
+
+if TYPE_CHECKING:
+    from ray.tune.trial import Trial


 class Callback:
@@ -39,7 +41,7 @@ class Callback:
     """

-    def on_step_begin(self, iteration: int, trials: List[Trial], **info):
+    def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
         """Called at the start of each tuning loop step.

         Arguments:
@@ -49,7 +51,7 @@ class Callback:
         """
         pass

-    def on_step_end(self, iteration: int, trials: List[Trial], **info):
+    def on_step_end(self, iteration: int, trials: List["Trial"], **info):
         """Called at the end of each tuning loop step.

         The iteration counter is increased before this hook is called.
@@ -61,8 +63,8 @@ class Callback:
         """
         pass

-    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after starting a trial instance.

         Arguments:
@@ -74,8 +76,8 @@ class Callback:
         """
         pass

-    def on_trial_restore(self, iteration: int, trials: List[Trial],
-                         trial: Trial, **info):
+    def on_trial_restore(self, iteration: int, trials: List["Trial"],
+                         trial: "Trial", **info):
         """Called after restoring a trial instance.

         Arguments:
@@ -86,8 +88,8 @@ class Callback:
         """
         pass

-    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial,
-                      **info):
+    def on_trial_save(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", **info):
         """Called after receiving a checkpoint from a trial.

         Arguments:
@@ -98,8 +100,8 @@ class Callback:
         """
         pass

-    def on_trial_result(self, iteration: int, trials: List[Trial],
-                        trial: Trial, result: Dict, **info):
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
         """Called after receiving a result from a trial.

         The search algorithm and scheduler are notified before this
@@ -114,8 +116,8 @@ class Callback:
         """
         pass

-    def on_trial_complete(self, iteration: int, trials: List[Trial],
-                          trial: Trial, **info):
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
         """Called after a trial instance completed.

         The search algorithm and scheduler are notified before this
@@ -129,8 +131,8 @@ class Callback:
         """
         pass

-    def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_error(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after a trial instance failed (errored).

         The search algorithm and scheduler are notified before this
@@ -144,8 +146,8 @@ class Callback:
         """
         pass

-    def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial,
-                      checkpoint: Checkpoint, **info):
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
         """Called after a trial saved a checkpoint with Tune.

         Arguments:
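
All hook signatures now use forward references (List["Trial"], trial: "Trial") so that ray.tune.trial is only imported under TYPE_CHECKING. User-defined callbacks against this interface are unaffected; a minimal sketch (the print statement and class name are illustrative only, not part of the diff):

from typing import Dict, List

from ray.tune.callback import Callback


class PrintResultCallback(Callback):
    """Minimal sketch of a custom callback; override only the hooks you need."""

    def on_trial_result(self, iteration: int, trials: List["Trial"],
                        trial: "Trial", result: Dict, **info):
        # The string annotation is never evaluated at runtime, matching the
        # TYPE_CHECKING pattern introduced above.
        print(f"Trial {trial}: iteration {result.get('training_iteration')}")

Such a callback can be passed via tune.run(..., callbacks=[PrintResultCallback()]), which this commit routes through the new create_default_callbacks() helper shown further below.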


@@ -193,6 +193,5 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None
-        sync_to_driver_fn=spec.get("sync_to_driver"),
         max_failures=args.max_failures,
         **trial_kwargs)


@@ -1,21 +1,19 @@
 import csv
 import json
 import logging
-import numbers
 import numpy as np
 import os
 import yaml
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type

 import ray.cloudpickle as cloudpickle
 from ray.tune.utils.util import SafeFallbackEncoder
 from ray.util.debug import log_once
-from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
-                             TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
-                             EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE,
-                             EXPR_RESULT_FILE)
-from ray.tune.syncer import get_node_syncer
+from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
+                             EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE,
+                             EXPR_PROGRESS_FILE, EXPR_RESULT_FILE)
 from ray.tune.utils import flatten_dict

 if TYPE_CHECKING:
@@ -317,16 +315,13 @@ class UnifiedLogger(Logger):
         logdir: Directory for all logger creators to log to.
         loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
             and JSON loggers.
-        sync_function (func|str): Optional function for syncer to run.
-            See ray/python/ray/tune/syncer.py
     """

     def __init__(self,
                  config: Dict,
                  logdir: str,
                  trial: Optional["Trial"] = None,
-                 loggers: Optional[List[Type[Logger]]] = None,
-                 sync_function: Union[None, Callable, str] = None):
+                 loggers: Optional[List[Type[Logger]]] = None):
         if loggers is None:
             self._logger_cls_list = DEFAULT_LOGGERS
         else:
@@ -336,8 +331,6 @@ class UnifiedLogger(Logger):
                 logger.warning(
                     "JsonLogger not provided. The ExperimentAnalysis tool is "
                     "disabled.")
-        self._sync_function = sync_function
-        self._log_syncer = None

         super(UnifiedLogger, self).__init__(config, logdir, trial)
@@ -350,16 +343,10 @@ class UnifiedLogger(Logger):
                 if log_once(f"instantiate:{cls.__name__}"):
                     logger.warning("Could not instantiate %s: %s.",
                                    cls.__name__, str(exc))
-        self._log_syncer = get_node_syncer(
-            self.logdir,
-            remote_dir=self.logdir,
-            sync_function=self._sync_function)

     def on_result(self, result):
         for _logger in self._loggers:
             _logger.on_result(result)
-        self._log_syncer.set_worker_ip(result.get(NODE_IP))
-        self._log_syncer.sync_down_if_needed()

     def update_config(self, config):
         for _logger in self._loggers:
@@ -369,68 +356,9 @@ class UnifiedLogger(Logger):
         for _logger in self._loggers:
             _logger.close()

-    def flush(self, sync_down=True):
+    def flush(self):
         for _logger in self._loggers:
             _logger.flush()
-        if sync_down:
-            if not self._log_syncer.sync_down():
-                logger.warning("Trial %s: Post-flush sync skipped.",
-                               self.trial)
-
-    def sync_up(self):
-        return self._log_syncer.sync_up()
-
-    def sync_down(self):
-        return self._log_syncer.sync_down()
-
-    def wait(self):
-        self._log_syncer.wait()
-
-    def sync_results_to_new_location(self, worker_ip):
-        """Sends the current log directory to the remote node.
-
-        Syncing will not occur if the cluster is not started
-        with the Ray autoscaler.
-        """
-        if worker_ip != self._log_syncer.worker_ip:
-            logger.info("Trial %s: Syncing (blocking) results to %s",
-                        self.trial, worker_ip)
-            self._log_syncer.reset()
-            self._log_syncer.set_worker_ip(worker_ip)
-            if not self._log_syncer.sync_up():
-                logger.error(
-                    "Trial %s: Sync up to new location skipped. "
-                    "This should not occur.", self.trial)
-            self._log_syncer.wait()
-        else:
-            logger.error(
-                "Trial %s: Sync attempted to same IP %s. This "
-                "should not occur.", self.trial, worker_ip)
-
-
-class _SafeFallbackEncoder(json.JSONEncoder):
-    def __init__(self, nan_str="null", **kwargs):
-        super(_SafeFallbackEncoder, self).__init__(**kwargs)
-        self.nan_str = nan_str
-
-    def default(self, value):
-        try:
-            if np.isnan(value):
-                return self.nan_str
-            if (type(value).__module__ == np.__name__
-                    and isinstance(value, np.ndarray)):
-                return value.tolist()
-            if issubclass(type(value), numbers.Integral):
-                return int(value)
-            if issubclass(type(value), numbers.Number):
-                return float(value)
-            return super(_SafeFallbackEncoder, self).default(value)
-        except Exception:
-            return str(value)  # give up, just stringify it (ok for logs)


 def pretty_print(result):
@@ -442,5 +370,5 @@ def pretty_print(result):
         if v is not None:
             out[k] = v

-    cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
+    cleaned = json.dumps(out, cls=SafeFallbackEncoder)
     return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)


@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union

 import distutils
 import logging
@@ -9,13 +9,21 @@ from dataclasses import dataclass
 from inspect import isclass
 from shlex import quote

+import ray
 from ray import services
+from ray.tune import TuneError
+from ray.tune.callback import Callback
+from ray.tune.checkpoint_manager import Checkpoint
+from ray.tune.result import NODE_IP
 from ray.util.debug import log_once
 from ray.tune.utils.util import env_integer
 from ray.tune.cluster_info import get_ssh_key, get_ssh_user
 from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
                                   get_cloud_sync_client, NOOP)

+if TYPE_CHECKING:
+    from ray.tune.trial import Trial
+
 logger = logging.getLogger(__name__)

 # Syncing period for syncing local checkpoints to cloud.
@@ -355,3 +363,88 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):
         _syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)

     return _syncers[key]
+
+
+class SyncerCallback(Callback):
+    def __init__(self, sync_function: Union[None, bool, Callable]):
+        self._sync_function = sync_function
+        self._syncers: Dict["Trial", NodeSyncer] = {}
+
+    def _get_trial_syncer(self, trial: "Trial"):
+        if trial not in self._syncers:
+            self._syncers[trial] = self._create_trial_syncer(trial)
+        return self._syncers[trial]
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        return get_node_syncer(
+            trial.logdir,
+            remote_dir=trial.logdir,
+            sync_function=self._sync_function)
+
+    def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
+        if checkpoint.storage == Checkpoint.MEMORY:
+            return
+
+        # Local import to avoid circular dependencies between syncer and
+        # trainable
+        from ray.tune.durable_trainable import DurableTrainable
+
+        trial_syncer = self._get_trial_syncer(trial)
+        if trial.sync_on_checkpoint:
+            try:
+                # Wait for any other syncs to finish. We need to sync again
+                # after this to handle checkpoints taken mid-sync.
+                trial_syncer.wait()
+            except TuneError as e:
+                # Errors occurring during this wait are not fatal for this
+                # checkpoint, so it should just be logged.
+                logger.error(
+                    "Trial %s: An error occurred during the "
+                    "checkpoint pre-sync wait - %s", trial, str(e))
+            # Force sync down and wait before tracking the new checkpoint.
+            try:
+                if trial_syncer.sync_down():
+                    trial_syncer.wait()
+                else:
+                    logger.error(
+                        "Trial %s: Checkpoint sync skipped. "
+                        "This should not happen.", trial)
+            except TuneError as e:
+                if issubclass(trial.get_trainable_cls(), DurableTrainable):
+                    # Even though rsync failed the trainable can restore
+                    # from remote durable storage.
+                    logger.error("Trial %s: Sync error - %s", trial, str(e))
+                else:
+                    # If the trainable didn't have remote storage to upload
+                    # to then this checkpoint may have been lost, so we
+                    # shouldn't track it with the checkpoint_manager.
+                    raise e
+
+            if not issubclass(trial.get_trainable_cls(), DurableTrainable):
+                if not os.path.exists(checkpoint.value):
+                    raise TuneError("Trial {}: Checkpoint path {} not "
+                                    "found after successful sync down.".format(
+                                        trial, checkpoint.value))
+
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
+        self._get_trial_syncer(trial)
+
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        trial_syncer.set_worker_ip(result.get(NODE_IP))
+        trial_syncer.sync_down_if_needed()
+
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        if NODE_IP in trial.last_result:
+            trainable_ip = trial.last_result[NODE_IP]
+        else:
+            trainable_ip = ray.get(trial.runner.get_current_ip.remote())
+        trial_syncer.set_worker_ip(trainable_ip)
+        trial_syncer.sync_down_if_needed()
+
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
+        self._sync_trial_checkpoint(trial, checkpoint)
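
The tests in this commit override _create_trial_syncer() to pick a sync function per trial, and the same hook can be used to plug in custom syncing in user code. A sketch, assuming a user-supplied custom_sync_fn(source, target) that is not part of this diff:

from ray.tune.syncer import SyncerCallback, get_node_syncer


def custom_sync_fn(source, target):
    # Hypothetical user-defined sync function: copy `source` (the trial logdir
    # on the worker) to `target`.
    pass


class CustomSyncerCallback(SyncerCallback):
    def _create_trial_syncer(self, trial: "Trial"):
        # Same wiring as the default implementation above, but with a fixed
        # custom sync function instead of self._sync_function.
        return get_node_syncer(
            trial.logdir, remote_dir=trial.logdir, sync_function=custom_sync_fn)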


@@ -1,12 +1,15 @@
 import inspect
 import time
 import os
 import pytest
 import shutil
 import subprocess
 import sys
 from unittest.mock import MagicMock, patch
+from typing import Callable, Union

 import ray
 from ray import tune
 from ray.rllib import _register_all
@@ -18,7 +21,7 @@ from ray.tune.error import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.resources import Resources
 from ray.tune.suggest import BasicVariantGenerator
-from ray.tune.syncer import CloudSyncer
+from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
 from ray.tune.utils.trainable import TrainableUtil
 from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
@@ -55,6 +58,19 @@ def _start_new_cluster():
     return cluster


+class _PerTrialSyncerCallback(SyncerCallback):
+    def __init__(
+            self,
+            get_sync_fn: Callable[["Trial"], Union[None, bool, Callable]]):
+        self._get_sync_fn = get_sync_fn
+        super(_PerTrialSyncerCallback, self).__init__(None)
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        sync_fn = self._get_sync_fn(trial)
+        return get_node_syncer(
+            trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn)
+
+
 @pytest.fixture
 def start_connected_cluster():
     # Start the Ray processes.
@@ -255,7 +271,9 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4
@@ -263,7 +281,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }

     # Test recovery of trial that hasn't been checkpointed
@@ -316,7 +333,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
             "training_iteration": 3
         },
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
     t3 = Trial(trainable_id, **kwargs)
     runner.add_trial(t3)
@@ -341,7 +357,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 5
@@ -349,7 +367,6 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }

     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
@@ -382,7 +399,13 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    class _SyncerCallback(SyncerCallback):
+        def _create_trial_syncer(self, trial: "Trial"):
+            client = mock_storage_client()
+            return MockNodeSyncer(trial.logdir, trial.logdir, client)
+
+    syncer_callback = _SyncerCallback(None)
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4
@@ -390,7 +413,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake_remote",
     }

     # The following patches only affect __fake_remote.
@@ -415,33 +437,26 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
         # TrainableUtil will not check this path unless we mock it.
         mock_find.side_effect = hide_remote_path(find_func)
         mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)

-        with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
-
-            def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
-                client = mock_storage_client()
-                return MockNodeSyncer(local_dir, remote_dir, client)
-
-            mock_get_node_syncer.side_effect = mock_get_syncer_fn
-
-            # Test recovery of trial that has been checkpointed
-            t1 = Trial(trainable_id, **kwargs)
-            runner.add_trial(t1)
-
-            # Start trial, process result (x2), process save
-            for _ in range(4):
-                runner.step()
-            assert t1.has_checkpoint()
-
-            cluster.add_node(num_cpus=1)
-            cluster.remove_node(node)
-            cluster.wait_for_nodes()
-            shutil.rmtree(os.path.dirname(t1.checkpoint.value))
-            runner.step()  # Collect result 3, kick off + fail result 4
-            runner.step()  # Dispatch restore
-            runner.step()  # Process restore + step 4
-            for _ in range(3):
-                if t1.status != Trial.TERMINATED:
-                    runner.step()
+        # Test recovery of trial that has been checkpointed
+        t1 = Trial(trainable_id, **kwargs)
+        runner.add_trial(t1)
+
+        # Start trial, process result (x2), process save
+        for _ in range(4):
+            runner.step()
+        assert t1.has_checkpoint()
+
+        cluster.add_node(num_cpus=1)
+        cluster.remove_node(node)
+        cluster.wait_for_nodes()
+        shutil.rmtree(os.path.dirname(t1.checkpoint.value))
+        runner.step()  # Collect result 3, kick off + fail result 4
+        runner.step()  # Dispatch restore
+        runner.step()  # Process restore + step 4
+        for _ in range(3):
+            if t1.status != Trial.TERMINATED:
+                runner.step()

     assert t1.status == Trial.TERMINATED, runner.debug_string()
@@ -454,7 +469,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
     cluster.wait_for_nodes()

     dirpath = str(tmpdir)
-    runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(
+        local_checkpoint_dir=dirpath,
+        checkpoint_period=0,
+        callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 2
@@ -462,7 +482,6 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }

     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
     for t in trials:


@@ -11,7 +11,6 @@ import ray
 from ray.rllib import _register_all

 from ray import tune
-from ray.tune import TuneError
 from ray.tune.syncer import CommandBasedClient
@@ -87,8 +86,8 @@ class TestSyncFunctionality(unittest.TestCase):
     def testClusterProperString(self):
         """Tests that invalid commands throw.."""
-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
             [trial] = tune.run(
                 "__fake",
@@ -100,8 +99,8 @@ class TestSyncFunctionality(unittest.TestCase):
                 sync_config=sync_config,
             ).trials

-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
             [trial] = tune.run(
                 "__fake",

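For contrast with the failing cases above: a string-based sync command is only accepted if it contains both the {source} and {target} placeholders, which is what the ValueError now surfaces directly instead of being wrapped in a TuneError. A sketch of a valid configuration (the rsync flags are an example, not taken from this diff):

from ray import tune

# Both placeholders are present, so the command-based sync client accepts it.
sync_config = tune.SyncConfig(sync_to_driver="rsync -avz {source} {target}")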

@@ -124,11 +124,7 @@ class TuneExampleTest(unittest.TestCase):
 class AutoInitTest(unittest.TestCase):
     def testTuneRestore(self):
         self.assertFalse(ray.is_initialized())
-        tune.run(
-            "__fake",
-            name="TestAutoInit",
-            stop={"training_iteration": 1},
-            ray_auto_init=True)
+        tune.run("__fake", name="TestAutoInit", stop={"training_iteration": 1})
         self.assertTrue(ray.is_initialized())

     def tearDown(self):


@@ -12,7 +12,6 @@ import os
 from numbers import Number

 from ray.tune import TuneError
 from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
-from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.logger import pretty_print, UnifiedLogger
 # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
 # need because there are cyclic imports that may cause specific names to not
@@ -192,7 +191,6 @@ class Trial:
                  trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
-                 sync_to_driver_fn=None,
                  max_failures=0):
         """Initialize a new trial.
@@ -232,7 +230,6 @@ class Trial:
                 or not len(self.log_to_file) == 2:
             self.log_to_file = (None, None)

-        self.sync_to_driver_fn = sync_to_driver_fn
         self.verbose = True
         self.max_failures = max_failures
@@ -289,7 +286,6 @@ class Trial:
         self._nonjson_fields = [
             "loggers",
-            "sync_to_driver_fn",
             "results",
             "best_result",
             "param_config",
@@ -356,7 +352,6 @@ class Trial:
             trial_name_creator=self.trial_name_creator,
             loggers=self.loggers,
             log_to_file=self.log_to_file,
-            sync_to_driver_fn=self.sync_to_driver_fn,
             max_failures=self.max_failures,
         )
@@ -370,11 +365,7 @@ class Trial:
         os.makedirs(self.logdir, exist_ok=True)

         self.result_logger = UnifiedLogger(
-            self.config,
-            self.logdir,
-            trial=self,
-            loggers=self.loggers,
-            sync_function=self.sync_to_driver_fn)
+            self.config, self.logdir, trial=self, loggers=self.loggers)

     def update_resources(self, cpu, gpu, **kwargs):
         """EXPERIMENTAL: Updates the resource requirements.
@@ -459,43 +450,6 @@ class Trial:
         Args:
             checkpoint (Checkpoint): Checkpoint taken.
         """
-        if checkpoint.storage == Checkpoint.MEMORY:
-            self.checkpoint_manager.on_checkpoint(checkpoint)
-            return
-        if self.sync_on_checkpoint:
-            try:
-                # Wait for any other syncs to finish. We need to sync again
-                # after this to handle checkpoints taken mid-sync.
-                self.result_logger.wait()
-            except TuneError as e:
-                # Errors occurring during this wait are not fatal for this
-                # checkpoint, so it should just be logged.
-                logger.error(
-                    "Trial %s: An error occurred during the "
-                    "checkpoint pre-sync wait - %s", self, str(e))
-            # Force sync down and wait before tracking the new checkpoint.
-            try:
-                if self.result_logger.sync_down():
-                    self.result_logger.wait()
-                else:
-                    logger.error(
-                        "Trial %s: Checkpoint sync skipped. "
-                        "This should not happen.", self)
-            except TuneError as e:
-                if issubclass(self.get_trainable_cls(), DurableTrainable):
-                    # Even though rsync failed the trainable can restore
-                    # from remote durable storage.
-                    logger.error("Trial %s: Sync error - %s", self, str(e))
-                else:
-                    # If the trainable didn't have remote storage to upload
-                    # to then this checkpoint may have been lost, so we
-                    # shouldn't track it with the checkpoint_manager.
-                    raise e
-            if not issubclass(self.get_trainable_cls(), DurableTrainable):
-                if not os.path.exists(checkpoint.value):
-                    raise TuneError("Trial {}: Checkpoint path {} not "
-                                    "found after successful sync down.".format(
-                                        self, checkpoint.value))
         self.checkpoint_manager.on_checkpoint(checkpoint)

     def on_restore(self):
@@ -515,7 +469,6 @@ class Trial:
         return self.num_failures < self.max_failures or self.max_failures < 0

     def update_last_result(self, result, terminate=False):
-        result.update(trial_id=self.trial_id, done=terminate)
         if self.experiment_tag:
             result.update(experiment_tag=self.experiment_tag)
         if self.verbose and (terminate or time.time() - self.last_debug >
@@ -634,7 +587,7 @@ class Trial:
         state["resuming_from"] = None
         state["saving_to"] = None
         if self.result_logger:
-            self.result_logger.flush(sync_down=False)
+            self.result_logger.flush()
             state["__logger_started__"] = True
         else:
             state["__logger_started__"] = False


@@ -11,8 +11,10 @@ from ray.tune.suggest.variant_generator import has_unresolved_values
 from ray.tune.trial import Trial
 from ray.tune.trainable import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.registry import get_trainable_cls
-from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
+from ray.tune.syncer import wait_for_sync, set_sync_periods, \
+    SyncConfig
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
 from ray.tune.schedulers import FIFOScheduler
@@ -353,6 +355,9 @@ def run(
             "own `metric` and `mode` parameters. Either remove the arguments "
             "from your scheduler or from your call to `tune.run()`")

+    # Create syncer callbacks
+    callbacks = create_default_callbacks(callbacks, sync_config)
+
     runner = TrialRunner(
         search_alg=search_alg,
         scheduler=scheduler,
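
From the caller's side this means the callbacks passed to tune.run() are merged with a default SyncerCallback built from sync_config; supplying your own SyncerCallback suppresses the automatic one (see the new create_default_callbacks() helper below). A sketch of that combination (the trainable name is a placeholder):

from ray import tune
from ray.tune.syncer import SyncerCallback

# Because a SyncerCallback is already present, create_default_callbacks() will
# not append a second one derived from sync_config.sync_to_driver.
tune.run(
    "my_trainable",  # placeholder for any registered trainable
    stop={"training_iteration": 1},
    sync_config=tune.SyncConfig(sync_to_driver=None),
    callbacks=[SyncerCallback(sync_function=None)])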


@@ -0,0 +1,24 @@
+import os
+from typing import List, Optional
+
+from ray.tune.callback import Callback
+from ray.tune.syncer import SyncConfig
+from ray.tune.syncer import SyncerCallback
+
+
+def create_default_callbacks(callbacks: Optional[List[Callback]],
+                             sync_config: SyncConfig):
+    callbacks = callbacks or []
+
+    # Check if there is a SyncerCallback
+    has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
+
+    # If no SyncerCallback was found, add
+    if not has_syncer_callback and os.environ.get(
+            "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
+        syncer_callback = SyncerCallback(
+            sync_function=sync_config.sync_to_driver)
+        callbacks.append(syncer_callback)
+
+    return callbacks
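
The environment variable check gives a global kill switch for driver syncing. A small sketch of the resulting behavior, assuming the variable is set before the callbacks are created:

import os

from ray.tune.syncer import SyncConfig
from ray.tune.utils.callback import create_default_callbacks

os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"

# With the kill switch set, no SyncerCallback is appended, so results and
# checkpoints are not synced back to the driver node.
callbacks = create_default_callbacks([], SyncConfig(sync_to_driver=None))
assert callbacks == []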