Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00
[tune] logger refactor part 2: Add SyncerCallback (#11748)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
commit 007634fd1b
parent 05c4e3fb2a
10 changed files with 211 additions and 193 deletions
@@ -1,7 +1,9 @@
-from typing import Dict, List
+from typing import TYPE_CHECKING, Dict, List

 from ray.tune.checkpoint_manager import Checkpoint
-from ray.tune.trial import Trial
+
+if TYPE_CHECKING:
+    from ray.tune.trial import Trial


 class Callback:

@@ -39,7 +41,7 @@ class Callback:
     """

-    def on_step_begin(self, iteration: int, trials: List[Trial], **info):
+    def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
         """Called at the start of each tuning loop step.

         Arguments:

@@ -49,7 +51,7 @@
         """
         pass

-    def on_step_end(self, iteration: int, trials: List[Trial], **info):
+    def on_step_end(self, iteration: int, trials: List["Trial"], **info):
         """Called at the end of each tuning loop step.

         The iteration counter is increased before this hook is called.

@@ -61,8 +63,8 @@
         """
         pass

-    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after starting a trial instance.

         Arguments:

@@ -74,8 +76,8 @@
         """
         pass

-    def on_trial_restore(self, iteration: int, trials: List[Trial],
-                         trial: Trial, **info):
+    def on_trial_restore(self, iteration: int, trials: List["Trial"],
+                         trial: "Trial", **info):
         """Called after restoring a trial instance.

         Arguments:

@@ -86,8 +88,8 @@
         """
         pass

-    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial,
-                      **info):
+    def on_trial_save(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", **info):
         """Called after receiving a checkpoint from a trial.

         Arguments:

@@ -98,8 +100,8 @@
         """
         pass

-    def on_trial_result(self, iteration: int, trials: List[Trial],
-                        trial: Trial, result: Dict, **info):
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
         """Called after receiving a result from a trial.

         The search algorithm and scheduler are notified before this

@@ -114,8 +116,8 @@
         """
         pass

-    def on_trial_complete(self, iteration: int, trials: List[Trial],
-                          trial: Trial, **info):
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
         """Called after a trial instance completed.

         The search algorithm and scheduler are notified before this

@@ -129,8 +131,8 @@
         """
         pass

-    def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_error(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after a trial instance failed (errored).

         The search algorithm and scheduler are notified before this

@@ -144,8 +146,8 @@
         """
         pass

-    def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial,
-                      checkpoint: Checkpoint, **info):
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
         """Called after a trial saved a checkpoint with Tune.

         Arguments:

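Aside (not part of the diff): with the hooks above, a user-defined callback only overrides the methods it needs. A minimal sketch against this API; the class name and printed message are illustrative, and the string annotation "Trial" mirrors the forward references introduced in this hunk.

    from typing import Dict, List

    from ray.tune.callback import Callback


    class PrintResultCallback(Callback):
        """Illustrative: print every result a trial reports."""

        def on_trial_result(self, iteration: int, trials: List["Trial"],
                            trial: "Trial", result: Dict, **info):
            # `result` is the metrics dict reported by the trial this step.
            print(f"{trial} @ loop iteration {iteration}: {result}")

Passed as tune.run(..., callbacks=[PrintResultCallback()]), it runs alongside the SyncerCallback that create_default_callbacks appends later in this diff.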
@@ -193,6 +193,5 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None
-        sync_to_driver_fn=spec.get("sync_to_driver"),
         max_failures=args.max_failures,
         **trial_kwargs)

@@ -1,21 +1,19 @@
 import csv
 import json
 import logging
 import numbers
 import numpy as np
 import os
 import yaml

-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type

 import ray.cloudpickle as cloudpickle

+from ray.tune.utils.util import SafeFallbackEncoder
 from ray.util.debug import log_once
-from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
-                             TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
-                             EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE,
-                             EXPR_RESULT_FILE)
-from ray.tune.syncer import get_node_syncer
+from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
+                             EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE,
+                             EXPR_PROGRESS_FILE, EXPR_RESULT_FILE)
 from ray.tune.utils import flatten_dict

 if TYPE_CHECKING:

@@ -317,16 +315,13 @@ class UnifiedLogger(Logger):
         logdir: Directory for all logger creators to log to.
         loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
             and JSON loggers.
-        sync_function (func|str): Optional function for syncer to run.
-            See ray/python/ray/tune/syncer.py
     """

     def __init__(self,
                  config: Dict,
                  logdir: str,
                  trial: Optional["Trial"] = None,
-                 loggers: Optional[List[Type[Logger]]] = None,
-                 sync_function: Union[None, Callable, str] = None):
+                 loggers: Optional[List[Type[Logger]]] = None):
         if loggers is None:
             self._logger_cls_list = DEFAULT_LOGGERS
         else:

@@ -336,8 +331,6 @@ class UnifiedLogger(Logger):
                 logger.warning(
                     "JsonLogger not provided. The ExperimentAnalysis tool is "
                     "disabled.")
-        self._sync_function = sync_function
-        self._log_syncer = None

         super(UnifiedLogger, self).__init__(config, logdir, trial)

@@ -350,16 +343,10 @@ class UnifiedLogger(Logger):
                 if log_once(f"instantiate:{cls.__name__}"):
                     logger.warning("Could not instantiate %s: %s.",
                                    cls.__name__, str(exc))
-        self._log_syncer = get_node_syncer(
-            self.logdir,
-            remote_dir=self.logdir,
-            sync_function=self._sync_function)

     def on_result(self, result):
         for _logger in self._loggers:
             _logger.on_result(result)
-        self._log_syncer.set_worker_ip(result.get(NODE_IP))
-        self._log_syncer.sync_down_if_needed()

     def update_config(self, config):
         for _logger in self._loggers:

@@ -369,68 +356,9 @@ class UnifiedLogger(Logger):
         for _logger in self._loggers:
             _logger.close()

-    def flush(self, sync_down=True):
+    def flush(self):
         for _logger in self._loggers:
             _logger.flush()
-        if sync_down:
-            if not self._log_syncer.sync_down():
-                logger.warning("Trial %s: Post-flush sync skipped.",
-                               self.trial)
-
-    def sync_up(self):
-        return self._log_syncer.sync_up()
-
-    def sync_down(self):
-        return self._log_syncer.sync_down()
-
-    def wait(self):
-        self._log_syncer.wait()
-
-    def sync_results_to_new_location(self, worker_ip):
-        """Sends the current log directory to the remote node.
-
-        Syncing will not occur if the cluster is not started
-        with the Ray autoscaler.
-        """
-        if worker_ip != self._log_syncer.worker_ip:
-            logger.info("Trial %s: Syncing (blocking) results to %s",
-                        self.trial, worker_ip)
-            self._log_syncer.reset()
-            self._log_syncer.set_worker_ip(worker_ip)
-            if not self._log_syncer.sync_up():
-                logger.error(
-                    "Trial %s: Sync up to new location skipped. "
-                    "This should not occur.", self.trial)
-            self._log_syncer.wait()
-        else:
-            logger.error(
-                "Trial %s: Sync attempted to same IP %s. This "
-                "should not occur.", self.trial, worker_ip)
-
-
-class _SafeFallbackEncoder(json.JSONEncoder):
-    def __init__(self, nan_str="null", **kwargs):
-        super(_SafeFallbackEncoder, self).__init__(**kwargs)
-        self.nan_str = nan_str
-
-    def default(self, value):
-        try:
-            if np.isnan(value):
-                return self.nan_str
-
-            if (type(value).__module__ == np.__name__
-                    and isinstance(value, np.ndarray)):
-                return value.tolist()
-
-            if issubclass(type(value), numbers.Integral):
-                return int(value)
-            if issubclass(type(value), numbers.Number):
-                return float(value)
-
-            return super(_SafeFallbackEncoder, self).default(value)
-
-        except Exception:
-            return str(value)  # give up, just stringify it (ok for logs)
-
-
 def pretty_print(result):

@@ -442,5 +370,5 @@ def pretty_print(result):
         if v is not None:
             out[k] = v

-    cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
+    cleaned = json.dumps(out, cls=SafeFallbackEncoder)
     return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)

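Aside (not part of the diff): pretty_print now delegates JSON fallbacks to the shared ray.tune.utils.util.SafeFallbackEncoder instead of the private _SafeFallbackEncoder deleted above. A small sketch, assuming the shared encoder keeps the fallback behaviour of the removed class (numpy arrays via tolist(), numpy scalars via int()/float(), str() as a last resort):

    import json

    import numpy as np

    from ray.tune.utils.util import SafeFallbackEncoder

    result = {"iter": np.int64(3), "weights": np.arange(3), "loss": 0.25}
    # Numpy scalars and arrays are coerced to JSON-safe Python values
    # instead of raising TypeError.
    print(json.dumps(result, cls=SafeFallbackEncoder))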
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union

 import distutils
 import logging

@@ -9,13 +9,21 @@ from dataclasses import dataclass
 from inspect import isclass
 from shlex import quote

 import ray
 from ray import services
+from ray.tune import TuneError
+from ray.tune.callback import Callback
+from ray.tune.checkpoint_manager import Checkpoint
+from ray.tune.result import NODE_IP
 from ray.util.debug import log_once
 from ray.tune.utils.util import env_integer
 from ray.tune.cluster_info import get_ssh_key, get_ssh_user
 from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
                                   get_cloud_sync_client, NOOP)

+if TYPE_CHECKING:
+    from ray.tune.trial import Trial
+
 logger = logging.getLogger(__name__)

 # Syncing period for syncing local checkpoints to cloud.

@@ -355,3 +363,88 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):

     _syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)
     return _syncers[key]
+
+
+class SyncerCallback(Callback):
+    def __init__(self, sync_function: Union[None, bool, Callable]):
+        self._sync_function = sync_function
+        self._syncers: Dict["Trial", NodeSyncer] = {}
+
+    def _get_trial_syncer(self, trial: "Trial"):
+        if trial not in self._syncers:
+            self._syncers[trial] = self._create_trial_syncer(trial)
+        return self._syncers[trial]
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        return get_node_syncer(
+            trial.logdir,
+            remote_dir=trial.logdir,
+            sync_function=self._sync_function)
+
+    def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
+        if checkpoint.storage == Checkpoint.MEMORY:
+            return
+
+        # Local import to avoid circular dependencies between syncer and
+        # trainable
+        from ray.tune.durable_trainable import DurableTrainable
+
+        trial_syncer = self._get_trial_syncer(trial)
+        if trial.sync_on_checkpoint:
+            try:
+                # Wait for any other syncs to finish. We need to sync again
+                # after this to handle checkpoints taken mid-sync.
+                trial_syncer.wait()
+            except TuneError as e:
+                # Errors occurring during this wait are not fatal for this
+                # checkpoint, so it should just be logged.
+                logger.error(
+                    "Trial %s: An error occurred during the "
+                    "checkpoint pre-sync wait - %s", trial, str(e))
+            # Force sync down and wait before tracking the new checkpoint.
+            try:
+                if trial_syncer.sync_down():
+                    trial_syncer.wait()
+                else:
+                    logger.error(
+                        "Trial %s: Checkpoint sync skipped. "
+                        "This should not happen.", trial)
+            except TuneError as e:
+                if issubclass(trial.get_trainable_cls(), DurableTrainable):
+                    # Even though rsync failed the trainable can restore
+                    # from remote durable storage.
+                    logger.error("Trial %s: Sync error - %s", trial, str(e))
+                else:
+                    # If the trainable didn't have remote storage to upload
+                    # to then this checkpoint may have been lost, so we
+                    # shouldn't track it with the checkpoint_manager.
+                    raise e
+            if not issubclass(trial.get_trainable_cls(), DurableTrainable):
+                if not os.path.exists(checkpoint.value):
+                    raise TuneError("Trial {}: Checkpoint path {} not "
+                                    "found after successful sync down.".format(
+                                        trial, checkpoint.value))
+
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
+        self._get_trial_syncer(trial)
+
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        trial_syncer.set_worker_ip(result.get(NODE_IP))
+        trial_syncer.sync_down_if_needed()
+
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        if NODE_IP in trial.last_result:
+            trainable_ip = trial.last_result[NODE_IP]
+        else:
+            trainable_ip = ray.get(trial.runner.get_current_ip.remote())
+        trial_syncer.set_worker_ip(trainable_ip)
+        trial_syncer.sync_down_if_needed()
+
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
+        self._sync_trial_checkpoint(trial, checkpoint)

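Aside (not part of the diff): the tests below exercise the new SyncerCallback by overriding _create_trial_syncer per trial, and the same pattern works for user code. A sketch with an illustrative class name, relying on get_node_syncer accepting a boolean sync_function just as the test callbacks below do:

    from ray.tune.syncer import SyncerCallback, get_node_syncer


    class DisableDriverSync(SyncerCallback):
        """Illustrative: never sync trial logdirs back to the driver."""

        def _create_trial_syncer(self, trial):
            # A boolean False here disables driver syncing for this trial.
            return get_node_syncer(
                trial.logdir, remote_dir=trial.logdir, sync_function=False)

Because create_default_callbacks (added at the end of this diff) only appends a SyncerCallback when none is present, passing DisableDriverSync(None) through tune.run(callbacks=...) replaces the default syncing behaviour.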
@@ -1,12 +1,15 @@
 import inspect
 import time
+import os

 import pytest
 import shutil
 import subprocess
 import sys
 from unittest.mock import MagicMock, patch

+from typing import Callable, Union
+
 import ray
 from ray import tune
 from ray.rllib import _register_all

@@ -18,7 +21,7 @@ from ray.tune.error import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.resources import Resources
 from ray.tune.suggest import BasicVariantGenerator
-from ray.tune.syncer import CloudSyncer
+from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
 from ray.tune.utils.trainable import TrainableUtil
 from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner

@@ -55,6 +58,19 @@ def _start_new_cluster():
     return cluster


+class _PerTrialSyncerCallback(SyncerCallback):
+    def __init__(
+            self,
+            get_sync_fn: Callable[["Trial"], Union[None, bool, Callable]]):
+        self._get_sync_fn = get_sync_fn
+        super(_PerTrialSyncerCallback, self).__init__(None)
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        sync_fn = self._get_sync_fn(trial)
+        return get_node_syncer(
+            trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn)
+
+
 @pytest.fixture
 def start_connected_cluster():
     # Start the Ray processes.

@@ -255,7 +271,9 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4

@@ -263,7 +281,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }

     # Test recovery of trial that hasn't been checkpointed

@@ -316,7 +333,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
             "training_iteration": 3
         },
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
     t3 = Trial(trainable_id, **kwargs)
     runner.add_trial(t3)

@@ -341,7 +357,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 5

@@ -349,7 +367,6 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }

     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]

@@ -382,7 +399,13 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()

-    runner = TrialRunner(BasicVariantGenerator())
+    class _SyncerCallback(SyncerCallback):
+        def _create_trial_syncer(self, trial: "Trial"):
+            client = mock_storage_client()
+            return MockNodeSyncer(trial.logdir, trial.logdir, client)
+
+    syncer_callback = _SyncerCallback(None)
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4

@@ -390,7 +413,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake_remote",
     }

     # The following patches only affect __fake_remote.

@@ -415,33 +437,26 @@
         # TrainableUtil will not check this path unless we mock it.
         mock_find.side_effect = hide_remote_path(find_func)
         mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
-        with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
-
-            def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
-                client = mock_storage_client()
-                return MockNodeSyncer(local_dir, remote_dir, client)
-
-            mock_get_node_syncer.side_effect = mock_get_syncer_fn
-
-            # Test recovery of trial that has been checkpointed
-            t1 = Trial(trainable_id, **kwargs)
-            runner.add_trial(t1)
-
-            # Start trial, process result (x2), process save
-            for _ in range(4):
-                runner.step()
-            assert t1.has_checkpoint()
-
-            cluster.add_node(num_cpus=1)
-            cluster.remove_node(node)
-            cluster.wait_for_nodes()
-            shutil.rmtree(os.path.dirname(t1.checkpoint.value))
-            runner.step()  # Collect result 3, kick off + fail result 4
-            runner.step()  # Dispatch restore
-            runner.step()  # Process restore + step 4
-            for _ in range(3):
-                if t1.status != Trial.TERMINATED:
-                    runner.step()
+        # Test recovery of trial that has been checkpointed
+        t1 = Trial(trainable_id, **kwargs)
+        runner.add_trial(t1)
+
+        # Start trial, process result (x2), process save
+        for _ in range(4):
+            runner.step()
+        assert t1.has_checkpoint()
+
+        cluster.add_node(num_cpus=1)
+        cluster.remove_node(node)
+        cluster.wait_for_nodes()
+        shutil.rmtree(os.path.dirname(t1.checkpoint.value))
+        runner.step()  # Collect result 3, kick off + fail result 4
+        runner.step()  # Dispatch restore
+        runner.step()  # Process restore + step 4
+        for _ in range(3):
+            if t1.status != Trial.TERMINATED:
+                runner.step()
         assert t1.status == Trial.TERMINATED, runner.debug_string()

@@ -454,7 +469,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
     cluster.wait_for_nodes()

     dirpath = str(tmpdir)
-    runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(
+        local_checkpoint_dir=dirpath,
+        checkpoint_period=0,
+        callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 2

@@ -462,7 +482,6 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
     for t in trials:

@@ -11,7 +11,6 @@ import ray
 from ray.rllib import _register_all

 from ray import tune
-from ray.tune import TuneError
 from ray.tune.syncer import CommandBasedClient


@@ -87,8 +86,8 @@ class TestSyncFunctionality(unittest.TestCase):

     def testClusterProperString(self):
         """Tests that invalid commands throw.."""
-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
             [trial] = tune.run(
                 "__fake",

@@ -100,8 +99,8 @@ class TestSyncFunctionality(unittest.TestCase):
                 sync_config=sync_config,
             ).trials

-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
             [trial] = tune.run(
                 "__fake",

@@ -124,11 +124,7 @@ class TuneExampleTest(unittest.TestCase):
 class AutoInitTest(unittest.TestCase):
     def testTuneRestore(self):
         self.assertFalse(ray.is_initialized())
-        tune.run(
-            "__fake",
-            name="TestAutoInit",
-            stop={"training_iteration": 1},
-            ray_auto_init=True)
+        tune.run("__fake", name="TestAutoInit", stop={"training_iteration": 1})
         self.assertTrue(ray.is_initialized())

     def tearDown(self):

@@ -12,7 +12,6 @@ import os
 from numbers import Number
 from ray.tune import TuneError
 from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
-from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.logger import pretty_print, UnifiedLogger
 # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
 # need because there are cyclic imports that may cause specific names to not

@@ -192,7 +191,6 @@ class Trial:
                  trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
-                 sync_to_driver_fn=None,
                  max_failures=0):
         """Initialize a new trial.

@@ -232,7 +230,6 @@ class Trial:
                 or not len(self.log_to_file) == 2:
             self.log_to_file = (None, None)

-        self.sync_to_driver_fn = sync_to_driver_fn
         self.verbose = True
         self.max_failures = max_failures

@@ -289,7 +286,6 @@ class Trial:

         self._nonjson_fields = [
             "loggers",
-            "sync_to_driver_fn",
             "results",
             "best_result",
             "param_config",

@@ -356,7 +352,6 @@ class Trial:
             trial_name_creator=self.trial_name_creator,
             loggers=self.loggers,
             log_to_file=self.log_to_file,
-            sync_to_driver_fn=self.sync_to_driver_fn,
             max_failures=self.max_failures,
         )

@@ -370,11 +365,7 @@ class Trial:
         os.makedirs(self.logdir, exist_ok=True)

         self.result_logger = UnifiedLogger(
-            self.config,
-            self.logdir,
-            trial=self,
-            loggers=self.loggers,
-            sync_function=self.sync_to_driver_fn)
+            self.config, self.logdir, trial=self, loggers=self.loggers)

     def update_resources(self, cpu, gpu, **kwargs):
         """EXPERIMENTAL: Updates the resource requirements.

@@ -459,43 +450,6 @@ class Trial:
         Args:
             checkpoint (Checkpoint): Checkpoint taken.
         """
-        if checkpoint.storage == Checkpoint.MEMORY:
-            self.checkpoint_manager.on_checkpoint(checkpoint)
-            return
-        if self.sync_on_checkpoint:
-            try:
-                # Wait for any other syncs to finish. We need to sync again
-                # after this to handle checkpoints taken mid-sync.
-                self.result_logger.wait()
-            except TuneError as e:
-                # Errors occurring during this wait are not fatal for this
-                # checkpoint, so it should just be logged.
-                logger.error(
-                    "Trial %s: An error occurred during the "
-                    "checkpoint pre-sync wait - %s", self, str(e))
-            # Force sync down and wait before tracking the new checkpoint.
-            try:
-                if self.result_logger.sync_down():
-                    self.result_logger.wait()
-                else:
-                    logger.error(
-                        "Trial %s: Checkpoint sync skipped. "
-                        "This should not happen.", self)
-            except TuneError as e:
-                if issubclass(self.get_trainable_cls(), DurableTrainable):
-                    # Even though rsync failed the trainable can restore
-                    # from remote durable storage.
-                    logger.error("Trial %s: Sync error - %s", self, str(e))
-                else:
-                    # If the trainable didn't have remote storage to upload
-                    # to then this checkpoint may have been lost, so we
-                    # shouldn't track it with the checkpoint_manager.
-                    raise e
-        if not issubclass(self.get_trainable_cls(), DurableTrainable):
-            if not os.path.exists(checkpoint.value):
-                raise TuneError("Trial {}: Checkpoint path {} not "
-                                "found after successful sync down.".format(
-                                    self, checkpoint.value))
         self.checkpoint_manager.on_checkpoint(checkpoint)

     def on_restore(self):

@@ -515,7 +469,6 @@ class Trial:
         return self.num_failures < self.max_failures or self.max_failures < 0

     def update_last_result(self, result, terminate=False):
-        result.update(trial_id=self.trial_id, done=terminate)
         if self.experiment_tag:
             result.update(experiment_tag=self.experiment_tag)
         if self.verbose and (terminate or time.time() - self.last_debug >

@@ -634,7 +587,7 @@ class Trial:
         state["resuming_from"] = None
         state["saving_to"] = None
         if self.result_logger:
-            self.result_logger.flush(sync_down=False)
+            self.result_logger.flush()
             state["__logger_started__"] = True
         else:
             state["__logger_started__"] = False

@@ -11,8 +11,10 @@ from ray.tune.suggest.variant_generator import has_unresolved_values
 from ray.tune.trial import Trial
 from ray.tune.trainable import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.registry import get_trainable_cls
-from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
+from ray.tune.syncer import wait_for_sync, set_sync_periods, \
+    SyncConfig
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
 from ray.tune.schedulers import FIFOScheduler

@@ -353,6 +355,9 @@ def run(
             "own `metric` and `mode` parameters. Either remove the arguments "
             "from your scheduler or from your call to `tune.run()`")

+    # Create syncer callbacks
+    callbacks = create_default_callbacks(callbacks, sync_config)
+
     runner = TrialRunner(
         search_alg=search_alg,
         scheduler=scheduler,

python/ray/tune/utils/callback.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import os
+from typing import List, Optional
+
+from ray.tune.callback import Callback
+from ray.tune.syncer import SyncConfig
+from ray.tune.syncer import SyncerCallback
+
+
+def create_default_callbacks(callbacks: Optional[List[Callback]],
+                             sync_config: SyncConfig):
+
+    callbacks = callbacks or []
+
+    # Check if there is a SyncerCallback
+    has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
+
+    # If no SyncerCallback was found, add
+    if not has_syncer_callback and os.environ.get(
+            "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
+        syncer_callback = SyncerCallback(
+            sync_function=sync_config.sync_to_driver)
+        callbacks.append(syncer_callback)
+
+    return callbacks

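Aside (not part of the diff): a sketch of how the new helper behaves, assuming SyncConfig is constructible from its defaults:

    import os

    from ray.tune.syncer import SyncConfig, SyncerCallback
    from ray.tune.utils.callback import create_default_callbacks

    # With no user callbacks, a SyncerCallback is appended automatically.
    callbacks = create_default_callbacks([], SyncConfig())
    assert any(isinstance(c, SyncerCallback) for c in callbacks)

    # The opt-out variable checked above skips the automatic syncer.
    os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"
    assert create_default_callbacks([], SyncConfig()) == []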