mirror of
https://github.com/vale981/ray
synced 2025-03-08 11:31:40 -05:00
[tune] Ensure arguments passed to tune remote_run
match (#18733)
This commit is contained in:
parent
fc6a739e4b
commit
ca3fabc4cb
2 changed files with 55 additions and 40 deletions
|
@ -1,9 +1,12 @@
|
|||
import inspect
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.tune import register_trainable, run_experiments, run
|
||||
from ray.tune import register_trainable, run_experiments, run, choice
|
||||
from ray.tune.result import TIMESTEPS_TOTAL
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.trial import Trial
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
@ -36,6 +39,47 @@ class RemoteTest(unittest.TestCase):
|
|||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testRemoteRunArguments(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
def mocked_run(*args, **kwargs):
|
||||
capture_args_kwargs = (args, kwargs)
|
||||
return run(*args, **kwargs), capture_args_kwargs
|
||||
|
||||
with patch("ray.tune.tune.run", mocked_run):
|
||||
analysis, capture_args_kwargs = run(train, _remote=True)
|
||||
args, kwargs = capture_args_kwargs
|
||||
self.assertFalse(args)
|
||||
kwargs.pop("run_or_experiment")
|
||||
kwargs.pop("_remote")
|
||||
|
||||
default_kwargs = {
|
||||
k: v.default
|
||||
for k, v in inspect.signature(run).parameters.items()
|
||||
}
|
||||
default_kwargs.pop("run_or_experiment")
|
||||
default_kwargs.pop("_remote")
|
||||
|
||||
self.assertDictEqual(kwargs, default_kwargs)
|
||||
|
||||
def testRemoteRunWithSearcher(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
analysis = run(
|
||||
train,
|
||||
search_alg=HyperOptSearch(),
|
||||
config={"a": choice(["a", "b"])},
|
||||
metric="timesteps_total",
|
||||
mode="max",
|
||||
_remote=True)
|
||||
[trial] = analysis.trials
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||
|
||||
def testRemoteRunExperimentsInClient(self):
|
||||
ray.init()
|
||||
assert not ray.util.client.ray.is_connected()
|
||||
|
|
|
@ -292,6 +292,13 @@ def run(
|
|||
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
||||
"""
|
||||
|
||||
# NO CODE IS TO BE ADDED ABOVE THIS COMMENT
|
||||
# remote_run_kwargs must be defined before any other
|
||||
# code is ran to ensure that at this point,
|
||||
# `locals()` is equal to args and kwargs
|
||||
remote_run_kwargs = locals().copy()
|
||||
remote_run_kwargs.pop("_remote")
|
||||
|
||||
if _remote is None:
|
||||
_remote = ray.util.client.ray.is_connected()
|
||||
|
||||
|
@ -307,45 +314,9 @@ def run(
|
|||
# Make sure tune.run is called on the sever node.
|
||||
remote_run = force_on_current_node(remote_run)
|
||||
|
||||
return ray.get(
|
||||
remote_run.remote(
|
||||
run_or_experiment,
|
||||
name,
|
||||
metric,
|
||||
mode,
|
||||
stop,
|
||||
time_budget_s,
|
||||
config,
|
||||
resources_per_trial,
|
||||
num_samples,
|
||||
local_dir,
|
||||
search_alg,
|
||||
scheduler,
|
||||
keep_checkpoints_num,
|
||||
checkpoint_score_attr,
|
||||
checkpoint_freq,
|
||||
checkpoint_at_end,
|
||||
verbose,
|
||||
progress_reporter,
|
||||
log_to_file,
|
||||
trial_name_creator,
|
||||
trial_dirname_creator,
|
||||
sync_config,
|
||||
export_formats,
|
||||
max_failures,
|
||||
fail_fast,
|
||||
restore,
|
||||
server_port,
|
||||
resume,
|
||||
queue_trials,
|
||||
reuse_actors,
|
||||
trial_executor,
|
||||
raise_on_failed_trial,
|
||||
callbacks,
|
||||
max_concurrent_trials,
|
||||
# Deprecated args
|
||||
loggers,
|
||||
_remote=False))
|
||||
return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs))
|
||||
|
||||
del remote_run_kwargs
|
||||
|
||||
all_start = time.time()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue