mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00
[tune] Ensure arguments passed to tune remote_run
match (#18733)
This commit is contained in:
parent
fc6a739e4b
commit
ca3fabc4cb
2 changed files with 55 additions and 40 deletions
|
@ -1,9 +1,12 @@
|
||||||
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.tune import register_trainable, run_experiments, run
|
from ray.tune import register_trainable, run_experiments, run, choice
|
||||||
from ray.tune.result import TIMESTEPS_TOTAL
|
from ray.tune.result import TIMESTEPS_TOTAL
|
||||||
from ray.tune.experiment import Experiment
|
from ray.tune.experiment import Experiment
|
||||||
|
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||||
from ray.tune.trial import Trial
|
from ray.tune.trial import Trial
|
||||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||||
|
|
||||||
|
@ -36,6 +39,47 @@ class RemoteTest(unittest.TestCase):
|
||||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
|
||||||
|
|
||||||
|
def testRemoteRunArguments(self):
    """Check that `_remote=True` forwards exactly tune.run's arguments.

    Patches the server-side ``ray.tune.tune.run`` with a recording
    wrapper, triggers a remote run, and verifies that everything
    forwarded (minus ``run_or_experiment`` and ``_remote``) equals the
    defaults declared in ``run``'s signature.
    """

    def train(config, reporter):
        # Trivial trainable: report 100 increasing timestep counts.
        for step in range(100):
            reporter(timesteps_total=step)

    def recording_run(*forwarded_args, **forwarded_kwargs):
        # Capture what the remote wrapper forwarded, then delegate to
        # the real entry point so the run still completes.
        return run(*forwarded_args, **forwarded_kwargs), (
            forwarded_args, forwarded_kwargs)

    with patch("ray.tune.tune.run", recording_run):
        analysis, (forwarded_args, forwarded_kwargs) = run(
            train, _remote=True)

    # Everything must travel as keyword arguments.
    self.assertFalse(forwarded_args)
    forwarded_kwargs.pop("run_or_experiment")
    forwarded_kwargs.pop("_remote")

    # Expected: the declared defaults of run(), minus the two keys
    # the remote path consumes itself.
    expected_kwargs = {
        name: parameter.default
        for name, parameter in inspect.signature(run).parameters.items()
        if name not in ("run_or_experiment", "_remote")
    }

    self.assertDictEqual(forwarded_kwargs, expected_kwargs)
|
def testRemoteRunWithSearcher(self):
    """Check that `_remote=True` works together with a search algorithm.

    Runs a single HyperOpt-driven trial remotely and verifies it
    terminates after reporting the final timestep count.
    """

    def train(config, reporter):
        # Trivial trainable: report 100 increasing timestep counts.
        for step in range(100):
            reporter(timesteps_total=step)

    analysis = run(
        train,
        metric="timesteps_total",
        mode="max",
        config={"a": choice(["a", "b"])},
        search_alg=HyperOptSearch(),
        _remote=True)

    # Exactly one trial is expected; destructuring enforces that.
    [only_trial] = analysis.trials
    self.assertEqual(only_trial.status, Trial.TERMINATED)
    self.assertEqual(only_trial.last_result[TIMESTEPS_TOTAL], 99)
def testRemoteRunExperimentsInClient(self):
|
def testRemoteRunExperimentsInClient(self):
|
||||||
ray.init()
|
ray.init()
|
||||||
assert not ray.util.client.ray.is_connected()
|
assert not ray.util.client.ray.is_connected()
|
||||||
|
|
|
@ -292,6 +292,13 @@ def run(
|
||||||
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
TuneError: Any trials failed and `raise_on_failed_trial` is True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# NO CODE IS TO BE ADDED ABOVE THIS COMMENT
|
||||||
|
# remote_run_kwargs must be defined before any other
|
||||||
|
# code is ran to ensure that at this point,
|
||||||
|
# `locals()` is equal to args and kwargs
|
||||||
|
remote_run_kwargs = locals().copy()
|
||||||
|
remote_run_kwargs.pop("_remote")
|
||||||
|
|
||||||
if _remote is None:
|
if _remote is None:
|
||||||
_remote = ray.util.client.ray.is_connected()
|
_remote = ray.util.client.ray.is_connected()
|
||||||
|
|
||||||
|
@ -307,45 +314,9 @@ def run(
|
||||||
# Make sure tune.run is called on the server node.
|
# Make sure tune.run is called on the server node.
|
||||||
remote_run = force_on_current_node(remote_run)
|
remote_run = force_on_current_node(remote_run)
|
||||||
|
|
||||||
return ray.get(
|
return ray.get(remote_run.remote(_remote=False, **remote_run_kwargs))
|
||||||
remote_run.remote(
|
|
||||||
run_or_experiment,
|
del remote_run_kwargs
|
||||||
name,
|
|
||||||
metric,
|
|
||||||
mode,
|
|
||||||
stop,
|
|
||||||
time_budget_s,
|
|
||||||
config,
|
|
||||||
resources_per_trial,
|
|
||||||
num_samples,
|
|
||||||
local_dir,
|
|
||||||
search_alg,
|
|
||||||
scheduler,
|
|
||||||
keep_checkpoints_num,
|
|
||||||
checkpoint_score_attr,
|
|
||||||
checkpoint_freq,
|
|
||||||
checkpoint_at_end,
|
|
||||||
verbose,
|
|
||||||
progress_reporter,
|
|
||||||
log_to_file,
|
|
||||||
trial_name_creator,
|
|
||||||
trial_dirname_creator,
|
|
||||||
sync_config,
|
|
||||||
export_formats,
|
|
||||||
max_failures,
|
|
||||||
fail_fast,
|
|
||||||
restore,
|
|
||||||
server_port,
|
|
||||||
resume,
|
|
||||||
queue_trials,
|
|
||||||
reuse_actors,
|
|
||||||
trial_executor,
|
|
||||||
raise_on_failed_trial,
|
|
||||||
callbacks,
|
|
||||||
max_concurrent_trials,
|
|
||||||
# Deprecated args
|
|
||||||
loggers,
|
|
||||||
_remote=False))
|
|
||||||
|
|
||||||
all_start = time.time()
|
all_start = time.time()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue