[tune] Ensure arguments passed to tune remote_run match (#18733)

This commit is contained in:
Antoni Baum 2021-09-21 16:29:29 +02:00 committed by GitHub
parent fc6a739e4b
commit ca3fabc4cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 40 deletions

View file

@ -1,9 +1,12 @@
import inspect
import unittest
from unittest.mock import patch
import ray
from ray.tune import register_trainable, run_experiments, run
from ray.tune import register_trainable, run_experiments, run, choice
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune.experiment import Experiment
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.trial import Trial
from ray.util.client.ray_client_helpers import ray_start_client_server
@ -36,6 +39,47 @@ class RemoteTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRunArguments(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
def mocked_run(*args, **kwargs):
capture_args_kwargs = (args, kwargs)
return run(*args, **kwargs), capture_args_kwargs
with patch("ray.tune.tune.run", mocked_run):
analysis, capture_args_kwargs = run(train, _remote=True)
args, kwargs = capture_args_kwargs
self.assertFalse(args)
kwargs.pop("run_or_experiment")
kwargs.pop("_remote")
default_kwargs = {
k: v.default
for k, v in inspect.signature(run).parameters.items()
}
default_kwargs.pop("run_or_experiment")
default_kwargs.pop("_remote")
self.assertDictEqual(kwargs, default_kwargs)
def testRemoteRunWithSearcher(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
analysis = run(
train,
search_alg=HyperOptSearch(),
config={"a": choice(["a", "b"])},
metric="timesteps_total",
mode="max",
_remote=True)
[trial] = analysis.trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRunExperimentsInClient(self):
ray.init()
assert not ray.util.client.ray.is_connected()

View file

@ -292,6 +292,13 @@ def run(
TuneError: Any trials failed and `raise_on_failed_trial` is True.
"""
# NO CODE IS TO BE ADDED ABOVE THIS COMMENT
# remote_run_kwargs must be defined before any other
# code is ran to ensure that at this point,
# `locals()` is equal to args and kwargs
remote_run_kwargs = locals().copy()
remote_run_kwargs.pop("_remote")
if _remote is None:
_remote = ray.util.client.ray.is_connected()
@ -307,45 +314,9 @@ def run(
# Make sure tune.run is called on the sever node.
remote_run = force_on_current_node(remote_run)
return ray.get(
remote_run.remote(
run_or_experiment,
name,
metric,
mode,
stop,
time_budget_s,
config,
resources_per_trial,
num_samples,
local_dir,
search_alg,
scheduler,
keep_checkpoints_num,
checkpoint_score_attr,
checkpoint_freq,
checkpoint_at_end,
verbose,
progress_reporter,
log_to_file,
trial_name_creator,
trial_dirname_creator,
sync_config,
export_formats,
max_failures,
fail_fast,
restore,
server_port,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
callbacks,
max_concurrent_trials,
# Deprecated args
loggers,
_remote=False))
return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs))
del remote_run_kwargs
all_start = time.time()