[tune] Pass custom sync_to_cloud templates to durable trainables (#16739)
commit 4178655ba7
parent e250abf689
10 changed files with 97 additions and 10 deletions
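In outline, the commit threads a user-supplied sync command template (or callable) from `tune.SyncConfig` and `Experiment` through `Trial` and `RayTrialExecutor` into `DurableTrainable`, which now builds its storage client from that template before falling back to the default cloud sync client. A minimal usage sketch; the trainable, bucket path, and `aws s3 sync` command below are placeholders for illustration and do not appear in the diff:

from ray import tune
from ray.tune.durable_trainable import DurableTrainable


class MyDurableTrainable(DurableTrainable):
    # Placeholder trainable used only for illustration.
    def step(self):
        return {"score": 1}


tune.run(
    MyDurableTrainable,
    stop={"training_iteration": 3},
    sync_config=tune.SyncConfig(
        # Placeholder bucket; any S3/GS path accepted by the storage client.
        upload_dir="s3://my-bucket/tune-results",
        # New in this commit: a custom sync command template. Tune fills in
        # {source} and {target} when syncing trial output to cloud storage.
        sync_to_cloud="aws s3 sync {source} {target}",
    ),
)

The `{source}` and `{target}` fields are substituted by Tune's command-based sync client, as in the templates exercised by the new test further down.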
@@ -202,6 +202,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # json.load leads to str -> unicode in py2.7
         stopping_criterion=spec.get("stop", {}),
         remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
+        sync_to_cloud=spec.get("sync_to_cloud"),
         checkpoint_freq=args.checkpoint_freq,
         checkpoint_at_end=args.checkpoint_at_end,
         sync_on_checkpoint=args.sync_on_checkpoint,

@@ -1,4 +1,4 @@
-from typing import Callable, Type, Union
+from typing import Callable, Optional, Type, Union

 import inspect
 import logging

@@ -6,6 +6,7 @@ import os

 from ray.tune.function_runner import wrap_function
 from ray.tune.registry import get_trainable_cls
+from ray.tune.sync_client import get_sync_client
 from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.syncer import get_cloud_sync_client

@@ -31,15 +32,23 @@ class DurableTrainable(Trainable):

         >>> tune.run(MyDurableTrainable, sync_to_driver=False)
     """
+    _sync_function_tpl = None

-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self,
+                 remote_checkpoint_dir: str,
+                 sync_function_tpl: Optional[str] = None,
+                 *args,
+                 **kwargs):
         """Initializes a DurableTrainable.

         Args:
             remote_checkpoint_dir (str): Upload directory (S3 or GS path).
+            sync_function_tpl (str): Sync function template to use. Defaults
+                to `cls._sync_function` (which defaults to `None`).
         """
         super(DurableTrainable, self).__init__(*args, **kwargs)
         self.remote_checkpoint_dir = remote_checkpoint_dir
+        self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl
         self.storage_client = self._create_storage_client()

     def save(self, checkpoint_dir=None):

@@ -96,7 +105,9 @@ class DurableTrainable(Trainable):

     def _create_storage_client(self):
         """Returns a storage client."""
-        return get_cloud_sync_client(self.remote_checkpoint_dir)
+        return get_sync_client(
+            self.sync_function_tpl) or get_cloud_sync_client(
+                self.remote_checkpoint_dir)

     def _storage_path(self, local_path):
         rel_local_path = os.path.relpath(local_path, self.logdir)

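The resolution order introduced above is: an explicit `sync_function_tpl` argument wins, then the class attribute `_sync_function_tpl`, and only when both are unset does `_create_storage_client` fall back to `get_cloud_sync_client(self.remote_checkpoint_dir)`. A short sketch of pinning a template at the class level; the class name and the gsutil command are invented for illustration and are not part of this commit:

from ray.tune.durable_trainable import DurableTrainable


class GsRsyncDurableTrainable(DurableTrainable):
    # Class-level default, picked up when no sync_function_tpl kwarg is
    # passed by the trial executor. The command is only an example of a
    # {source}/{target} template.
    _sync_function_tpl = "gsutil -m rsync -r {source} {target}"

    def step(self):
        return {"score": 1}

With a template set, `_create_storage_client` returns a command-based client built from it (the new test below asserts `isinstance(self.storage_client, CommandBasedClient)`) instead of the default S3/GS client.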
@@ -115,6 +115,7 @@ class Experiment:
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,

@@ -217,6 +218,7 @@ class Experiment:
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
+            "sync_to_cloud": sync_to_cloud,
             "checkpoint_freq": checkpoint_freq,
             "checkpoint_at_end": checkpoint_at_end,
             "sync_on_checkpoint": sync_on_checkpoint,

@@ -361,6 +361,7 @@ class RayTrialExecutor(TrialExecutor):
         }
         if issubclass(trial.get_trainable_cls(), DurableTrainable):
             kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
+            kwargs["sync_function_tpl"] = trial.sync_to_cloud

         with self._change_working_directory(trial):
             return full_actor_class.remote(**kwargs)

@@ -21,6 +21,7 @@ from ray.tune.durable_trainable import durable
 from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
                                  AsyncHyperBandScheduler)
 from ray.tune.stopper import MaximumIterationStopper, TrialPlateauStopper
+from ray.tune.sync_client import CommandBasedClient
 from ray.tune.trial import Trial
 from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
                              EPISODES_TOTAL, TRAINING_ITERATION,

@@ -29,7 +30,7 @@ from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
 from ray.tune.logger import Logger
 from ray.tune.experiment import Experiment
 from ray.tune.resources import Resources
-from ray.tune.suggest import grid_search
+from ray.tune.suggest import BasicVariantGenerator, grid_search
 from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.suggest.ax import AxSearch
 from ray.tune.suggest._mock import _MockSuggestionAlgorithm

@@ -942,6 +943,58 @@ class TrainableFunctionApiTest(unittest.TestCase):

         self._testDurableTrainable(durable(test_train), function=True)

+    def testDurableTrainableSyncFunction(self):
+        """Check custom sync functions in durable trainables"""
+
+        class TestDurable(DurableTrainable):
+            def __init__(self, *args, **kwargs):
+                # Mock distutils.spawn.find_executable
+                # so `aws` command is found
+                import distutils.spawn
+                distutils.spawn.find_executable = lambda *_, **__: True
+                super(TestDurable, self).__init__(*args, **kwargs)
+
+            def check(self):
+                return bool(self.sync_function_tpl) and isinstance(
+                    self.storage_client, CommandBasedClient
+                ) and "aws" not in self.storage_client.sync_up_template
+
+        class TestTplDurable(TestDurable):
+            _sync_function_tpl = "echo static sync {source} {target}"
+
+        upload_dir = "s3://test-bucket/path"
+
+        def _create_remote_actor(trainable_cls, sync_to_cloud):
+            """Create a remote trainable actor from an experiment"""
+            exp = Experiment(
+                name="test_durable_sync",
+                run=trainable_cls,
+                sync_to_cloud=sync_to_cloud,
+                sync_to_driver=False,
+                upload_dir=upload_dir)
+
+            searchers = BasicVariantGenerator()
+            searchers.add_configurations([exp])
+            trial = searchers.next_trial()
+            cls = trial.get_trainable_cls()
+            actor = ray.remote(cls).remote(
+                remote_checkpoint_dir=upload_dir,
+                sync_function_tpl=trial.sync_to_cloud)
+            return actor
+
+        # This actor should create a default aws syncer, so check should fail
+        actor1 = _create_remote_actor(TestDurable, None)
+        self.assertFalse(ray.get(actor1.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor2 = _create_remote_actor(TestDurable,
+                                      "echo test sync {source} {target}")
+        self.assertTrue(ray.get(actor2.check.remote()))
+
+        # This actor should create a custom syncer, so check should pass
+        actor3 = _create_remote_actor(TestTplDurable, None)
+        self.assertTrue(ray.get(actor3.check.remote()))
+
     def testCheckpointDict(self):
         class TestTrain(Trainable):
             def setup(self, config):

@@ -1,5 +1,6 @@
 import glob
 import os
+import pickle
 import shutil
 import sys
 import tempfile

@@ -169,10 +170,16 @@ class TestSyncFunctionality(unittest.TestCase):
                 time.sleep(1)
                 tune.report(score=i)

-        mock = unittest.mock.Mock()
-
         def counter(local, remote):
-            mock()
+            count_file = os.path.join(tmpdir, "count.txt")
+            if not os.path.exists(count_file):
+                count = 0
+            else:
+                with open(count_file, "rb") as fp:
+                    count = pickle.load(fp)
+            count += 1
+            with open(count_file, "wb") as fp:
+                pickle.dump(count, fp)

         sync_config = tune.SyncConfig(
             upload_dir="test", sync_to_cloud=counter, cloud_sync_period=1)

@@ -191,7 +198,11 @@ class TestSyncFunctionality(unittest.TestCase):
             sync_config=sync_config,
         ).trials

-        self.assertEqual(mock.call_count, 12)
+        count_file = os.path.join(tmpdir, "count.txt")
+        with open(count_file, "rb") as fp:
+            count = pickle.load(fp)
+
+        self.assertEqual(count, 12)
         shutil.rmtree(tmpdir)

     def testClusterSyncFunction(self):

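As the updated test shows, `sync_to_cloud` also accepts a callable instead of a command template; it is invoked with the local and remote paths. A minimal sketch of such a callable (the function name and logging body are illustrative only):

import logging


def log_sync(local_dir, remote_dir):
    # Invoked by Tune when trial output is synced to cloud storage; a real
    # implementation would copy local_dir to remote_dir here.
    logging.info("would sync %s -> %s", local_dir, remote_dir)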
@@ -187,6 +187,7 @@ class Trial:
                  placement_group_factory=None,
                  stopping_criterion=None,
                  remote_checkpoint_dir=None,
+                 sync_to_cloud=None,
                  checkpoint_freq=0,
                  checkpoint_at_end=False,
                  sync_on_checkpoint=True,

@@ -283,6 +284,7 @@ class Trial:
             self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
         else:
             self.remote_checkpoint_dir_prefix = None
+        self.sync_to_cloud = sync_to_cloud
         self.checkpoint_freq = checkpoint_freq
         self.checkpoint_at_end = checkpoint_at_end
         self.keep_checkpoints_num = keep_checkpoints_num

@@ -397,6 +397,7 @@ def run(
             local_dir=local_dir,
             upload_dir=sync_config.upload_dir,
             sync_to_driver=sync_config.sync_to_driver,
+            sync_to_cloud=sync_config.sync_to_cloud,
             trial_name_creator=trial_name_creator,
             trial_dirname_creator=trial_dirname_creator,
             log_to_file=log_to_file,

@@ -65,7 +65,8 @@ class MockDurableTrainer(DurableTrainable, _MockTrainer):
     # TODO(ujvl): This class uses multiple inheritance; it should be cleaned
     # up once the durable training API converges.

-    def __init__(self, remote_checkpoint_dir, *args, **kwargs):
+    def __init__(self, remote_checkpoint_dir, sync_function_tpl, *args,
+                 **kwargs):
         _MockTrainer.__init__(self, *args, **kwargs)
         DurableTrainable.__init__(self, remote_checkpoint_dir, *args, **kwargs)

@@ -34,7 +34,11 @@ class ProgressCallback(tune.callback.Callback):


 class TestDurableTrainable(DurableTrainable):
-    def __init__(self, remote_checkpoint_dir, config, logger_creator=None):
+    def __init__(self,
+                 remote_checkpoint_dir,
+                 config,
+                 logger_creator=None,
+                 **kwargs):
         self.setup_env()

         super(TestDurableTrainable, self).__init__(