[tune] Fix syncer=None not disabling trial-to-driver syncing (#20418)

This commit is contained in:
Kai Fricke 2021-11-16 14:36:23 +00:00 committed by GitHub
parent f82880eda1
commit 8a6c936aa8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 3 deletions

View file

@ -556,8 +556,8 @@ def detect_cluster_syncer(
sync_config = sync_config or SyncConfig()
if bool(sync_config.upload_dir):
# No sync to driver for cloud checkpointing
if bool(sync_config.upload_dir) or sync_config.syncer is None:
# No sync to driver for cloud checkpointing or if manually disabled
return False
_syncer = sync_config.syncer

View file

@ -17,8 +17,10 @@ from ray.rllib import _register_all
from ray import tune
from ray.tune.integration.docker import DockerSyncer
from ray.tune.integration.kubernetes import KubernetesSyncer
from ray.tune.sync_client import NOOP
from ray.tune.syncer import (CommandBasedClient, detect_cluster_syncer,
get_cloud_sync_client)
get_cloud_sync_client, SyncerCallback)
from ray.tune.utils.callback import create_default_callbacks
class TestSyncFunctionality(unittest.TestCase):
@ -404,6 +406,34 @@ class TestSyncFunctionality(unittest.TestCase):
tune.SyncConfig(syncer=DockerSyncer), kubernetes_file)
self.assertTrue(issubclass(syncer, DockerSyncer))
@patch("ray.tune.syncer.log_sync_template",
lambda: "rsync {source} {target}")
def testNoSyncToDriver(self):
"""Test that sync to driver is disabled"""
class _Trial:
def __init__(self, id, logdir):
self.id = id,
self.logdir = logdir
trial = _Trial("0", "some_dir")
sync_config = tune.SyncConfig(syncer=None)
# Create syncer callbacks
callbacks = create_default_callbacks([], sync_config, loggers=None)
syncer_callback = callbacks[-1]
# Sanity check that we got the syncer callback
self.assertTrue(isinstance(syncer_callback, SyncerCallback))
# Sync function should be false (no sync to driver)
self.assertEquals(syncer_callback._sync_function, False)
# Sync to driver is disabled, so this should be no-op
trial_syncer = syncer_callback._get_trial_syncer(trial)
self.assertEquals(trial_syncer.sync_client, NOOP)
if __name__ == "__main__":
import pytest