[tune] only sync up and sync down checkpoint folder for cloud checkpoint. (#21658)
By default, only the checkpoint folder (~/ray_results/exp_name/trial_name/checkpoint_name) is synced up and down for cloud checkpointing, instead of the whole trial directory (~/ray_results/exp_name/trial_name/). Files like progress.csv, result.json, params.pkl, params.json, and events.out come from the driver process anyway. This also lets us decouple sync-up and delete: they no longer have to wait for each other to finish.
parent e8ce01c525
commit 0abcd5eea5
8 changed files with 122 additions and 39 deletions
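
A minimal sketch of what changes in practice (not Tune code; the bucket path and the helper name `storage_path` are made up here, mirroring the `_storage_path` helper added to `Trainable` below): before, the whole trial logdir was uploaded; now only the checkpoint folder is, placed at the matching relative path under the per-trial remote dir.

import os

# Sketch: map a local checkpoint folder under the trial logdir to the
# corresponding location under the per-trial remote checkpoint dir.
def storage_path(local_path, logdir, remote_checkpoint_dir):
    rel_local_path = os.path.relpath(local_path, logdir)
    return os.path.join(remote_checkpoint_dir, rel_local_path)

logdir = os.path.expanduser("~/ray_results/exp_name/trial_name")
checkpoint_dir = os.path.join(logdir, "checkpoint_00000")
remote_dir = "s3://my-bucket/exp_name/trial_name"  # hypothetical upload dir

# Old behavior: roughly sync_up(logdir, remote_dir), which also uploads
# progress.csv, result.json, params.pkl, params.json, events.out, ...
# New behavior: only the checkpoint folder is uploaded.
print(storage_path(checkpoint_dir, logdir, remote_dir))
# -> s3://my-bucket/exp_name/trial_name/checkpoint_00000
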
@@ -411,7 +411,7 @@ class FunctionRunner(Trainable):
     def execute(self, fn):
         return fn(self)
 
-    def save(self, checkpoint_path=None):
+    def save(self, checkpoint_path=None) -> str:
         if checkpoint_path:
             raise ValueError(
                 "Checkpoint path should not be used with function API.")
@@ -449,7 +449,7 @@ class FunctionRunner(Trainable):
         checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state)
 
-        self._maybe_save_to_cloud()
+        self._maybe_save_to_cloud(parent_dir)
 
        return checkpoint_path
@@ -2,6 +2,7 @@ import distutils
 import distutils.spawn
 import inspect
 import logging
+import pathlib
 import subprocess
 import tempfile
 import types
@@ -262,6 +263,9 @@ class CommandBasedClient(SyncClient):
         return self._execute(self.sync_up_template, source, target, exclude)
 
     def sync_down(self, source, target, exclude: Optional[List] = None):
+        # Just in case some command line sync client expects that local
+        # directory exists.
+        pathlib.Path(target).mkdir(parents=True, exist_ok=True)
         return self._execute(self.sync_down_template, source, target, exclude)
 
     def delete(self, target):
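
A quick sketch of why the added `mkdir` call matters (hedged: `sync_down_sketch` below is a stand-in, the template uses `{source}`/`{target}` placeholders, and the aws command is only an example): some command-line sync tools fail when the local destination directory does not exist yet.

import pathlib
import subprocess

def sync_down_sketch(sync_down_template, source, target):
    # Create the local destination up front, as the added pathlib line does,
    # so the external sync command never sees a missing target directory.
    pathlib.Path(target).mkdir(parents=True, exist_ok=True)
    cmd = sync_down_template.format(source=source, target=target)
    return subprocess.Popen(cmd, shell=True)

# Hypothetical template; real templates come from the user's sync config.
template = "aws s3 sync {source} {target}"
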
@@ -1,6 +1,7 @@
 import pickle
 from collections import Counter
 import copy
+from functools import partial
 import gym
 import numpy as np
 import os
@@ -12,32 +13,32 @@ import unittest
 from unittest.mock import patch
 
 import ray
-from ray.rllib import _register_all
-
 from ray import tune
-from ray.tune import (Trainable, TuneError, Stopper, run)
-from ray.tune.function_runner import wrap_function
-from ray.tune import register_env, register_trainable, run_experiments
+from ray.rllib import _register_all
+from ray.tune import (register_env, register_trainable, run, run_experiments,
+                      Trainable, TuneError, Stopper)
 from ray.tune.callback import Callback
-from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
-                                 AsyncHyperBandScheduler)
-from ray.tune.stopper import (MaximumIterationStopper, TrialPlateauStopper,
-                              ExperimentPlateauStopper)
-from ray.tune.suggest.suggestion import ConcurrencyLimiter
-from ray.tune.sync_client import CommandBasedClient
-from ray.tune.trial import Trial
-from ray.tune.trial_runner import TrialRunner
+from ray.tune.experiment import Experiment
+from ray.tune.function_runner import wrap_function
+from ray.tune.logger import Logger
+from ray.tune.ray_trial_executor import noop_logger_creator
+from ray.tune.resources import Resources
 from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
                              EPISODES_TOTAL, TRAINING_ITERATION,
                              TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
                              TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
-from ray.tune.logger import Logger
-from ray.tune.experiment import Experiment
-from ray.tune.resources import Resources
+from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
+                                 AsyncHyperBandScheduler)
+from ray.tune.stopper import (MaximumIterationStopper, TrialPlateauStopper,
+                              ExperimentPlateauStopper)
 from ray.tune.suggest import BasicVariantGenerator, grid_search
 from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.suggest.ax import AxSearch
 from ray.tune.suggest._mock import _MockSuggestionAlgorithm
+from ray.tune.suggest.suggestion import ConcurrencyLimiter
+from ray.tune.sync_client import CommandBasedClient
+from ray.tune.trial import Trial
+from ray.tune.trial_runner import TrialRunner
 from ray.tune.utils import (flatten_dict, get_pinned_object,
                             pin_in_object_store)
 from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
@@ -971,7 +972,12 @@ class TrainableFunctionApiTest(unittest.TestCase):
         mock_get_client = "ray.tune.trainable.get_cloud_sync_client"
         with patch(mock_get_client) as mock_get_cloud_sync_client:
             mock_get_cloud_sync_client.return_value = sync_client
-            test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
+            log_creator = partial(
+                noop_logger_creator, logdir="~/tmp/ray_results/exp/trial")
+            remote_checkpoint_dir = os.path.join(MOCK_REMOTE_DIR, "exp/trial")
+            test_trainable = trainable(
+                logger_creator=log_creator,
+                remote_checkpoint_dir=remote_checkpoint_dir)
             result = test_trainable.train()
             self.assertEqual(result["metric"], 1)
             checkpoint_path = test_trainable.save()
@@ -982,6 +988,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
             result = test_trainable.train()
             self.assertEqual(result["metric"], 4)
 
+            shutil.rmtree("~/tmp/ray_results/exp/")
             if not function:
                 test_trainable.state["hi"] = 2
                 test_trainable.restore(checkpoint_path)
@@ -990,7 +997,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
                 # Cannot re-use function trainable, create new
                 tune.session.shutdown()
                 test_trainable = trainable(
-                    remote_checkpoint_dir=MOCK_REMOTE_DIR)
+                    logger_creator=log_creator,
+                    remote_checkpoint_dir=remote_checkpoint_dir)
                 test_trainable.restore(checkpoint_path)
 
             result = test_trainable.train()
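
Side note on the test plumbing above (a sketch; `noop_logger_creator_stub` stands in for the imported `noop_logger_creator`, whose `(config, logdir)` signature is inferred from the `partial(..., logdir=...)` call): binding `logdir` up front pins the trial to a known local directory that mirrors the mocked remote layout under MOCK_REMOTE_DIR.

import os
from functools import partial

def noop_logger_creator_stub(config, logdir):
    # Stand-in: just make sure the trial logdir exists; no real logging.
    os.makedirs(os.path.expanduser(logdir), exist_ok=True)
    return None

log_creator = partial(noop_logger_creator_stub, logdir="~/tmp/ray_results/exp/trial")
logger = log_creator({"lr": 0.01})  # only the config is supplied at call time
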
@@ -366,13 +366,16 @@ def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
     node = cluster.add_node(num_cpus=1)
     cluster.wait_for_nodes()
 
+    # Only added for fake_remote case.
+    # For durable case, we don't do sync to head.
     class _SyncerCallback(SyncerCallback):
         def _create_trial_syncer(self, trial: "Trial"):
             client = mock_storage_client()
             return MockNodeSyncer(trial.logdir, trial.logdir, client)
 
-    syncer_callback = _SyncerCallback(None)
-    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
+    syncer_callback = [_SyncerCallback(None)
+                       ] if trainable_id == "__fake_remote" else None
+    runner = TrialRunner(BasicVariantGenerator(), callbacks=syncer_callback)
     kwargs = {
         "stopping_criterion": {
             "training_iteration": 4
@@ -1,6 +1,7 @@
 import copy
 from collections import OrderedDict
 import os
+import pytest
 import sys
 import shutil
 import unittest
@@ -15,6 +16,17 @@ from ray.tune.utils.util import (flatten_dict, unflatten_dict,
 from ray.tune.utils.trainable import TrainableUtil
 
 
+@pytest.mark.parametrize("checkpoint_path", [
+    "~/tmp/exp/trial/checkpoint0", "~/tmp/exp/trial/checkpoint0/",
+    "~/tmp/exp/trial/checkpoint0/checkpoint",
+    "~/tmp/exp/trial/checkpoint0/foo/bar/baz"
+])
+@pytest.mark.parametrize("logdir", ["~/tmp/exp/trial", "~/tmp/exp/trial/"])
+def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
+    assert TrainableUtil.find_rel_checkpoint_dir(
+        logdir, checkpoint_path) == "checkpoint0/"
+
+
 class TrainableUtilTest(unittest.TestCase):
     def setUp(self):
         self.checkpoint_dir = os.path.join(
@@ -83,6 +83,8 @@ class Trainable:
         logger_creator (func): Function that creates a ray.tune.Logger
             object. If unspecified, a default logger is created.
         remote_checkpoint_dir (str): Upload directory (S3 or GS path).
+            This is **per trial** directory,
+            which is different from **per checkpoint** directory.
         sync_function_tpl (str): Sync function template to use. Defaults
             to `cls._sync_function` (which defaults to `None`).
     """
@@ -147,6 +149,8 @@ class Trainable:
                 self.remote_checkpoint_dir)
 
+    def _storage_path(self, local_path):
+        """Converts a `local_path` to be based off of
+        `self.remote_checkpoint_dir`."""
+        rel_local_path = os.path.relpath(local_path, self.logdir)
+        return os.path.join(self.remote_checkpoint_dir, rel_local_path)
+
@@ -381,7 +385,7 @@ class Trainable:
             "ray_version": ray.__version__,
         }
 
-    def save(self, checkpoint_dir=None):
+    def save(self, checkpoint_dir=None) -> str:
         """Saves the current model state to a checkpoint.
 
         Subclasses should override ``save_checkpoint()`` instead to save state.
@@ -394,7 +398,10 @@ class Trainable:
             checkpoint_dir (str): Optional dir to place the checkpoint.
 
         Returns:
-            str: Checkpoint path or prefix that may be passed to restore().
+            str: path that points to xxx.pkl file.
+
+        Note the return path should match up with what is expected of
+        `restore()`.
         """
         checkpoint_dir = TrainableUtil.make_checkpoint_dir(
             checkpoint_dir or self.logdir, index=self.iteration)
@@ -406,15 +413,15 @@ class Trainable:
             trainable_state=trainable_state)
 
         # Maybe sync to cloud
-        self._maybe_save_to_cloud()
+        self._maybe_save_to_cloud(checkpoint_dir)
 
         return checkpoint_path
 
-    def _maybe_save_to_cloud(self):
+    def _maybe_save_to_cloud(self, checkpoint_dir):
         # Derived classes like the FunctionRunner might call this
         if self.uses_cloud_checkpointing:
-            self.storage_client.sync_up(self.logdir,
-                                        self.remote_checkpoint_dir)
+            self.storage_client.sync_up(checkpoint_dir,
+                                        self._storage_path(checkpoint_dir))
             self.storage_client.wait()
 
     def save_to_object(self):
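
This is also what enables the decoupling mentioned in the description: uploading a new checkpoint and deleting a stale one now target disjoint per-checkpoint remote paths. A hedged sketch (the `client` object stands for a sync client with `sync_up()`/`delete()` as in the hunks above; names are illustrative):

import os

def rotate_checkpoints_sketch(client, logdir, remote_dir, new_ckpt_dir, old_ckpt_dir):
    new_remote = os.path.join(remote_dir, os.path.relpath(new_ckpt_dir, logdir))
    old_remote = os.path.join(remote_dir, os.path.relpath(old_ckpt_dir, logdir))
    client.sync_up(new_ckpt_dir, new_remote)  # upload only the new checkpoint folder
    client.delete(old_remote)  # can proceed without waiting for a full-trial sync
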
@@ -437,13 +444,29 @@ class Trainable:
 
         These checkpoints are returned from calls to save().
 
-        Subclasses should override ``_restore()`` instead to restore state.
+        Subclasses should override ``load_checkpoint()`` instead to
+        restore state.
         This method restores additional metadata saved with the checkpoint.
+
+        `checkpoint_path` should match with the return from ``save()``.
+
+        `checkpoint_path` can be
+        `~/ray_results/exp/MyTrainable_abc/
+        checkpoint_00000/checkpoint`. Or,
+        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
+
+        `self.logdir` should generally be corresponding to `checkpoint_path`,
+        for example, `~/ray_results/exp/MyTrainable_abc`.
+
+        `self.remote_checkpoint_dir` in this case, is something like,
+        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`
         """
         # Maybe sync from cloud
         if self.uses_cloud_checkpointing:
-            self.storage_client.sync_down(self.remote_checkpoint_dir,
-                                          self.logdir)
+            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
+                self.logdir, checkpoint_path)
+            self.storage_client.sync_down(
+                os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
+                os.path.join(self.logdir, rel_checkpoint_dir))
             self.storage_client.wait()
 
         # Ensure TrialCheckpoints are converted
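
To make the docstring example concrete (a sketch using the illustrative paths from the docstring above, and the same relative-path computation as `find_rel_checkpoint_dir`): for this `logdir` and `checkpoint_path`, only `checkpoint_00000/` is pulled down on restore.

import os

logdir = "~/ray_results/exp/MyTrainable_abc"
remote_checkpoint_dir = "REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc"
checkpoint_path = logdir + "/checkpoint_00000/checkpoint"

# Same computation as find_rel_checkpoint_dir(logdir, checkpoint_path):
rel_checkpoint_dir = os.path.join(
    os.path.relpath(checkpoint_path, logdir).split(os.sep)[0], "")
print(os.path.join(remote_checkpoint_dir, rel_checkpoint_dir))
# -> REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc/checkpoint_00000/
print(os.path.join(logdir, rel_checkpoint_dir))
# -> ~/ray_results/exp/MyTrainable_abc/checkpoint_00000/
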
@@ -624,6 +647,8 @@ class Trainable:
         """Create logger from logger creator.
 
         Sets _logdir and _result_logger.
+
+        `_logdir` is the **per trial** directory for the Trainable.
         """
         if logger_creator:
             self._result_logger = logger_creator(config)
@@ -444,6 +444,10 @@ class Trial:
 
     @property
     def remote_checkpoint_dir(self):
+        """This is the **per trial** remote checkpoint dir.
+
+        This is different from **per experiment** remote checkpoint dir.
+        """
         assert self.logdir, "Trial {}: logdir not initialized.".format(self)
         if not self.remote_checkpoint_dir_prefix:
             return None
@@ -1,16 +1,14 @@
-from typing import Dict, Any
-
 import glob
 import inspect
 import io
 import logging
-import shutil
-
-import pandas as pd
-import ray.cloudpickle as pickle
 import os
+import pandas as pd
+import shutil
+from typing import Any, Dict, Union
 
 import ray
+import ray.cloudpickle as pickle
 from ray.tune.registry import _ParameterRegistry
 from ray.tune.utils import detect_checkpoint_function
 from ray.util import placement_group
@@ -21,7 +19,22 @@ logger = logging.getLogger(__name__)
 
 class TrainableUtil:
     @staticmethod
-    def process_checkpoint(checkpoint, parent_dir, trainable_state):
+    def process_checkpoint(checkpoint: Union[Dict, str], parent_dir: str,
+                           trainable_state: Dict) -> str:
+        """Creates checkpoint file structure and writes metadata
+        under `parent_dir`.
+
+        The file structure could either look like:
+        - checkpoint_00000 (returned path)
+        -- .is_checkpoint
+        -- .tune_metadata
+        -- xxx.pkl (or whatever user specifies in their Trainable)
+        Or,
+        - checkpoint_00000
+        -- .is_checkpoint
+        -- checkpoint (returned path)
+        -- checkpoint.tune_metadata
+        """
         saved_as_dict = False
         if isinstance(checkpoint, string_types):
             if not checkpoint.startswith(parent_dir):
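
A hedged sketch of the second layout described in the new docstring (not the actual Tune implementation; the metadata payload and file handling are illustrative): a single `checkpoint` file is written next to Tune's marker and metadata files, and its path is what gets returned.

import os
import pickle

def write_checkpoint_sketch(checkpoint_data, parent_dir, trainable_state):
    os.makedirs(parent_dir, exist_ok=True)
    # Marker file so the directory is recognizable as a checkpoint.
    open(os.path.join(parent_dir, ".is_checkpoint"), "w").close()
    checkpoint_path = os.path.join(parent_dir, "checkpoint")
    with open(checkpoint_path, "wb") as f:
        pickle.dump(checkpoint_data, f)
    with open(checkpoint_path + ".tune_metadata", "wb") as f:
        pickle.dump(trainable_state, f)
    return checkpoint_path  # the "returned path" in the second layout
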
@@ -100,6 +113,20 @@ class TrainableUtil:
                     checkpoint_path))
         return os.path.normpath(checkpoint_dir)
 
+    @staticmethod
+    def find_rel_checkpoint_dir(logdir, checkpoint_path):
+        """Returns the (relative) directory name of the checkpoint.
+
+        Note, the assumption here is `logdir` should be the prefix of
+        `checkpoint_path`.
+        For example, returns `checkpoint00000/`.
+        """
+        assert checkpoint_path.startswith(
+            logdir), "expecting `logdir` to be a prefix of `checkpoint_path`"
+        rel_path = os.path.relpath(checkpoint_path, logdir)
+        tokens = rel_path.split(os.sep)
+        return os.path.join(tokens[0], "")
+
     @staticmethod
     def make_checkpoint_dir(checkpoint_dir, index, override=False):
         """Creates a checkpoint directory within the provided path.