diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py
index 67fd5a665..06d6d7a97 100644
--- a/python/ray/tune/config_parser.py
+++ b/python/ray/tune/config_parser.py
@@ -202,6 +202,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
+        sync_to_cloud=spec.get("sync_to_cloud"),
         checkpoint_freq=args.checkpoint_freq,
         checkpoint_at_end=args.checkpoint_at_end,
         sync_on_checkpoint=args.sync_on_checkpoint,
diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py
index aba1d9839..0ca3c692d 100644
--- a/python/ray/tune/durable_trainable.py
+++ b/python/ray/tune/durable_trainable.py
@@ -1,4 +1,4 @@
-from typing import Callable, Type, Union
+from typing import Callable, Optional, Type, Union
 
 import inspect
 import logging
@@ -6,6 +6,7 @@ import os
 
 from ray.tune.function_runner import wrap_function
 from ray.tune.registry import get_trainable_cls
+from ray.tune.sync_client import get_sync_client
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.syncer import get_cloud_sync_client
 
@@ -31,15 +32,23 @@ class DurableTrainable(Trainable):
 
     >>> tune.run(MyDurableTrainable, sync_to_driver=False)
     """
+    _sync_function_tpl = None
 
-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self,
+                 remote_checkpoint_dir: str,
+                 sync_function_tpl: Optional[str] = None,
+                 *args,
+                 **kwargs):
         """Initializes a DurableTrainable.
 
         Args:
             remote_checkpoint_dir (str): Upload directory (S3 or GS path).
+            sync_function_tpl (str): Sync function template to use. Defaults
+                to `cls._sync_function_tpl` (which defaults to `None`).
         """
         super(DurableTrainable, self).__init__(*args, **kwargs)
         self.remote_checkpoint_dir = remote_checkpoint_dir
+        self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
         self.storage_client = self._create_storage_client()
 
     def save(self, checkpoint_dir=None):
@@ -96,7 +105,9 @@ class DurableTrainable(Trainable):
 
     def _create_storage_client(self):
         """Returns a storage client."""
-        return get_cloud_sync_client(self.remote_checkpoint_dir)
+        return get_sync_client(
+            self.sync_function_tpl) or get_cloud_sync_client(
+                self.remote_checkpoint_dir)
 
     def _storage_path(self, local_path):
         rel_local_path = os.path.relpath(local_path, self.logdir)
diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py
index 5c18f48e4..51db759f4 100644
--- a/python/ray/tune/experiment.py
+++ b/python/ray/tune/experiment.py
@@ -115,6 +115,7 @@
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,
@@ -217,6 +218,7 @@
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
+            "sync_to_cloud": sync_to_cloud,
             "checkpoint_freq": checkpoint_freq,
             "checkpoint_at_end": checkpoint_at_end,
             "sync_on_checkpoint": sync_on_checkpoint,
diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index 5221a133f..ee7e7e461 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -361,6 +361,7 @@ class RayTrialExecutor(TrialExecutor):
         }
         if issubclass(trial.get_trainable_cls(), DurableTrainable):
             kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
+            kwargs["sync_function_tpl"] = trial.sync_to_cloud
         with self._change_working_directory(trial):
             return full_actor_class.remote(**kwargs)
 
diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py
index 0638ffa04..cf7454a4f 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -21,6 +21,7 @@ from ray.tune.durable_trainable import durable
 from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
                                  AsyncHyperBandScheduler)
 from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
+from ray.tune.sync_client import CommandBasedClient
 from ray.tune.trial import Trial
 from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
                              EPISODES_TOTAL, TRAINING_ITERATION,
@@ -29,7 +30,7 @@ from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
 from ray.tune.logger import Logger
 from ray.tune.experiment import Experiment
 from ray.tune.resources import Resources
-from ray.tune.suggest import grid_search
+from ray.tune.suggest import BasicVariantGenerator, grid_search
 from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.suggest.ax import AxSearch
 from ray.tune.suggest._mock import _MockSuggestionAlgorithm
@@ -942,6 +943,58 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         self._testDurableTrainable(durable(test_train), function=True)
 
+    def testDurableTrainableSyncFunction(self):
+        """Check custom sync functions in durable trainables"""
+
+        class TestDurable(DurableTrainable):
+            def __init__(self, *args, **kwargs):
+                # Mock distutils.spawn.find_executable
+                # so `aws` command is found
+                import distutils.spawn
+                distutils.spawn.find_executable = lambda *_, **__: True
+                super(TestDurable, self).__init__(*args, **kwargs)
+
+            def check(self):
+                return bool(self.sync_function_tpl) and isinstance(
+                    self.storage_client, CommandBasedClient
+                ) and "aws" not in self.storage_client.sync_up_template
+
+        class TestTplDurable(TestDurable):
+            _sync_function_tpl = "echo static sync {source} {target}"
+
+        upload_dir = "s3://test-bucket/path"
+
+        def _create_remote_actor(trainable_cls, sync_to_cloud):
+            """Create a remote trainable actor from an experiment"""
+            exp = Experiment(
+                name="test_durable_sync",
+                run=trainable_cls,
+                sync_to_cloud=sync_to_cloud,
+                sync_to_driver=False,
+                upload_dir=upload_dir)
+
+            searchers = BasicVariantGenerator()
+            searchers.add_configurations([exp])
+            trial = searchers.next_trial()
+            cls = trial.get_trainable_cls()
+            actor = ray.remote(cls).remote(
+                remote_checkpoint_dir=upload_dir,
+                sync_function_tpl=trial.sync_to_cloud)
+            return actor
+
+        # This actor should create a default aws syncer, so check should fail
+        actor1 = _create_remote_actor(TestDurable, None)
+        self.assertFalse(ray.get(actor1.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor2 = _create_remote_actor(TestDurable,
+                                      "echo test sync {source} {target}")
+        self.assertTrue(ray.get(actor2.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor3 = _create_remote_actor(TestTplDurable, None)
+        self.assertTrue(ray.get(actor3.check.remote()))
+
     def testCheckpointDict(self):
         class TestTrain(Trainable):
             def setup(self, config):
diff --git a/python/ray/tune/tests/test_sync.py b/python/ray/tune/tests/test_sync.py
index 3675da055..bced13edb 100644
--- a/python/ray/tune/tests/test_sync.py
+++ b/python/ray/tune/tests/test_sync.py
@@ -1,5 +1,6 @@
 import glob
 import os
+import pickle
 import shutil
 import sys
 import tempfile
@@ -169,10 +170,16 @@ class TestSyncFunctionality(unittest.TestCase):
                 time.sleep(1)
                 tune.report(score=i)
 
-        mock = unittest.mock.Mock()
-
         def counter(local, remote):
-            mock()
+            count_file = os.path.join(tmpdir, "count.txt")
+            if not os.path.exists(count_file):
+                count = 0
+            else:
+                with open(count_file, "rb") as fp:
+                    count = pickle.load(fp)
+            count += 1
+            with open(count_file, "wb") as fp:
+                pickle.dump(count, fp)
 
         sync_config = tune.SyncConfig(
             upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)
@@ -191,7 +198,11 @@ class TestSyncFunctionality(unittest.TestCase):
             sync_config=sync_config,
         ).trials
 
-        self.assertEqual(mock.call_count, 12)
+        count_file = os.path.join(tmpdir, "count.txt")
+        with open(count_file, "rb") as fp:
+            count = pickle.load(fp)
+
+        self.assertEqual(count, 12)
         shutil.rmtree(tmpdir)
 
     def testClusterSyncFunction(self):
diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py
index c27ffa1d8..60fb1ecab 100644
--- a/python/ray/tune/trial.py
+++ b/python/ray/tune/trial.py
@@ -187,6 +187,7 @@ class Trial:
                  placement_group_factory=None,
                  stopping_criterion=None,
                  remote_checkpoint_dir=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,
@@ -283,6 +284,7 @@ class Trial:
             self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
         else:
             self.remote_checkpoint_dir_prefix = None
+        self.sync_to_cloud = sync_to_cloud
         self.checkpoint_freq = checkpoint_freq
         self.checkpoint_at_end = checkpoint_at_end
         self.keep_checkpoints_num = keep_checkpoints_num
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index f820c9394..a61bf6f00 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -397,6 +397,7 @@ def run(
             local_dir=local_dir,
             upload_dir=sync_config.upload_dir,
             sync_to_driver=sync_config.sync_to_driver,
+            sync_to_cloud=sync_config.sync_to_cloud,
             trial_name_creator=trial_name_creator,
             trial_dirname_creator=trial_dirname_creator,
             log_to_file=log_to_file,
diff --git a/python/ray/tune/utils/mock.py b/python/ray/tune/utils/mock.py
index 50de8868d..ca82a6244 100644
--- a/python/ray/tune/utils/mock.py
+++ b/python/ray/tune/utils/mock.py
@@ -65,7 +65,8 @@ class MockDurableTrainer(DurableTrainable, _MockTrainer):
     # TODO(ujvl): This class uses multiple inheritance; it should be cleaned
     # up once the durable training API converges.
 
-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self, remote_checkpoint_dir, sync_function_tpl, *args,
+                 **kwargs):
         _MockTrainer.__init__(self, *args, **kwargs)
         DurableTrainable.__init__(self, remote_checkpoint_dir, *args,
                                   **kwargs)
diff --git a/python/ray/tune/utils/release_test_util.py b/python/ray/tune/utils/release_test_util.py
index 7c4074cb7..177b134ff 100644
--- a/python/ray/tune/utils/release_test_util.py
+++ b/python/ray/tune/utils/release_test_util.py
@@ -34,7 +34,11 @@ class ProgressCallback(tune.callback.Callback):
 
 
 class TestDurableTrainable(DurableTrainable):
-    def __init__(self, remote_checkpoint_dir, config, logger_creator=None):
+    def __init__(self,
+                 remote_checkpoint_dir,
+                 config,
+                 logger_creator=None,
+                 **kwargs):
         self.setup_env()
         super(TestDurableTrainable, self).__init__(
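
Usage sketch (illustrative, not part of the patch; the bucket path and the `MyDurable` class are hypothetical): with this change, a custom sync command template can be supplied per experiment via `SyncConfig.sync_to_cloud`, or as a class-level default via the new `_sync_function_tpl` attribute. Tune substitutes the `{source}` and `{target}` placeholders when it builds the sync command:

    from ray import tune
    from ray.tune.durable_trainable import DurableTrainable

    class MyDurable(DurableTrainable):
        # Class-level default template; used only when no explicit
        # sync_function_tpl is wired in from SyncConfig.sync_to_cloud.
        _sync_function_tpl = "aws s3 sync {source} {target}"

        def step(self):
            return {"score": 1}

    # Per-experiment template; forwarded by this patch from SyncConfig
    # through Experiment/Trial/RayTrialExecutor into the trainable.
    sync_config = tune.SyncConfig(
        upload_dir="s3://my-bucket/tune-results",  # hypothetical bucket
        sync_to_driver=False,
        sync_to_cloud="aws s3 sync {source} {target}",
    )

    tune.run(MyDurable, sync_config=sync_config)

An explicit template takes precedence over the class attribute (`sync_function_tpl or self._sync_function_tpl`); when neither is set, `get_sync_client` returns `None` and `_create_storage_client` falls back to `get_cloud_sync_client`, i.e. the pre-patch behavior.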