[tune] Pass custom sync_to_cloud templates to durable trainables (#16739)

Kai Fricke 2021-07-06 10:50:59 +02:00 committed by GitHub
parent e250abf689
commit 4178655ba7
10 changed files with 97 additions and 10 deletions
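
For orientation, here is a rough usage sketch of what this change enables, pieced together from the diffs below. The bucket path and the sync command are placeholders, and the public surface (`tune.SyncConfig`, `tune.run`, the `durable()` wrapper, the `{source}`/`{target}` template format) is inferred from the touched tests rather than confirmed here:

    from ray import tune
    from ray.tune.durable_trainable import durable

    def train_fn(config):
        for i in range(10):
            tune.report(score=i)

    # Custom upload template. Tune's command-based sync client fills in
    # {source} (local checkpoint dir) and {target} (cloud path) when
    # checkpoints are pushed to cloud storage.
    sync_config = tune.SyncConfig(
        upload_dir="s3://my-bucket/tune-results",       # placeholder bucket
        sync_to_cloud="aws s3 sync {source} {target}")  # custom template

    # With this PR, the template is forwarded all the way to the remote
    # DurableTrainable actor as `sync_function_tpl`, instead of the actor
    # always falling back to the default cloud sync client.
    tune.run(durable(train_fn), sync_config=sync_config)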


@@ -202,6 +202,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
+        sync_to_cloud=spec.get("sync_to_cloud"),
         checkpoint_freq=args.checkpoint_freq,
         checkpoint_at_end=args.checkpoint_at_end,
         sync_on_checkpoint=args.sync_on_checkpoint,


@@ -1,4 +1,4 @@
-from typing import Callable, Type, Union
+from typing import Callable, Optional, Type, Union
 import inspect
 import logging
@@ -6,6 +6,7 @@ import os
 from ray.tune.function_runner import wrap_function
 from ray.tune.registry import get_trainable_cls
+from ray.tune.sync_client import get_sync_client
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.syncer import get_cloud_sync_client
@@ -31,15 +32,23 @@ class DurableTrainable(Trainable):
     >>> tune.run(MyDurableTrainable, sync_to_driver=False)
     """
+    _sync_function_tpl = None
 
-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self,
+                 remote_checkpoint_dir: str,
+                 sync_function_tpl: Optional[str] = None,
+                 *args,
+                 **kwargs):
         """Initializes a DurableTrainable.
 
         Args:
             remote_checkpoint_dir (str): Upload directory (S3 or GS path).
+            sync_function_tpl (str): Sync function template to use. Defaults
+                to `cls._sync_function_tpl` (which defaults to `None`).
""" """
super(DurableTrainable, self).__init__(*args, **kwargs) super(DurableTrainable, self).__init__(*args, **kwargs)
self.remote_checkpoint_dir = remote_checkpoint_dir self.remote_checkpoint_dir = remote_checkpoint_dir
self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
self.storage_client = self._create_storage_client() self.storage_client = self._create_storage_client()
def save(self, checkpoint_dir=None): def save(self, checkpoint_dir=None):
@@ -96,7 +105,9 @@ class DurableTrainable(Trainable):
     def _create_storage_client(self):
         """Returns a storage client."""
-        return get_cloud_sync_client(self.remote_checkpoint_dir)
+        return get_sync_client(
+            self.sync_function_tpl) or get_cloud_sync_client(
+                self.remote_checkpoint_dir)
 
     def _storage_path(self, local_path):
         rel_local_path = os.path.relpath(local_path, self.logdir)
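
The new class attribute also lets a subclass pin a default template without passing `sync_function_tpl` explicitly, mirroring `TestTplDurable` in the test added further down; a minimal sketch (the echo command is just an illustrative placeholder):

    from ray.tune.durable_trainable import DurableTrainable

    class MyDurable(DurableTrainable):
        # Class-level default; an explicit `sync_function_tpl` passed to
        # __init__ (e.g. Trial.sync_to_cloud via the trial executor) wins,
        # since __init__ does `sync_function_tpl or self._sync_function_tpl`.
        _sync_function_tpl = "echo sync {source} {target}"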


@@ -115,6 +115,7 @@ class Experiment:
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,
@@ -217,6 +218,7 @@ class Experiment:
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
+            "sync_to_cloud": sync_to_cloud,
             "checkpoint_freq": checkpoint_freq,
             "checkpoint_at_end": checkpoint_at_end,
             "sync_on_checkpoint": sync_on_checkpoint,


@@ -361,6 +361,7 @@ class RayTrialExecutor(TrialExecutor):
         }
         if issubclass(trial.get_trainable_cls(), DurableTrainable):
             kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
+            kwargs["sync_function_tpl"] = trial.sync_to_cloud
 
         with self._change_working_directory(trial):
             return full_actor_class.remote(**kwargs)


@@ -21,6 +21,7 @@ from ray.tune.durable_trainable import durable
 from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
                                  AsyncHyperBandScheduler)
 from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
+from ray.tune.sync_client import CommandBasedClient
 from ray.tune.trial import Trial
 from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
                              EPISODES_TOTAL, TRAINING_ITERATION,
@@ -29,7 +30,7 @@ from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
 from ray.tune.logger import Logger
 from ray.tune.experiment import Experiment
 from ray.tune.resources import Resources
-from ray.tune.suggest import grid_search
+from ray.tune.suggest import BasicVariantGenerator, grid_search
 from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.suggest.ax import AxSearch
 from ray.tune.suggest._mock import _MockSuggestionAlgorithm
@@ -942,6 +943,58 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         self._testDurableTrainable(durable(test_train), function=True)
 
+    def testDurableTrainableSyncFunction(self):
+        """Check custom sync functions in durable trainables"""
+
+        class TestDurable(DurableTrainable):
+            def __init__(self, *args, **kwargs):
+                # Mock distutils.spawn.find_executable
+                # so `aws` command is found
+                import distutils.spawn
+                distutils.spawn.find_executable = lambda *_, **__: True
+                super(TestDurable, self).__init__(*args, **kwargs)
+
+            def check(self):
+                return bool(self.sync_function_tpl) and isinstance(
+                    self.storage_client, CommandBasedClient
+                ) and "aws" not in self.storage_client.sync_up_template
+
+        class TestTplDurable(TestDurable):
+            _sync_function_tpl = "echo static sync {source} {target}"
+
+        upload_dir = "s3://test-bucket/path"
+
+        def _create_remote_actor(trainable_cls, sync_to_cloud):
+            """Create a remote trainable actor from an experiment"""
+            exp = Experiment(
+                name="test_durable_sync",
+                run=trainable_cls,
+                sync_to_cloud=sync_to_cloud,
+                sync_to_driver=False,
+                upload_dir=upload_dir)
+
+            searchers = BasicVariantGenerator()
+            searchers.add_configurations([exp])
+
+            trial = searchers.next_trial()
+            cls = trial.get_trainable_cls()
+            actor = ray.remote(cls).remote(
+                remote_checkpoint_dir=upload_dir,
+                sync_function_tpl=trial.sync_to_cloud)
+            return actor
+
+        # This actor should create a default aws syncer, so check should fail
+        actor1 = _create_remote_actor(TestDurable, None)
+        self.assertFalse(ray.get(actor1.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor2 = _create_remote_actor(TestDurable,
+                                      "echo test sync {source} {target}")
+        self.assertTrue(ray.get(actor2.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor3 = _create_remote_actor(TestTplDurable, None)
+        self.assertTrue(ray.get(actor3.check.remote()))
+
     def testCheckpointDict(self):
         class TestTrain(Trainable):
             def setup(self, config):


@@ -1,5 +1,6 @@
 import glob
 import os
+import pickle
 import shutil
 import sys
 import tempfile
@@ -169,10 +170,16 @@ class TestSyncFunctionality(unittest.TestCase):
                 time.sleep(1)
                 tune.report(score=i)
 
-        mock = unittest.mock.Mock()
-
         def counter(local, remote):
-            mock()
+            count_file = os.path.join(tmpdir, "count.txt")
+            if not os.path.exists(count_file):
+                count = 0
+            else:
+                with open(count_file, "rb") as fp:
+                    count = pickle.load(fp)
+            count += 1
+            with open(count_file, "wb") as fp:
+                pickle.dump(count, fp)
 
         sync_config = tune.SyncConfig(
             upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)
@@ -191,7 +198,11 @@
                 sync_config=sync_config,
             ).trials
 
-        self.assertEqual(mock.call_count, 12)
+        count_file = os.path.join(tmpdir, "count.txt")
+        with open(count_file, "rb") as fp:
+            count = pickle.load(fp)
+        self.assertEqual(count, 12)
 
         shutil.rmtree(tmpdir)
 
     def testClusterSyncFunction(self):


@@ -187,6 +187,7 @@ class Trial:
                  placement_group_factory=None,
                  stopping_criterion=None,
                  remote_checkpoint_dir=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,
@@ -283,6 +284,7 @@ class Trial:
             self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
         else:
             self.remote_checkpoint_dir_prefix = None
+        self.sync_to_cloud = sync_to_cloud
         self.checkpoint_freq = checkpoint_freq
         self.checkpoint_at_end = checkpoint_at_end
         self.keep_checkpoints_num = keep_checkpoints_num


@@ -397,6 +397,7 @@ def run(
         local_dir=local_dir,
         upload_dir=sync_config.upload_dir,
         sync_to_driver=sync_config.sync_to_driver,
+        sync_to_cloud=sync_config.sync_to_cloud,
         trial_name_creator=trial_name_creator,
         trial_dirname_creator=trial_dirname_creator,
         log_to_file=log_to_file,


@@ -65,7 +65,8 @@ class MockDurableTrainer(DurableTrainable, _MockTrainer):
     # TODO(ujvl): This class uses multiple inheritance; it should be cleaned
     # up once the durable training API converges.
 
-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self, remote_checkpoint_dir, sync_function_tpl, *args,
+                 **kwargs):
         _MockTrainer.__init__(self, *args, **kwargs)
         DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs)


@@ -34,7 +34,11 @@ class ProgressCallback(tune.callback.Callback):
 class TestDurableTrainable(DurableTrainable):
-    def __init__(self, remote_checkpoint_dir, config, logger_creator=None):
+    def __init__(self,
+                 remote_checkpoint_dir,
+                 config,
+                 logger_creator=None,
+                 **kwargs):
         self.setup_env()
         super(TestDurableTrainable, self).__init__(