Improve code related to node (#4383)

* Make full use of node

implement local node

fix bugs mentioned in comments

* Add more tests

* Use more specific exception handling

* fix, lint

* fix for py2.x
This commit is contained in:
Si-Yuan 2019-04-09 17:27:54 +08:00 committed by GitHub
parent c5bcec54f3
commit dab99d26af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 243 additions and 245 deletions

View file

@ -4,7 +4,6 @@ from __future__ import print_function
from collections import defaultdict
import json
import redis
import sys
import time
@ -13,6 +12,7 @@ from ray.function_manager import FunctionDescriptor
import ray.gcs_utils
from ray.ray_constants import ID_SIZE
from ray import services
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@ -126,8 +126,7 @@ class GlobalState(object):
self.redis_clients = None
def _initialize_global_state(self,
redis_ip_address,
redis_port,
redis_address,
redis_password=None,
timeout=20):
"""Initialize the GlobalState object by connecting to Redis.
@ -137,18 +136,15 @@ class GlobalState(object):
been populated or we exceed a timeout.
Args:
redis_ip_address: The IP address of the node that the Redis server
lives on.
redis_port: The port that the Redis server is listening on.
redis_address: The Redis address to connect.
redis_password: The password of the redis server.
"""
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port, password=redis_password)
self.redis_client = services.create_redis_client(
redis_address, redis_password)
start_time = time.time()
num_redis_shards = None
ip_address_ports = []
redis_shard_addresses = []
while time.time() - start_time < timeout:
# Attempt to get the number of Redis shards.
@ -163,9 +159,9 @@ class GlobalState(object):
"{}.".format(num_redis_shards))
# Attempt to get all of the Redis shards.
ip_address_ports = self.redis_client.lrange(
redis_shard_addresses = self.redis_client.lrange(
"RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
if len(redis_shard_addresses) != num_redis_shards:
print("Waiting longer for RedisShards to be populated.")
time.sleep(1)
continue
@ -177,18 +173,15 @@ class GlobalState(object):
if time.time() - start_time >= timeout:
raise Exception("Timed out while attempting to initialize the "
"global state. num_redis_shards = {}, "
"ip_address_ports = {}".format(
num_redis_shards, ip_address_ports))
"redis_shard_addresses = {}".format(
num_redis_shards, redis_shard_addresses))
# Get the rest of the information.
self.redis_clients = []
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
for shard_address in redis_shard_addresses:
self.redis_clients.append(
redis.StrictRedis(
host=shard_address,
port=shard_port,
password=redis_password))
services.create_redis_client(shard_address.decode(),
redis_password))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.

View file

@ -16,7 +16,6 @@ import ray.cloudpickle as pickle
import ray.gcs_utils
import ray.utils
import ray.ray_constants as ray_constants
from ray.services import get_ip_address, get_port
from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary,
setup_logger)
@ -32,17 +31,15 @@ class Monitor(object):
Attributes:
redis: A connection to the Redis server.
subscribe_client: A pubsub client for the Redis server. This is used to
receive notifications about failed components.
primary_subscribe_client: A pubsub client for the Redis server.
This is used to receive notifications about failed components.
"""
def __init__(self, redis_address, autoscaling_config, redis_password=None):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
self.state._initialize_global_state(
redis_ip_address, redis_port, redis_password=redis_password)
args.redis_address, redis_password=redis_password)
self.redis = ray.services.create_redis_client(
redis_address, password=redis_password)
# Setup subscriptions to the primary Redis server and the Redis shards.

View file

@ -31,7 +31,8 @@ PY3 = sys.version_info.major >= 3
class Node(object):
"""An encapsulation of the Ray processes on a single node.
This class is responsible for starting Ray processes and killing them.
This class is responsible for starting Ray processes and killing them,
and it also controls the temp file policy.
Attributes:
all_processes (dict): A mapping from process type (str) to a list of
@ -63,8 +64,17 @@ class Node(object):
"be both true.")
self.all_processes = {}
# Try to get node IP address with the parameters.
if ray_params.node_ip_address:
node_ip_address = ray_params.node_ip_address
elif ray_params.redis_address:
node_ip_address = ray.services.get_node_ip_address(
ray_params.redis_address)
else:
node_ip_address = ray.services.get_node_ip_address()
self._node_ip_address = node_ip_address
ray_params.update_if_absent(
node_ip_address=ray.services.get_node_ip_address(),
include_log_monitor=True,
resources={},
include_webui=False,
@ -73,31 +83,51 @@ class Node(object):
"workers/default_worker.py"))
self._ray_params = ray_params
self._node_ip_address = ray_params.node_ip_address
self._redis_address = ray_params.redis_address
self._config = (json.loads(ray_params._internal_config)
if ray_params._internal_config else None)
if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._plasma_store_socket_name = None
self._raylet_socket_name = None
self._webui_url = None
else:
self._init_temp()
if connect_only:
# Get socket names from the configuration.
self._plasma_store_socket_name = (
ray_params.plasma_store_socket_name)
self._raylet_socket_name = ray_params.raylet_socket_name
# If user does not provide the socket name, get it from Redis.
if (self._plasma_store_socket_name is None
or self._raylet_socket_name is None):
# Get the address info of the processes to connect to
# from Redis.
address_info = ray.services.get_address_info_from_redis(
self.redis_address,
self._node_ip_address,
redis_password=self.redis_password)
self._plasma_store_socket_name = address_info[
"object_store_address"]
self._raylet_socket_name = address_info["raylet_socket_name"]
else:
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._webui_url = None
else:
redis_client = self.create_redis_client()
# TODO(suquark): Replace _webui_url_helper in worker.py in
# another PR.
_webui_url = redis_client.hmget("webui", "url")[0]
self._webui_url = (ray.utils.decode(_webui_url)
if _webui_url is not None else None)
self._webui_url = (
ray.services.get_webui_url_from_redis(redis_client))
ray_params.include_java = (
ray.services.include_java_from_redis(redis_client))
self._init_temp()
# Start processes.
if head:
self.start_head_processes()
if not connect_only:
self.start_ray_processes()
@ -136,6 +166,20 @@ class Node(object):
"""Get the cluster Redis address."""
return self._redis_address
@property
def redis_password(self):
"""Get the cluster Redis password"""
return self._ray_params.redis_password
@property
def load_code_from_local(self):
return self._ray_params.load_code_from_local
@property
def object_id_seed(self):
"""Get the seed for deterministic generation of object IDs"""
return self._ray_params.object_id_seed
@property
def plasma_store_socket_name(self):
"""Get the node's plasma store socket name."""
@ -151,6 +195,17 @@ class Node(object):
"""Get the node's raylet socket name."""
return self._raylet_socket_name
@property
def address_info(self):
"""Get a dictionary of addresses."""
return {
"node_ip_address": self._node_ip_address,
"redis_address": self._redis_address,
"object_store_address": self._plasma_store_socket_name,
"raylet_socket_name": self._raylet_socket_name,
"webui_url": self._webui_url,
}
def create_redis_client(self):
"""Create a redis client."""
return ray.services.create_redis_client(
@ -321,11 +376,6 @@ class Node(object):
def start_plasma_store(self):
"""Start the plasma store."""
assert self._plasma_store_socket_name is None
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
stdout_file, stderr_file = self.new_log_files("plasma_store")
process_info = ray.services.start_plasma_store(
stdout_file=stdout_file,
@ -349,10 +399,6 @@ class Node(object):
use_profiler (bool): True if we should start the process in the
valgrind profiler.
"""
assert self._raylet_socket_name is None
# If the user specified a socket name, use it.
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
stdout_file, stderr_file = self.new_log_files("raylet")
process_info = ray.services.start_raylet(
self._redis_address,
@ -416,20 +462,26 @@ class Node(object):
process_info
]
def start_head_processes(self):
"""Start head processes on the node."""
logger.info(
"Process STDOUT and STDERR is being redirected to {}.".format(
self._logs_dir))
assert self._redis_address is None
# If this is the head node, start the relevant head node processes.
self.start_redis()
self.start_monitor()
self.start_raylet_monitor()
# The dashboard is Python3.x only.
if PY3 and self._ray_params.include_webui:
self.start_dashboard()
def start_ray_processes(self):
"""Start all of the processes on the node."""
logger.info(
"Process STDOUT and STDERR is being redirected to {}.".format(
self._logs_dir))
# If this is the head node, start the relevant head node processes.
if self._redis_address is None:
self.start_redis()
self.start_monitor()
self.start_raylet_monitor()
if PY3 and self._ray_params.include_webui:
self.start_dashboard()
self.start_plasma_store()
self.start_raylet()
if PY3:
@ -685,3 +737,16 @@ class Node(object):
True if any process that wasn't explicitly killed is still alive.
"""
return not any(self.dead_processes())
class LocalNode(object):
"""Imitate the node that manages the processes in local mode."""
def kill_all_processes(self, *args, **kwargs):
"""Kill all of the processes."""
pass # Keep this function empty because it will be used in worker.py
@property
def address_info(self):
"""Get a dictionary of addresses."""
return {} # Return a null dict.

View file

@ -77,20 +77,6 @@ def address(ip_address, port):
return ip_address + ":" + str(port)
def get_ip_address(address):
assert type(address) == str, "Address must be a string"
ip_address = address.split(":")[0]
return ip_address
def get_port(address):
try:
port = int(address.split(":")[1])
except Exception:
raise Exception("Unable to parse port from address {}".format(address))
return port
def new_port():
return random.randint(10000, 65535)
@ -107,6 +93,64 @@ def include_java_from_redis(redis_client):
return redis_client.get("INCLUDE_JAVA") == b"1"
def get_address_info_from_redis_helper(redis_address,
node_ip_address,
redis_password=None):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = create_redis_client(redis_address, password=redis_password)
client_table = ray.experimental.state.parse_client_table(redis_client)
if len(client_table) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
relevant_client = None
for client_info in client_table:
client_node_ip_address = client_info["NodeManagerAddress"]
if (client_node_ip_address == node_ip_address
or (client_node_ip_address == "127.0.0.1"
and redis_ip_address == get_node_ip_address())):
relevant_client = client_info
break
if relevant_client is None:
raise Exception(
"Redis has started but no raylets have registered yet.")
return {
"object_store_address": relevant_client["ObjectStoreSocketName"],
"raylet_socket_name": relevant_client["RayletSocketName"],
}
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
redis_password=None):
counter = 0
while True:
try:
return get_address_info_from_redis_helper(
redis_address, node_ip_address, redis_password=redis_password)
except Exception:
if counter == num_retries:
raise
# Some of the information may not be in Redis yet, so wait a little
# bit.
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
time.sleep(1)
counter += 1
def get_webui_url_from_redis(redis_client):
webui_url = redis_client.hmget("webui", "url")[0]
return ray.utils.decode(webui_url) if webui_url is not None else None
def remaining_processes_alive():
"""See if the remaining processes are alive or not.

View file

@ -7,6 +7,13 @@ import shutil
import time
import pytest
import ray
from ray.tests.cluster_utils import Cluster
# Py2 compatibility
try:
FileNotFoundError
except NameError:
FileNotFoundError = OSError
def test_conn_cluster():
@ -52,8 +59,17 @@ def test_raylet_socket_name():
ray.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket")
except Exception:
pass
except FileNotFoundError:
pass # It could have been removed by Ray.
cluster = Cluster(True)
cluster.add_node(raylet_socket_name="/tmp/i_am_a_temp_socket_2")
assert os.path.exists(
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
cluster.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket_2")
except FileNotFoundError:
pass # It could have been removed by Ray.
def test_temp_plasma_store_socket():
@ -63,8 +79,17 @@ def test_temp_plasma_store_socket():
ray.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket")
except Exception:
pass
except FileNotFoundError:
pass # It could have been removed by Ray.
cluster = Cluster(True)
cluster.add_node(plasma_store_socket_name="/tmp/i_am_a_temp_socket_2")
assert os.path.exists(
"/tmp/i_am_a_temp_socket_2"), "Specified socket path not found."
cluster.shutdown()
try:
os.remove("/tmp/i_am_a_temp_socket_2")
except FileNotFoundError:
pass # It could have been removed by Ray.
def test_raylet_tempfiles():

View file

@ -12,7 +12,6 @@ import json
import logging
import numpy as np
import os
import redis
import signal
from six.moves import queue
import sys
@ -115,6 +114,7 @@ class Worker(object):
Attributes:
connected (bool): True if Ray has been started and False otherwise.
node (ray.node.Node): The node this worker is attached to.
mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
WORKER_MODE.
cached_functions_to_run (List): A list of functions to run on all of
@ -124,7 +124,7 @@ class Worker(object):
def __init__(self):
"""Initialize a Worker object."""
self.connected = False
self.node = None
self.mode = None
self.cached_functions_to_run = []
self.actor_init_error = None
@ -144,7 +144,6 @@ class Worker(object):
# A dictionary that maps from driver id to SerializationContext
# TODO: clean up the SerializationContext once the job finished.
self.serialization_context_map = {}
self.load_code_from_local = False
self.function_actor_manager = FunctionActorManager(self)
# Identity of the driver that this worker is processing.
# It is a DriverID.
@ -158,6 +157,20 @@ class Worker(object):
self._session_index = 0
self._current_task = None
@property
def connected(self):
return self.node is not None
@property
def node_ip_address(self):
self.check_connected()
return self.node.node_ip_address
@property
def load_code_from_local(self):
self.check_connected()
return self.node.load_code_from_local
@property
def task_context(self):
"""A thread-local that contains the following attributes.
@ -1072,19 +1085,6 @@ def get_resource_ids():
return global_worker.raylet_client.resource_ids()
def _webui_url_helper(client):
"""Parsing for getting the url of the web UI.
Args:
client: A redis client to use to query the primary Redis shard.
Returns:
The URL of the web UI as a string.
"""
result = client.hmget("webui", "url")[0]
return ray.utils.decode(result) if result is not None else result
def get_webui_url():
"""Get the URL to access the web UI.
@ -1093,10 +1093,9 @@ def get_webui_url():
Returns:
The URL of the web UI as a string.
"""
if _mode() == LOCAL_MODE:
raise Exception("ray.get_webui_url() currently does not work in "
"PYTHON MODE.")
return _webui_url_helper(global_worker.redis_client)
if _global_node is None:
raise Exception("Ray has not been initialized/connected.")
return _global_node.get_webui_url
global_worker = Worker()
@ -1211,64 +1210,6 @@ def _initialize_serialization(driver_id, worker=global_worker):
class_id="ray.signature.FunctionSignature")
def get_address_info_from_redis_helper(redis_address,
node_ip_address,
redis_password=None):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=redis_password)
client_table = ray.experimental.state.parse_client_table(redis_client)
if len(client_table) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
relevant_client = None
for client_info in client_table:
client_node_ip_address = client_info["NodeManagerAddress"]
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
relevant_client = client_info
break
if relevant_client is None:
raise Exception(
"Redis has started but no raylets have registered yet.")
return {
"node_ip_address": node_ip_address,
"redis_address": redis_address,
"object_store_address": relevant_client["ObjectStoreSocketName"],
"raylet_socket_name": relevant_client["RayletSocketName"],
# Web UI should be running.
"webui_url": _webui_url_helper(redis_client)
}
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
redis_password=None):
counter = 0
while True:
try:
return get_address_info_from_redis_helper(
redis_address, node_ip_address, redis_password=redis_password)
except Exception:
if counter == num_retries:
raise
# Some of the information may not be in Redis yet, so wait a little
# bit.
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
time.sleep(1)
counter += 1
def init(redis_address=None,
num_cpus=None,
num_gpus=None,
@ -1414,22 +1355,11 @@ def init(redis_address=None,
if redis_address is not None:
redis_address = services.address_to_ip(redis_address)
address_info = {
"node_ip_address": node_ip_address,
"redis_address": redis_address
}
global _global_node
if driver_mode == LOCAL_MODE:
# If starting Ray in LOCAL_MODE, don't start any other processes.
pass
_global_node = ray.node.LocalNode()
elif redis_address is None:
# TODO(suquark): We should remove the code below because they
# have been set when initializing the node.
if node_ip_address is None:
node_ip_address = ray.services.get_node_ip_address()
if num_redis_shards is None:
num_redis_shards = 1
# In this case, we need to start a new cluster.
ray_params = ray.parameter.RayParams(
redis_address=redis_address,
@ -1461,11 +1391,6 @@ def init(redis_address=None,
# handler.
_global_node = ray.node.Node(
head=True, shutdown_at_exit=False, ray_params=ray_params)
address_info["redis_address"] = _global_node.redis_address
address_info[
"object_store_address"] = _global_node.plasma_store_socket_name
address_info["webui_url"] = _global_node.webui_url
address_info["raylet_socket_name"] = _global_node.raylet_socket_name
else:
# In this case, we are connecting to an existing cluster.
if num_cpus is not None or num_gpus is not None:
@ -1505,51 +1430,28 @@ def init(redis_address=None,
raise Exception("When connecting to an existing cluster, "
"_internal_config must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
# Get the address info of the processes to connect to from Redis.
address_info = get_address_info_from_redis(
redis_address, node_ip_address, redis_password=redis_password)
# TODO(suquark): Use "node" as the input of "connect()".
# In this case, we only need to connect the node.
ray_params = ray.parameter.RayParams(
node_ip_address=node_ip_address,
redis_address=redis_address,
redis_password=redis_password,
plasma_store_socket_name=address_info["object_store_address"],
raylet_socket_name=address_info["raylet_socket_name"],
object_id_seed=object_id_seed,
temp_dir=temp_dir)
temp_dir=temp_dir,
load_code_from_local=load_code_from_local)
_global_node = ray.node.Node(
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
if driver_mode == LOCAL_MODE:
driver_address_info = {}
else:
driver_address_info = {
"node_ip_address": node_ip_address,
"redis_address": address_info["redis_address"],
"store_socket_name": address_info["object_store_address"],
"webui_url": address_info["webui_url"],
"raylet_socket_name": address_info["raylet_socket_name"],
}
connect(
driver_address_info,
redis_password=redis_password,
object_id_seed=object_id_seed,
_global_node,
mode=driver_mode,
log_to_driver=log_to_driver,
worker=global_worker,
driver_id=driver_id,
load_code_from_local=load_code_from_local)
driver_id=driver_id)
for hook in _post_init_hooks:
hook()
return address_info
return _global_node.address_info
# Functions to run as callback after a successful ray init
@ -1782,9 +1684,7 @@ def is_initialized():
return ray.worker.global_worker.connected
def connect(info,
redis_password=None,
object_id_seed=None,
def connect(node,
mode=WORKER_MODE,
log_to_driver=False,
worker=global_worker,
@ -1793,14 +1693,7 @@ def connect(info,
"""Connect this worker to the raylet, to Plasma, and to Redis.
Args:
info (dict): A dictionary with address of the Redis server and the
sockets of the plasma store and raylet.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
object_id_seed (int): Used to seed the deterministic generation of
object IDs. The same value can be used across multiple runs of the
same job in order to generate the object IDs in a consistent
manner. However, the same ID should not be used for different jobs.
node (ray.node.Node): The node to connect.
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
LOCAL_MODE.
log_to_driver (bool): If true, then output from all of the worker
@ -1844,25 +1737,19 @@ def connect(info,
# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
worker.actor_id = ActorID.nil()
worker.connected = True
worker.node = node
worker.set_mode(mode)
worker.load_code_from_local = load_code_from_local
# If running Ray in LOCAL_MODE, there is no need to create call
# create_worker or to start the worker service.
if mode == LOCAL_MODE:
return
# Set the node IP address.
worker.node_ip_address = info["node_ip_address"]
worker.redis_address = info["redis_address"]
# Create a Redis client.
redis_ip_address, redis_port = info["redis_address"].split(":")
# The Redis client can safely be shared between threads. However, that is
# not true of Redis pubsub clients. See the documentation at
# https://github.com/andymccurdy/redis-py#thread-safety.
worker.redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=redis_password)
worker.redis_client = node.create_redis_client()
# For driver's check that the version information matches the version
# information that the Ray cluster was started with.
@ -1883,7 +1770,7 @@ def connect(info,
# Create an object for interfacing with the global state.
global_state._initialize_global_state(
redis_ip_address, int(redis_port), redis_password=redis_password)
node.redis_address, redis_password=node.redis_password)
# Register the worker with Redis.
if mode == SCRIPT_MODE:
@ -1891,11 +1778,11 @@ def connect(info,
# Register the driver/job with Redis here.
import __main__ as main
driver_info = {
"node_ip_address": worker.node_ip_address,
"node_ip_address": node.node_ip_address,
"driver_id": worker.worker_id,
"start_time": time.time(),
"plasma_store_socket": info["store_socket_name"],
"raylet_socket": info.get("raylet_socket_name"),
"plasma_store_socket": node.plasma_store_socket_name,
"raylet_socket": node.raylet_socket_name,
"name": (main.__file__
if hasattr(main, "__file__") else "INTERACTIVE MODE")
}
@ -1903,8 +1790,8 @@ def connect(info,
elif mode == WORKER_MODE:
# Register the worker with Redis.
worker_dict = {
"node_ip_address": worker.node_ip_address,
"plasma_store_socket": info["store_socket_name"],
"node_ip_address": node.node_ip_address,
"plasma_store_socket": node.plasma_store_socket_name,
}
# Check the RedirectOutput key in Redis and based on its value redirect
# worker output and error to their own files.
@ -1913,7 +1800,7 @@ def connect(info,
if (redirect_worker_output_val is not None
and int(redirect_worker_output_val) == 1):
log_stdout_file, log_stderr_file = (
_global_node.new_worker_redirected_log_file(worker.worker_id))
node.new_worker_redirected_log_file(worker.worker_id))
# Redirect stdout/stderr at the file descriptor level. If we simply
# set sys.stdout and sys.stderr, then logging from C++ can fail to
# be redirected.
@ -1941,7 +1828,7 @@ def connect(info,
# Create an object store client.
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"], None, 0, 300))
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
@ -1951,8 +1838,8 @@ def connect(info,
# the user's random number generator). Otherwise, set the current task
# ID randomly to avoid object ID collisions.
numpy_state = np.random.get_state()
if object_id_seed is not None:
np.random.seed(object_id_seed)
if node.object_id_seed is not None:
np.random.seed(node.object_id_seed)
else:
# Try to use true randomness.
np.random.seed(None)
@ -1999,7 +1886,7 @@ def connect(info,
worker.task_context.current_task_id = driver_task.task_id()
worker.raylet_client = ray._raylet.RayletClient(
info["raylet_socket_name"],
node.raylet_socket_name,
ClientID(worker.worker_id),
(mode == WORKER_MODE),
DriverID(worker.current_task_id.binary()),
@ -2096,7 +1983,7 @@ def disconnect():
worker.threads_stopped.clear()
worker._session_index += 1
worker.connected = False
worker.node = None # Disconnect the worker from the node.
worker.cached_functions_to_run = []
worker.function_actor_manager.reset_cache()
worker.serialization_context_map.clear()

View file

@ -66,14 +66,6 @@ parser.add_argument(
if __name__ == "__main__":
args = parser.parse_args()
info = {
"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"redis_password": args.redis_password,
"store_socket_name": args.object_store_name,
"raylet_socket_name": args.raylet_name,
}
ray.utils.setup_logger(args.logging_level, args.logging_format)
ray_params = RayParams(
@ -89,12 +81,7 @@ if __name__ == "__main__":
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
ray.worker._global_node = node
# TODO(suquark): Use "node" as the input of "connect".
ray.worker.connect(
info,
redis_password=args.redis_password,
mode=ray.WORKER_MODE,
load_code_from_local=args.load_code_from_local)
ray.worker.connect(node, mode=ray.WORKER_MODE)
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker