mirror of
https://github.com/vale981/ray
synced 2025-03-08 11:31:40 -05:00
Improve code related to node (#4383)
* Make full use of node implement local node fix bugs mentioned in comments * Add more tests * Use more specific exception handling * fix, lint * fix for py2.x
This commit is contained in:
parent
c5bcec54f3
commit
dab99d26af
7 changed files with 243 additions and 245 deletions
|
@ -4,7 +4,6 @@ from __future__ import print_function
|
|||
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import redis
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
@ -13,6 +12,7 @@ from ray.function_manager import FunctionDescriptor
|
|||
import ray.gcs_utils
|
||||
|
||||
from ray.ray_constants import ID_SIZE
|
||||
from ray import services
|
||||
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
|
@ -126,8 +126,7 @@ class GlobalState(object):
|
|||
self.redis_clients = None
|
||||
|
||||
def _initialize_global_state(self,
|
||||
redis_ip_address,
|
||||
redis_port,
|
||||
redis_address,
|
||||
redis_password=None,
|
||||
timeout=20):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
|
@ -137,18 +136,15 @@ class GlobalState(object):
|
|||
been populated or we exceed a timeout.
|
||||
|
||||
Args:
|
||||
redis_ip_address: The IP address of the node that the Redis server
|
||||
lives on.
|
||||
redis_port: The port that the Redis server is listening on.
|
||||
redis_address: The Redis address to connect.
|
||||
redis_password: The password of the redis server.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port, password=redis_password)
|
||||
|
||||
self.redis_client = services.create_redis_client(
|
||||
redis_address, redis_password)
|
||||
start_time = time.time()
|
||||
|
||||
num_redis_shards = None
|
||||
ip_address_ports = []
|
||||
redis_shard_addresses = []
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Attempt to get the number of Redis shards.
|
||||
|
@ -163,9 +159,9 @@ class GlobalState(object):
|
|||
"{}.".format(num_redis_shards))
|
||||
|
||||
# Attempt to get all of the Redis shards.
|
||||
ip_address_ports = self.redis_client.lrange(
|
||||
redis_shard_addresses = self.redis_client.lrange(
|
||||
"RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
if len(redis_shard_addresses) != num_redis_shards:
|
||||
print("Waiting longer for RedisShards to be populated.")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
@ -177,18 +173,15 @@ class GlobalState(object):
|
|||
if time.time() - start_time >= timeout:
|
||||
raise Exception("Timed out while attempting to initialize the "
|
||||
"global state. num_redis_shards = {}, "
|
||||
"ip_address_ports = {}".format(
|
||||
num_redis_shards, ip_address_ports))
|
||||
"redis_shard_addresses = {}".format(
|
||||
num_redis_shards, redis_shard_addresses))
|
||||
|
||||
# Get the rest of the information.
|
||||
self.redis_clients = []
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
for shard_address in redis_shard_addresses:
|
||||
self.redis_clients.append(
|
||||
redis.StrictRedis(
|
||||
host=shard_address,
|
||||
port=shard_port,
|
||||
password=redis_password))
|
||||
services.create_redis_client(shard_address.decode(),
|
||||
redis_password))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
|
|
|
@ -16,7 +16,6 @@ import ray.cloudpickle as pickle
|
|||
import ray.gcs_utils
|
||||
import ray.utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.services import get_ip_address, get_port
|
||||
from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary,
|
||||
setup_logger)
|
||||
|
||||
|
@ -32,17 +31,15 @@ class Monitor(object):
|
|||
|
||||
Attributes:
|
||||
redis: A connection to the Redis server.
|
||||
subscribe_client: A pubsub client for the Redis server. This is used to
|
||||
receive notifications about failed components.
|
||||
primary_subscribe_client: A pubsub client for the Redis server.
|
||||
This is used to receive notifications about failed components.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_address, autoscaling_config, redis_password=None):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
self.state._initialize_global_state(
|
||||
redis_ip_address, redis_port, redis_password=redis_password)
|
||||
args.redis_address, redis_password=redis_password)
|
||||
self.redis = ray.services.create_redis_client(
|
||||
redis_address, password=redis_password)
|
||||
# Setup subscriptions to the primary Redis server and the Redis shards.
|
||||
|
|
|
@ -31,7 +31,8 @@ PY3 = sys.version_info.major >= 3
|
|||
class Node(object):
|
||||
"""An encapsulation of the Ray processes on a single node.
|
||||
|
||||
This class is responsible for starting Ray processes and killing them.
|
||||
This class is responsible for starting Ray processes and killing them,
|
||||
and it also controls the temp file policy.
|
||||
|
||||
Attributes:
|
||||
all_processes (dict): A mapping from process type (str) to a list of
|
||||
|
@ -63,8 +64,17 @@ class Node(object):
|
|||
"be both true.")
|
||||
self.all_processes = {}
|
||||
|
||||
# Try to get node IP address with the parameters.
|
||||
if ray_params.node_ip_address:
|
||||
node_ip_address = ray_params.node_ip_address
|
||||
elif ray_params.redis_address:
|
||||
node_ip_address = ray.services.get_node_ip_address(
|
||||
ray_params.redis_address)
|
||||
else:
|
||||
node_ip_address = ray.services.get_node_ip_address()
|
||||
self._node_ip_address = node_ip_address
|
||||
|
||||
ray_params.update_if_absent(
|
||||
node_ip_address=ray.services.get_node_ip_address(),
|
||||
include_log_monitor=True,
|
||||
resources={},
|
||||
include_webui=False,
|
||||
|
@ -73,31 +83,51 @@ class Node(object):
|
|||
"workers/default_worker.py"))
|
||||
|
||||
self._ray_params = ray_params
|
||||
|
||||
self._node_ip_address = ray_params.node_ip_address
|
||||
self._redis_address = ray_params.redis_address
|
||||
self._config = (json.loads(ray_params._internal_config)
|
||||
if ray_params._internal_config else None)
|
||||
|
||||
if head:
|
||||
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
|
||||
self._plasma_store_socket_name = None
|
||||
self._raylet_socket_name = None
|
||||
self._webui_url = None
|
||||
else:
|
||||
self._init_temp()
|
||||
|
||||
if connect_only:
|
||||
# Get socket names from the configuration.
|
||||
self._plasma_store_socket_name = (
|
||||
ray_params.plasma_store_socket_name)
|
||||
self._raylet_socket_name = ray_params.raylet_socket_name
|
||||
|
||||
# If user does not provide the socket name, get it from Redis.
|
||||
if (self._plasma_store_socket_name is None
|
||||
or self._raylet_socket_name is None):
|
||||
# Get the address info of the processes to connect to
|
||||
# from Redis.
|
||||
address_info = ray.services.get_address_info_from_redis(
|
||||
self.redis_address,
|
||||
self._node_ip_address,
|
||||
redis_password=self.redis_password)
|
||||
self._plasma_store_socket_name = address_info[
|
||||
"object_store_address"]
|
||||
self._raylet_socket_name = address_info["raylet_socket_name"]
|
||||
else:
|
||||
# If the user specified a socket name, use it.
|
||||
self._plasma_store_socket_name = self._prepare_socket_file(
|
||||
self._ray_params.plasma_store_socket_name,
|
||||
default_prefix="plasma_store")
|
||||
self._raylet_socket_name = self._prepare_socket_file(
|
||||
self._ray_params.raylet_socket_name, default_prefix="raylet")
|
||||
|
||||
if head:
|
||||
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
|
||||
self._webui_url = None
|
||||
else:
|
||||
redis_client = self.create_redis_client()
|
||||
# TODO(suquark): Replace _webui_url_helper in worker.py in
|
||||
# another PR.
|
||||
_webui_url = redis_client.hmget("webui", "url")[0]
|
||||
self._webui_url = (ray.utils.decode(_webui_url)
|
||||
if _webui_url is not None else None)
|
||||
self._webui_url = (
|
||||
ray.services.get_webui_url_from_redis(redis_client))
|
||||
ray_params.include_java = (
|
||||
ray.services.include_java_from_redis(redis_client))
|
||||
|
||||
self._init_temp()
|
||||
# Start processes.
|
||||
if head:
|
||||
self.start_head_processes()
|
||||
|
||||
if not connect_only:
|
||||
self.start_ray_processes()
|
||||
|
@ -136,6 +166,20 @@ class Node(object):
|
|||
"""Get the cluster Redis address."""
|
||||
return self._redis_address
|
||||
|
||||
@property
|
||||
def redis_password(self):
|
||||
"""Get the cluster Redis password"""
|
||||
return self._ray_params.redis_password
|
||||
|
||||
@property
|
||||
def load_code_from_local(self):
|
||||
return self._ray_params.load_code_from_local
|
||||
|
||||
@property
|
||||
def object_id_seed(self):
|
||||
"""Get the seed for deterministic generation of object IDs"""
|
||||
return self._ray_params.object_id_seed
|
||||
|
||||
@property
|
||||
def plasma_store_socket_name(self):
|
||||
"""Get the node's plasma store socket name."""
|
||||
|
@ -151,6 +195,17 @@ class Node(object):
|
|||
"""Get the node's raylet socket name."""
|
||||
return self._raylet_socket_name
|
||||
|
||||
@property
|
||||
def address_info(self):
|
||||
"""Get a dictionary of addresses."""
|
||||
return {
|
||||
"node_ip_address": self._node_ip_address,
|
||||
"redis_address": self._redis_address,
|
||||
"object_store_address": self._plasma_store_socket_name,
|
||||
"raylet_socket_name": self._raylet_socket_name,
|
||||
"webui_url": self._webui_url,
|
||||
}
|
||||
|
||||
def create_redis_client(self):
|
||||
"""Create a redis client."""
|
||||
return ray.services.create_redis_client(
|
||||
|
@ -321,11 +376,6 @@ class Node(object):
|
|||
|
||||
def start_plasma_store(self):
|
||||
"""Start the plasma store."""
|
||||
assert self._plasma_store_socket_name is None
|
||||
# If the user specified a socket name, use it.
|
||||
self._plasma_store_socket_name = self._prepare_socket_file(
|
||||
self._ray_params.plasma_store_socket_name,
|
||||
default_prefix="plasma_store")
|
||||
stdout_file, stderr_file = self.new_log_files("plasma_store")
|
||||
process_info = ray.services.start_plasma_store(
|
||||
stdout_file=stdout_file,
|
||||
|
@ -349,10 +399,6 @@ class Node(object):
|
|||
use_profiler (bool): True if we should start the process in the
|
||||
valgrind profiler.
|
||||
"""
|
||||
assert self._raylet_socket_name is None
|
||||
# If the user specified a socket name, use it.
|
||||
self._raylet_socket_name = self._prepare_socket_file(
|
||||
self._ray_params.raylet_socket_name, default_prefix="raylet")
|
||||
stdout_file, stderr_file = self.new_log_files("raylet")
|
||||
process_info = ray.services.start_raylet(
|
||||
self._redis_address,
|
||||
|
@ -416,20 +462,26 @@ class Node(object):
|
|||
process_info
|
||||
]
|
||||
|
||||
def start_head_processes(self):
|
||||
"""Start head processes on the node."""
|
||||
logger.info(
|
||||
"Process STDOUT and STDERR is being redirected to {}.".format(
|
||||
self._logs_dir))
|
||||
assert self._redis_address is None
|
||||
# If this is the head node, start the relevant head node processes.
|
||||
self.start_redis()
|
||||
self.start_monitor()
|
||||
self.start_raylet_monitor()
|
||||
# The dashboard is Python3.x only.
|
||||
if PY3 and self._ray_params.include_webui:
|
||||
self.start_dashboard()
|
||||
|
||||
def start_ray_processes(self):
|
||||
"""Start all of the processes on the node."""
|
||||
logger.info(
|
||||
"Process STDOUT and STDERR is being redirected to {}.".format(
|
||||
self._logs_dir))
|
||||
|
||||
# If this is the head node, start the relevant head node processes.
|
||||
if self._redis_address is None:
|
||||
self.start_redis()
|
||||
self.start_monitor()
|
||||
self.start_raylet_monitor()
|
||||
if PY3 and self._ray_params.include_webui:
|
||||
self.start_dashboard()
|
||||
|
||||
self.start_plasma_store()
|
||||
self.start_raylet()
|
||||
if PY3:
|
||||
|
@ -685,3 +737,16 @@ class Node(object):
|
|||
True if any process that wasn't explicitly killed is still alive.
|
||||
"""
|
||||
return not any(self.dead_processes())
|
||||
|
||||
|
||||
class LocalNode(object):
|
||||
"""Imitate the node that manages the processes in local mode."""
|
||||
|
||||
def kill_all_processes(self, *args, **kwargs):
|
||||
"""Kill all of the processes."""
|
||||
pass # Keep this function empty because it will be used in worker.py
|
||||
|
||||
@property
|
||||
def address_info(self):
|
||||
"""Get a dictionary of addresses."""
|
||||
return {} # Return a null dict.
|
||||
|
|
|
@ -77,20 +77,6 @@ def address(ip_address, port):
|
|||
return ip_address + ":" + str(port)
|
||||
|
||||
|
||||
def get_ip_address(address):
|
||||
assert type(address) == str, "Address must be a string"
|
||||
ip_address = address.split(":")[0]
|
||||
return ip_address
|
||||
|
||||
|
||||
def get_port(address):
|
||||
try:
|
||||
port = int(address.split(":")[1])
|
||||
except Exception:
|
||||
raise Exception("Unable to parse port from address {}".format(address))
|
||||
return port
|
||||
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
|
@ -107,6 +93,64 @@ def include_java_from_redis(redis_client):
|
|||
return redis_client.get("INCLUDE_JAVA") == b"1"
|
||||
|
||||
|
||||
def get_address_info_from_redis_helper(redis_address,
|
||||
node_ip_address,
|
||||
redis_password=None):
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine as
|
||||
# Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = create_redis_client(redis_address, password=redis_password)
|
||||
|
||||
client_table = ray.experimental.state.parse_client_table(redis_client)
|
||||
if len(client_table) == 0:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
relevant_client = None
|
||||
for client_info in client_table:
|
||||
client_node_ip_address = client_info["NodeManagerAddress"]
|
||||
if (client_node_ip_address == node_ip_address
|
||||
or (client_node_ip_address == "127.0.0.1"
|
||||
and redis_ip_address == get_node_ip_address())):
|
||||
relevant_client = client_info
|
||||
break
|
||||
if relevant_client is None:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
return {
|
||||
"object_store_address": relevant_client["ObjectStoreSocketName"],
|
||||
"raylet_socket_name": relevant_client["RayletSocketName"],
|
||||
}
|
||||
|
||||
|
||||
def get_address_info_from_redis(redis_address,
|
||||
node_ip_address,
|
||||
num_retries=5,
|
||||
redis_password=None):
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
return get_address_info_from_redis_helper(
|
||||
redis_address, node_ip_address, redis_password=redis_password)
|
||||
except Exception:
|
||||
if counter == num_retries:
|
||||
raise
|
||||
# Some of the information may not be in Redis yet, so wait a little
|
||||
# bit.
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
time.sleep(1)
|
||||
counter += 1
|
||||
|
||||
|
||||
def get_webui_url_from_redis(redis_client):
|
||||
webui_url = redis_client.hmget("webui", "url")[0]
|
||||
return ray.utils.decode(webui_url) if webui_url is not None else None
|
||||
|
||||
|
||||
def remaining_processes_alive():
|
||||
"""See if the remaining processes are alive or not.
|
||||
|
||||
|
|
|
@ -7,6 +7,13 @@ import shutil
|
|||
import time
|
||||
import pytest
|
||||
import ray
|
||||
from ray.tests.cluster_utils import Cluster
|
||||
|
||||
# Py2 compatibility
|
||||
try:
|
||||
FileNotFoundError
|
||||
except NameError:
|
||||
FileNotFoundError = OSError
|
||||
|
||||
|
||||
def test_conn_cluster():
|
||||
|
@ -52,8 +59,17 @@ def test_raylet_socket_name():
|
|||
ray.shutdown()
|
||||
try:
|
||||
os.remove("/tmp/i_am_a_temp_socket")
|
||||
except Exception:
|
||||
pass
|
||||
except FileNotFoundError:
|
||||
pass # It could have been removed by Ray.
|
||||
cluster = Cluster(True)
|
||||
cluster.add_node(raylet_socket_name="/tmp/i_am_a_temp_socket_2")
|
||||
assert os.path.exists(
|
||||
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
|
||||
cluster.shutdown()
|
||||
try:
|
||||
os.remove("/tmp/i_am_a_temp_socket_2")
|
||||
except FileNotFoundError:
|
||||
pass # It could have been removed by Ray.
|
||||
|
||||
|
||||
def test_temp_plasma_store_socket():
|
||||
|
@ -63,8 +79,17 @@ def test_temp_plasma_store_socket():
|
|||
ray.shutdown()
|
||||
try:
|
||||
os.remove("/tmp/i_am_a_temp_socket")
|
||||
except Exception:
|
||||
pass
|
||||
except FileNotFoundError:
|
||||
pass # It could have been removed by Ray.
|
||||
cluster = Cluster(True)
|
||||
cluster.add_node(plasma_store_socket_name="/tmp/i_am_a_temp_socket_2")
|
||||
assert os.path.exists(
|
||||
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
|
||||
cluster.shutdown()
|
||||
try:
|
||||
os.remove("/tmp/i_am_a_temp_socket_2")
|
||||
except FileNotFoundError:
|
||||
pass # It could have been removed by Ray.
|
||||
|
||||
|
||||
def test_raylet_tempfiles():
|
||||
|
|
|
@ -12,7 +12,6 @@ import json
|
|||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import redis
|
||||
import signal
|
||||
from six.moves import queue
|
||||
import sys
|
||||
|
@ -115,6 +114,7 @@ class Worker(object):
|
|||
|
||||
Attributes:
|
||||
connected (bool): True if Ray has been started and False otherwise.
|
||||
node (ray.node.Node): The node this worker is attached to.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
|
||||
WORKER_MODE.
|
||||
cached_functions_to_run (List): A list of functions to run on all of
|
||||
|
@ -124,7 +124,7 @@ class Worker(object):
|
|||
|
||||
def __init__(self):
|
||||
"""Initialize a Worker object."""
|
||||
self.connected = False
|
||||
self.node = None
|
||||
self.mode = None
|
||||
self.cached_functions_to_run = []
|
||||
self.actor_init_error = None
|
||||
|
@ -144,7 +144,6 @@ class Worker(object):
|
|||
# A dictionary that maps from driver id to SerializationContext
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
self.load_code_from_local = False
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
# It is a DriverID.
|
||||
|
@ -158,6 +157,20 @@ class Worker(object):
|
|||
self._session_index = 0
|
||||
self._current_task = None
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.node is not None
|
||||
|
||||
@property
|
||||
def node_ip_address(self):
|
||||
self.check_connected()
|
||||
return self.node.node_ip_address
|
||||
|
||||
@property
|
||||
def load_code_from_local(self):
|
||||
self.check_connected()
|
||||
return self.node.load_code_from_local
|
||||
|
||||
@property
|
||||
def task_context(self):
|
||||
"""A thread-local that contains the following attributes.
|
||||
|
@ -1072,19 +1085,6 @@ def get_resource_ids():
|
|||
return global_worker.raylet_client.resource_ids()
|
||||
|
||||
|
||||
def _webui_url_helper(client):
|
||||
"""Parsing for getting the url of the web UI.
|
||||
|
||||
Args:
|
||||
client: A redis client to use to query the primary Redis shard.
|
||||
|
||||
Returns:
|
||||
The URL of the web UI as a string.
|
||||
"""
|
||||
result = client.hmget("webui", "url")[0]
|
||||
return ray.utils.decode(result) if result is not None else result
|
||||
|
||||
|
||||
def get_webui_url():
|
||||
"""Get the URL to access the web UI.
|
||||
|
||||
|
@ -1093,10 +1093,9 @@ def get_webui_url():
|
|||
Returns:
|
||||
The URL of the web UI as a string.
|
||||
"""
|
||||
if _mode() == LOCAL_MODE:
|
||||
raise Exception("ray.get_webui_url() currently does not work in "
|
||||
"PYTHON MODE.")
|
||||
return _webui_url_helper(global_worker.redis_client)
|
||||
if _global_node is None:
|
||||
raise Exception("Ray has not been initialized/connected.")
|
||||
return _global_node.get_webui_url
|
||||
|
||||
|
||||
global_worker = Worker()
|
||||
|
@ -1211,64 +1210,6 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
|||
class_id="ray.signature.FunctionSignature")
|
||||
|
||||
|
||||
def get_address_info_from_redis_helper(redis_address,
|
||||
node_ip_address,
|
||||
redis_password=None):
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine as
|
||||
# Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=int(redis_port), password=redis_password)
|
||||
|
||||
client_table = ray.experimental.state.parse_client_table(redis_client)
|
||||
if len(client_table) == 0:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
relevant_client = None
|
||||
for client_info in client_table:
|
||||
client_node_ip_address = client_info["NodeManagerAddress"]
|
||||
if (client_node_ip_address == node_ip_address or
|
||||
(client_node_ip_address == "127.0.0.1"
|
||||
and redis_ip_address == ray.services.get_node_ip_address())):
|
||||
relevant_client = client_info
|
||||
break
|
||||
if relevant_client is None:
|
||||
raise Exception(
|
||||
"Redis has started but no raylets have registered yet.")
|
||||
|
||||
return {
|
||||
"node_ip_address": node_ip_address,
|
||||
"redis_address": redis_address,
|
||||
"object_store_address": relevant_client["ObjectStoreSocketName"],
|
||||
"raylet_socket_name": relevant_client["RayletSocketName"],
|
||||
# Web UI should be running.
|
||||
"webui_url": _webui_url_helper(redis_client)
|
||||
}
|
||||
|
||||
|
||||
def get_address_info_from_redis(redis_address,
|
||||
node_ip_address,
|
||||
num_retries=5,
|
||||
redis_password=None):
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
return get_address_info_from_redis_helper(
|
||||
redis_address, node_ip_address, redis_password=redis_password)
|
||||
except Exception:
|
||||
if counter == num_retries:
|
||||
raise
|
||||
# Some of the information may not be in Redis yet, so wait a little
|
||||
# bit.
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
time.sleep(1)
|
||||
counter += 1
|
||||
|
||||
|
||||
def init(redis_address=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
|
@ -1414,22 +1355,11 @@ def init(redis_address=None,
|
|||
if redis_address is not None:
|
||||
redis_address = services.address_to_ip(redis_address)
|
||||
|
||||
address_info = {
|
||||
"node_ip_address": node_ip_address,
|
||||
"redis_address": redis_address
|
||||
}
|
||||
|
||||
global _global_node
|
||||
if driver_mode == LOCAL_MODE:
|
||||
# If starting Ray in LOCAL_MODE, don't start any other processes.
|
||||
pass
|
||||
_global_node = ray.node.LocalNode()
|
||||
elif redis_address is None:
|
||||
# TODO(suquark): We should remove the code below because they
|
||||
# have been set when initializing the node.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = ray.services.get_node_ip_address()
|
||||
if num_redis_shards is None:
|
||||
num_redis_shards = 1
|
||||
# In this case, we need to start a new cluster.
|
||||
ray_params = ray.parameter.RayParams(
|
||||
redis_address=redis_address,
|
||||
|
@ -1461,11 +1391,6 @@ def init(redis_address=None,
|
|||
# handler.
|
||||
_global_node = ray.node.Node(
|
||||
head=True, shutdown_at_exit=False, ray_params=ray_params)
|
||||
address_info["redis_address"] = _global_node.redis_address
|
||||
address_info[
|
||||
"object_store_address"] = _global_node.plasma_store_socket_name
|
||||
address_info["webui_url"] = _global_node.webui_url
|
||||
address_info["raylet_socket_name"] = _global_node.raylet_socket_name
|
||||
else:
|
||||
# In this case, we are connecting to an existing cluster.
|
||||
if num_cpus is not None or num_gpus is not None:
|
||||
|
@ -1505,51 +1430,28 @@ def init(redis_address=None,
|
|||
raise Exception("When connecting to an existing cluster, "
|
||||
"_internal_config must not be provided.")
|
||||
|
||||
# Get the node IP address if one is not provided.
|
||||
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
# Get the address info of the processes to connect to from Redis.
|
||||
address_info = get_address_info_from_redis(
|
||||
redis_address, node_ip_address, redis_password=redis_password)
|
||||
# TODO(suquark): Use "node" as the input of "connect()".
|
||||
# In this case, we only need to connect the node.
|
||||
ray_params = ray.parameter.RayParams(
|
||||
node_ip_address=node_ip_address,
|
||||
redis_address=redis_address,
|
||||
redis_password=redis_password,
|
||||
plasma_store_socket_name=address_info["object_store_address"],
|
||||
raylet_socket_name=address_info["raylet_socket_name"],
|
||||
object_id_seed=object_id_seed,
|
||||
temp_dir=temp_dir)
|
||||
temp_dir=temp_dir,
|
||||
load_code_from_local=load_code_from_local)
|
||||
_global_node = ray.node.Node(
|
||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||
|
||||
if driver_mode == LOCAL_MODE:
|
||||
driver_address_info = {}
|
||||
else:
|
||||
driver_address_info = {
|
||||
"node_ip_address": node_ip_address,
|
||||
"redis_address": address_info["redis_address"],
|
||||
"store_socket_name": address_info["object_store_address"],
|
||||
"webui_url": address_info["webui_url"],
|
||||
"raylet_socket_name": address_info["raylet_socket_name"],
|
||||
}
|
||||
|
||||
connect(
|
||||
driver_address_info,
|
||||
redis_password=redis_password,
|
||||
object_id_seed=object_id_seed,
|
||||
_global_node,
|
||||
mode=driver_mode,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id,
|
||||
load_code_from_local=load_code_from_local)
|
||||
driver_id=driver_id)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
|
||||
return address_info
|
||||
return _global_node.address_info
|
||||
|
||||
|
||||
# Functions to run as callback after a successful ray init
|
||||
|
@ -1782,9 +1684,7 @@ def is_initialized():
|
|||
return ray.worker.global_worker.connected
|
||||
|
||||
|
||||
def connect(info,
|
||||
redis_password=None,
|
||||
object_id_seed=None,
|
||||
def connect(node,
|
||||
mode=WORKER_MODE,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
|
@ -1793,14 +1693,7 @@ def connect(info,
|
|||
"""Connect this worker to the raylet, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
info (dict): A dictionary with address of the Redis server and the
|
||||
sockets of the plasma store and raylet.
|
||||
redis_password (str): Prevents external clients without the password
|
||||
from connecting to Redis if provided.
|
||||
object_id_seed (int): Used to seed the deterministic generation of
|
||||
object IDs. The same value can be used across multiple runs of the
|
||||
same job in order to generate the object IDs in a consistent
|
||||
manner. However, the same ID should not be used for different jobs.
|
||||
node (ray.node.Node): The node to connect.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
|
||||
LOCAL_MODE.
|
||||
log_to_driver (bool): If true, then output from all of the worker
|
||||
|
@ -1844,25 +1737,19 @@ def connect(info,
|
|||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
worker.actor_id = ActorID.nil()
|
||||
worker.connected = True
|
||||
worker.node = node
|
||||
worker.set_mode(mode)
|
||||
worker.load_code_from_local = load_code_from_local
|
||||
|
||||
# If running Ray in LOCAL_MODE, there is no need to create call
|
||||
# create_worker or to start the worker service.
|
||||
if mode == LOCAL_MODE:
|
||||
return
|
||||
# Set the node IP address.
|
||||
worker.node_ip_address = info["node_ip_address"]
|
||||
worker.redis_address = info["redis_address"]
|
||||
|
||||
# Create a Redis client.
|
||||
redis_ip_address, redis_port = info["redis_address"].split(":")
|
||||
# The Redis client can safely be shared between threads. However, that is
|
||||
# not true of Redis pubsub clients. See the documentation at
|
||||
# https://github.com/andymccurdy/redis-py#thread-safety.
|
||||
worker.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=int(redis_port), password=redis_password)
|
||||
worker.redis_client = node.create_redis_client()
|
||||
|
||||
# For driver's check that the version information matches the version
|
||||
# information that the Ray cluster was started with.
|
||||
|
@ -1883,7 +1770,7 @@ def connect(info,
|
|||
|
||||
# Create an object for interfacing with the global state.
|
||||
global_state._initialize_global_state(
|
||||
redis_ip_address, int(redis_port), redis_password=redis_password)
|
||||
node.redis_address, redis_password=node.redis_password)
|
||||
|
||||
# Register the worker with Redis.
|
||||
if mode == SCRIPT_MODE:
|
||||
|
@ -1891,11 +1778,11 @@ def connect(info,
|
|||
# Register the driver/job with Redis here.
|
||||
import __main__ as main
|
||||
driver_info = {
|
||||
"node_ip_address": worker.node_ip_address,
|
||||
"node_ip_address": node.node_ip_address,
|
||||
"driver_id": worker.worker_id,
|
||||
"start_time": time.time(),
|
||||
"plasma_store_socket": info["store_socket_name"],
|
||||
"raylet_socket": info.get("raylet_socket_name"),
|
||||
"plasma_store_socket": node.plasma_store_socket_name,
|
||||
"raylet_socket": node.raylet_socket_name,
|
||||
"name": (main.__file__
|
||||
if hasattr(main, "__file__") else "INTERACTIVE MODE")
|
||||
}
|
||||
|
@ -1903,8 +1790,8 @@ def connect(info,
|
|||
elif mode == WORKER_MODE:
|
||||
# Register the worker with Redis.
|
||||
worker_dict = {
|
||||
"node_ip_address": worker.node_ip_address,
|
||||
"plasma_store_socket": info["store_socket_name"],
|
||||
"node_ip_address": node.node_ip_address,
|
||||
"plasma_store_socket": node.plasma_store_socket_name,
|
||||
}
|
||||
# Check the RedirectOutput key in Redis and based on its value redirect
|
||||
# worker output and error to their own files.
|
||||
|
@ -1913,7 +1800,7 @@ def connect(info,
|
|||
if (redirect_worker_output_val is not None
|
||||
and int(redirect_worker_output_val) == 1):
|
||||
log_stdout_file, log_stderr_file = (
|
||||
_global_node.new_worker_redirected_log_file(worker.worker_id))
|
||||
node.new_worker_redirected_log_file(worker.worker_id))
|
||||
# Redirect stdout/stderr at the file descriptor level. If we simply
|
||||
# set sys.stdout and sys.stderr, then logging from C++ can fail to
|
||||
# be redirected.
|
||||
|
@ -1941,7 +1828,7 @@ def connect(info,
|
|||
|
||||
# Create an object store client.
|
||||
worker.plasma_client = thread_safe_client(
|
||||
plasma.connect(info["store_socket_name"], None, 0, 300))
|
||||
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
|
||||
|
||||
# If this is a driver, set the current task ID, the task driver ID, and set
|
||||
# the task index to 0.
|
||||
|
@ -1951,8 +1838,8 @@ def connect(info,
|
|||
# the user's random number generator). Otherwise, set the current task
|
||||
# ID randomly to avoid object ID collisions.
|
||||
numpy_state = np.random.get_state()
|
||||
if object_id_seed is not None:
|
||||
np.random.seed(object_id_seed)
|
||||
if node.object_id_seed is not None:
|
||||
np.random.seed(node.object_id_seed)
|
||||
else:
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
|
@ -1999,7 +1886,7 @@ def connect(info,
|
|||
worker.task_context.current_task_id = driver_task.task_id()
|
||||
|
||||
worker.raylet_client = ray._raylet.RayletClient(
|
||||
info["raylet_socket_name"],
|
||||
node.raylet_socket_name,
|
||||
ClientID(worker.worker_id),
|
||||
(mode == WORKER_MODE),
|
||||
DriverID(worker.current_task_id.binary()),
|
||||
|
@ -2096,7 +1983,7 @@ def disconnect():
|
|||
worker.threads_stopped.clear()
|
||||
worker._session_index += 1
|
||||
|
||||
worker.connected = False
|
||||
worker.node = None # Disconnect the worker from the node.
|
||||
worker.cached_functions_to_run = []
|
||||
worker.function_actor_manager.reset_cache()
|
||||
worker.serialization_context_map.clear()
|
||||
|
|
|
@ -66,14 +66,6 @@ parser.add_argument(
|
|||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
info = {
|
||||
"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"redis_password": args.redis_password,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"raylet_socket_name": args.raylet_name,
|
||||
}
|
||||
|
||||
ray.utils.setup_logger(args.logging_level, args.logging_format)
|
||||
|
||||
ray_params = RayParams(
|
||||
|
@ -89,12 +81,7 @@ if __name__ == "__main__":
|
|||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||
ray.worker._global_node = node
|
||||
|
||||
# TODO(suquark): Use "node" as the input of "connect".
|
||||
ray.worker.connect(
|
||||
info,
|
||||
redis_password=args.redis_password,
|
||||
mode=ray.WORKER_MODE,
|
||||
load_code_from_local=args.load_code_from_local)
|
||||
ray.worker.connect(node, mode=ray.WORKER_MODE)
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
|
|
Loading…
Add table
Reference in a new issue