mirror of https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
[tune] logger refactor part 2: Add SyncerCallback (#11748)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
parent 05c4e3fb2a
commit 007634fd1b
10 changed files with 211 additions and 193 deletions
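The change moves driver syncing out of Trial/UnifiedLogger and into a new SyncerCallback that the trial runner drives through the generic Callback hooks. A minimal sketch of how the pieces added in this diff fit together (editor's sketch, not part of the commit; assumes the post-refactor tune.run signature and a toy trainable):

# Editor's sketch, not part of this commit's diff.
# Assumes the post-refactor API: tune.run accepts a callbacks list, and
# SyncerCallback is the class added to the syncer module below.
from ray import tune
from ray.tune.syncer import SyncerCallback


def train_fn(config):
    for step in range(3):
        tune.report(score=step * config["lr"])


# Passing a SyncerCallback explicitly; if none is given, tune.run now calls
# create_default_callbacks() (new file at the end of this diff) to append one.
tune.run(
    train_fn,
    config={"lr": 0.1},
    callbacks=[SyncerCallback(sync_function=None)])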
@@ -1,7 +1,9 @@
-from typing import Dict, List
+from typing import TYPE_CHECKING, Dict, List
 
 from ray.tune.checkpoint_manager import Checkpoint
-from ray.tune.trial import Trial
+
+if TYPE_CHECKING:
+    from ray.tune.trial import Trial
 
 
 class Callback:
@@ -39,7 +41,7 @@ class Callback:
 
         """
 
-    def on_step_begin(self, iteration: int, trials: List[Trial], **info):
+    def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
         """Called at the start of each tuning loop step.
 
         Arguments:
@@ -49,7 +51,7 @@ class Callback:
         """
         pass
 
-    def on_step_end(self, iteration: int, trials: List[Trial], **info):
+    def on_step_end(self, iteration: int, trials: List["Trial"], **info):
         """Called at the end of each tuning loop step.
 
         The iteration counter is increased before this hook is called.
@@ -61,8 +63,8 @@ class Callback:
         """
         pass
 
-    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after starting a trial instance.
 
         Arguments:
@@ -74,8 +76,8 @@ class Callback:
         """
         pass
 
-    def on_trial_restore(self, iteration: int, trials: List[Trial],
-                         trial: Trial, **info):
+    def on_trial_restore(self, iteration: int, trials: List["Trial"],
+                         trial: "Trial", **info):
         """Called after restoring a trial instance.
 
         Arguments:
@@ -86,8 +88,8 @@ class Callback:
         """
         pass
 
-    def on_trial_save(self, iteration: int, trials: List[Trial], trial: Trial,
-                      **info):
+    def on_trial_save(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", **info):
         """Called after receiving a checkpoint from a trial.
 
         Arguments:
@@ -98,8 +100,8 @@ class Callback:
         """
         pass
 
-    def on_trial_result(self, iteration: int, trials: List[Trial],
-                        trial: Trial, result: Dict, **info):
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
         """Called after receiving a result from a trial.
 
         The search algorithm and scheduler are notified before this
@@ -114,8 +116,8 @@ class Callback:
         """
         pass
 
-    def on_trial_complete(self, iteration: int, trials: List[Trial],
-                          trial: Trial, **info):
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
         """Called after a trial instance completed.
 
         The search algorithm and scheduler are notified before this
@@ -129,8 +131,8 @@ class Callback:
         """
         pass
 
-    def on_trial_error(self, iteration: int, trials: List[Trial], trial: Trial,
-                       **info):
+    def on_trial_error(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
         """Called after a trial instance failed (errored).
 
         The search algorithm and scheduler are notified before this
@@ -144,8 +146,8 @@ class Callback:
         """
         pass
 
-    def on_checkpoint(self, iteration: int, trials: List[Trial], trial: Trial,
-                      checkpoint: Checkpoint, **info):
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
        """Called after a trial saved a checkpoint with Tune.
 
         Arguments:
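The hunks above only change typing: Trial becomes a string forward reference imported under TYPE_CHECKING, so the hooks keep the same runtime behavior for user callbacks. A hedged sketch of a custom callback written against these hooks (the "score" result key is illustrative, not something this diff defines):

# Editor's sketch of a user-defined callback against the hooks shown above.
from typing import Dict, List, TYPE_CHECKING

from ray.tune.callback import Callback

if TYPE_CHECKING:
    from ray.tune.trial import Trial


class BestScoreTracker(Callback):
    def __init__(self):
        self._best = float("-inf")

    def on_trial_result(self, iteration: int, trials: List["Trial"],
                        trial: "Trial", result: Dict, **info):
        # Track the best "score" reported by any trial so far.
        score = result.get("score", float("-inf"))
        if score > self._best:
            self._best = score
            print(f"New best score {score} from trial {trial}")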
@@ -193,6 +193,5 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None
-        sync_to_driver_fn=spec.get("sync_to_driver"),
         max_failures=args.max_failures,
         **trial_kwargs)
@@ -1,21 +1,19 @@
 import csv
 import json
 import logging
-import numbers
 import numpy as np
 import os
 import yaml
 
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type
 
 import ray.cloudpickle as cloudpickle
 
 from ray.tune.utils.util import SafeFallbackEncoder
 from ray.util.debug import log_once
-from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
-                             TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
-                             EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE,
-                             EXPR_RESULT_FILE)
-from ray.tune.syncer import get_node_syncer
+from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
+                             EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE,
+                             EXPR_PROGRESS_FILE, EXPR_RESULT_FILE)
 from ray.tune.utils import flatten_dict
 
 if TYPE_CHECKING:
@@ -317,16 +315,13 @@ class UnifiedLogger(Logger):
         logdir: Directory for all logger creators to log to.
         loggers (list): List of logger creators. Defaults to CSV, Tensorboard,
             and JSON loggers.
-        sync_function (func|str): Optional function for syncer to run.
-            See ray/python/ray/tune/syncer.py
     """
 
     def __init__(self,
                  config: Dict,
                  logdir: str,
                  trial: Optional["Trial"] = None,
-                 loggers: Optional[List[Type[Logger]]] = None,
-                 sync_function: Union[None, Callable, str] = None):
+                 loggers: Optional[List[Type[Logger]]] = None):
         if loggers is None:
             self._logger_cls_list = DEFAULT_LOGGERS
         else:
@@ -336,8 +331,6 @@ class UnifiedLogger(Logger):
                 logger.warning(
                     "JsonLogger not provided. The ExperimentAnalysis tool is "
                     "disabled.")
-        self._sync_function = sync_function
-        self._log_syncer = None
 
         super(UnifiedLogger, self).__init__(config, logdir, trial)
 
@@ -350,16 +343,10 @@ class UnifiedLogger(Logger):
                 if log_once(f"instantiate:{cls.__name__}"):
                     logger.warning("Could not instantiate %s: %s.",
                                    cls.__name__, str(exc))
-        self._log_syncer = get_node_syncer(
-            self.logdir,
-            remote_dir=self.logdir,
-            sync_function=self._sync_function)
 
     def on_result(self, result):
         for _logger in self._loggers:
             _logger.on_result(result)
-        self._log_syncer.set_worker_ip(result.get(NODE_IP))
-        self._log_syncer.sync_down_if_needed()
 
     def update_config(self, config):
         for _logger in self._loggers:
@@ -369,68 +356,9 @@ class UnifiedLogger(Logger):
         for _logger in self._loggers:
             _logger.close()
 
-    def flush(self, sync_down=True):
+    def flush(self):
         for _logger in self._loggers:
             _logger.flush()
-        if sync_down:
-            if not self._log_syncer.sync_down():
-                logger.warning("Trial %s: Post-flush sync skipped.",
-                               self.trial)
-
-    def sync_up(self):
-        return self._log_syncer.sync_up()
-
-    def sync_down(self):
-        return self._log_syncer.sync_down()
-
-    def wait(self):
-        self._log_syncer.wait()
-
-    def sync_results_to_new_location(self, worker_ip):
-        """Sends the current log directory to the remote node.
-
-        Syncing will not occur if the cluster is not started
-        with the Ray autoscaler.
-        """
-        if worker_ip != self._log_syncer.worker_ip:
-            logger.info("Trial %s: Syncing (blocking) results to %s",
-                        self.trial, worker_ip)
-            self._log_syncer.reset()
-            self._log_syncer.set_worker_ip(worker_ip)
-            if not self._log_syncer.sync_up():
-                logger.error(
-                    "Trial %s: Sync up to new location skipped. "
-                    "This should not occur.", self.trial)
-            self._log_syncer.wait()
-        else:
-            logger.error(
-                "Trial %s: Sync attempted to same IP %s. This "
-                "should not occur.", self.trial, worker_ip)
-
-
-class _SafeFallbackEncoder(json.JSONEncoder):
-    def __init__(self, nan_str="null", **kwargs):
-        super(_SafeFallbackEncoder, self).__init__(**kwargs)
-        self.nan_str = nan_str
-
-    def default(self, value):
-        try:
-            if np.isnan(value):
-                return self.nan_str
-
-            if (type(value).__module__ == np.__name__
-                    and isinstance(value, np.ndarray)):
-                return value.tolist()
-
-            if issubclass(type(value), numbers.Integral):
-                return int(value)
-            if issubclass(type(value), numbers.Number):
-                return float(value)
-
-            return super(_SafeFallbackEncoder, self).default(value)
-
-        except Exception:
-            return str(value)  # give up, just stringify it (ok for logs)
-
 
 def pretty_print(result):
@@ -442,5 +370,5 @@ def pretty_print(result):
         if v is not None:
             out[k] = v
 
-    cleaned = json.dumps(out, cls=_SafeFallbackEncoder)
+    cleaned = json.dumps(out, cls=SafeFallbackEncoder)
     return yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
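With these hunks UnifiedLogger only fans results out to the configured Logger classes; it no longer owns a node syncer, so flush() loses its sync_down argument and the sync_* helpers disappear. A hedged sketch of direct use under the new constructor (Tune normally builds this object itself per trial; the config and result values are placeholders):

# Editor's sketch: the slimmed-down UnifiedLogger after this diff.
import tempfile

from ray.tune.logger import UnifiedLogger

logdir = tempfile.mkdtemp()
unified = UnifiedLogger(config={"lr": 0.1}, logdir=logdir, trial=None)
unified.on_result({"training_iteration": 1, "time_total_s": 0.5, "score": 1.0})
unified.flush()  # no sync_down flag anymore; syncing now lives in SyncerCallback
unified.close()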
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union
 
 import distutils
 import logging
@@ -9,13 +9,21 @@ from dataclasses import dataclass
 from inspect import isclass
 from shlex import quote
 
+import ray
 from ray import services
+from ray.tune import TuneError
+from ray.tune.callback import Callback
+from ray.tune.checkpoint_manager import Checkpoint
+from ray.tune.result import NODE_IP
 from ray.util.debug import log_once
 from ray.tune.utils.util import env_integer
 from ray.tune.cluster_info import get_ssh_key, get_ssh_user
 from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
                                   get_cloud_sync_client, NOOP)
 
+if TYPE_CHECKING:
+    from ray.tune.trial import Trial
+
 logger = logging.getLogger(__name__)
 
 # Syncing period for syncing local checkpoints to cloud.
@@ -355,3 +363,88 @@ def get_node_syncer(local_dir, remote_dir=None, sync_function=None):
 
     _syncers[key] = NodeSyncer(local_dir, remote_dir, sync_client)
     return _syncers[key]
+
+
+class SyncerCallback(Callback):
+    def __init__(self, sync_function: Union[None, bool, Callable]):
+        self._sync_function = sync_function
+        self._syncers: Dict["Trial", NodeSyncer] = {}
+
+    def _get_trial_syncer(self, trial: "Trial"):
+        if trial not in self._syncers:
+            self._syncers[trial] = self._create_trial_syncer(trial)
+        return self._syncers[trial]
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        return get_node_syncer(
+            trial.logdir,
+            remote_dir=trial.logdir,
+            sync_function=self._sync_function)
+
+    def _sync_trial_checkpoint(self, trial: "Trial", checkpoint: Checkpoint):
+        if checkpoint.storage == Checkpoint.MEMORY:
+            return
+
+        # Local import to avoid circular dependencies between syncer and
+        # trainable
+        from ray.tune.durable_trainable import DurableTrainable
+
+        trial_syncer = self._get_trial_syncer(trial)
+        if trial.sync_on_checkpoint:
+            try:
+                # Wait for any other syncs to finish. We need to sync again
+                # after this to handle checkpoints taken mid-sync.
+                trial_syncer.wait()
+            except TuneError as e:
+                # Errors occurring during this wait are not fatal for this
+                # checkpoint, so it should just be logged.
+                logger.error(
+                    "Trial %s: An error occurred during the "
+                    "checkpoint pre-sync wait - %s", trial, str(e))
+            # Force sync down and wait before tracking the new checkpoint.
+            try:
+                if trial_syncer.sync_down():
+                    trial_syncer.wait()
+                else:
+                    logger.error(
+                        "Trial %s: Checkpoint sync skipped. "
+                        "This should not happen.", trial)
+            except TuneError as e:
+                if issubclass(trial.get_trainable_cls(), DurableTrainable):
+                    # Even though rsync failed the trainable can restore
+                    # from remote durable storage.
+                    logger.error("Trial %s: Sync error - %s", trial, str(e))
+                else:
+                    # If the trainable didn't have remote storage to upload
+                    # to then this checkpoint may have been lost, so we
+                    # shouldn't track it with the checkpoint_manager.
+                    raise e
+            if not issubclass(trial.get_trainable_cls(), DurableTrainable):
+                if not os.path.exists(checkpoint.value):
+                    raise TuneError("Trial {}: Checkpoint path {} not "
+                                    "found after successful sync down.".format(
+                                        trial, checkpoint.value))
+
+    def on_trial_start(self, iteration: int, trials: List["Trial"],
+                       trial: "Trial", **info):
+        self._get_trial_syncer(trial)
+
+    def on_trial_result(self, iteration: int, trials: List["Trial"],
+                        trial: "Trial", result: Dict, **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        trial_syncer.set_worker_ip(result.get(NODE_IP))
+        trial_syncer.sync_down_if_needed()
+
+    def on_trial_complete(self, iteration: int, trials: List["Trial"],
+                          trial: "Trial", **info):
+        trial_syncer = self._get_trial_syncer(trial)
+        if NODE_IP in trial.last_result:
+            trainable_ip = trial.last_result[NODE_IP]
+        else:
+            trainable_ip = ray.get(trial.runner.get_current_ip.remote())
+        trial_syncer.set_worker_ip(trainable_ip)
+        trial_syncer.sync_down_if_needed()
+
+    def on_checkpoint(self, iteration: int, trials: List["Trial"],
+                      trial: "Trial", checkpoint: Checkpoint, **info):
+        self._sync_trial_checkpoint(trial, checkpoint)
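SyncerCallback keeps one NodeSyncer per trial, syncs trial logdirs down to the driver on results and completion, and force-syncs checkpoints in on_checkpoint. A hedged sketch of customizing the per-trial syncer by overriding _create_trial_syncer, the same pattern the test changes later in this diff use (the class name and prefix argument are made up for illustration):

# Editor's sketch: per-trial sync selection via _create_trial_syncer,
# mirroring the _PerTrialSyncerCallback used in the test changes below.
from ray.tune.syncer import SyncerCallback, get_node_syncer


class SelectiveSyncerCallback(SyncerCallback):
    """Sync only trials whose trainable name starts with a given prefix."""

    def __init__(self, prefix):
        super(SelectiveSyncerCallback, self).__init__(sync_function=None)
        self._prefix = prefix

    def _create_trial_syncer(self, trial: "Trial"):
        # False is intended to disable driver syncing for this trial,
        # as the tests below rely on when they pass a bool sync function.
        sync_fn = None if trial.trainable_name.startswith(self._prefix) else False
        return get_node_syncer(
            trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn)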
@@ -1,12 +1,15 @@
 import inspect
 import time
 import os
 
 import pytest
 import shutil
 import subprocess
 import sys
 from unittest.mock import MagicMock, patch
+
+from typing import Callable, Union
+
 import ray
 from ray import tune
 from ray.rllib import _register_all
@@ -18,7 +21,7 @@ from ray.tune.error import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.resources import Resources
 from ray.tune.suggest import BasicVariantGenerator
-from ray.tune.syncer import CloudSyncer
+from ray.tune.syncer import CloudSyncer, SyncerCallback, get_node_syncer
 from ray.tune.utils.trainable import TrainableUtil
 from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
@@ -55,6 +58,19 @@ def _start_new_cluster():
     return cluster
 
 
+class _PerTrialSyncerCallback(SyncerCallback):
+    def __init__(
+            self,
+            get_sync_fn: Callable[["Trial"], Union[None, bool, Callable]]):
+        self._get_sync_fn = get_sync_fn
+        super(_PerTrialSyncerCallback, self).__init__(None)
+
+    def _create_trial_syncer(self, trial: "Trial"):
+        sync_fn = self._get_sync_fn(trial)
+        return get_node_syncer(
+            trial.logdir, remote_dir=trial.logdir, sync_function=sync_fn)
+
+
 @pytest.fixture
 def start_connected_cluster():
     # Start the Ray processes.
@@ -255,7 +271,9 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()
 
-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4
@@ -263,7 +281,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
 
     # Test recovery of trial that hasn't been checkpointed
@@ -316,7 +333,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
             "training_iteration": 3
         },
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
     t3 = Trial(trainable_id, **kwargs)
     runner.add_trial(t3)
@@ -341,7 +357,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()
 
-    runner = TrialRunner(BasicVariantGenerator())
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 5
@@ -349,7 +367,6 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
 
     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
@@ -382,7 +399,13 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()
 
-    runner = TrialRunner(BasicVariantGenerator())
+    class _SyncerCallback(SyncerCallback):
+        def _create_trial_syncer(self, trial: "Trial"):
+            client = mock_storage_client()
+            return MockNodeSyncer(trial.logdir, trial.logdir, client)
+
+    syncer_callback = _SyncerCallback(None)
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4
@@ -390,7 +413,6 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
         "checkpoint_freq": 2,
         "max_failures": 2,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake_remote",
     }
 
     # The following patches only affect __fake_remote.
@@ -415,33 +437,26 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
         # TrainableUtil will not check this path unless we mock it.
         mock_find.side_effect = hide_remote_path(find_func)
         mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)
-        with patch("ray.tune.logger.get_node_syncer") as mock_get_node_syncer:
-
-            def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
-                client = mock_storage_client()
-                return MockNodeSyncer(local_dir, remote_dir, client)
-
-            mock_get_node_syncer.side_effect = mock_get_syncer_fn
-
-            # Test recovery of trial that has been checkpointed
-            t1 = Trial(trainable_id, **kwargs)
-            runner.add_trial(t1)
-
-            # Start trial, process result (x2), process save
-            for _ in range(4):
-                runner.step()
-            assert t1.has_checkpoint()
-
-            cluster.add_node(num_cpus=1)
-            cluster.remove_node(node)
-            cluster.wait_for_nodes()
-            shutil.rmtree(os.path.dirname(t1.checkpoint.value))
-            runner.step()  # Collect result 3, kick off + fail result 4
-            runner.step()  # Dispatch restore
-            runner.step()  # Process restore + step 4
-            for _ in range(3):
-                if t1.status != Trial.TERMINATED:
-                    runner.step()
+
+        # Test recovery of trial that has been checkpointed
+        t1 = Trial(trainable_id, **kwargs)
+        runner.add_trial(t1)
+
+        # Start trial, process result (x2), process save
+        for _ in range(4):
+            runner.step()
+        assert t1.has_checkpoint()
+
+        cluster.add_node(num_cpus=1)
+        cluster.remove_node(node)
+        cluster.wait_for_nodes()
+        shutil.rmtree(os.path.dirname(t1.checkpoint.value))
+        runner.step()  # Collect result 3, kick off + fail result 4
+        runner.step()  # Dispatch restore
+        runner.step()  # Process restore + step 4
+        for _ in range(3):
+            if t1.status != Trial.TERMINATED:
+                runner.step()
     assert t1.status == Trial.TERMINATED, runner.debug_string()
 
 
@@ -454,7 +469,12 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
     cluster.wait_for_nodes()
 
     dirpath = str(tmpdir)
-    runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0)
+    syncer_callback = _PerTrialSyncerCallback(
+        lambda trial: trial.trainable_name == "__fake")
+    runner = TrialRunner(
+        local_checkpoint_dir=dirpath,
+        checkpoint_period=0,
+        callbacks=[syncer_callback])
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 2
@@ -462,7 +482,6 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir, trainable_id):
         "checkpoint_freq": 1,
         "max_failures": 1,
         "remote_checkpoint_dir": MOCK_REMOTE_DIR,
-        "sync_to_driver_fn": trainable_id == "__fake",
     }
     trials = [Trial(trainable_id, **kwargs), Trial(trainable_id, **kwargs)]
     for t in trials:
@@ -11,7 +11,6 @@ import ray
 from ray.rllib import _register_all
 
 from ray import tune
-from ray.tune import TuneError
 from ray.tune.syncer import CommandBasedClient
 
 
@@ -87,8 +86,8 @@ class TestSyncFunctionality(unittest.TestCase):
 
     def testClusterProperString(self):
         """Tests that invalid commands throw.."""
-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {target}")
             [trial] = tune.run(
                 "__fake",
@@ -100,8 +99,8 @@ class TestSyncFunctionality(unittest.TestCase):
                 sync_config=sync_config,
             ).trials
 
-        with self.assertRaises(TuneError):
-            # This raises TuneError because logger is init in safe zone.
+        with self.assertRaises(ValueError):
+            # This raises ValueError because logger is init in safe zone.
             sync_config = tune.SyncConfig(sync_to_driver="ls {source}")
             [trial] = tune.run(
                 "__fake",
@@ -124,11 +124,7 @@ class TuneExampleTest(unittest.TestCase):
 class AutoInitTest(unittest.TestCase):
     def testTuneRestore(self):
         self.assertFalse(ray.is_initialized())
-        tune.run(
-            "__fake",
-            name="TestAutoInit",
-            stop={"training_iteration": 1},
-            ray_auto_init=True)
+        tune.run("__fake", name="TestAutoInit", stop={"training_iteration": 1})
         self.assertTrue(ray.is_initialized())
 
     def tearDown(self):
@@ -12,7 +12,6 @@ import os
 from numbers import Number
 from ray.tune import TuneError
 from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
-from ray.tune.durable_trainable import DurableTrainable
 from ray.tune.logger import pretty_print, UnifiedLogger
 # NOTE(rkn): We import ray.tune.registry here instead of importing the names we
 # need because there are cyclic imports that may cause specific names to not
@@ -192,7 +191,6 @@ class Trial:
                  trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
-                 sync_to_driver_fn=None,
                  max_failures=0):
         """Initialize a new trial.
 
@@ -232,7 +230,6 @@ class Trial:
                 or not len(self.log_to_file) == 2:
             self.log_to_file = (None, None)
 
-        self.sync_to_driver_fn = sync_to_driver_fn
         self.verbose = True
         self.max_failures = max_failures
 
@@ -289,7 +286,6 @@ class Trial:
 
         self._nonjson_fields = [
             "loggers",
-            "sync_to_driver_fn",
             "results",
             "best_result",
             "param_config",
@@ -356,7 +352,6 @@ class Trial:
             trial_name_creator=self.trial_name_creator,
             loggers=self.loggers,
             log_to_file=self.log_to_file,
-            sync_to_driver_fn=self.sync_to_driver_fn,
             max_failures=self.max_failures,
         )
 
@@ -370,11 +365,7 @@ class Trial:
             os.makedirs(self.logdir, exist_ok=True)
 
         self.result_logger = UnifiedLogger(
-            self.config,
-            self.logdir,
-            trial=self,
-            loggers=self.loggers,
-            sync_function=self.sync_to_driver_fn)
+            self.config, self.logdir, trial=self, loggers=self.loggers)
 
     def update_resources(self, cpu, gpu, **kwargs):
         """EXPERIMENTAL: Updates the resource requirements.
@@ -459,43 +450,6 @@ class Trial:
         Args:
             checkpoint (Checkpoint): Checkpoint taken.
         """
-        if checkpoint.storage == Checkpoint.MEMORY:
-            self.checkpoint_manager.on_checkpoint(checkpoint)
-            return
-
-        if self.sync_on_checkpoint:
-            try:
-                # Wait for any other syncs to finish. We need to sync again
-                # after this to handle checkpoints taken mid-sync.
-                self.result_logger.wait()
-            except TuneError as e:
-                # Errors occurring during this wait are not fatal for this
-                # checkpoint, so it should just be logged.
-                logger.error(
-                    "Trial %s: An error occurred during the "
-                    "checkpoint pre-sync wait - %s", self, str(e))
-            # Force sync down and wait before tracking the new checkpoint.
-            try:
-                if self.result_logger.sync_down():
-                    self.result_logger.wait()
-                else:
-                    logger.error(
-                        "Trial %s: Checkpoint sync skipped. "
-                        "This should not happen.", self)
-            except TuneError as e:
-                if issubclass(self.get_trainable_cls(), DurableTrainable):
-                    # Even though rsync failed the trainable can restore
-                    # from remote durable storage.
-                    logger.error("Trial %s: Sync error - %s", self, str(e))
-                else:
-                    # If the trainable didn't have remote storage to upload
-                    # to then this checkpoint may have been lost, so we
-                    # shouldn't track it with the checkpoint_manager.
-                    raise e
-        if not issubclass(self.get_trainable_cls(), DurableTrainable):
-            if not os.path.exists(checkpoint.value):
-                raise TuneError("Trial {}: Checkpoint path {} not "
-                                "found after successful sync down.".format(
-                                    self, checkpoint.value))
         self.checkpoint_manager.on_checkpoint(checkpoint)
 
     def on_restore(self):
@@ -515,7 +469,6 @@ class Trial:
         return self.num_failures < self.max_failures or self.max_failures < 0
 
     def update_last_result(self, result, terminate=False):
-        result.update(trial_id=self.trial_id, done=terminate)
         if self.experiment_tag:
             result.update(experiment_tag=self.experiment_tag)
         if self.verbose and (terminate or time.time() - self.last_debug >
@@ -634,7 +587,7 @@ class Trial:
         state["resuming_from"] = None
         state["saving_to"] = None
         if self.result_logger:
-            self.result_logger.flush(sync_down=False)
+            self.result_logger.flush()
             state["__logger_started__"] = True
         else:
             state["__logger_started__"] = False
@@ -11,8 +11,10 @@ from ray.tune.suggest.variant_generator import has_unresolved_values
 from ray.tune.trial import Trial
 from ray.tune.trainable import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
+from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.registry import get_trainable_cls
-from ray.tune.syncer import wait_for_sync, set_sync_periods, SyncConfig
+from ray.tune.syncer import wait_for_sync, set_sync_periods, \
+    SyncConfig
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
 from ray.tune.schedulers import FIFOScheduler
@@ -353,6 +355,9 @@ def run(
             "own `metric` and `mode` parameters. Either remove the arguments "
             "from your scheduler or from your call to `tune.run()`")
 
+    # Create syncer callbacks
+    callbacks = create_default_callbacks(callbacks, sync_config)
+
     runner = TrialRunner(
         search_alg=search_alg,
         scheduler=scheduler,
python/ray/tune/utils/callback.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import os
+from typing import List, Optional
+
+from ray.tune.callback import Callback
+from ray.tune.syncer import SyncConfig
+from ray.tune.syncer import SyncerCallback
+
+
+def create_default_callbacks(callbacks: Optional[List[Callback]],
+                             sync_config: SyncConfig):
+
+    callbacks = callbacks or []
+
+    # Check if there is a SyncerCallback
+    has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks)
+
+    # If no SyncerCallback was found, add
+    if not has_syncer_callback and os.environ.get(
+            "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
+        syncer_callback = SyncerCallback(
+            sync_function=sync_config.sync_to_driver)
+        callbacks.append(syncer_callback)
+
+    return callbacks
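The new helper appends a driver-syncing SyncerCallback unless the caller already supplied one or opted out through the TUNE_DISABLE_AUTO_CALLBACK_SYNCER environment variable. A hedged sketch of both paths (editor's example; values illustrative):

# Editor's sketch of the helper added above; values are illustrative.
import os

from ray.tune.syncer import SyncConfig, SyncerCallback
from ray.tune.utils.callback import create_default_callbacks

sync_config = SyncConfig()  # defaults; sync_to_driver is passed through

# No syncer supplied: one is appended automatically.
callbacks = create_default_callbacks([], sync_config)
assert any(isinstance(c, SyncerCallback) for c in callbacks)

# Opt out globally via the environment variable checked above.
os.environ["TUNE_DISABLE_AUTO_CALLBACK_SYNCER"] = "1"
assert create_default_callbacks([], sync_config) == []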