[core] Add type hints to public APIs in worker.py (#18049)

This commit is contained in:
Edward Oakes 2021-08-26 09:51:44 -05:00 committed by GitHub
parent 07c05e16fa
commit 0c5f7a698d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -12,7 +12,7 @@ import sys
import threading
import time
import traceback
from typing import Any, Dict, List, Optional, Iterator
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
# Ray modules
from ray.autoscaler._private.constants import AUTOSCALER_EVENTS
@ -577,37 +577,37 @@ _global_node = None
@PublicAPI
@client_mode_hook
def init(
address=None,
address: Optional[str] = None,
*,
num_cpus=None,
num_gpus=None,
resources=None,
object_store_memory=None,
local_mode=False,
ignore_reinit_error=False,
include_dashboard=None,
dashboard_host=ray_constants.DEFAULT_DASHBOARD_IP,
dashboard_port=None,
job_config=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
log_to_driver=True,
namespace=None,
runtime_env=None,
num_cpus: Optional[int] = None,
num_gpus: Optional[int] = None,
resources: Optional[Dict[str, float]] = None,
object_store_memory: Optional[int] = None,
local_mode: bool = False,
ignore_reinit_error: bool = False,
include_dashboard: Optional[bool] = None,
dashboard_host: str = ray_constants.DEFAULT_DASHBOARD_IP,
dashboard_port: Optional[int] = None,
job_config: "ray.job_config.JobConfig" = None,
configure_logging: bool = True,
logging_level: int = logging.INFO,
logging_format: str = ray_constants.LOGGER_FORMAT,
log_to_driver: bool = True,
namespace: Optional[str] = None,
runtime_env: Dict[str, Any] = None,
# The following are unstable parameters and their use is discouraged.
_enable_object_reconstruction=False,
_redis_max_memory=None,
_plasma_directory=None,
_node_ip_address=ray_constants.NODE_DEFAULT_IP,
_driver_object_store_memory=None,
_memory=None,
_redis_password=ray_constants.REDIS_DEFAULT_PASSWORD,
_temp_dir=None,
_lru_evict=False,
_metrics_export_port=None,
_system_config=None,
_tracing_startup_hook=None,
_enable_object_reconstruction: bool = False,
_redis_max_memory: Optional[int] = None,
_plasma_directory: Optional[str] = None,
_node_ip_address: str = ray_constants.NODE_DEFAULT_IP,
_driver_object_store_memory: Optional[int] = None,
_memory: Optional[int] = None,
_redis_password: str = ray_constants.REDIS_DEFAULT_PASSWORD,
_temp_dir: Optional[str] = None,
_lru_evict: bool = False,
_metrics_export_port: Optional[int] = None,
_system_config: Optional[Dict[str, str]] = None,
_tracing_startup_hook: Optional[Callable] = None,
**kwargs):
"""
Connect to an existing Ray cluster or start one and connect to it.
@ -974,7 +974,7 @@ _post_init_hooks = []
@PublicAPI
@client_mode_hook
def shutdown(_exiting_interpreter=False):
def shutdown(_exiting_interpreter: bool = False):
"""Disconnect the worker, and terminate processes started by ray.init().
This will automatically run at the end when a Python process that uses Ray
@ -1221,7 +1221,7 @@ def listen_error_messages_raylet(worker, threads_stopped):
@PublicAPI
@client_mode_hook
def is_initialized():
def is_initialized() -> bool:
"""Check if ray.init has been called yet.
Returns:
@ -1522,7 +1522,7 @@ def _changeproctitle(title, next_title):
@DeveloperAPI
def show_in_dashboard(message, key="", dtype="text"):
def show_in_dashboard(message: str, key: str = "", dtype: str = "text"):
"""Display message in dashboard.
Display message for the current task or actor in the dashboard.
@ -1556,7 +1556,9 @@ blocking_get_inside_async_warned = False
@PublicAPI
@client_mode_hook
def get(object_refs, *, timeout=None):
def get(object_refs: Union[ray.ObjectRef, List[ray.ObjectRef]],
*,
timeout: Optional[float] = None) -> Union[Any, List[Any]]:
"""Get a remote object or a list of remote objects from the object store.
This method blocks until the object corresponding to the object ref is
@ -1643,7 +1645,8 @@ def get(object_refs, *, timeout=None):
@PublicAPI
@client_mode_hook
def put(value, *, _owner=None):
def put(value: Any, *,
_owner: Optional["ray.actor.ActorHandle"] = None) -> ray.ObjectRef:
"""Store an object in the object store.
The object may not be evicted while a reference to the returned ID exists.
@ -1696,7 +1699,12 @@ blocking_wait_inside_async_warned = False
@PublicAPI
@client_mode_hook
def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True):
def wait(object_refs: List[ray.ObjectRef],
*,
num_returns: int = 1,
timeout: Optional[float] = None,
fetch_local: bool = True
) -> Tuple[List[ray.ObjectRef], List[ray.ObjectRef]]:
"""Return a list of IDs that are ready and a list of IDs that are not.
If timeout is set, the function returns either when the requested number of
@ -1798,7 +1806,8 @@ def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True):
@PublicAPI
@client_mode_hook
def get_actor(name: str, namespace: Optional[str] = None):
def get_actor(name: str,
namespace: Optional[str] = None) -> "ray.actor.ActorHandle":
"""Get a handle to a named actor.
Gets a handle to an actor with the given name. The actor must
@ -1829,7 +1838,7 @@ def get_actor(name: str, namespace: Optional[str] = None):
@PublicAPI
@client_mode_hook
def kill(actor, *, no_restart=True):
def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True):
"""Kill an actor forcefully.
This will interrupt any running tasks on the actor, causing them to fail
@ -1858,7 +1867,10 @@ def kill(actor, *, no_restart=True):
@PublicAPI
@client_mode_hook
def cancel(object_ref, *, force=False, recursive=True):
def cancel(object_ref: ray.ObjectRef,
*,
force: bool = False,
recursive: bool = True):
"""Cancels a task according to the following conditions.
If the specified task is pending execution, it will not be executed. If