mirror of
https://github.com/vale981/ray
synced 2025-03-12 06:06:39 -04:00
[Autoscaler] Introduce callback system (#11674)
Co-authored-by: Nikita Vemuri <nikitavemuri@Nikitas-MacBook-Pro.local> Co-authored-by: Xiayue Charles Lin <xcl@anyscale.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
ee2da0cf45
commit
aba9288615
5 changed files with 158 additions and 1 deletions
|
@ -17,6 +17,8 @@ from ray.autoscaler._private.providers import _PROVIDER_PRETTY_NAMES
|
||||||
from ray.autoscaler._private.aws.utils import LazyDefaultDict, \
|
from ray.autoscaler._private.aws.utils import LazyDefaultDict, \
|
||||||
handle_boto_error
|
handle_boto_error
|
||||||
from ray.autoscaler._private.cli_logger import cli_logger, cf
|
from ray.autoscaler._private.cli_logger import cli_logger, cf
|
||||||
|
from ray.autoscaler._private.event_system import (CreateClusterEvent,
|
||||||
|
global_event_system)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -191,6 +193,9 @@ def bootstrap_aws(config):
|
||||||
|
|
||||||
# Configure SSH access, using an existing key pair if possible.
|
# Configure SSH access, using an existing key pair if possible.
|
||||||
config = _configure_key_pair(config)
|
config = _configure_key_pair(config)
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.ssh_keypair_downloaded,
|
||||||
|
{"ssh_key_path": config["auth"]["ssh_private_key"]})
|
||||||
|
|
||||||
# Pick a reasonable subnet if not specified by the user.
|
# Pick a reasonable subnet if not specified by the user.
|
||||||
config = _configure_subnet(config)
|
config = _configure_subnet(config)
|
||||||
|
|
|
@ -36,6 +36,8 @@ from ray.autoscaler._private.cli_logger import cli_logger, cf
|
||||||
from ray.autoscaler._private.updater import NodeUpdaterThread
|
from ray.autoscaler._private.updater import NodeUpdaterThread
|
||||||
from ray.autoscaler._private.command_runner import set_using_login_shells, \
|
from ray.autoscaler._private.command_runner import set_using_login_shells, \
|
||||||
set_rsync_silent
|
set_rsync_silent
|
||||||
|
from ray.autoscaler._private.event_system import (CreateClusterEvent,
|
||||||
|
global_event_system)
|
||||||
from ray.autoscaler._private.log_timer import LogTimer
|
from ray.autoscaler._private.log_timer import LogTimer
|
||||||
from ray.worker import global_worker # type: ignore
|
from ray.worker import global_worker # type: ignore
|
||||||
from ray.util.debug import log_once
|
from ray.util.debug import log_once
|
||||||
|
@ -167,6 +169,8 @@ def create_or_update_cluster(config_file: str,
|
||||||
except yaml.scanner.ScannerError as e:
|
except yaml.scanner.ScannerError as e:
|
||||||
handle_yaml_error(e)
|
handle_yaml_error(e)
|
||||||
raise
|
raise
|
||||||
|
global_event_system.execute_callback(CreateClusterEvent.up_started,
|
||||||
|
{"cluster_config": config})
|
||||||
|
|
||||||
# todo: validate file_mounts, ssh keys, etc.
|
# todo: validate file_mounts, ssh keys, etc.
|
||||||
|
|
||||||
|
@ -476,6 +480,8 @@ def get_or_create_head_node(config: Dict[str, Any],
|
||||||
_provider: Optional[NodeProvider] = None,
|
_provider: Optional[NodeProvider] = None,
|
||||||
_runner: ModuleType = subprocess) -> None:
|
_runner: ModuleType = subprocess) -> None:
|
||||||
"""Create the cluster head node, which in turn creates the workers."""
|
"""Create the cluster head node, which in turn creates the workers."""
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.cluster_booting_started)
|
||||||
provider = (_provider or _get_node_provider(config["provider"],
|
provider = (_provider or _get_node_provider(config["provider"],
|
||||||
config["cluster_name"]))
|
config["cluster_name"]))
|
||||||
|
|
||||||
|
@ -536,6 +542,8 @@ def get_or_create_head_node(config: Dict[str, Any],
|
||||||
if head_node is None or provider.node_tags(head_node).get(
|
if head_node is None or provider.node_tags(head_node).get(
|
||||||
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
|
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
|
||||||
with cli_logger.group("Acquiring an up-to-date head node"):
|
with cli_logger.group("Acquiring an up-to-date head node"):
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.acquiring_new_head_node)
|
||||||
if head_node is not None:
|
if head_node is not None:
|
||||||
cli_logger.print(
|
cli_logger.print(
|
||||||
"Currently running head node is out-of-date with "
|
"Currently running head node is out-of-date with "
|
||||||
|
@ -571,6 +579,8 @@ def get_or_create_head_node(config: Dict[str, Any],
|
||||||
time.sleep(POLL_INTERVAL)
|
time.sleep(POLL_INTERVAL)
|
||||||
cli_logger.newline()
|
cli_logger.newline()
|
||||||
|
|
||||||
|
global_event_system.execute_callback(CreateClusterEvent.head_node_acquired)
|
||||||
|
|
||||||
with cli_logger.group(
|
with cli_logger.group(
|
||||||
"Setting up head node",
|
"Setting up head node",
|
||||||
_numbered=("<>", 1, 1),
|
_numbered=("<>", 1, 1),
|
||||||
|
@ -664,6 +674,11 @@ def get_or_create_head_node(config: Dict[str, Any],
|
||||||
cli_logger.abort("Failed to setup head node.")
|
cli_logger.abort("Failed to setup head node.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.cluster_booting_completed, {
|
||||||
|
"head_node_id": head_node,
|
||||||
|
})
|
||||||
|
|
||||||
monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
|
monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
|
||||||
if override_cluster_name:
|
if override_cluster_name:
|
||||||
modifiers = " --cluster-name={}".format(quote(override_cluster_name))
|
modifiers = " --cluster-name={}".format(quote(override_cluster_name))
|
||||||
|
|
100
python/ray/autoscaler/_private/event_system.py
Normal file
100
python/ray/autoscaler/_private/event_system.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Any, Callable, Dict, List, Union
|
||||||
|
|
||||||
|
from ray.autoscaler._private.cli_logger import cli_logger
|
||||||
|
|
||||||
|
|
||||||
|
class CreateClusterEvent(Enum):
|
||||||
|
"""Events to track in ray.autoscaler.sdk.create_or_update_cluster.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
up_started : Invoked at the beginning of create_or_update_cluster.
|
||||||
|
ssh_keypair_downloaded : Invoked when the ssh keypair is downloaded.
|
||||||
|
cluster_booting_started : Invoked when the cluster booting starts.
|
||||||
|
acquiring_new_head_node : Invoked before the head node is acquired.
|
||||||
|
head_node_acquired : Invoked after the head node is acquired.
|
||||||
|
ssh_control_acquired : Invoked when the node is being updated.
|
||||||
|
run_initialization_cmd : Invoked before all initialization
|
||||||
|
commands are called and again before each initialization command.
|
||||||
|
run_setup_cmd : Invoked before all setup commands are
|
||||||
|
called and again before each setup command.
|
||||||
|
start_ray_runtime : Invoked before ray start commands are run.
|
||||||
|
start_ray_runtime_completed : Invoked after ray start commands
|
||||||
|
are run.
|
||||||
|
cluster_booting_completed : Invoked after cluster booting
|
||||||
|
is completed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
up_started = auto()
|
||||||
|
ssh_keypair_downloaded = auto()
|
||||||
|
cluster_booting_started = auto()
|
||||||
|
acquiring_new_head_node = auto()
|
||||||
|
head_node_acquired = auto()
|
||||||
|
ssh_control_acquired = auto()
|
||||||
|
run_initialization_cmd = auto()
|
||||||
|
run_setup_cmd = auto()
|
||||||
|
start_ray_runtime = auto()
|
||||||
|
start_ray_runtime_completed = auto()
|
||||||
|
cluster_booting_completed = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class _EventSystem:
    """Event system that handles storing and calling callbacks for events.

    Attributes:
        callback_map (Dict[str, List[Callable]]) : Stores list of callbacks
            for events when registered. Keys are whatever ``event`` value
            was passed to ``add_callback_handler``.
    """

    def __init__(self):
        self.callback_map = {}

    def add_callback_handler(
            self,
            event: str,
            callback: Union[Callable[[Dict], None], List[Callable[[Dict],
                                                                  None]]],
    ):
        """Stores callback handler(s) for event.

        A warning is logged when ``event`` is not one of the
        CreateClusterEvent members; the callback is still stored in that
        case.

        Args:
            event (str): Event that callback should be called on. See
                CreateClusterEvent for details on the events available to be
                registered against.
            callback (Callable[[Dict], None]): Callable object, or list of
                callable objects, invoked when specified event occurs.
        """
        if event not in CreateClusterEvent.__members__.values():
            cli_logger.warning(f"{event} is not currently tracked, and this"
                               " callback will not be invoked.")

        # Normalize to a list so single callables and lists share one path.
        callbacks = callback if isinstance(callback, list) else [callback]
        self.callback_map.setdefault(event, []).extend(callbacks)

    def execute_callback(self,
                         event: str,
                         event_data: Union[Dict[str, Any], None] = None):
        """Executes all callbacks for event.

        Each callback receives a dict holding the entries of ``event_data``
        plus an ``"event_name"`` key set to ``event``.

        Args:
            event (str): Event that is invoked. See CreateClusterEvent
                for details on the available events.
            event_data (Dict[str, Any]): Argument that is passed to each
                callable object stored for this particular event. Defaults
                to an empty dict. The caller's dict is not modified.
        """
        # Build a fresh dict per invocation. A shared ``{}`` default would be
        # overwritten by every later call, silently changing payloads that
        # callbacks from earlier events may still hold a reference to.
        payload = dict(event_data) if event_data is not None else {}
        payload["event_name"] = event
        for callback in self.callback_map.get(event, ()):
            callback(payload)

    def clear_callbacks_for_event(self, event: str):
        """Clears stored callable objects for event.

        Does nothing when no callbacks are stored for ``event``.

        Args:
            event (str): Event that has callable objects stored in map.
                See CreateClusterEvent for details on the available events.
        """
        self.callback_map.pop(event, None)


global_event_system = _EventSystem()
|
|
@ -17,6 +17,8 @@ from ray.autoscaler._private.cli_logger import cli_logger, cf
|
||||||
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
|
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
|
||||||
from ray.autoscaler._private.constants import \
|
from ray.autoscaler._private.constants import \
|
||||||
RESOURCES_ENVIRONMENT_VARIABLE
|
RESOURCES_ENVIRONMENT_VARIABLE
|
||||||
|
from ray.autoscaler._private.event_system import (CreateClusterEvent,
|
||||||
|
global_event_system)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -271,6 +273,8 @@ class NodeUpdater:
|
||||||
|
|
||||||
deadline = time.time() + NODE_START_WAIT_S
|
deadline = time.time() + NODE_START_WAIT_S
|
||||||
self.wait_ready(deadline)
|
self.wait_ready(deadline)
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.ssh_control_acquired)
|
||||||
|
|
||||||
node_tags = self.provider.node_tags(self.node_id)
|
node_tags = self.provider.node_tags(self.node_id)
|
||||||
logger.debug("Node tags: {}".format(str(node_tags)))
|
logger.debug("Node tags: {}".format(str(node_tags)))
|
||||||
|
@ -317,10 +321,15 @@ class NodeUpdater:
|
||||||
with cli_logger.group(
|
with cli_logger.group(
|
||||||
"Running initialization commands",
|
"Running initialization commands",
|
||||||
_numbered=("[]", 3, 5)):
|
_numbered=("[]", 3, 5)):
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.run_initialization_cmd)
|
||||||
with LogTimer(
|
with LogTimer(
|
||||||
self.log_prefix + "Initialization commands",
|
self.log_prefix + "Initialization commands",
|
||||||
show_status=True):
|
show_status=True):
|
||||||
for cmd in self.initialization_commands:
|
for cmd in self.initialization_commands:
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.run_initialization_cmd,
|
||||||
|
{"command": cmd})
|
||||||
try:
|
try:
|
||||||
# Overriding the existing SSHOptions class
|
# Overriding the existing SSHOptions class
|
||||||
# with a new SSHOptions class that uses
|
# with a new SSHOptions class that uses
|
||||||
|
@ -352,12 +361,17 @@ class NodeUpdater:
|
||||||
"Running setup commands",
|
"Running setup commands",
|
||||||
# todo: fix command numbering
|
# todo: fix command numbering
|
||||||
_numbered=("[]", 4, 6)):
|
_numbered=("[]", 4, 6)):
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.run_setup_cmd)
|
||||||
with LogTimer(
|
with LogTimer(
|
||||||
self.log_prefix + "Setup commands",
|
self.log_prefix + "Setup commands",
|
||||||
show_status=True):
|
show_status=True):
|
||||||
|
|
||||||
total = len(self.setup_commands)
|
total = len(self.setup_commands)
|
||||||
for i, cmd in enumerate(self.setup_commands):
|
for i, cmd in enumerate(self.setup_commands):
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.run_setup_cmd,
|
||||||
|
{"command": cmd})
|
||||||
if cli_logger.verbosity == 0 and len(cmd) > 30:
|
if cli_logger.verbosity == 0 and len(cmd) > 30:
|
||||||
cmd_to_print = cf.bold(cmd[:30]) + "..."
|
cmd_to_print = cf.bold(cmd[:30]) + "..."
|
||||||
else:
|
else:
|
||||||
|
@ -385,6 +399,8 @@ class NodeUpdater:
|
||||||
|
|
||||||
with cli_logger.group(
|
with cli_logger.group(
|
||||||
"Starting the Ray runtime", _numbered=("[]", 6, 6)):
|
"Starting the Ray runtime", _numbered=("[]", 6, 6)):
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.start_ray_runtime)
|
||||||
with LogTimer(
|
with LogTimer(
|
||||||
self.log_prefix + "Ray start commands", show_status=True):
|
self.log_prefix + "Ray start commands", show_status=True):
|
||||||
for cmd in self.ray_start_commands:
|
for cmd in self.ray_start_commands:
|
||||||
|
@ -409,6 +425,8 @@ class NodeUpdater:
|
||||||
cli_logger.error("See above for stderr.")
|
cli_logger.error("See above for stderr.")
|
||||||
|
|
||||||
raise click.ClickException("Start command failed.")
|
raise click.ClickException("Start command failed.")
|
||||||
|
global_event_system.execute_callback(
|
||||||
|
CreateClusterEvent.start_ray_runtime_completed)
|
||||||
|
|
||||||
def rsync_up(self, source, target, docker_mount_if_possible=False):
|
def rsync_up(self, source, target, docker_mount_if_possible=False):
|
||||||
options = {}
|
options = {}
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
"""IMPORTANT: this is an experimental interface and not currently stable."""
|
"""IMPORTANT: this is an experimental interface and not currently stable."""
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from ray.autoscaler._private import commands
|
from ray.autoscaler._private import commands
|
||||||
|
from ray.autoscaler._private.event_system import ( # noqa: F401
|
||||||
|
CreateClusterEvent, # noqa: F401
|
||||||
|
global_event_system)
|
||||||
|
|
||||||
|
|
||||||
def create_or_update_cluster(cluster_config: Union[dict, str],
|
def create_or_update_cluster(cluster_config: Union[dict, str],
|
||||||
|
@ -224,6 +227,22 @@ def fillout_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
return fillout_defaults(config)
|
return fillout_defaults(config)
|
||||||
|
|
||||||
|
|
||||||
|
def register_callback_handler(
        event_name: str,
        callback: Union[Callable[[Dict], None], List[Callable[[Dict], None]]],
) -> None:
    """Attach one or more callbacks to an autoscaler event.

    Both arguments are forwarded unchanged to the global event system.

    Args:
        event_name (str): The event to listen for. The events that can be
            registered against are described by CreateClusterEvent.
        callback (Callable): A callable, or a list of callables, to invoke
            when the specified event occurs.
    """
    global_event_system.add_callback_handler(
        event=event_name, callback=callback)
|
||||||
|
|
||||||
|
|
||||||
def get_docker_host_mount_location(cluster_name: str) -> str:
|
def get_docker_host_mount_location(cluster_name: str) -> str:
|
||||||
"""Return host path that Docker mounts attach to."""
|
"""Return host path that Docker mounts attach to."""
|
||||||
docker_mount_prefix = "/tmp/ray_tmp_mount/{cluster_name}"
|
docker_mount_prefix = "/tmp/ray_tmp_mount/{cluster_name}"
|
||||||
|
|
Loading…
Add table
Reference in a new issue