[Tune] Place remote tune.run on node running the client server (#16034)

* force placement on persistent node

* address comments

* doc
Amog Kamsetty 2021-05-28 18:32:57 -07:00 committed by GitHub
parent cfa2997b86
commit 38b657cb65
7 changed files with 79 additions and 13 deletions
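The mechanism behind this change, for context before the per-file diffs: every Ray node exposes a custom resource named node:<its-IP-address>, and requesting a small fraction of it pins a task or actor to that exact node. A minimal sketch of the trick, assuming a locally started cluster (the `pinned` function is illustrative, not part of the commit):

import ray

ray.init()

# Look up this node's `node:<IP>` resource key, the same way the new
# helper added in the last diff below does it.
current_node_id = ray.get_runtime_context().node_id.hex()
node = next(n for n in ray.nodes() if n["NodeID"] == current_node_id)
node_key = next(k for k in node["Resources"] if k.startswith("node:"))

@ray.remote
def pinned():
    return "scheduled on the driver's node"

# Requesting a sliver (0.01) of the node's own resource forces the
# scheduler to place the task on exactly that node.
print(ray.get(pinned.options(resources={node_key: 0.01}).remote()))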


@@ -348,6 +348,8 @@ Utilities
 .. autofunction:: ray.tune.utils.validate_save_restore
+.. autofunction:: ray.tune.utils.force_on_current_node
+
 .. _tune-ddp-doc:


@@ -224,7 +224,11 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
     if ray.util.client.ray.is_connected():
         # If using Ray Client, we want to make sure checkpoint access
         # happens on the server. So we wrap `test_best_model` in a Ray task.
-        ray.get(ray.remote(test_best_model).remote(best_trial))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils.util import force_on_current_node
+        remote_fn = force_on_current_node(ray.remote(test_best_model))
+        ray.get(remote_fn.remote(best_trial))
     else:
         test_best_model(best_trial)
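Why the wrapper matters: an unpinned ray.remote(test_best_model) task may be scheduled on any node, where the checkpoints written while tune.run executed on the server node do not exist. A sketch of the fixed flow, assuming a multi-node cluster reached over Ray Client (the address and the stub function are hypothetical):

import ray
from ray.tune.utils.util import force_on_current_node

ray.util.connect("<head-node-ip>:10001")  # hypothetical client address

def test_best_model(best_trial):
    # Stand-in for the tutorial's function, which loads the best
    # trial's checkpoint from a node-local path and evaluates it.
    print(best_trial)

best_trial = "stand-in for analysis.best_trial"

# Pinned: runs on the node that executed tune.run, where checkpoints
# are local. Under Ray Client, the node key is resolved server-side,
# so "current node" means the node running the client server.
remote_fn = force_on_current_node(ray.remote(test_best_model))
ray.get(remote_fn.remote(best_trial))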


@@ -139,6 +139,11 @@ if __name__ == "__main__":
     if args.server_address:
         # If using Ray Client, we want to make sure checkpoint access
         # happens on the server. So we wrap `test_best_model` in a Ray task.
-        ray.get(ray.remote(test_best_model).remote(analysis))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils.util import force_on_current_node
+        remote_fn = force_on_current_node(ray.remote(test_best_model))
+        ray.get(remote_fn.remote(analysis))
     else:
         test_best_model(analysis)


@@ -88,8 +88,12 @@ if __name__ == "__main__":
     if args.server_address:
         # If connecting to a remote server with Ray Client, checkpoint loading
         # should be wrapped in a task so it will execute on the server.
-        best_bst = ray.get(
-            ray.remote(get_best_model_checkpoint.remote(analysis)))
+        # We have to make sure it gets executed on the same node that
+        # ``tune.run`` is called on.
+        from ray.tune.utils import force_on_current_node
+        remote_fn = force_on_current_node(
+            ray.remote(get_best_model_checkpoint))
+        best_bst = ray.get(remote_fn.remote(analysis))
     else:
         best_bst = get_best_model_checkpoint(analysis)


@@ -29,6 +29,7 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.utils.callback import create_default_callbacks
 from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
+from ray.tune.utils import force_on_current_node
 
 # Must come last to avoid circular imports
 from ray.tune.schedulers import FIFOScheduler, TrialScheduler
@@ -297,8 +298,13 @@ def run(
     _ray_auto_init()
 
     if _remote:
+        remote_run = ray.remote(num_cpus=0)(run)
+
+        # Make sure tune.run is called on the server node.
+        remote_run = force_on_current_node(remote_run)
+
         return ray.get(
-            ray.remote(num_cpus=0)(run).remote(
+            remote_run.remote(
                 run_or_experiment,
                 name,
                 metric,
@@ -601,8 +607,13 @@ def run_experiments(
     _ray_auto_init()
 
     if _remote:
+        remote_run = ray.remote(num_cpus=0)(run_experiments)
+
+        # Make sure tune.run_experiments is run on the server node.
+        remote_run = force_on_current_node(remote_run)
+
         return ray.get(
-            ray.remote(num_cpus=0)(run_experiments).remote(
+            remote_run.remote(
                 experiments,
                 scheduler,
                 server_port,
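For orientation, how the _remote branch above is reached: earlier in tune.run, _remote falls back to ray.util.client.ray.is_connected() (assumed from the surrounding code, not shown in this hunk). A sketch of a client-side driver, with a placeholder trainable and address:

import ray
from ray import tune

ray.util.connect("<head-node-ip>:10001")  # hypothetical client address

def trainable(config):
    tune.report(score=config["x"] ** 2)

# With a client connection active, tune.run sees _remote=True, wraps
# itself in ray.remote(num_cpus=0), pins that task to the server node
# via force_on_current_node, and blocks on ray.get. Experiment state
# and checkpoints therefore live on the server node, not the machine
# running the client.
analysis = tune.run(trainable, config={"x": tune.grid_search([1, 2, 3])})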


@@ -1,14 +1,15 @@
 from ray.tune.utils.util import (
-    deep_update, date_str, find_free_port, flatten_dict, get_pinned_object,
-    merge_dicts, pin_in_object_store, unflattened_lookup, UtilMonitor,
-    validate_save_restore, warn_if_slow, diagnose_serialization,
+    deep_update, date_str, find_free_port, flatten_dict, force_on_current_node,
+    get_pinned_object, merge_dicts, pin_in_object_store, unflattened_lookup,
+    UtilMonitor, validate_save_restore, warn_if_slow, diagnose_serialization,
     detect_checkpoint_function, detect_reporter, detect_config_single,
     wait_for_gpu)
 
 __all__ = [
     "deep_update", "date_str", "find_free_port", "flatten_dict",
-    "get_pinned_object", "merge_dicts", "pin_in_object_store",
-    "unflattened_lookup", "UtilMonitor", "validate_save_restore",
-    "warn_if_slow", "diagnose_serialization", "detect_checkpoint_function",
-    "detect_reporter", "detect_config_single", "wait_for_gpu"
+    "force_on_current_node", "get_pinned_object", "merge_dicts",
+    "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
+    "validate_save_restore", "warn_if_slow", "diagnose_serialization",
+    "detect_checkpoint_function", "detect_reporter", "detect_config_single",
+    "wait_for_gpu"
 ]
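With this re-export, both spellings used across the diffs above resolve to the same helper:

from ray.tune.utils import force_on_current_node
from ray.tune.utils.util import force_on_current_node  # same function object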


@@ -759,6 +759,45 @@ def validate_warmstart(parameter_names: List[str],
             " do not match.")
 
 
+def get_current_node_resource_key() -> str:
+    """Get the Ray resource key for current node.
+    It can be used for actor placement.
+
+    If using Ray Client, this will return the resource key for the node that
+    is running the client server.
+
+    Returns:
+        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
+    """
+    current_node_id = ray.get_runtime_context().node_id.hex()
+    for node in ray.nodes():
+        if node["NodeID"] == current_node_id:
+            # Found the node.
+            for key in node["Resources"].keys():
+                if key.startswith("node:"):
+                    return key
+    else:
+        # The for-else fires when the loop finishes without returning.
+        raise ValueError("Cannot find the node dictionary for current node.")
+
+
+def force_on_current_node(task_or_actor):
+    """Given a task or actor, place it on the current node.
+
+    If using Ray Client, the current node is the client server node.
+
+    Args:
+        task_or_actor: A Ray remote function or class to place on the
+            current node.
+
+    Returns:
+        The provided task or actor, but with options modified to force
+        placement on the current node.
+    """
+    node_resource_key = get_current_node_resource_key()
+    options = {"resources": {node_resource_key: 0.01}}
+    return task_or_actor.options(**options)
+
+
 class SafeFallbackEncoder(json.JSONEncoder):
     def __init__(self, nan_str="null", **kwargs):
         super(SafeFallbackEncoder, self).__init__(**kwargs)
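A short usage sketch of the two new helpers, assuming a running cluster; .options() exists on both remote functions and actor classes, so force_on_current_node covers tasks and actors alike (the Recorder actor is illustrative only):

import ray
from ray.tune.utils import force_on_current_node

ray.init()

@ray.remote
class Recorder:
    def where(self):
        # Report which node this actor landed on.
        return ray.get_runtime_context().node_id.hex()

pinned = force_on_current_node(Recorder)  # attaches {"node:<IP>": 0.01}
actor = pinned.remote()

# The actor runs on the driver's node (or, under Ray Client, on the
# node running the client server).
assert ray.get(actor.where.remote()) == \
    ray.get_runtime_context().node_id.hex()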