Add actor table to global state API (#6629)

This commit is contained in:
Philipp Moritz 2019-12-31 15:11:59 -08:00 committed by GitHub
parent a4d64de39a
commit ecddaafd94
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 133 additions and 14 deletions

View file

@ -102,8 +102,8 @@ from ray._raylet import (
_config = _Config()
from ray.profiling import profile # noqa: E402
from ray.state import (global_state, jobs, nodes, tasks, objects, timeline,
object_transfer_timeline, cluster_resources,
from ray.state import (global_state, jobs, nodes, actors, tasks, objects,
timeline, object_transfer_timeline, cluster_resources,
available_resources, errors) # noqa: E402
from ray.worker import (
LOCAL_MODE,
@ -139,6 +139,7 @@ __all__ = [
"global_state",
"jobs",
"nodes",
"actors",
"tasks",
"objects",
"timeline",

View file

@ -250,7 +250,8 @@ class NodeStats(threading.Thread):
# Mapping from IP address to PID to list of error messages
self._errors = defaultdict(lambda: defaultdict(list))
ray.init(address=redis_address, redis_password=redis_password)
ray.state.state._initialize_global_state(
redis_address=redis_address, redis_password=redis_password)
super().__init__()

View file

@ -61,6 +61,7 @@ TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_ERROR_INFO_string = "ERROR_INFO"
TablePrefix_PROFILE_string = "PROFILE"
TablePrefix_JOB_string = "JOB"
TablePrefix_ACTOR_string = "ACTOR"
def construct_error_message(job_id, error_type, message, timestamp):

View file

@ -306,6 +306,71 @@ class GlobalState(object):
self._object_table(binary_to_object_id(object_id_binary)))
return results
def _actor_table(self, actor_id):
"""Fetch and parse the actor table information for a single actor ID.
Args:
actor_id: A actor ID to get information about.
Returns:
A dictionary with information about the actor ID in question.
"""
assert isinstance(actor_id, ray.ActorID)
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ACTOR"), "",
actor_id.binary())
if message is None:
return {}
gcs_entries = gcs_utils.GcsEntry.FromString(message)
assert len(gcs_entries.entries) == 1
actor_table_data = gcs_utils.ActorTableData.FromString(
gcs_entries.entries[0])
actor_info = {
"JobID": binary_to_hex(actor_table_data.job_id),
"Address": {
"IPAddress": actor_table_data.address.ip_address,
"Port": actor_table_data.address.port
},
"OwnerAddress": {
"IPAddress": actor_table_data.owner_address.ip_address,
"Port": actor_table_data.owner_address.port
},
"IsDirectCall": actor_table_data.is_direct_call
}
return actor_info
def actor_table(self, actor_id=None):
"""Fetch and parse the actor table information for one or more actor IDs.
Args:
actor_id: A hex string of the actor ID to fetch information about.
If this is None, then the actor table is fetched.
Returns:
Information from the actor table.
"""
self._check_connected()
if actor_id is not None:
actor_id = ray.ActorID(hex_to_binary(actor_id))
return self._actor_table(actor_id)
else:
actor_table_keys = list(
self.redis_client.scan_iter(
match=gcs_utils.TablePrefix_ACTOR_string + "*"))
actor_ids_binary = [
key[len(gcs_utils.TablePrefix_ACTOR_string):]
for key in actor_table_keys
]
results = {}
for actor_id_binary in actor_ids_binary:
results[binary_to_hex(actor_id_binary)] = self._actor_table(
ray.ActorID(actor_id_binary))
return results
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single task ID.
@ -1120,6 +1185,19 @@ def node_ids():
return node_ids
def actors(actor_id=None):
"""Fetch and parse the actor info for one or more actor IDs.
Args:
actor_id: A hex string of the actor ID to fetch information about. If
this is None, then all actor information is fetched.
Returns:
Information about the actors.
"""
return state.actor_table(actor_id=actor_id)
def tasks(task_id=None):
"""Fetch and parse the task table information for one or more task IDs.

View file

@ -89,6 +89,15 @@ def test_load_balancing_with_dependencies(ray_start_cluster):
attempt_to_load_balance(f, [x], 100, num_nodes, 25)
def wait_for_num_actors(num_actors, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
if len(ray.actors()) >= num_actors:
return
time.sleep(0.1)
raise RayTestTimeoutException("Timed out while waiting for global state.")
def wait_for_num_tasks(num_tasks, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
@ -107,11 +116,6 @@ def wait_for_num_objects(num_objects, timeout=10):
raise RayTestTimeoutException("Timed out while waiting for global state.")
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") == "on",
reason="New GCS API doesn't have a Python API yet.")
@pytest.mark.skipif(
ray_constants.direct_call_enabled(), reason="state API not supported")
def test_global_state_api(shutdown_only):
error_message = ("The ray global state API cannot be used "
@ -120,6 +124,9 @@ def test_global_state_api(shutdown_only):
with pytest.raises(Exception, match=error_message):
ray.objects()
with pytest.raises(Exception, match=error_message):
ray.actors()
with pytest.raises(Exception, match=error_message):
ray.tasks()
@ -163,6 +170,43 @@ def test_global_state_api(shutdown_only):
assert len(client_table) == 1
assert client_table[0]["NodeManagerAddress"] == node_ip_address
@ray.remote
class Actor:
def __init__(self):
pass
_ = Actor.remote()
# Wait for actor to be created
wait_for_num_actors(1)
actor_table = ray.actors()
assert len(actor_table) == 1
actor_info, = actor_table.values()
assert actor_info["JobID"] == job_id.hex()
assert "IPAddress" in actor_info["Address"]
assert "IPAddress" in actor_info["OwnerAddress"]
assert actor_info["Address"]["Port"] != actor_info["OwnerAddress"]["Port"]
job_table = ray.jobs()
assert len(job_table) == 1
assert job_table[0]["JobID"] == job_id.hex()
assert job_table[0]["NodeManagerAddress"] == node_ip_address
@pytest.mark.skipif(
ray_constants.direct_call_enabled(),
reason="object and task API not supported")
def test_global_state_task_object_api(shutdown_only):
ray.init()
job_id = ray.utils.compute_job_id_from_driver(
ray.WorkerID(ray.worker.global_worker.worker_id))
driver_task_id = ray.worker.global_worker.current_task_id.hex()
nil_actor_id_hex = ray.ActorID.nil().hex()
@ray.remote
def f(*xs):
return 1
@ -213,12 +257,6 @@ def test_global_state_api(shutdown_only):
object_table_entry = ray.objects(result_id)
assert object_table[result_id] == object_table_entry
job_table = ray.jobs()
assert len(job_table) == 1
assert job_table[0]["JobID"] == job_id.hex()
assert job_table[0]["NodeManagerAddress"] == node_ip_address
# TODO(rkn): Pytest actually has tools for capturing stdout and stderr, so we
# should use those, but they seem to conflict with Ray's use of faulthandler.