Improve code related to node (#4383)

* Make full use of node

implement local node

fix bugs mentioned in comments

* Add more tests

* Use more specific exception handling

* fix, lint

* fix for py2.x
This commit is contained in:
Si-Yuan 2019-04-09 17:27:54 +08:00 committed by GitHub
parent c5bcec54f3
commit dab99d26af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 243 additions and 245 deletions

View file

@ -4,7 +4,6 @@ from __future__ import print_function
from collections import defaultdict from collections import defaultdict
import json import json
import redis
import sys import sys
import time import time
@ -13,6 +12,7 @@ from ray.function_manager import FunctionDescriptor
import ray.gcs_utils import ray.gcs_utils
from ray.ray_constants import ID_SIZE from ray.ray_constants import ID_SIZE
from ray import services
from ray.utils import (decode, binary_to_object_id, binary_to_hex, from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary) hex_to_binary)
@ -126,8 +126,7 @@ class GlobalState(object):
self.redis_clients = None self.redis_clients = None
def _initialize_global_state(self, def _initialize_global_state(self,
redis_ip_address, redis_address,
redis_port,
redis_password=None, redis_password=None,
timeout=20): timeout=20):
"""Initialize the GlobalState object by connecting to Redis. """Initialize the GlobalState object by connecting to Redis.
@ -137,18 +136,15 @@ class GlobalState(object):
been populated or we exceed a timeout. been populated or we exceed a timeout.
Args: Args:
redis_ip_address: The IP address of the node that the Redis server redis_address: The Redis address to connect.
lives on.
redis_port: The port that the Redis server is listening on.
redis_password: The password of the redis server. redis_password: The password of the redis server.
""" """
self.redis_client = redis.StrictRedis( self.redis_client = services.create_redis_client(
host=redis_ip_address, port=redis_port, password=redis_password) redis_address, redis_password)
start_time = time.time() start_time = time.time()
num_redis_shards = None num_redis_shards = None
ip_address_ports = [] redis_shard_addresses = []
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
# Attempt to get the number of Redis shards. # Attempt to get the number of Redis shards.
@ -163,9 +159,9 @@ class GlobalState(object):
"{}.".format(num_redis_shards)) "{}.".format(num_redis_shards))
# Attempt to get all of the Redis shards. # Attempt to get all of the Redis shards.
ip_address_ports = self.redis_client.lrange( redis_shard_addresses = self.redis_client.lrange(
"RedisShards", start=0, end=-1) "RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards: if len(redis_shard_addresses) != num_redis_shards:
print("Waiting longer for RedisShards to be populated.") print("Waiting longer for RedisShards to be populated.")
time.sleep(1) time.sleep(1)
continue continue
@ -177,18 +173,15 @@ class GlobalState(object):
if time.time() - start_time >= timeout: if time.time() - start_time >= timeout:
raise Exception("Timed out while attempting to initialize the " raise Exception("Timed out while attempting to initialize the "
"global state. num_redis_shards = {}, " "global state. num_redis_shards = {}, "
"ip_address_ports = {}".format( "redis_shard_addresses = {}".format(
num_redis_shards, ip_address_ports)) num_redis_shards, redis_shard_addresses))
# Get the rest of the information. # Get the rest of the information.
self.redis_clients = [] self.redis_clients = []
for ip_address_port in ip_address_ports: for shard_address in redis_shard_addresses:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append( self.redis_clients.append(
redis.StrictRedis( services.create_redis_client(shard_address.decode(),
host=shard_address, redis_password))
port=shard_port,
password=redis_password))
def _execute_command(self, key, *args): def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key. """Execute a Redis command on the appropriate Redis shard based on key.

View file

@ -16,7 +16,6 @@ import ray.cloudpickle as pickle
import ray.gcs_utils import ray.gcs_utils
import ray.utils import ray.utils
import ray.ray_constants as ray_constants import ray.ray_constants as ray_constants
from ray.services import get_ip_address, get_port
from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary,
setup_logger) setup_logger)
@ -32,17 +31,15 @@ class Monitor(object):
Attributes: Attributes:
redis: A connection to the Redis server. redis: A connection to the Redis server.
subscribe_client: A pubsub client for the Redis server. This is used to primary_subscribe_client: A pubsub client for the Redis server.
receive notifications about failed components. This is used to receive notifications about failed components.
""" """
def __init__(self, redis_address, autoscaling_config, redis_password=None): def __init__(self, redis_address, autoscaling_config, redis_password=None):
# Initialize the Redis clients. # Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState() self.state = ray.experimental.state.GlobalState()
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
self.state._initialize_global_state( self.state._initialize_global_state(
redis_ip_address, redis_port, redis_password=redis_password) args.redis_address, redis_password=redis_password)
self.redis = ray.services.create_redis_client( self.redis = ray.services.create_redis_client(
redis_address, password=redis_password) redis_address, password=redis_password)
# Setup subscriptions to the primary Redis server and the Redis shards. # Setup subscriptions to the primary Redis server and the Redis shards.

View file

@ -31,7 +31,8 @@ PY3 = sys.version_info.major >= 3
class Node(object): class Node(object):
"""An encapsulation of the Ray processes on a single node. """An encapsulation of the Ray processes on a single node.
This class is responsible for starting Ray processes and killing them. This class is responsible for starting Ray processes and killing them,
and it also controls the temp file policy.
Attributes: Attributes:
all_processes (dict): A mapping from process type (str) to a list of all_processes (dict): A mapping from process type (str) to a list of
@ -63,8 +64,17 @@ class Node(object):
"be both true.") "be both true.")
self.all_processes = {} self.all_processes = {}
# Try to get node IP address with the parameters.
if ray_params.node_ip_address:
node_ip_address = ray_params.node_ip_address
elif ray_params.redis_address:
node_ip_address = ray.services.get_node_ip_address(
ray_params.redis_address)
else:
node_ip_address = ray.services.get_node_ip_address()
self._node_ip_address = node_ip_address
ray_params.update_if_absent( ray_params.update_if_absent(
node_ip_address=ray.services.get_node_ip_address(),
include_log_monitor=True, include_log_monitor=True,
resources={}, resources={},
include_webui=False, include_webui=False,
@ -73,31 +83,51 @@ class Node(object):
"workers/default_worker.py")) "workers/default_worker.py"))
self._ray_params = ray_params self._ray_params = ray_params
self._node_ip_address = ray_params.node_ip_address
self._redis_address = ray_params.redis_address self._redis_address = ray_params.redis_address
self._config = (json.loads(ray_params._internal_config) self._config = (json.loads(ray_params._internal_config)
if ray_params._internal_config else None) if ray_params._internal_config else None)
if head: self._init_temp()
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._plasma_store_socket_name = None if connect_only:
self._raylet_socket_name = None # Get socket names from the configuration.
self._webui_url = None
else:
self._plasma_store_socket_name = ( self._plasma_store_socket_name = (
ray_params.plasma_store_socket_name) ray_params.plasma_store_socket_name)
self._raylet_socket_name = ray_params.raylet_socket_name self._raylet_socket_name = ray_params.raylet_socket_name
# If user does not provide the socket name, get it from Redis.
if (self._plasma_store_socket_name is None
or self._raylet_socket_name is None):
# Get the address info of the processes to connect to
# from Redis.
address_info = ray.services.get_address_info_from_redis(
self.redis_address,
self._node_ip_address,
redis_password=self.redis_password)
self._plasma_store_socket_name = address_info[
"object_store_address"]
self._raylet_socket_name = address_info["raylet_socket_name"]
else:
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._webui_url = None
else:
redis_client = self.create_redis_client() redis_client = self.create_redis_client()
# TODO(suquark): Replace _webui_url_helper in worker.py in self._webui_url = (
# another PR. ray.services.get_webui_url_from_redis(redis_client))
_webui_url = redis_client.hmget("webui", "url")[0]
self._webui_url = (ray.utils.decode(_webui_url)
if _webui_url is not None else None)
ray_params.include_java = ( ray_params.include_java = (
ray.services.include_java_from_redis(redis_client)) ray.services.include_java_from_redis(redis_client))
self._init_temp() # Start processes.
if head:
self.start_head_processes()
if not connect_only: if not connect_only:
self.start_ray_processes() self.start_ray_processes()
@ -136,6 +166,20 @@ class Node(object):
"""Get the cluster Redis address.""" """Get the cluster Redis address."""
return self._redis_address return self._redis_address
@property
def redis_password(self):
    """Return the Redis password held in this node's Ray parameters."""
    params = self._ray_params
    return params.redis_password
@property
def load_code_from_local(self):
    """Return the load_code_from_local setting of the Ray parameters."""
    params = self._ray_params
    return params.load_code_from_local
@property
def object_id_seed(self):
    """Return the object ID seed held in this node's Ray parameters.

    The seed is used for deterministic generation of object IDs.
    """
    params = self._ray_params
    return params.object_id_seed
@property @property
def plasma_store_socket_name(self): def plasma_store_socket_name(self):
"""Get the node's plasma store socket name.""" """Get the node's plasma store socket name."""
@ -151,6 +195,17 @@ class Node(object):
"""Get the node's raylet socket name.""" """Get the node's raylet socket name."""
return self._raylet_socket_name return self._raylet_socket_name
@property
def address_info(self):
    """Collect this node's addresses into a single dictionary.

    Returns:
        A dict with the keys "node_ip_address", "redis_address",
        "object_store_address", "raylet_socket_name" and "webui_url",
        each mapped to the matching private attribute of this node.
    """
    info = {}
    info["node_ip_address"] = self._node_ip_address
    info["redis_address"] = self._redis_address
    info["object_store_address"] = self._plasma_store_socket_name
    info["raylet_socket_name"] = self._raylet_socket_name
    info["webui_url"] = self._webui_url
    return info
def create_redis_client(self): def create_redis_client(self):
"""Create a redis client.""" """Create a redis client."""
return ray.services.create_redis_client( return ray.services.create_redis_client(
@ -321,11 +376,6 @@ class Node(object):
def start_plasma_store(self): def start_plasma_store(self):
"""Start the plasma store.""" """Start the plasma store."""
assert self._plasma_store_socket_name is None
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
stdout_file, stderr_file = self.new_log_files("plasma_store") stdout_file, stderr_file = self.new_log_files("plasma_store")
process_info = ray.services.start_plasma_store( process_info = ray.services.start_plasma_store(
stdout_file=stdout_file, stdout_file=stdout_file,
@ -349,10 +399,6 @@ class Node(object):
use_profiler (bool): True if we should start the process in the use_profiler (bool): True if we should start the process in the
valgrind profiler. valgrind profiler.
""" """
assert self._raylet_socket_name is None
# If the user specified a socket name, use it.
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
stdout_file, stderr_file = self.new_log_files("raylet") stdout_file, stderr_file = self.new_log_files("raylet")
process_info = ray.services.start_raylet( process_info = ray.services.start_raylet(
self._redis_address, self._redis_address,
@ -416,20 +462,26 @@ class Node(object):
process_info process_info
] ]
def start_head_processes(self):
    """Launch the processes that run only on the head node.

    Redis, the monitor and the raylet monitor are started in that order.
    The dashboard is started as well when running on Python 3 and the web
    UI is enabled in the Ray parameters.
    """
    redirect_notice = "Process STDOUT and STDERR is being redirected to {}."
    logger.info(redirect_notice.format(self._logs_dir))
    # Redis is started right below, so no Redis address may be set yet.
    assert self._redis_address is None
    # Bring up the processes that exist only on the head node.
    self.start_redis()
    self.start_monitor()
    self.start_raylet_monitor()
    # The dashboard is Python3.x only.
    if PY3:
        if self._ray_params.include_webui:
            self.start_dashboard()
def start_ray_processes(self): def start_ray_processes(self):
"""Start all of the processes on the node.""" """Start all of the processes on the node."""
logger.info( logger.info(
"Process STDOUT and STDERR is being redirected to {}.".format( "Process STDOUT and STDERR is being redirected to {}.".format(
self._logs_dir)) self._logs_dir))
# If this is the head node, start the relevant head node processes.
if self._redis_address is None:
self.start_redis()
self.start_monitor()
self.start_raylet_monitor()
if PY3 and self._ray_params.include_webui:
self.start_dashboard()
self.start_plasma_store() self.start_plasma_store()
self.start_raylet() self.start_raylet()
if PY3: if PY3:
@ -685,3 +737,16 @@ class Node(object):
True if any process that wasn't explicitly killed is still alive. True if any process that wasn't explicitly killed is still alive.
""" """
return not any(self.dead_processes()) return not any(self.dead_processes())
class LocalNode(object):
    """Stand-in for a node when Ray runs in local mode.

    It offers the same entry points as a real node but manages no
    processes and exposes no addresses.
    """

    def kill_all_processes(self, *args, **kwargs):
        """Accept any arguments and do nothing.

        The method is intentionally a no-op; it exists because worker.py
        calls it on the node.
        """
        return None

    @property
    def address_info(self):
        """Return an empty dictionary of addresses."""
        return dict()

View file

@ -77,20 +77,6 @@ def address(ip_address, port):
return ip_address + ":" + str(port) return ip_address + ":" + str(port)
def get_ip_address(address):
    """Return the part of an address string that precedes the first colon.

    Args:
        address (str): An address such as "ip:port". A string without a
            colon is returned unchanged.

    Returns:
        The text before the first ":" in the address.

    Raises:
        AssertionError: If address is not a string.
    """
    # isinstance also accepts str subclasses, unlike an exact type match.
    assert isinstance(address, str), "Address must be a string"
    return address.split(":")[0]
def get_port(address):
    """Return the port of an "ip:port" address as an integer.

    Args:
        address (str): An address whose second colon-separated field is
            the port number.

    Returns:
        The port as an int.

    Raises:
        Exception: If the address has no second field or that field is
            not an integer.
    """
    try:
        port = int(address.split(":")[1])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parsing failures are translated; anything else propagates.
        raise Exception("Unable to parse port from address {}".format(address))
    return port
def new_port(): def new_port():
return random.randint(10000, 65535) return random.randint(10000, 65535)
@ -107,6 +93,64 @@ def include_java_from_redis(redis_client):
return redis_client.get("INCLUDE_JAVA") == b"1" return redis_client.get("INCLUDE_JAVA") == b"1"
def get_address_info_from_redis_helper(redis_address,
                                       node_ip_address,
                                       redis_password=None):
    """Look up the socket names of the raylet registered for a node.

    The client table is read from Redis and searched for the first entry
    whose "NodeManagerAddress" equals node_ip_address, or is "127.0.0.1"
    while the Redis host equals the IP returned by get_node_ip_address().

    Args:
        redis_address (str): The Redis address in "ip:port" form.
        node_ip_address (str): The IP address of the node to look up.
        redis_password (str): The password used to connect to Redis.

    Returns:
        A dict with the keys "object_store_address" and
        "raylet_socket_name" taken from the matching client entry.

    Raises:
        Exception: If the client table is empty or contains no entry that
            matches the node.
    """
    # Only the host is needed here; the unpacking still rejects addresses
    # that are not of the form "ip:port".
    redis_ip_address, _ = redis_address.split(":")
    # For this command to work, some other client (on the same machine as
    # Redis) must have run "CONFIG SET protected-mode no".
    redis_client = create_redis_client(redis_address, password=redis_password)
    client_table = ray.experimental.state.parse_client_table(redis_client)
    if len(client_table) == 0:
        raise Exception(
            "Redis has started but no raylets have registered yet.")

    relevant_client = None
    for client_info in client_table:
        client_node_ip_address = client_info["NodeManagerAddress"]
        if (client_node_ip_address == node_ip_address
                or (client_node_ip_address == "127.0.0.1"
                    and redis_ip_address == get_node_ip_address())):
            relevant_client = client_info
            break
    if relevant_client is None:
        # Raylets are registered at this point, just none for this node.
        raise Exception(
            "Redis has started and raylets have registered, but none of "
            "them matches the node IP address {}.".format(node_ip_address))

    return {
        "object_store_address": relevant_client["ObjectStoreSocketName"],
        "raylet_socket_name": relevant_client["RayletSocketName"],
    }
def get_address_info_from_redis(redis_address,
                                node_ip_address,
                                num_retries=5,
                                redis_password=None):
    """Fetch the node's address info from Redis, retrying on failure.

    get_address_info_from_redis_helper is called repeatedly. After each
    failure a warning is logged and the call is repeated one second later.
    Once num_retries retries have failed, the last exception is re-raised.

    Args:
        redis_address (str): The Redis address.
        node_ip_address (str): The IP address of the node to look up.
        num_retries (int): How many times a failed lookup is retried.
        redis_password (str): The password used to connect to Redis.

    Returns:
        Whatever get_address_info_from_redis_helper returns.
    """
    failed_attempts = 0
    while True:
        try:
            address_info = get_address_info_from_redis_helper(
                redis_address, node_ip_address, redis_password=redis_password)
        except Exception:
            # Out of retries: let the most recent error propagate.
            if failed_attempts == num_retries:
                raise
        else:
            return address_info
        # Some of the information may not be in Redis yet, so wait a little
        # bit.
        logger.warning(
            "Some processes that the driver needs to connect to have "
            "not registered with Redis, so retrying. Have you run "
            "'ray start' on this node?")
        time.sleep(1)
        failed_attempts += 1
def get_webui_url_from_redis(redis_client):
    """Read the web UI URL from the "url" field of the "webui" hash.

    Args:
        redis_client: A Redis client used to run the query.

    Returns:
        The decoded URL, or None if the field holds no value.
    """
    webui_url = redis_client.hmget("webui", "url")[0]
    if webui_url is None:
        return None
    return ray.utils.decode(webui_url)
def remaining_processes_alive(): def remaining_processes_alive():
"""See if the remaining processes are alive or not. """See if the remaining processes are alive or not.

View file

@ -7,6 +7,13 @@ import shutil
import time import time
import pytest import pytest
import ray import ray
from ray.tests.cluster_utils import Cluster
# Py2 compatibility
try:
FileNotFoundError
except NameError:
FileNotFoundError = OSError
def test_conn_cluster(): def test_conn_cluster():
@ -52,8 +59,17 @@ def test_raylet_socket_name():
ray.shutdown() ray.shutdown()
try: try:
os.remove("/tmp/i_am_a_temp_socket") os.remove("/tmp/i_am_a_temp_socket")
except Exception: except FileNotFoundError:
pass pass # It could have been removed by Ray.
cluster = Cluster(True)
cluster.add_node(raylet_socket_name="/tmp/i_am_a_temp_socket_2")
assert os.path.exists(
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
cluster.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket_2")
except FileNotFoundError:
pass # It could have been removed by Ray.
def test_temp_plasma_store_socket(): def test_temp_plasma_store_socket():
@ -63,8 +79,17 @@ def test_temp_plasma_store_socket():
ray.shutdown() ray.shutdown()
try: try:
os.remove("/tmp/i_am_a_temp_socket") os.remove("/tmp/i_am_a_temp_socket")
except Exception: except FileNotFoundError:
pass pass # It could have been removed by Ray.
cluster = Cluster(True)
cluster.add_node(plasma_store_socket_name="/tmp/i_am_a_temp_socket_2")
assert os.path.exists(
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
cluster.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket_2")
except FileNotFoundError:
pass # It could have been removed by Ray.
def test_raylet_tempfiles(): def test_raylet_tempfiles():

View file

@ -12,7 +12,6 @@ import json
import logging import logging
import numpy as np import numpy as np
import os import os
import redis
import signal import signal
from six.moves import queue from six.moves import queue
import sys import sys
@ -115,6 +114,7 @@ class Worker(object):
Attributes: Attributes:
connected (bool): True if Ray has been started and False otherwise. connected (bool): True if Ray has been started and False otherwise.
node (ray.node.Node): The node this worker is attached to.
mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
WORKER_MODE. WORKER_MODE.
cached_functions_to_run (List): A list of functions to run on all of cached_functions_to_run (List): A list of functions to run on all of
@ -124,7 +124,7 @@ class Worker(object):
def __init__(self): def __init__(self):
"""Initialize a Worker object.""" """Initialize a Worker object."""
self.connected = False self.node = None
self.mode = None self.mode = None
self.cached_functions_to_run = [] self.cached_functions_to_run = []
self.actor_init_error = None self.actor_init_error = None
@ -144,7 +144,6 @@ class Worker(object):
# A dictionary that maps from driver id to SerializationContext # A dictionary that maps from driver id to SerializationContext
# TODO: clean up the SerializationContext once the job finished. # TODO: clean up the SerializationContext once the job finished.
self.serialization_context_map = {} self.serialization_context_map = {}
self.load_code_from_local = False
self.function_actor_manager = FunctionActorManager(self) self.function_actor_manager = FunctionActorManager(self)
# Identity of the driver that this worker is processing. # Identity of the driver that this worker is processing.
# It is a DriverID. # It is a DriverID.
@ -158,6 +157,20 @@ class Worker(object):
self._session_index = 0 self._session_index = 0
self._current_task = None self._current_task = None
@property
def connected(self):
    """Return True if this worker has a node attached, False otherwise."""
    attached_node = self.node
    return attached_node is not None
@property
def node_ip_address(self):
    """Return the IP address of the node this worker is attached to.

    check_connected() is called first, before the node is accessed.
    """
    self.check_connected()
    attached_node = self.node
    return attached_node.node_ip_address
@property
def load_code_from_local(self):
    """Return the load_code_from_local setting of the attached node.

    check_connected() is called first, before the node is accessed.
    """
    self.check_connected()
    attached_node = self.node
    return attached_node.load_code_from_local
@property @property
def task_context(self): def task_context(self):
"""A thread-local that contains the following attributes. """A thread-local that contains the following attributes.
@ -1072,19 +1085,6 @@ def get_resource_ids():
return global_worker.raylet_client.resource_ids() return global_worker.raylet_client.resource_ids()
def _webui_url_helper(client):
"""Parsing for getting the url of the web UI.
Args:
client: A redis client to use to query the primary Redis shard.
Returns:
The URL of the web UI as a string.
"""
result = client.hmget("webui", "url")[0]
return ray.utils.decode(result) if result is not None else result
def get_webui_url(): def get_webui_url():
"""Get the URL to access the web UI. """Get the URL to access the web UI.
@ -1093,10 +1093,9 @@ def get_webui_url():
Returns: Returns:
The URL of the web UI as a string. The URL of the web UI as a string.
""" """
if _mode() == LOCAL_MODE: if _global_node is None:
raise Exception("ray.get_webui_url() currently does not work in " raise Exception("Ray has not been initialized/connected.")
"PYTHON MODE.") return _global_node.get_webui_url
return _webui_url_helper(global_worker.redis_client)
global_worker = Worker() global_worker = Worker()
@ -1211,64 +1210,6 @@ def _initialize_serialization(driver_id, worker=global_worker):
class_id="ray.signature.FunctionSignature") class_id="ray.signature.FunctionSignature")
def get_address_info_from_redis_helper(redis_address,
                                       node_ip_address,
                                       redis_password=None):
    """Look up the addresses a driver needs to connect to a node.

    The client table is read from Redis and searched for the first entry
    whose "NodeManagerAddress" equals node_ip_address, or is "127.0.0.1"
    while the Redis host equals the IP returned by
    ray.services.get_node_ip_address().

    Args:
        redis_address (str): The Redis address in "ip:port" form.
        node_ip_address (str): The IP address of the node to look up.
        redis_password (str): The password used to connect to Redis.

    Returns:
        A dict with the keys "node_ip_address", "redis_address",
        "object_store_address", "raylet_socket_name" and "webui_url".

    Raises:
        Exception: If the client table is empty or contains no entry that
            matches the node.
    """
    redis_ip_address, redis_port = redis_address.split(":")
    # For this command to work, some other client (on the same machine as
    # Redis) must have run "CONFIG SET protected-mode no".
    redis_client = redis.StrictRedis(
        host=redis_ip_address, port=int(redis_port), password=redis_password)
    client_table = ray.experimental.state.parse_client_table(redis_client)
    if len(client_table) == 0:
        raise Exception(
            "Redis has started but no raylets have registered yet.")

    relevant_client = None
    for client_info in client_table:
        client_node_ip_address = client_info["NodeManagerAddress"]
        if (client_node_ip_address == node_ip_address or
            (client_node_ip_address == "127.0.0.1"
             and redis_ip_address == ray.services.get_node_ip_address())):
            relevant_client = client_info
            break
    if relevant_client is None:
        # Raylets are registered at this point, just none for this node.
        raise Exception(
            "Redis has started and raylets have registered, but none of "
            "them matches the node IP address {}.".format(node_ip_address))

    return {
        "node_ip_address": node_ip_address,
        "redis_address": redis_address,
        "object_store_address": relevant_client["ObjectStoreSocketName"],
        "raylet_socket_name": relevant_client["RayletSocketName"],
        # Web UI should be running.
        "webui_url": _webui_url_helper(redis_client)
    }
def get_address_info_from_redis(redis_address,
                                node_ip_address,
                                num_retries=5,
                                redis_password=None):
    """Fetch the node's address info from Redis, retrying on failure.

    get_address_info_from_redis_helper is called repeatedly. After each
    failure a warning is logged and the call is repeated one second later.
    Once num_retries retries have failed, the last exception is re-raised.

    Args:
        redis_address (str): The Redis address.
        node_ip_address (str): The IP address of the node to look up.
        num_retries (int): How many times a failed lookup is retried.
        redis_password (str): The password used to connect to Redis.

    Returns:
        Whatever get_address_info_from_redis_helper returns.
    """
    failed_attempts = 0
    while True:
        try:
            address_info = get_address_info_from_redis_helper(
                redis_address, node_ip_address, redis_password=redis_password)
        except Exception:
            # Out of retries: let the most recent error propagate.
            if failed_attempts == num_retries:
                raise
        else:
            return address_info
        # Some of the information may not be in Redis yet, so wait a little
        # bit.
        logger.warning(
            "Some processes that the driver needs to connect to have "
            "not registered with Redis, so retrying. Have you run "
            "'ray start' on this node?")
        time.sleep(1)
        failed_attempts += 1
def init(redis_address=None, def init(redis_address=None,
num_cpus=None, num_cpus=None,
num_gpus=None, num_gpus=None,
@ -1414,22 +1355,11 @@ def init(redis_address=None,
if redis_address is not None: if redis_address is not None:
redis_address = services.address_to_ip(redis_address) redis_address = services.address_to_ip(redis_address)
address_info = {
"node_ip_address": node_ip_address,
"redis_address": redis_address
}
global _global_node global _global_node
if driver_mode == LOCAL_MODE: if driver_mode == LOCAL_MODE:
# If starting Ray in LOCAL_MODE, don't start any other processes. # If starting Ray in LOCAL_MODE, don't start any other processes.
pass _global_node = ray.node.LocalNode()
elif redis_address is None: elif redis_address is None:
# TODO(suquark): We should remove the code below because they
# have been set when initializing the node.
if node_ip_address is None:
node_ip_address = ray.services.get_node_ip_address()
if num_redis_shards is None:
num_redis_shards = 1
# In this case, we need to start a new cluster. # In this case, we need to start a new cluster.
ray_params = ray.parameter.RayParams( ray_params = ray.parameter.RayParams(
redis_address=redis_address, redis_address=redis_address,
@ -1461,11 +1391,6 @@ def init(redis_address=None,
# handler. # handler.
_global_node = ray.node.Node( _global_node = ray.node.Node(
head=True, shutdown_at_exit=False, ray_params=ray_params) head=True, shutdown_at_exit=False, ray_params=ray_params)
address_info["redis_address"] = _global_node.redis_address
address_info[
"object_store_address"] = _global_node.plasma_store_socket_name
address_info["webui_url"] = _global_node.webui_url
address_info["raylet_socket_name"] = _global_node.raylet_socket_name
else: else:
# In this case, we are connecting to an existing cluster. # In this case, we are connecting to an existing cluster.
if num_cpus is not None or num_gpus is not None: if num_cpus is not None or num_gpus is not None:
@ -1505,51 +1430,28 @@ def init(redis_address=None,
raise Exception("When connecting to an existing cluster, " raise Exception("When connecting to an existing cluster, "
"_internal_config must not be provided.") "_internal_config must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
# Get the address info of the processes to connect to from Redis.
address_info = get_address_info_from_redis(
redis_address, node_ip_address, redis_password=redis_password)
# TODO(suquark): Use "node" as the input of "connect()".
# In this case, we only need to connect the node. # In this case, we only need to connect the node.
ray_params = ray.parameter.RayParams( ray_params = ray.parameter.RayParams(
node_ip_address=node_ip_address, node_ip_address=node_ip_address,
redis_address=redis_address, redis_address=redis_address,
redis_password=redis_password, redis_password=redis_password,
plasma_store_socket_name=address_info["object_store_address"],
raylet_socket_name=address_info["raylet_socket_name"],
object_id_seed=object_id_seed, object_id_seed=object_id_seed,
temp_dir=temp_dir) temp_dir=temp_dir,
load_code_from_local=load_code_from_local)
_global_node = ray.node.Node( _global_node = ray.node.Node(
ray_params, head=False, shutdown_at_exit=False, connect_only=True) ray_params, head=False, shutdown_at_exit=False, connect_only=True)
if driver_mode == LOCAL_MODE:
driver_address_info = {}
else:
driver_address_info = {
"node_ip_address": node_ip_address,
"redis_address": address_info["redis_address"],
"store_socket_name": address_info["object_store_address"],
"webui_url": address_info["webui_url"],
"raylet_socket_name": address_info["raylet_socket_name"],
}
connect( connect(
driver_address_info, _global_node,
redis_password=redis_password,
object_id_seed=object_id_seed,
mode=driver_mode, mode=driver_mode,
log_to_driver=log_to_driver, log_to_driver=log_to_driver,
worker=global_worker, worker=global_worker,
driver_id=driver_id, driver_id=driver_id)
load_code_from_local=load_code_from_local)
for hook in _post_init_hooks: for hook in _post_init_hooks:
hook() hook()
return address_info return _global_node.address_info
# Functions to run as callback after a successful ray init # Functions to run as callback after a successful ray init
@ -1782,9 +1684,7 @@ def is_initialized():
return ray.worker.global_worker.connected return ray.worker.global_worker.connected
def connect(info, def connect(node,
redis_password=None,
object_id_seed=None,
mode=WORKER_MODE, mode=WORKER_MODE,
log_to_driver=False, log_to_driver=False,
worker=global_worker, worker=global_worker,
@ -1793,14 +1693,7 @@ def connect(info,
"""Connect this worker to the raylet, to Plasma, and to Redis. """Connect this worker to the raylet, to Plasma, and to Redis.
Args: Args:
info (dict): A dictionary with address of the Redis server and the node (ray.node.Node): The node to connect.
sockets of the plasma store and raylet.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
object_id_seed (int): Used to seed the deterministic generation of
object IDs. The same value can be used across multiple runs of the
same job in order to generate the object IDs in a consistent
manner. However, the same ID should not be used for different jobs.
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
LOCAL_MODE. LOCAL_MODE.
log_to_driver (bool): If true, then output from all of the worker log_to_driver (bool): If true, then output from all of the worker
@ -1844,25 +1737,19 @@ def connect(info,
# All workers start out as non-actors. A worker can be turned into an actor # All workers start out as non-actors. A worker can be turned into an actor
# after it is created. # after it is created.
worker.actor_id = ActorID.nil() worker.actor_id = ActorID.nil()
worker.connected = True worker.node = node
worker.set_mode(mode) worker.set_mode(mode)
worker.load_code_from_local = load_code_from_local
# If running Ray in LOCAL_MODE, there is no need to create call # If running Ray in LOCAL_MODE, there is no need to create call
# create_worker or to start the worker service. # create_worker or to start the worker service.
if mode == LOCAL_MODE: if mode == LOCAL_MODE:
return return
# Set the node IP address.
worker.node_ip_address = info["node_ip_address"]
worker.redis_address = info["redis_address"]
# Create a Redis client. # Create a Redis client.
redis_ip_address, redis_port = info["redis_address"].split(":")
# The Redis client can safely be shared between threads. However, that is # The Redis client can safely be shared between threads. However, that is
# not true of Redis pubsub clients. See the documentation at # not true of Redis pubsub clients. See the documentation at
# https://github.com/andymccurdy/redis-py#thread-safety. # https://github.com/andymccurdy/redis-py#thread-safety.
worker.redis_client = redis.StrictRedis( worker.redis_client = node.create_redis_client()
host=redis_ip_address, port=int(redis_port), password=redis_password)
# For driver's check that the version information matches the version # For driver's check that the version information matches the version
# information that the Ray cluster was started with. # information that the Ray cluster was started with.
@ -1883,7 +1770,7 @@ def connect(info,
# Create an object for interfacing with the global state. # Create an object for interfacing with the global state.
global_state._initialize_global_state( global_state._initialize_global_state(
redis_ip_address, int(redis_port), redis_password=redis_password) node.redis_address, redis_password=node.redis_password)
# Register the worker with Redis. # Register the worker with Redis.
if mode == SCRIPT_MODE: if mode == SCRIPT_MODE:
@ -1891,11 +1778,11 @@ def connect(info,
# Register the driver/job with Redis here. # Register the driver/job with Redis here.
import __main__ as main import __main__ as main
driver_info = { driver_info = {
"node_ip_address": worker.node_ip_address, "node_ip_address": node.node_ip_address,
"driver_id": worker.worker_id, "driver_id": worker.worker_id,
"start_time": time.time(), "start_time": time.time(),
"plasma_store_socket": info["store_socket_name"], "plasma_store_socket": node.plasma_store_socket_name,
"raylet_socket": info.get("raylet_socket_name"), "raylet_socket": node.raylet_socket_name,
"name": (main.__file__ "name": (main.__file__
if hasattr(main, "__file__") else "INTERACTIVE MODE") if hasattr(main, "__file__") else "INTERACTIVE MODE")
} }
@ -1903,8 +1790,8 @@ def connect(info,
elif mode == WORKER_MODE: elif mode == WORKER_MODE:
# Register the worker with Redis. # Register the worker with Redis.
worker_dict = { worker_dict = {
"node_ip_address": worker.node_ip_address, "node_ip_address": node.node_ip_address,
"plasma_store_socket": info["store_socket_name"], "plasma_store_socket": node.plasma_store_socket_name,
} }
# Check the RedirectOutput key in Redis and based on its value redirect # Check the RedirectOutput key in Redis and based on its value redirect
# worker output and error to their own files. # worker output and error to their own files.
@ -1913,7 +1800,7 @@ def connect(info,
if (redirect_worker_output_val is not None if (redirect_worker_output_val is not None
and int(redirect_worker_output_val) == 1): and int(redirect_worker_output_val) == 1):
log_stdout_file, log_stderr_file = ( log_stdout_file, log_stderr_file = (
_global_node.new_worker_redirected_log_file(worker.worker_id)) node.new_worker_redirected_log_file(worker.worker_id))
# Redirect stdout/stderr at the file descriptor level. If we simply # Redirect stdout/stderr at the file descriptor level. If we simply
# set sys.stdout and sys.stderr, then logging from C++ can fail to # set sys.stdout and sys.stderr, then logging from C++ can fail to
# be redirected. # be redirected.
@ -1941,7 +1828,7 @@ def connect(info,
# Create an object store client. # Create an object store client.
worker.plasma_client = thread_safe_client( worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"], None, 0, 300)) plasma.connect(node.plasma_store_socket_name, None, 0, 300))
# If this is a driver, set the current task ID, the task driver ID, and set # If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0. # the task index to 0.
@ -1951,8 +1838,8 @@ def connect(info,
# the user's random number generator). Otherwise, set the current task # the user's random number generator). Otherwise, set the current task
# ID randomly to avoid object ID collisions. # ID randomly to avoid object ID collisions.
numpy_state = np.random.get_state() numpy_state = np.random.get_state()
if object_id_seed is not None: if node.object_id_seed is not None:
np.random.seed(object_id_seed) np.random.seed(node.object_id_seed)
else: else:
# Try to use true randomness. # Try to use true randomness.
np.random.seed(None) np.random.seed(None)
@ -1999,7 +1886,7 @@ def connect(info,
worker.task_context.current_task_id = driver_task.task_id() worker.task_context.current_task_id = driver_task.task_id()
worker.raylet_client = ray._raylet.RayletClient( worker.raylet_client = ray._raylet.RayletClient(
info["raylet_socket_name"], node.raylet_socket_name,
ClientID(worker.worker_id), ClientID(worker.worker_id),
(mode == WORKER_MODE), (mode == WORKER_MODE),
DriverID(worker.current_task_id.binary()), DriverID(worker.current_task_id.binary()),
@ -2096,7 +1983,7 @@ def disconnect():
worker.threads_stopped.clear() worker.threads_stopped.clear()
worker._session_index += 1 worker._session_index += 1
worker.connected = False worker.node = None # Disconnect the worker from the node.
worker.cached_functions_to_run = [] worker.cached_functions_to_run = []
worker.function_actor_manager.reset_cache() worker.function_actor_manager.reset_cache()
worker.serialization_context_map.clear() worker.serialization_context_map.clear()

View file

@ -66,14 +66,6 @@ parser.add_argument(
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
info = {
"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"redis_password": args.redis_password,
"store_socket_name": args.object_store_name,
"raylet_socket_name": args.raylet_name,
}
ray.utils.setup_logger(args.logging_level, args.logging_format) ray.utils.setup_logger(args.logging_level, args.logging_format)
ray_params = RayParams( ray_params = RayParams(
@ -89,12 +81,7 @@ if __name__ == "__main__":
ray_params, head=False, shutdown_at_exit=False, connect_only=True) ray_params, head=False, shutdown_at_exit=False, connect_only=True)
ray.worker._global_node = node ray.worker._global_node = node
# TODO(suquark): Use "node" as the input of "connect". ray.worker.connect(node, mode=ray.WORKER_MODE)
ray.worker.connect(
info,
redis_password=args.redis_password,
mode=ray.WORKER_MODE,
load_code_from_local=args.load_code_from_local)
error_explanation = """ error_explanation = """
This error is unexpected and should not have happened. Somehow a worker This error is unexpected and should not have happened. Somehow a worker