[tune] Only sync up and sync down the checkpoint folder for cloud checkpointing. (#21658)

Sync only the checkpoint folder (by default ~/ray_results/exp_name/trial_name/checkpoint_name)
instead of the whole trial directory (~/ray_results/exp_name/trial_name/).
Files such as progress.csv, result.json, params.pkl, params.json, and events.out come from the driver process, so the trainable does not need to sync them.
This also lets us decouple sync up and delete later on: they no longer have to wait for each other to finish.
Author: xwjiang2010, 2022-01-21 17:56:05 -08:00 (committed by GitHub)
Parent: e8ce01c525
Commit: 0abcd5eea5
8 changed files with 122 additions and 39 deletions
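
For context, here is a minimal sketch of how cloud checkpointing is typically enabled around this Ray version (the bucket URI and experiment name below are placeholders, not taken from this PR); with this change the trainable uploads only the checkpoint folders, while files like progress.csv and result.json stay with the driver:

from ray import tune

def train_fn(config):
    for step in range(3):
        tune.report(metric=step)

# A SyncConfig with an upload_dir is what switches the trials over to
# cloud (durable) checkpointing; "s3://my-bucket/tune" is a placeholder.
tune.run(
    train_fn,
    name="exp_name",
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/tune"),
)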

View file

@ -411,7 +411,7 @@ class FunctionRunner(Trainable):
def execute(self, fn):
return fn(self)
def save(self, checkpoint_path=None):
def save(self, checkpoint_path=None) -> str:
if checkpoint_path:
raise ValueError(
"Checkpoint path should not be used with function API.")
@ -449,7 +449,7 @@ class FunctionRunner(Trainable):
checkpoint_path = TrainableUtil.process_checkpoint(
checkpoint, parent_dir, state)
self._maybe_save_to_cloud()
self._maybe_save_to_cloud(parent_dir)
return checkpoint_path
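
The FunctionRunner change above passes the freshly created checkpoint folder (`parent_dir`) into `_maybe_save_to_cloud`, so only that folder gets uploaded. For reference, a rough sketch of a function-API trainable that produces such per-checkpoint folders (file name and loop bounds are illustrative):

import os
import pickle

from ray import tune

def train_fn(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        # Resuming: state is read back from the restored checkpoint folder.
        with open(os.path.join(checkpoint_dir, "ckpt.pkl"), "rb") as f:
            start = pickle.load(f)["step"] + 1
    for step in range(start, 5):
        # tune.checkpoint_dir creates a per-checkpoint folder under the trial
        # logdir; after this PR, only that folder is synced to the cloud.
        with tune.checkpoint_dir(step=step) as ckpt_dir:
            with open(os.path.join(ckpt_dir, "ckpt.pkl"), "wb") as f:
                pickle.dump({"step": step}, f)
        tune.report(metric=step)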

View file

@ -2,6 +2,7 @@ import distutils
import distutils.spawn
import inspect
import logging
import pathlib
import subprocess
import tempfile
import types
@ -262,6 +263,9 @@ class CommandBasedClient(SyncClient):
return self._execute(self.sync_up_template, source, target, exclude)
def sync_down(self, source, target, exclude: Optional[List] = None):
# Just in case some command-line sync client expects that the local
# directory exists.
pathlib.Path(target).mkdir(parents=True, exist_ok=True)
return self._execute(self.sync_down_template, source, target, exclude)
def delete(self, target):
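
The `pathlib.Path(target).mkdir(...)` call above guards command-line sync tools that assume the local destination already exists, which now matters because `sync_down` targets a per-checkpoint folder that may not have been created locally yet. A rough usage sketch (the command templates and paths are illustrative assumptions, not from this PR):

import pathlib

from ray.tune.sync_client import CommandBasedClient

# Templates use {source} and {target} placeholders; the actual commands
# depend on your storage setup.
client = CommandBasedClient(
    "aws s3 sync {source} {target}",   # sync up
    "aws s3 sync {source} {target}",   # sync down
    "aws s3 rm --recursive {target}",  # delete
)

local_ckpt = pathlib.Path("~/ray_results/exp/trial/checkpoint_00000").expanduser()
local_ckpt.mkdir(parents=True, exist_ok=True)  # mirrors what sync_down now does
client.sync_down("s3://my-bucket/exp/trial/checkpoint_00000", str(local_ckpt))
client.wait()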

View file

@ -1,6 +1,7 @@
import pickle
from collections import Counter
import copy
from functools import partial
import gym
import numpy as np
import os
@ -12,32 +13,32 @@ import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import (Trainable, TuneError, Stopper, run)
from ray.tune.function_runner import wrap_function
from ray.tune import register_env, register_trainable, run_experiments
from ray.rllib import _register_all
from ray.tune import (register_env, register_trainable, run, run_experiments,
Trainable, TuneError, Stopper)
from ray.tune.callback import Callback
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.stopper import (MaximumIterationStopper, TrialPlateauStopper,
ExperimentPlateauStopper)
from ray.tune.suggest.suggestion import ConcurrencyLimiter
from ray.tune.sync_client import CommandBasedClient
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.experiment import Experiment
from ray.tune.function_runner import wrap_function
from ray.tune.logger import Logger
from ray.tune.ray_trial_executor import noop_logger_creator
from ray.tune.resources import Resources
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
EPISODES_TOTAL, TRAINING_ITERATION,
TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
from ray.tune.logger import Logger
from ray.tune.experiment import Experiment
from ray.tune.resources import Resources
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.stopper import (MaximumIterationStopper, TrialPlateauStopper,
ExperimentPlateauStopper)
from ray.tune.suggest import BasicVariantGenerator, grid_search
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import ConcurrencyLimiter
from ray.tune.sync_client import CommandBasedClient
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.utils import (flatten_dict, get_pinned_object,
pin_in_object_store)
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
@ -971,7 +972,12 @@ class TrainableFunctionApiTest(unittest.TestCase):
mock_get_client = "ray.tune.trainable.get_cloud_sync_client"
with patch(mock_get_client) as mock_get_cloud_sync_client:
mock_get_cloud_sync_client.return_value = sync_client
test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
log_creator = partial(
noop_logger_creator, logdir="~/tmp/ray_results/exp/trial")
remote_checkpoint_dir = os.path.join(MOCK_REMOTE_DIR, "exp/trial")
test_trainable = trainable(
logger_creator=log_creator,
remote_checkpoint_dir=remote_checkpoint_dir)
result = test_trainable.train()
self.assertEqual(result["metric"], 1)
checkpoint_path = test_trainable.save()
@ -982,6 +988,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
result = test_trainable.train()
self.assertEqual(result["metric"], 4)
shutil.rmtree("~/tmp/ray_results/exp/")
if not function:
test_trainable.state["hi"] = 2
test_trainable.restore(checkpoint_path)
@ -990,7 +997,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
# Cannot re-use function trainable, create new
tune.session.shutdown()
test_trainable = trainable(
remote_checkpoint_dir=MOCK_REMOTE_DIR)
logger_creator=log_creator,
remote_checkpoint_dir=remote_checkpoint_dir)
test_trainable.restore(checkpoint_path)
result = test_trainable.train()

View file

@ -366,13 +366,16 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
node = cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
# Only added for the fake_remote case.
# For the durable case, we don't sync to the head node.
class _SyncerCallback(SyncerCallback):
def _create_trial_syncer(self, trial: "Trial"):
client = mock_storage_client()
return MockNodeSyncer(trial.logdir, trial.logdir, client)
syncer_callback = _SyncerCallback(None)
runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
syncer_callback = [_SyncerCallback(None)
] if trainable_id == "__fake_remote" else None
runner = TrialRunner(BasicVariantGenerator(), callbacks=syncer_callback)
kwargs = {
"stopping_criterion": {
"training_iteration": 4

View file

@ -1,6 +1,7 @@
import copy
from collections import OrderedDict
import os
import pytest
import sys
import shutil
import unittest
@ -15,6 +16,17 @@ from ray.tune.utils.util import (flatten_dict, unflatten_dict,
from ray.tune.utils.trainable import TrainableUtil
@pytest.mark.parametrize("checkpoint_path", [
"~/tmp/exp/trial/checkpoint0", "~/tmp/exp/trial/checkpoint0/",
"~/tmp/exp/trial/checkpoint0/checkpoint",
"~/tmp/exp/trial/checkpoint0/foo/bar/baz"
])
@pytest.mark.parametrize("logdir", ["~/tmp/exp/trial", "~/tmp/exp/trial/"])
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
assert TrainableUtil.find_rel_checkpoint_dir(
logdir, checkpoint_path) == "checkpoint0/"
class TrainableUtilTest(unittest.TestCase):
def setUp(self):
self.checkpoint_dir = os.path.join(

View file

@ -83,6 +83,8 @@ class Trainable:
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
remote_checkpoint_dir (str): Upload directory (S3 or GS path).
This is the **per-trial** directory,
which is different from the **per-checkpoint** directory.
sync_function_tpl (str): Sync function template to use. Defaults
to `cls._sync_function` (which defaults to `None`).
"""
@ -147,6 +149,8 @@ class Trainable:
self.remote_checkpoint_dir)
def _storage_path(self, local_path):
"""Converts a `local_path` to be based off of
`self.remote_checkpoint_dir`."""
rel_local_path = os.path.relpath(local_path, self.logdir)
return os.path.join(self.remote_checkpoint_dir, rel_local_path)
@ -381,7 +385,7 @@ class Trainable:
"ray_version": ray.__version__,
}
def save(self, checkpoint_dir=None):
def save(self, checkpoint_dir=None) -> str:
"""Saves the current model state to a checkpoint.
Subclasses should override ``save_checkpoint()`` instead to save state.
@ -394,7 +398,10 @@ class Trainable:
checkpoint_dir (str): Optional dir to place the checkpoint.
Returns:
str: Checkpoint path or prefix that may be passed to restore().
str: Path that points to the xxx.pkl file.
Note that the returned path should match what is expected by
`restore()`.
"""
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
checkpoint_dir or self.logdir, index=self.iteration)
@ -406,15 +413,15 @@ class Trainable:
trainable_state=trainable_state)
# Maybe sync to cloud
self._maybe_save_to_cloud()
self._maybe_save_to_cloud(checkpoint_dir)
return checkpoint_path
def _maybe_save_to_cloud(self):
def _maybe_save_to_cloud(self, checkpoint_dir):
# Derived classes like the FunctionRunner might call this
if self.uses_cloud_checkpointing:
self.storage_client.sync_up(self.logdir,
self.remote_checkpoint_dir)
self.storage_client.sync_up(checkpoint_dir,
self._storage_path(checkpoint_dir))
self.storage_client.wait()
def save_to_object(self):
@ -437,13 +444,29 @@ class Trainable:
These checkpoints are returned from calls to save().
Subclasses should override ``_restore()`` instead to restore state.
Subclasses should override ``load_checkpoint()`` instead to
restore state.
This method restores additional metadata saved with the checkpoint.
`checkpoint_path` should match the return value of ``save()``.
`checkpoint_path` can be
`~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint` or
`~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
`self.logdir` should generally correspond to `checkpoint_path`,
for example `~/ray_results/exp/MyTrainable_abc`.
`self.remote_checkpoint_dir` in this case is something like
`REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
"""
# Maybe sync from cloud
if self.uses_cloud_checkpointing:
self.storage_client.sync_down(self.remote_checkpoint_dir,
self.logdir)
rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
self.logdir, checkpoint_path)
self.storage_client.sync_down(
os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
os.path.join(self.logdir, rel_checkpoint_dir))
self.storage_client.wait()
# Ensure TrialCheckpoints are converted
@ -624,6 +647,8 @@ class Trainable:
"""Create logger from logger creator.
Sets _logdir and _result_logger.
`_logdir` is the **per-trial** directory for the Trainable.
"""
if logger_creator:
self._result_logger = logger_creator(config)
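
To make the new path arithmetic concrete, here is a toy walk-through of what `_storage_path` and the per-checkpoint sync in `save()`/`restore()` compute (all paths below are made-up examples):

import os

logdir = os.path.expanduser("~/ray_results/exp/MyTrainable_abc")  # per-trial local dir
remote_checkpoint_dir = "s3://bucket/exp/MyTrainable_abc"         # per-trial remote dir

# save(): only the new checkpoint folder is uploaded, to its mirrored location.
checkpoint_dir = os.path.join(logdir, "checkpoint_00000")
rel = os.path.relpath(checkpoint_dir, logdir)               # "checkpoint_00000"
storage_path = os.path.join(remote_checkpoint_dir, rel)
# sync_up: ~/ray_results/exp/MyTrainable_abc/checkpoint_00000
#          -> s3://bucket/exp/MyTrainable_abc/checkpoint_00000

# restore(): only the checkpoint folder named by checkpoint_path is pulled down.
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
rel_ckpt_dir = os.path.join(
    os.path.relpath(checkpoint_path, logdir).split(os.sep)[0], "")  # "checkpoint_00000/"
# sync_down: s3://bucket/exp/MyTrainable_abc/checkpoint_00000/
#            -> ~/ray_results/exp/MyTrainable_abc/checkpoint_00000/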

View file

@ -444,6 +444,10 @@ class Trial:
@property
def remote_checkpoint_dir(self):
"""This is the **per trial** remote checkpoint dir.
This is different from **per experiment** remote checkpoint dir.
"""
assert self.logdir, "Trial {}: logdir not initialized.".format(self)
if not self.remote_checkpoint_dir_prefix:
return None
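
A back-of-the-envelope illustration of the per-experiment vs. per-trial distinction the docstring draws (values are placeholders and the composition is simplified):

import os

remote_checkpoint_dir_prefix = "s3://my-bucket/tune/exp_name"  # per-experiment prefix
trial_dirname = "MyTrainable_abc"                              # basename of trial.logdir

per_trial_remote_dir = os.path.join(remote_checkpoint_dir_prefix, trial_dirname)
# -> s3://my-bucket/tune/exp_name/MyTrainable_abc
# Individual checkpoints are then mirrored beneath it, e.g.
# s3://my-bucket/tune/exp_name/MyTrainable_abc/checkpoint_00000/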

View file

@ -1,16 +1,14 @@
from typing import Dict, Any
import glob
import inspect
import io
import logging
import shutil
import pandas as pd
import ray.cloudpickle as pickle
import os
import pandas as pd
import shutil
from typing import Any, Dict, Union
import ray
import ray.cloudpickle as pickle
from ray.tune.registry import _ParameterRegistry
from ray.tune.utils import detect_checkpoint_function
from ray.util import placement_group
@ -21,7 +19,22 @@ logger = logging.getLogger(__name__)
class TrainableUtil:
@staticmethod
def process_checkpoint(checkpoint, parent_dir, trainable_state):
def process_checkpoint(checkpoint: Union[Dict, str], parent_dir: str,
trainable_state: Dict) -> str:
"""Creates checkpoint file structure and writes metadata
under `parent_dir`.
The file structure could either look like:
- checkpoint_00000 (returned path)
-- .is_checkpoint
-- .tune_metadata
-- xxx.pkl (or whatever user specifies in their Trainable)
Or,
- checkpoint_00000
-- .is_checkpoint
-- checkpoint (returned path)
-- checkpoint.tune_metadata
"""
saved_as_dict = False
if isinstance(checkpoint, string_types):
if not checkpoint.startswith(parent_dir):
@ -100,6 +113,20 @@ class TrainableUtil:
checkpoint_path))
return os.path.normpath(checkpoint_dir)
@staticmethod
def find_rel_checkpoint_dir(logdir, checkpoint_path):
"""Returns the (relative) directory name of the checkpoint.
Note, the assumption here is `logdir` should be the prefix of
`checkpoint_path`.
For example, returns `checkpoint00000/`.
"""
assert checkpoint_path.startswith(
logdir), "expecting `logdir` to be a prefix of `checkpoint_path`"
rel_path = os.path.relpath(checkpoint_path, logdir)
tokens = rel_path.split(os.sep)
return os.path.join(tokens[0], "")
@staticmethod
def make_checkpoint_dir(checkpoint_dir, index, override=False):
"""Creates a checkpoint directory within the provided path.