[Tune] Place remote tune.run on node running the client server (#16034)
* force placement on persistent node
* address comments
* doc
This commit is contained in:
parent cfa2997b86
commit 38b657cb65
7 changed files with 79 additions and 13 deletions
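
For context, the change means that when `tune.run` is invoked through a Ray Client connection, the driver-side call is re-dispatched as a zero-CPU remote task pinned to the node running the client server, so experiment state and checkpoints land on that node rather than on the local client machine. A minimal usage sketch of what this enables (the host name "head-node" and the `objective` function are hypothetical; around this release the client connection was made with `ray.util.connect`, while newer Ray uses `ray.init("ray://...")`):

import ray
from ray import tune

# Hypothetical cluster address; assumes the Ray Client server listens on the
# default port 10001.
ray.util.connect("head-node:10001")


def objective(config):
    tune.report(score=config["x"] ** 2)


# With a client connection active, tune.run() detects it and re-dispatches
# itself as a zero-CPU remote task pinned to the client server node, so all
# experiment state and checkpoints are written on that node.
analysis = tune.run(
    objective,
    metric="score",
    mode="max",
    config={"x": tune.grid_search([1, 2, 3])})
print(analysis.best_config)
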
@@ -348,6 +348,8 @@ Utilities

 .. autofunction:: ray.tune.utils.validate_save_restore

+.. autofunction:: ray.tune.utils.force_on_current_node
+
 .. _tune-ddp-doc:

@@ -224,7 +224,11 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
     if ray.util.client.ray.is_connected():
         # If using Ray Client, we want to make sure checkpoint access
         # happens on the server. So we wrap `test_best_model` in a Ray task.
-        ray.get(ray.remote(test_best_model).remote(best_trial))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils.util import force_on_current_node
+        remote_fn = force_on_current_node(ray.remote(test_best_model))
+        ray.get(remote_fn.remote(best_trial))
     else:
         test_best_model(best_trial)

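The same wrapping applies to any driver-side helper that reads trial results or checkpoints from local disk, as the three example diffs in this commit show. A hedged, self-contained sketch of the pattern (the `trainable` and `evaluate_best_trial` functions are made up for illustration; `force_on_current_node` is the utility added by this commit):

import ray
from ray import tune
from ray.tune.utils.util import force_on_current_node


def trainable(config):
    # Hypothetical toy trainable.
    tune.report(accuracy=config["lr"])


def evaluate_best_trial(best_trial):
    # Hypothetical post-processing; in the real examples this loads a
    # checkpoint from the trial's local logdir, which only exists on the
    # node where ``tune.run`` executed.
    return best_trial.last_result["accuracy"]


if __name__ == "__main__":
    analysis = tune.run(
        trainable,
        metric="accuracy",
        mode="max",
        config={"lr": tune.grid_search([0.01, 0.1])})
    best_trial = analysis.best_trial
    if ray.util.client.ray.is_connected():
        # Connected via Ray Client: run the evaluation on the server node,
        # next to the checkpoints, instead of on the local client machine.
        remote_fn = force_on_current_node(ray.remote(evaluate_best_trial))
        result = ray.get(remote_fn.remote(best_trial))
    else:
        result = evaluate_best_trial(best_trial)
    print("best accuracy:", result)
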
@@ -139,6 +139,11 @@ if __name__ == "__main__":
     if args.server_address:
         # If using Ray Client, we want to make sure checkpoint access
         # happens on the server. So we wrap `test_best_model` in a Ray task.
-        ray.get(ray.remote(test_best_model).remote(analysis))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils.util import force_on_current_node
+
+        remote_fn = force_on_current_node(ray.remote(test_best_model))
+        ray.get(remote_fn.remote(analysis))
     else:
         test_best_model(analysis)

@@ -88,8 +88,12 @@ if __name__ == "__main__":
     if args.server_address:
         # If connecting to a remote server with Ray Client, checkpoint loading
        # should be wrapped in a task so it will execute on the server.
-        best_bst = ray.get(
-            ray.remote(get_best_model_checkpoint).remote(analysis))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils import force_on_current_node
+        remote_fn = force_on_current_node(
+            ray.remote(get_best_model_checkpoint))
+        best_bst = ray.get(remote_fn.remote(analysis))
     else:
         best_bst = get_best_model_checkpoint(analysis)

@@ -29,6 +29,7 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
+from ray.tune.utils import force_on_current_node

 # Must come last to avoid circular imports
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler

@@ -297,8 +298,13 @@ def run(
     _ray_auto_init()

     if _remote:
+        remote_run = ray.remote(num_cpus=0)(run)
+
+        # Make sure tune.run is called on the server node.
+        remote_run = force_on_current_node(remote_run)
+
         return ray.get(
-            ray.remote(num_cpus=0)(run).remote(
+            remote_run.remote(
                 run_or_experiment,
                 name,
                 metric,
@@ -601,8 +607,13 @@ def run_experiments(
     _ray_auto_init()

     if _remote:
+        remote_run = ray.remote(num_cpus=0)(run_experiments)
+
+        # Make sure tune.run_experiments is run on the server node.
+        remote_run = force_on_current_node(remote_run)
+
         return ray.get(
-            ray.remote(num_cpus=0)(run_experiments).remote(
+            remote_run.remote(
                 experiments,
                 scheduler,
                 server_port,

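Both `run` and `run_experiments` apply the same re-dispatch pattern: wrap the driver entry point in a zero-CPU remote task and pin it to the current (client server) node before submitting. A condensed sketch of that pattern outside of Tune (`drive_experiment` is a hypothetical stand-in for `run`/`run_experiments`):

import ray
from ray.tune.utils import force_on_current_node


def drive_experiment(config: dict, _remote: bool = None) -> str:
    # Hypothetical driver function mirroring tune.run's structure.
    if _remote is None:
        # When called through Ray Client, hop onto the cluster.
        _remote = ray.util.client.ray.is_connected()

    if _remote:
        # num_cpus=0 so the long-running driver task does not hold a CPU
        # slot that the actual trials need.
        remote_fn = ray.remote(num_cpus=0)(drive_experiment)
        # Pin the task to the node running the client server so any files
        # it writes end up in a predictable place.
        remote_fn = force_on_current_node(remote_fn)
        return ray.get(remote_fn.remote(config, _remote=False))

    # ... the actual driver logic runs here, on the server node ...
    return f"ran with {config}"
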
@@ -1,14 +1,15 @@
 from ray.tune.utils.util import (
-    deep_update, date_str, find_free_port, flatten_dict, get_pinned_object,
-    merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor,
-    validate_save_restore, warn_if_slow, diagnose_serialization,
+    deep_update, date_str, find_free_port, flatten_dict, force_on_current_node,
+    get_pinned_object, merge_dicts, pin_in_object_store, unflattened_lookup,
+    UtilMonitor, validate_save_restore, warn_if_slow, diagnose_serialization,
     detect_checkpoint_function, detect_reporter, detect_config_single,
     wait_for_gpu)

 __all__ = [
     "deep_update", "date_str", "find_free_port", "flatten_dict",
-    "get_pinned_object", "merge_dicts", "pin_in_object_store",
-    "unflattened_lookup", "UtilMonitor", "validate_save_restore",
-    "warn_if_slow", "diagnose_serialization", "detect_checkpoint_function",
-    "detect_reporter", "detect_config_single", "wait_for_gpu"
+    "force_on_current_node", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization",
+    "detect_checkpoint_function", "detect_reporter", "detect_config_single",
+    "wait_for_gpu"
 ]

@@ -759,6 +759,45 @@ def validate_warmstart(parameter_names: List[str],
                          " do not match.")


+def get_current_node_resource_key() -> str:
+    """Get the Ray resource key for the current node.
+    It can be used for actor placement.
+
+    If using Ray Client, this will return the resource key for the node that
+    is running the client server.
+
+    Returns:
+        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
+    """
+    current_node_id = ray.get_runtime_context().node_id.hex()
+    for node in ray.nodes():
+        if node["NodeID"] == current_node_id:
+            # Found the node.
+            for key in node["Resources"].keys():
+                if key.startswith("node:"):
+                    return key
+    else:
+        raise ValueError("Cannot find the node dictionary for current node.")
+
+
+def force_on_current_node(task_or_actor):
+    """Given a task or actor, place it on the current node.
+
+    If using Ray Client, the current node is the client server node.
+
+    Args:
+        task_or_actor: A Ray remote function or class to place on the
+            current node.
+
+    Returns:
+        The provided task or actor, but with options modified to force
+        placement on the current node.
+    """
+    node_resource_key = get_current_node_resource_key()
+    options = {"resources": {node_resource_key: 0.01}}
+    return task_or_actor.options(**options)
+
+
 class SafeFallbackEncoder(json.JSONEncoder):
     def __init__(self, nan_str="null", **kwargs):
         super(SafeFallbackEncoder, self).__init__(**kwargs)
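The two helpers compose: `get_current_node_resource_key` discovers the current node's `node:<ip>` custom resource, and `force_on_current_node` requests a small fraction (0.01) of it, which constrains scheduling to that node without meaningfully consuming the resource (each node advertises 1.0 of its own key, so many pinned tasks can coexist). A small sketch under those assumptions; the `report_host` task is made up:

import socket

import ray
from ray.tune.utils.util import (force_on_current_node,
                                 get_current_node_resource_key)

ray.init()


@ray.remote
def report_host() -> str:
    # Hypothetical task used only to show where it was scheduled.
    return socket.gethostname()


# Equivalent pinning, spelled out manually with the node resource key:
key = get_current_node_resource_key()  # e.g. "node:10.0.0.5"
manual = report_host.options(resources={key: 0.01})

# Or via the convenience wrapper added in this commit:
pinned = force_on_current_node(report_host)

print(ray.get(manual.remote()), ray.get(pinned.remote()))
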