From dab99d26af061bc3447ff15479f1fb9a7b7740c9 Mon Sep 17 00:00:00 2001 From: Si-Yuan Date: Tue, 9 Apr 2019 17:27:54 +0800 Subject: [PATCH] Improve code related to node (#4383) * Make full use of node implement local node fix bugs mentioned in comments * Add more tests * Use more specific exception handling * fix, lint * fix for py2.x --- python/ray/experimental/state.py | 33 ++--- python/ray/monitor.py | 9 +- python/ray/node.py | 131 +++++++++++++----- python/ray/services.py | 72 ++++++++-- python/ray/tests/test_tempfile.py | 33 ++++- python/ray/worker.py | 195 ++++++--------------------- python/ray/workers/default_worker.py | 15 +-- 7 files changed, 243 insertions(+), 245 deletions(-) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 268f32fb8..b884d0794 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -4,7 +4,6 @@ from __future__ import print_function from collections import defaultdict import json -import redis import sys import time @@ -13,6 +12,7 @@ from ray.function_manager import FunctionDescriptor import ray.gcs_utils from ray.ray_constants import ID_SIZE +from ray import services from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -126,8 +126,7 @@ class GlobalState(object): self.redis_clients = None def _initialize_global_state(self, - redis_ip_address, - redis_port, + redis_address, redis_password=None, timeout=20): """Initialize the GlobalState object by connecting to Redis. @@ -137,18 +136,15 @@ class GlobalState(object): been populated or we exceed a timeout. Args: - redis_ip_address: The IP address of the node that the Redis server - lives on. - redis_port: The port that the Redis server is listening on. + redis_address: The Redis address to connect. redis_password: The password of the redis server. """ - self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port, password=redis_password) - + self.redis_client = services.create_redis_client( + redis_address, redis_password) start_time = time.time() num_redis_shards = None - ip_address_ports = [] + redis_shard_addresses = [] while time.time() - start_time < timeout: # Attempt to get the number of Redis shards. @@ -163,9 +159,9 @@ class GlobalState(object): "{}.".format(num_redis_shards)) # Attempt to get all of the Redis shards. - ip_address_ports = self.redis_client.lrange( + redis_shard_addresses = self.redis_client.lrange( "RedisShards", start=0, end=-1) - if len(ip_address_ports) != num_redis_shards: + if len(redis_shard_addresses) != num_redis_shards: print("Waiting longer for RedisShards to be populated.") time.sleep(1) continue @@ -177,18 +173,15 @@ class GlobalState(object): if time.time() - start_time >= timeout: raise Exception("Timed out while attempting to initialize the " "global state. num_redis_shards = {}, " - "ip_address_ports = {}".format( - num_redis_shards, ip_address_ports)) + "redis_shard_addresses = {}".format( + num_redis_shards, redis_shard_addresses)) # Get the rest of the information. self.redis_clients = [] - for ip_address_port in ip_address_ports: - shard_address, shard_port = ip_address_port.split(b":") + for shard_address in redis_shard_addresses: self.redis_clients.append( - redis.StrictRedis( - host=shard_address, - port=shard_port, - password=redis_password)) + services.create_redis_client(shard_address.decode(), + redis_password)) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key. diff --git a/python/ray/monitor.py b/python/ray/monitor.py index df1e179dc..49b337c2d 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -16,7 +16,6 @@ import ray.cloudpickle as pickle import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.services import get_ip_address, get_port from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, setup_logger) @@ -32,17 +31,15 @@ class Monitor(object): Attributes: redis: A connection to the Redis server. - subscribe_client: A pubsub client for the Redis server. This is used to - receive notifications about failed components. + primary_subscribe_client: A pubsub client for the Redis server. + This is used to receive notifications about failed components. """ def __init__(self, redis_address, autoscaling_config, redis_password=None): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() - redis_ip_address = get_ip_address(args.redis_address) - redis_port = get_port(args.redis_address) self.state._initialize_global_state( - redis_ip_address, redis_port, redis_password=redis_password) + args.redis_address, redis_password=redis_password) self.redis = ray.services.create_redis_client( redis_address, password=redis_password) # Setup subscriptions to the primary Redis server and the Redis shards. diff --git a/python/ray/node.py b/python/ray/node.py index 4da4d6989..991131d8b 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -31,7 +31,8 @@ PY3 = sys.version_info.major >= 3 class Node(object): """An encapsulation of the Ray processes on a single node. - This class is responsible for starting Ray processes and killing them. + This class is responsible for starting Ray processes and killing them, + and it also controls the temp file policy. Attributes: all_processes (dict): A mapping from process type (str) to a list of @@ -63,8 +64,17 @@ class Node(object): "be both true.") self.all_processes = {} + # Try to get node IP address with the parameters. + if ray_params.node_ip_address: + node_ip_address = ray_params.node_ip_address + elif ray_params.redis_address: + node_ip_address = ray.services.get_node_ip_address( + ray_params.redis_address) + else: + node_ip_address = ray.services.get_node_ip_address() + self._node_ip_address = node_ip_address + ray_params.update_if_absent( - node_ip_address=ray.services.get_node_ip_address(), include_log_monitor=True, resources={}, include_webui=False, @@ -73,31 +83,51 @@ class Node(object): "workers/default_worker.py")) self._ray_params = ray_params - - self._node_ip_address = ray_params.node_ip_address self._redis_address = ray_params.redis_address self._config = (json.loads(ray_params._internal_config) if ray_params._internal_config else None) - if head: - ray_params.update_if_absent(num_redis_shards=1, include_webui=True) - self._plasma_store_socket_name = None - self._raylet_socket_name = None - self._webui_url = None - else: + self._init_temp() + + if connect_only: + # Get socket names from the configuration. self._plasma_store_socket_name = ( ray_params.plasma_store_socket_name) self._raylet_socket_name = ray_params.raylet_socket_name + + # If user does not provide the socket name, get it from Redis. + if (self._plasma_store_socket_name is None + or self._raylet_socket_name is None): + # Get the address info of the processes to connect to + # from Redis. + address_info = ray.services.get_address_info_from_redis( + self.redis_address, + self._node_ip_address, + redis_password=self.redis_password) + self._plasma_store_socket_name = address_info[ + "object_store_address"] + self._raylet_socket_name = address_info["raylet_socket_name"] + else: + # If the user specified a socket name, use it. + self._plasma_store_socket_name = self._prepare_socket_file( + self._ray_params.plasma_store_socket_name, + default_prefix="plasma_store") + self._raylet_socket_name = self._prepare_socket_file( + self._ray_params.raylet_socket_name, default_prefix="raylet") + + if head: + ray_params.update_if_absent(num_redis_shards=1, include_webui=True) + self._webui_url = None + else: redis_client = self.create_redis_client() - # TODO(suquark): Replace _webui_url_helper in worker.py in - # another PR. - _webui_url = redis_client.hmget("webui", "url")[0] - self._webui_url = (ray.utils.decode(_webui_url) - if _webui_url is not None else None) + self._webui_url = ( + ray.services.get_webui_url_from_redis(redis_client)) ray_params.include_java = ( ray.services.include_java_from_redis(redis_client)) - self._init_temp() + # Start processes. + if head: + self.start_head_processes() if not connect_only: self.start_ray_processes() @@ -136,6 +166,20 @@ class Node(object): """Get the cluster Redis address.""" return self._redis_address + @property + def redis_password(self): + """Get the cluster Redis password""" + return self._ray_params.redis_password + + @property + def load_code_from_local(self): + return self._ray_params.load_code_from_local + + @property + def object_id_seed(self): + """Get the seed for deterministic generation of object IDs""" + return self._ray_params.object_id_seed + @property def plasma_store_socket_name(self): """Get the node's plasma store socket name.""" @@ -151,6 +195,17 @@ class Node(object): """Get the node's raylet socket name.""" return self._raylet_socket_name + @property + def address_info(self): + """Get a dictionary of addresses.""" + return { + "node_ip_address": self._node_ip_address, + "redis_address": self._redis_address, + "object_store_address": self._plasma_store_socket_name, + "raylet_socket_name": self._raylet_socket_name, + "webui_url": self._webui_url, + } + def create_redis_client(self): """Create a redis client.""" return ray.services.create_redis_client( @@ -321,11 +376,6 @@ class Node(object): def start_plasma_store(self): """Start the plasma store.""" - assert self._plasma_store_socket_name is None - # If the user specified a socket name, use it. - self._plasma_store_socket_name = self._prepare_socket_file( - self._ray_params.plasma_store_socket_name, - default_prefix="plasma_store") stdout_file, stderr_file = self.new_log_files("plasma_store") process_info = ray.services.start_plasma_store( stdout_file=stdout_file, @@ -349,10 +399,6 @@ class Node(object): use_profiler (bool): True if we should start the process in the valgrind profiler. """ - assert self._raylet_socket_name is None - # If the user specified a socket name, use it. - self._raylet_socket_name = self._prepare_socket_file( - self._ray_params.raylet_socket_name, default_prefix="raylet") stdout_file, stderr_file = self.new_log_files("raylet") process_info = ray.services.start_raylet( self._redis_address, @@ -416,20 +462,26 @@ class Node(object): process_info ] + def start_head_processes(self): + """Start head processes on the node.""" + logger.info( + "Process STDOUT and STDERR is being redirected to {}.".format( + self._logs_dir)) + assert self._redis_address is None + # If this is the head node, start the relevant head node processes. + self.start_redis() + self.start_monitor() + self.start_raylet_monitor() + # The dashboard is Python3.x only. + if PY3 and self._ray_params.include_webui: + self.start_dashboard() + def start_ray_processes(self): """Start all of the processes on the node.""" logger.info( "Process STDOUT and STDERR is being redirected to {}.".format( self._logs_dir)) - # If this is the head node, start the relevant head node processes. - if self._redis_address is None: - self.start_redis() - self.start_monitor() - self.start_raylet_monitor() - if PY3 and self._ray_params.include_webui: - self.start_dashboard() - self.start_plasma_store() self.start_raylet() if PY3: @@ -685,3 +737,16 @@ class Node(object): True if any process that wasn't explicitly killed is still alive. """ return not any(self.dead_processes()) + + +class LocalNode(object): + """Imitate the node that manages the processes in local mode.""" + + def kill_all_processes(self, *args, **kwargs): + """Kill all of the processes.""" + pass # Keep this function empty because it will be used in worker.py + + @property + def address_info(self): + """Get a dictionary of addresses.""" + return {} # Return a null dict. diff --git a/python/ray/services.py b/python/ray/services.py index c440df2ea..0fc7f4e73 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -77,20 +77,6 @@ def address(ip_address, port): return ip_address + ":" + str(port) -def get_ip_address(address): - assert type(address) == str, "Address must be a string" - ip_address = address.split(":")[0] - return ip_address - - -def get_port(address): - try: - port = int(address.split(":")[1]) - except Exception: - raise Exception("Unable to parse port from address {}".format(address)) - return port - - def new_port(): return random.randint(10000, 65535) @@ -107,6 +93,64 @@ def include_java_from_redis(redis_client): return redis_client.get("INCLUDE_JAVA") == b"1" +def get_address_info_from_redis_helper(redis_address, + node_ip_address, + redis_password=None): + redis_ip_address, redis_port = redis_address.split(":") + # For this command to work, some other client (on the same machine as + # Redis) must have run "CONFIG SET protected-mode no". + redis_client = create_redis_client(redis_address, password=redis_password) + + client_table = ray.experimental.state.parse_client_table(redis_client) + if len(client_table) == 0: + raise Exception( + "Redis has started but no raylets have registered yet.") + + relevant_client = None + for client_info in client_table: + client_node_ip_address = client_info["NodeManagerAddress"] + if (client_node_ip_address == node_ip_address + or (client_node_ip_address == "127.0.0.1" + and redis_ip_address == get_node_ip_address())): + relevant_client = client_info + break + if relevant_client is None: + raise Exception( + "Redis has started but no raylets have registered yet.") + + return { + "object_store_address": relevant_client["ObjectStoreSocketName"], + "raylet_socket_name": relevant_client["RayletSocketName"], + } + + +def get_address_info_from_redis(redis_address, + node_ip_address, + num_retries=5, + redis_password=None): + counter = 0 + while True: + try: + return get_address_info_from_redis_helper( + redis_address, node_ip_address, redis_password=redis_password) + except Exception: + if counter == num_retries: + raise + # Some of the information may not be in Redis yet, so wait a little + # bit. + logger.warning( + "Some processes that the driver needs to connect to have " + "not registered with Redis, so retrying. Have you run " + "'ray start' on this node?") + time.sleep(1) + counter += 1 + + +def get_webui_url_from_redis(redis_client): + webui_url = redis_client.hmget("webui", "url")[0] + return ray.utils.decode(webui_url) if webui_url is not None else None + + def remaining_processes_alive(): """See if the remaining processes are alive or not. diff --git a/python/ray/tests/test_tempfile.py b/python/ray/tests/test_tempfile.py index ead1d31d9..191c9287e 100644 --- a/python/ray/tests/test_tempfile.py +++ b/python/ray/tests/test_tempfile.py @@ -7,6 +7,13 @@ import shutil import time import pytest import ray +from ray.tests.cluster_utils import Cluster + +# Py2 compatibility +try: + FileNotFoundError +except NameError: + FileNotFoundError = OSError def test_conn_cluster(): @@ -52,8 +59,17 @@ def test_raylet_socket_name(): ray.shutdown() try: os.remove("/tmp/i_am_a_temp_socket") - except Exception: - pass + except FileNotFoundError: + pass # It could have been removed by Ray. + cluster = Cluster(True) + cluster.add_node(raylet_socket_name="/tmp/i_am_a_temp_socket_2") + assert os.path.exists( + "/tmp/i_am_a_temp_socket_2"), "Specified socket path not found." + cluster.shutdown() + try: + os.remove("/tmp/i_am_a_temp_socket_2") + except FileNotFoundError: + pass # It could have been removed by Ray. def test_temp_plasma_store_socket(): @@ -63,8 +79,17 @@ def test_temp_plasma_store_socket(): ray.shutdown() try: os.remove("/tmp/i_am_a_temp_socket") - except Exception: - pass + except FileNotFoundError: + pass # It could have been removed by Ray. + cluster = Cluster(True) + cluster.add_node(plasma_store_socket_name="/tmp/i_am_a_temp_socket_2") + assert os.path.exists( + "/tmp/i_am_a_temp_socket_2"), "Specified socket path not found." + cluster.shutdown() + try: + os.remove("/tmp/i_am_a_temp_socket_2") + except FileNotFoundError: + pass # It could have been removed by Ray. def test_raylet_tempfiles(): diff --git a/python/ray/worker.py b/python/ray/worker.py index b6d84c434..64c8fbda7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -12,7 +12,6 @@ import json import logging import numpy as np import os -import redis import signal from six.moves import queue import sys @@ -115,6 +114,7 @@ class Worker(object): Attributes: connected (bool): True if Ray has been started and False otherwise. + node (ray.node.Node): The node this worker is attached to. mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and WORKER_MODE. cached_functions_to_run (List): A list of functions to run on all of @@ -124,7 +124,7 @@ class Worker(object): def __init__(self): """Initialize a Worker object.""" - self.connected = False + self.node = None self.mode = None self.cached_functions_to_run = [] self.actor_init_error = None @@ -144,7 +144,6 @@ class Worker(object): # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} - self.load_code_from_local = False self.function_actor_manager = FunctionActorManager(self) # Identity of the driver that this worker is processing. # It is a DriverID. @@ -158,6 +157,20 @@ class Worker(object): self._session_index = 0 self._current_task = None + @property + def connected(self): + return self.node is not None + + @property + def node_ip_address(self): + self.check_connected() + return self.node.node_ip_address + + @property + def load_code_from_local(self): + self.check_connected() + return self.node.load_code_from_local + @property def task_context(self): """A thread-local that contains the following attributes. @@ -1072,19 +1085,6 @@ def get_resource_ids(): return global_worker.raylet_client.resource_ids() -def _webui_url_helper(client): - """Parsing for getting the url of the web UI. - - Args: - client: A redis client to use to query the primary Redis shard. - - Returns: - The URL of the web UI as a string. - """ - result = client.hmget("webui", "url")[0] - return ray.utils.decode(result) if result is not None else result - - def get_webui_url(): """Get the URL to access the web UI. @@ -1093,10 +1093,9 @@ def get_webui_url(): Returns: The URL of the web UI as a string. """ - if _mode() == LOCAL_MODE: - raise Exception("ray.get_webui_url() currently does not work in " - "PYTHON MODE.") - return _webui_url_helper(global_worker.redis_client) + if _global_node is None: + raise Exception("Ray has not been initialized/connected.") + return _global_node.get_webui_url global_worker = Worker() @@ -1211,64 +1210,6 @@ def _initialize_serialization(driver_id, worker=global_worker): class_id="ray.signature.FunctionSignature") -def get_address_info_from_redis_helper(redis_address, - node_ip_address, - redis_password=None): - redis_ip_address, redis_port = redis_address.split(":") - # For this command to work, some other client (on the same machine as - # Redis) must have run "CONFIG SET protected-mode no". - redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port), password=redis_password) - - client_table = ray.experimental.state.parse_client_table(redis_client) - if len(client_table) == 0: - raise Exception( - "Redis has started but no raylets have registered yet.") - - relevant_client = None - for client_info in client_table: - client_node_ip_address = client_info["NodeManagerAddress"] - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - relevant_client = client_info - break - if relevant_client is None: - raise Exception( - "Redis has started but no raylets have registered yet.") - - return { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_address": relevant_client["ObjectStoreSocketName"], - "raylet_socket_name": relevant_client["RayletSocketName"], - # Web UI should be running. - "webui_url": _webui_url_helper(redis_client) - } - - -def get_address_info_from_redis(redis_address, - node_ip_address, - num_retries=5, - redis_password=None): - counter = 0 - while True: - try: - return get_address_info_from_redis_helper( - redis_address, node_ip_address, redis_password=redis_password) - except Exception: - if counter == num_retries: - raise - # Some of the information may not be in Redis yet, so wait a little - # bit. - logger.warning( - "Some processes that the driver needs to connect to have " - "not registered with Redis, so retrying. Have you run " - "'ray start' on this node?") - time.sleep(1) - counter += 1 - - def init(redis_address=None, num_cpus=None, num_gpus=None, @@ -1414,22 +1355,11 @@ def init(redis_address=None, if redis_address is not None: redis_address = services.address_to_ip(redis_address) - address_info = { - "node_ip_address": node_ip_address, - "redis_address": redis_address - } - global _global_node if driver_mode == LOCAL_MODE: # If starting Ray in LOCAL_MODE, don't start any other processes. - pass + _global_node = ray.node.LocalNode() elif redis_address is None: - # TODO(suquark): We should remove the code below because they - # have been set when initializing the node. - if node_ip_address is None: - node_ip_address = ray.services.get_node_ip_address() - if num_redis_shards is None: - num_redis_shards = 1 # In this case, we need to start a new cluster. ray_params = ray.parameter.RayParams( redis_address=redis_address, @@ -1461,11 +1391,6 @@ def init(redis_address=None, # handler. _global_node = ray.node.Node( head=True, shutdown_at_exit=False, ray_params=ray_params) - address_info["redis_address"] = _global_node.redis_address - address_info[ - "object_store_address"] = _global_node.plasma_store_socket_name - address_info["webui_url"] = _global_node.webui_url - address_info["raylet_socket_name"] = _global_node.raylet_socket_name else: # In this case, we are connecting to an existing cluster. if num_cpus is not None or num_gpus is not None: @@ -1505,51 +1430,28 @@ def init(redis_address=None, raise Exception("When connecting to an existing cluster, " "_internal_config must not be provided.") - # Get the node IP address if one is not provided. - - if node_ip_address is None: - node_ip_address = services.get_node_ip_address(redis_address) - # Get the address info of the processes to connect to from Redis. - address_info = get_address_info_from_redis( - redis_address, node_ip_address, redis_password=redis_password) - # TODO(suquark): Use "node" as the input of "connect()". # In this case, we only need to connect the node. ray_params = ray.parameter.RayParams( node_ip_address=node_ip_address, redis_address=redis_address, redis_password=redis_password, - plasma_store_socket_name=address_info["object_store_address"], - raylet_socket_name=address_info["raylet_socket_name"], object_id_seed=object_id_seed, - temp_dir=temp_dir) + temp_dir=temp_dir, + load_code_from_local=load_code_from_local) _global_node = ray.node.Node( ray_params, head=False, shutdown_at_exit=False, connect_only=True) - if driver_mode == LOCAL_MODE: - driver_address_info = {} - else: - driver_address_info = { - "node_ip_address": node_ip_address, - "redis_address": address_info["redis_address"], - "store_socket_name": address_info["object_store_address"], - "webui_url": address_info["webui_url"], - "raylet_socket_name": address_info["raylet_socket_name"], - } - connect( - driver_address_info, - redis_password=redis_password, - object_id_seed=object_id_seed, + _global_node, mode=driver_mode, log_to_driver=log_to_driver, worker=global_worker, - driver_id=driver_id, - load_code_from_local=load_code_from_local) + driver_id=driver_id) for hook in _post_init_hooks: hook() - return address_info + return _global_node.address_info # Functions to run as callback after a successful ray init @@ -1782,9 +1684,7 @@ def is_initialized(): return ray.worker.global_worker.connected -def connect(info, - redis_password=None, - object_id_seed=None, +def connect(node, mode=WORKER_MODE, log_to_driver=False, worker=global_worker, @@ -1793,14 +1693,7 @@ def connect(info, """Connect this worker to the raylet, to Plasma, and to Redis. Args: - info (dict): A dictionary with address of the Redis server and the - sockets of the plasma store and raylet. - redis_password (str): Prevents external clients without the password - from connecting to Redis if provided. - object_id_seed (int): Used to seed the deterministic generation of - object IDs. The same value can be used across multiple runs of the - same job in order to generate the object IDs in a consistent - manner. However, the same ID should not be used for different jobs. + node (ray.node.Node): The node to connect. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. log_to_driver (bool): If true, then output from all of the worker @@ -1844,25 +1737,19 @@ def connect(info, # All workers start out as non-actors. A worker can be turned into an actor # after it is created. worker.actor_id = ActorID.nil() - worker.connected = True + worker.node = node worker.set_mode(mode) - worker.load_code_from_local = load_code_from_local # If running Ray in LOCAL_MODE, there is no need to create call # create_worker or to start the worker service. if mode == LOCAL_MODE: return - # Set the node IP address. - worker.node_ip_address = info["node_ip_address"] - worker.redis_address = info["redis_address"] # Create a Redis client. - redis_ip_address, redis_port = info["redis_address"].split(":") # The Redis client can safely be shared between threads. However, that is # not true of Redis pubsub clients. See the documentation at # https://github.com/andymccurdy/redis-py#thread-safety. - worker.redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port), password=redis_password) + worker.redis_client = node.create_redis_client() # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -1883,7 +1770,7 @@ def connect(info, # Create an object for interfacing with the global state. global_state._initialize_global_state( - redis_ip_address, int(redis_port), redis_password=redis_password) + node.redis_address, redis_password=node.redis_password) # Register the worker with Redis. if mode == SCRIPT_MODE: @@ -1891,11 +1778,11 @@ def connect(info, # Register the driver/job with Redis here. import __main__ as main driver_info = { - "node_ip_address": worker.node_ip_address, + "node_ip_address": node.node_ip_address, "driver_id": worker.worker_id, "start_time": time.time(), - "plasma_store_socket": info["store_socket_name"], - "raylet_socket": info.get("raylet_socket_name"), + "plasma_store_socket": node.plasma_store_socket_name, + "raylet_socket": node.raylet_socket_name, "name": (main.__file__ if hasattr(main, "__file__") else "INTERACTIVE MODE") } @@ -1903,8 +1790,8 @@ def connect(info, elif mode == WORKER_MODE: # Register the worker with Redis. worker_dict = { - "node_ip_address": worker.node_ip_address, - "plasma_store_socket": info["store_socket_name"], + "node_ip_address": node.node_ip_address, + "plasma_store_socket": node.plasma_store_socket_name, } # Check the RedirectOutput key in Redis and based on its value redirect # worker output and error to their own files. @@ -1913,7 +1800,7 @@ def connect(info, if (redirect_worker_output_val is not None and int(redirect_worker_output_val) == 1): log_stdout_file, log_stderr_file = ( - _global_node.new_worker_redirected_log_file(worker.worker_id)) + node.new_worker_redirected_log_file(worker.worker_id)) # Redirect stdout/stderr at the file descriptor level. If we simply # set sys.stdout and sys.stderr, then logging from C++ can fail to # be redirected. @@ -1941,7 +1828,7 @@ def connect(info, # Create an object store client. worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], None, 0, 300)) + plasma.connect(node.plasma_store_socket_name, None, 0, 300)) # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1951,8 +1838,8 @@ def connect(info, # the user's random number generator). Otherwise, set the current task # ID randomly to avoid object ID collisions. numpy_state = np.random.get_state() - if object_id_seed is not None: - np.random.seed(object_id_seed) + if node.object_id_seed is not None: + np.random.seed(node.object_id_seed) else: # Try to use true randomness. np.random.seed(None) @@ -1999,7 +1886,7 @@ def connect(info, worker.task_context.current_task_id = driver_task.task_id() worker.raylet_client = ray._raylet.RayletClient( - info["raylet_socket_name"], + node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), DriverID(worker.current_task_id.binary()), @@ -2096,7 +1983,7 @@ def disconnect(): worker.threads_stopped.clear() worker._session_index += 1 - worker.connected = False + worker.node = None # Disconnect the worker from the node. worker.cached_functions_to_run = [] worker.function_actor_manager.reset_cache() worker.serialization_context_map.clear() diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 71b3a5f26..a8b7040d1 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -66,14 +66,6 @@ parser.add_argument( if __name__ == "__main__": args = parser.parse_args() - info = { - "node_ip_address": args.node_ip_address, - "redis_address": args.redis_address, - "redis_password": args.redis_password, - "store_socket_name": args.object_store_name, - "raylet_socket_name": args.raylet_name, - } - ray.utils.setup_logger(args.logging_level, args.logging_format) ray_params = RayParams( @@ -89,12 +81,7 @@ if __name__ == "__main__": ray_params, head=False, shutdown_at_exit=False, connect_only=True) ray.worker._global_node = node - # TODO(suquark): Use "node" as the input of "connect". - ray.worker.connect( - info, - redis_password=args.redis_password, - mode=ray.WORKER_MODE, - load_code_from_local=args.load_code_from_local) + ray.worker.connect(node, mode=ray.WORKER_MODE) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker