[tune] Pass custom sync_to_cloud templates to durable trainables (#16739)
parent e250abf689
commit 4178655ba7
10 changed files with 97 additions and 10 deletions
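
In practice, the new plumbing lets a Tune user supply their own cloud sync command for durable trainables: any shell template containing {source} and {target} placeholders (or a callable) passed as sync_to_cloud. A minimal usage sketch, not taken from the patch itself; the bucket path and the my_sync_tool command are assumptions:

    from ray import tune
    from ray.tune.durable_trainable import durable


    def train_fn(config):
        # Hypothetical training loop reporting a dummy metric.
        for i in range(10):
            tune.report(score=i)


    # Custom sync template: any shell command with {source} and {target}
    # placeholders, used instead of the default cloud syncer.
    sync_config = tune.SyncConfig(
        upload_dir="s3://my-bucket/tune-results",  # assumed bucket path
        sync_to_cloud="my_sync_tool cp {source} {target}",  # assumed CLI
    )

    tune.run(durable(train_fn), sync_config=sync_config)
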
@@ -202,6 +202,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
+        sync_to_cloud=spec.get("sync_to_cloud"),
         checkpoint_freq=args.checkpoint_freq,
         checkpoint_at_end=args.checkpoint_at_end,
         sync_on_checkpoint=args.sync_on_checkpoint,

@@ -1,4 +1,4 @@
-from typing import Callable, Type, Union
+from typing import Callable, Optional, Type, Union

 import inspect
 import logging

@@ -6,6 +6,7 @@ import os

 from ray.tune.function_runner import wrap_function
 from ray.tune.registry import get_trainable_cls
+from ray.tune.sync_client import get_sync_client
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.syncer import get_cloud_sync_client

@@ -31,15 +32,23 @@ class DurableTrainable(Trainable):
     >>> tune.run(MyDurableTrainable, sync_to_driver=False)
     """
+    _sync_function_tpl = None

-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self,
+                 remote_checkpoint_dir: str,
+                 sync_function_tpl: Optional[str] = None,
+                 *args,
+                 **kwargs):
         """Initializes a DurableTrainable.

         Args:
             remote_checkpoint_dir (str): Upload directory (S3 or GS path).
+            sync_function_tpl (str): Sync function template to use. Defaults
+                to `cls._sync_function_tpl` (which defaults to `None`).
         """
         super(DurableTrainable, self).__init__(*args, **kwargs)
         self.remote_checkpoint_dir = remote_checkpoint_dir
+        self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
         self.storage_client = self._create_storage_client()

     def save(self, checkpoint_dir=None):

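For class-based trainables, the hunk above gives the template two entry points: a class-level _sync_function_tpl default on the subclass, or the sync_function_tpl constructor argument, which Tune's trial executor fills from the trial's sync_to_cloud setting. A rough sketch of both routes, with an assumed metric, command, and bucket path:

    from ray import tune
    from ray.tune.durable_trainable import DurableTrainable


    class MyDurableTrainable(DurableTrainable):
        # Class-level default template, used when no sync_function_tpl is
        # passed to the constructor.
        _sync_function_tpl = "gsutil rsync -r {source} {target}"

        def step(self):
            return {"score": 1}  # hypothetical one-step training result

        def save_checkpoint(self, checkpoint_dir):
            return checkpoint_dir

        def load_checkpoint(self, checkpoint_path):
            pass


    # Per-run override: SyncConfig.sync_to_cloud becomes the trial's
    # sync_to_cloud, which reaches the trainable as sync_function_tpl.
    tune.run(
        MyDurableTrainable,
        stop={"training_iteration": 3},
        sync_config=tune.SyncConfig(
            upload_dir="gs://my-bucket/exp",                    # assumed
            sync_to_cloud="gsutil rsync -r {source} {target}",  # assumed
        ),
    )
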
@@ -96,7 +105,9 @@ class DurableTrainable(Trainable):

     def _create_storage_client(self):
         """Returns a storage client."""
-        return get_cloud_sync_client(self.remote_checkpoint_dir)
+        return get_sync_client(
+            self.sync_function_tpl) or get_cloud_sync_client(
+                self.remote_checkpoint_dir)

     def _storage_path(self, local_path):
         rel_local_path = os.path.relpath(local_path, self.logdir)

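With this fallback, a configured template wins: get_sync_client(self.sync_function_tpl) yields a command-based client when a template is set and None otherwise, so get_cloud_sync_client only builds the default syncer from the S3/GS upload URL when no template was given. The snippet below only illustrates the template contract (placeholder substitution); the command and paths are made up and it is not the real sync client:

    # A sync template is a shell command with {source} and {target}
    # placeholders that the command-based client fills in on each sync.
    template = "my_sync_tool cp {source} {target}"   # assumed custom CLI

    local_dir = "/tmp/ray_results/exp/trial_0"       # hypothetical logdir
    remote_dir = "s3://my-bucket/exp/trial_0"        # hypothetical target

    # Roughly what gets executed for an upload (simplified):
    upload_cmd = template.format(source=local_dir, target=remote_dir)
    print(upload_cmd)
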
@@ -115,6 +115,7 @@ class Experiment:
                 loggers=None,
                 log_to_file=False,
                 sync_to_driver=None,
+                sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,

@@ -217,6 +218,7 @@ class Experiment:
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
+            "sync_to_cloud": sync_to_cloud,
             "checkpoint_freq": checkpoint_freq,
             "checkpoint_at_end": checkpoint_at_end,
             "sync_on_checkpoint": sync_on_checkpoint,

@@ -361,6 +361,7 @@ class RayTrialExecutor(TrialExecutor):
         }
         if issubclass(trial.get_trainable_cls(), DurableTrainable):
             kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
+            kwargs["sync_function_tpl"] = trial.sync_to_cloud

         with self._change_working_directory(trial):
             return full_actor_class.remote(**kwargs)

@@ -21,6 +21,7 @@ from ray.tune.durable_trainable import durable
 from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
                                  AsyncHyperBandScheduler)
 from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
+from ray.tune.sync_client import CommandBasedClient
 from ray.tune.trial import Trial
 from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
                              EPISODES_TOTAL, TRAINING_ITERATION,

@@ -29,7 +30,7 @@ from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
 from ray.tune.logger import Logger
 from ray.tune.experiment import Experiment
 from ray.tune.resources import Resources
-from ray.tune.suggest import grid_search
+from ray.tune.suggest import BasicVariantGenerator, grid_search
 from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.suggest.ax import AxSearch
 from ray.tune.suggest._mock import _MockSuggestionAlgorithm

@@ -942,6 +943,58 @@ class TrainableFunctionApiTest(unittest.TestCase):

         self._testDurableTrainable(durable(test_train), function=True)

+    def testDurableTrainableSyncFunction(self):
+        """Check custom sync functions in durable trainables"""
+
+        class TestDurable(DurableTrainable):
+            def __init__(self, *args, **kwargs):
+                # Mock distutils.spawn.find_executable
+                # so `aws` command is found
+                import distutils.spawn
+                distutils.spawn.find_executable = lambda *_, **__: True
+                super(TestDurable, self).__init__(*args, **kwargs)
+
+            def check(self):
+                return bool(self.sync_function_tpl) and isinstance(
+                    self.storage_client, CommandBasedClient
+                ) and "aws" not in self.storage_client.sync_up_template
+
+        class TestTplDurable(TestDurable):
+            _sync_function_tpl = "echo static sync {source} {target}"
+
+        upload_dir = "s3://test-bucket/path"
+
+        def _create_remote_actor(trainable_cls, sync_to_cloud):
+            """Create a remote trainable actor from an experiment"""
+            exp = Experiment(
+                name="test_durable_sync",
+                run=trainable_cls,
+                sync_to_cloud=sync_to_cloud,
+                sync_to_driver=False,
+                upload_dir=upload_dir)
+
+            searchers = BasicVariantGenerator()
+            searchers.add_configurations([exp])
+            trial = searchers.next_trial()
+            cls = trial.get_trainable_cls()
+            actor = ray.remote(cls).remote(
+                remote_checkpoint_dir=upload_dir,
+                sync_function_tpl=trial.sync_to_cloud)
+            return actor
+
+        # This actor should create a default aws syncer, so check should fail
+        actor1 = _create_remote_actor(TestDurable, None)
+        self.assertFalse(ray.get(actor1.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor2 = _create_remote_actor(TestDurable,
+                                      "echo test sync {source} {target}")
+        self.assertTrue(ray.get(actor2.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor3 = _create_remote_actor(TestTplDurable, None)
+        self.assertTrue(ray.get(actor3.check.remote()))
+
     def testCheckpointDict(self):
         class TestTrain(Trainable):
             def setup(self, config):

@@ -1,5 +1,6 @@
 import glob
 import os
+import pickle
 import shutil
 import sys
 import tempfile

@@ -169,10 +170,16 @@ class TestSyncFunctionality(unittest.TestCase):
                 time.sleep(1)
                 tune.report(score=i)

-        mock = unittest.mock.Mock()
-
         def counter(local, remote):
-            mock()
+            count_file = os.path.join(tmpdir, "count.txt")
+            if not os.path.exists(count_file):
+                count = 0
+            else:
+                with open(count_file, "rb") as fp:
+                    count = pickle.load(fp)
+            count += 1
+            with open(count_file, "wb") as fp:
+                pickle.dump(count, fp)

         sync_config = tune.SyncConfig(
             upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)

@@ -191,7 +198,11 @@ class TestSyncFunctionality(unittest.TestCase):
             sync_config=sync_config,
         ).trials

-        self.assertEqual(mock.call_count, 12)
+        count_file = os.path.join(tmpdir, "count.txt")
+        with open(count_file, "rb") as fp:
+            count = pickle.load(fp)
+
+        self.assertEqual(count, 12)
         shutil.rmtree(tmpdir)

     def testClusterSyncFunction(self):

@@ -187,6 +187,7 @@ class Trial:
                 placement_group_factory=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
+                sync_to_cloud=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,

@@ -283,6 +284,7 @@ class Trial:
             self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
         else:
             self.remote_checkpoint_dir_prefix = None
+        self.sync_to_cloud = sync_to_cloud
         self.checkpoint_freq = checkpoint_freq
         self.checkpoint_at_end = checkpoint_at_end
         self.keep_checkpoints_num = keep_checkpoints_num

@@ -397,6 +397,7 @@ def run(
             local_dir=local_dir,
             upload_dir=sync_config.upload_dir,
             sync_to_driver=sync_config.sync_to_driver,
+            sync_to_cloud=sync_config.sync_to_cloud,
             trial_name_creator=trial_name_creator,
             trial_dirname_creator=trial_dirname_creator,
             log_to_file=log_to_file,

@@ -65,7 +65,8 @@ class MockDurableTrainer(DurableTrainable, _MockTrainer):
     # TODO(ujvl): This class uses multiple inheritance; it should be cleaned
    # up once the durable training API converges.

-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self, remote_checkpoint_dir, sync_function_tpl, *args,
+                 **kwargs):
         _MockTrainer.__init__(self, *args, **kwargs)
         DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs)

@@ -34,7 +34,11 @@ class ProgressCallback(tune.callback.Callback):


 class TestDurableTrainable(DurableTrainable):
-    def __init__(self, remote_checkpoint_dir, config, logger_creator=None):
+    def __init__(self,
+                 remote_checkpoint_dir,
+                 config,
+                 logger_creator=None,
+                 **kwargs):
         self.setup_env()

         super(TestDurableTrainable, self).__init__(