From ca3fabc4cb6d85771db7802f57cca4f56565211d Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 21 Sep 2021 16:29:29 +0200
Subject: [PATCH] [tune] Ensure arguments passed to tune `remote_run` match
 (#18733)

---
 python/ray/tune/tests/test_remote.py | 46 +++++++++++++++++++++++++-
 python/ray/tune/tune.py              | 49 ++++++----------------------
 2 files changed, 55 insertions(+), 40 deletions(-)

diff --git a/python/ray/tune/tests/test_remote.py b/python/ray/tune/tests/test_remote.py
index 1e521c54b..adac70dcd 100644
--- a/python/ray/tune/tests/test_remote.py
+++ b/python/ray/tune/tests/test_remote.py
@@ -1,9 +1,12 @@
+import inspect
 import unittest
+from unittest.mock import patch
 
 import ray
-from ray.tune import register_trainable, run_experiments, run
+from ray.tune import register_trainable, run_experiments, run, choice
 from ray.tune.result import TIMESTEPS_TOTAL
 from ray.tune.experiment import Experiment
+from ray.tune.suggest.hyperopt import HyperOptSearch
 from ray.tune.trial import Trial
 from ray.util.client.ray_client_helpers import ray_start_client_server
 
@@ -36,6 +39,47 @@ class RemoteTest(unittest.TestCase):
         self.assertEqual(trial.status, Trial.TERMINATED)
         self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
 
+    def testRemoteRunArguments(self):
+        def train(config, reporter):
+            for i in range(100):
+                reporter(timesteps_total=i)
+
+        def mocked_run(*args, **kwargs):
+            capture_args_kwargs = (args, kwargs)
+            return run(*args, **kwargs), capture_args_kwargs
+
+        with patch("ray.tune.tune.run", mocked_run):
+            analysis, capture_args_kwargs = run(train, _remote=True)
+            args, kwargs = capture_args_kwargs
+            self.assertFalse(args)
+            kwargs.pop("run_or_experiment")
+            kwargs.pop("_remote")
+
+            default_kwargs = {
+                k: v.default
+                for k, v in inspect.signature(run).parameters.items()
+            }
+            default_kwargs.pop("run_or_experiment")
+            default_kwargs.pop("_remote")
+
+            self.assertDictEqual(kwargs, default_kwargs)
+
+    def testRemoteRunWithSearcher(self):
+        def train(config, reporter):
+            for i in range(100):
+                reporter(timesteps_total=i)
+
+        analysis = run(
+            train,
+            search_alg=HyperOptSearch(),
+            config={"a": choice(["a", "b"])},
+            metric="timesteps_total",
+            mode="max",
+            _remote=True)
+        [trial] = analysis.trials
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
+
     def testRemoteRunExperimentsInClient(self):
         ray.init()
         assert not ray.util.client.ray.is_connected()
diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py
index 394dad1aa..8077f7c6e 100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -292,6 +292,13 @@ def run(
         TuneError: Any trials failed and `raise_on_failed_trial` is True.
     """
+    # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
+    # remote_run_kwargs must be defined before any other
+    # code is run to ensure that at this point,
+    # `locals()` is equal to args and kwargs
+    remote_run_kwargs = locals().copy()
+    remote_run_kwargs.pop("_remote")
+
     if _remote is None:
         _remote = ray.util.client.ray.is_connected()
 
@@ -307,45 +314,9 @@ def run(
         # Make sure tune.run is called on the server node.
         remote_run = force_on_current_node(remote_run)
 
-        return ray.get(
-            remote_run.remote(
-                run_or_experiment,
-                name,
-                metric,
-                mode,
-                stop,
-                time_budget_s,
-                config,
-                resources_per_trial,
-                num_samples,
-                local_dir,
-                search_alg,
-                scheduler,
-                keep_checkpoints_num,
-                checkpoint_score_attr,
-                checkpoint_freq,
-                checkpoint_at_end,
-                verbose,
-                progress_reporter,
-                log_to_file,
-                trial_name_creator,
-                trial_dirname_creator,
-                sync_config,
-                export_formats,
-                max_failures,
-                fail_fast,
-                restore,
-                server_port,
-                resume,
-                queue_trials,
-                reuse_actors,
-                trial_executor,
-                raise_on_failed_trial,
-                callbacks,
-                max_concurrent_trials,
-                # Deprecated args
-                loggers,
-                _remote=False))
+        return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs))
+
+    del remote_run_kwargs
 
     all_start = time.time()
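
Note on the pattern used by this patch (not part of the diff): the forwarding works because `locals()` is copied at the very top of `tune.run`, before any other assignment, so the snapshot holds exactly the arguments the caller passed and can be handed verbatim to the Ray task that re-enters `run` on the server node. Below is a minimal, self-contained sketch of the same idea; `entrypoint`, `a`, and `b` are hypothetical names used only for illustration and are not part of Ray Tune.

import ray

ray.init()


def entrypoint(a=1, b="x", _remote=None):
    # Snapshot the arguments before creating any other local variable, so
    # `locals()` holds exactly this call's parameters (the trick the patch
    # relies on at the top of `tune.run`).
    forwarded = locals().copy()
    forwarded.pop("_remote")

    if _remote:
        # Wrap the same function as a Ray task and re-enter it with the
        # captured arguments, forcing the local code path on the worker.
        remote_entry = ray.remote(num_cpus=0)(entrypoint)
        return ray.get(remote_entry.remote(_remote=False, **forwarded))

    # Local path: drop the snapshot, mirroring `del remote_run_kwargs`.
    del forwarded
    return {"a": a, "b": b}


print(entrypoint(b="y", _remote=True))  # {'a': 1, 'b': 'y'}

Passing `_remote=False` explicitly in the forwarded call prevents a second round of remoting, which is exactly what the rewritten `remote_run.remote(_remote=False, **remote_run_kwargs)` call does.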