[xray] Implement timeline and profiling API. (#2306)

* Add profile table and store profiling information there.

* Code for dumping timeline.

* Improve color scheme.

* Push timeline events on driver only for raylet.

* Improvements to profiling and timeline visualization

* Some linting

* Small fix.

* Linting

* Propagate node IP address through profiling events.

* Fix test.

* object_id.hex() should return byte string in python 2.

* Include gcs.fbs in node_manager.fbs.

* Remove flatbuffer definition duplication.

* Decode to unicode in Python 3 and bytes in Python 2.

* Minor

* Submit profile events in a batch. Revert some CMake changes.

* Fix

* Workaround test failure.

* Fix linting

* Linting

* Don't return anything from chrome_tracing_dump when filename is provided.

* Remove some redundancy from profile table.

* Linting

* Move TODOs out of docstring.

* Minor
Robert Nishihara 2018-07-04 23:23:48 -07:00 committed by Philipp Moritz
parent 8e687cbc98
commit b90e551b41
27 changed files with 777 additions and 147 deletions


@ -55,6 +55,17 @@ enable_testing()
include(ThirdpartyToolchain)
# TODO(rkn): Fix all of this. This include is needed for the following
# reason. The local scheduler depends on tables.cc which depends on
# node_manager_generated.h which depends on gcs_generated.h. However,
# the include statement for gcs_generated.h doesn't include the file
# path, so we include the relevant directory here.
set(GCS_FBS_OUTPUT_DIRECTORY
"${CMAKE_CURRENT_LIST_DIR}/src/ray/gcs/format")
include_directories(${GCS_FBS_OUTPUT_DIRECTORY})
include_directories(SYSTEM ${ARROW_INCLUDE_DIR})
include_directories(SYSTEM ${PLASMA_INCLUDE_DIR})
include_directories("${CMAKE_CURRENT_LIST_DIR}/src/")


@ -48,6 +48,7 @@ MOCK_MODULES = ["gym",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
"ray.core.generated.ObjectTableData",
"ray.core.generated.ray.protocol.Task",
"ray.core.generated.TablePrefix",


@ -48,7 +48,7 @@ except ImportError as e:
from ray.local_scheduler import ObjectID, _config # noqa: E402
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
remote, log_event, log_span, flush_log, get_gpu_ids,
remote, profile, flush_profile_data, get_gpu_ids,
get_resource_ids, get_webui_url,
register_custom_serializer) # noqa: E402
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
@ -65,7 +65,7 @@ __version__ = "0.4.0"
__all__ = [
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"remote", "profile", "flush_profile_data", "actor", "method",
"get_gpu_ids", "get_resource_ids", "get_webui_url",
"register_custom_serializer", "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE",
"SILENT_MODE", "global_state", "ObjectID", "_config", "__version__"


@ -14,6 +14,7 @@ import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
from ray.utils import (
decode,
_random_string,
check_oversized_pickle,
is_cython,
@ -292,10 +293,10 @@ def fetch_and_register_actor(actor_class_key, worker):
"checkpoint_interval", "actor_method_names"
])
class_name = class_name.decode("ascii")
module = module.decode("ascii")
class_name = decode(class_name)
module = decode(module)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(actor_method_names.decode("ascii"))
actor_method_names = json.loads(decode(actor_method_names))
# Create a temporary actor with some temporary methods so that if the actor
# fails to be unpickled, the temporary actor can be used (just to produce


@ -357,8 +357,6 @@ class GlobalState(object):
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(0), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.local_scheduler.task_from_string(task_spec)
@ -487,11 +485,10 @@ class GlobalState(object):
decode(value))
elif client_info[b"client_type"] == b"local_scheduler":
# The remaining fields are resource types.
client_info_parsed[field.decode("ascii")] = float(
client_info_parsed[decode(field)] = float(
decode(value))
else:
client_info_parsed[field.decode("ascii")] = decode(
value)
client_info_parsed[decode(field)] = decode(value)
node_info[node_ip_address].append(client_info_parsed)
@ -513,21 +510,19 @@ class GlobalState(object):
gcs_entry.Entries(i), 0))
resources = {
client.ResourcesTotalLabel(i).decode("ascii"):
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
}
node_info.append({
"ClientID": ray.utils.binary_to_hex(client.ClientId()),
"IsInsertion": client.IsInsertion(),
"NodeManagerAddress": client.NodeManagerAddress().decode(
"ascii"),
"NodeManagerAddress": decode(client.NodeManagerAddress()),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": client.ObjectStoreSocketName()
.decode("ascii"),
"RayletSocketName": client.RayletSocketName().decode(
"ascii"),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName()),
"RayletSocketName": decode(client.RayletSocketName()),
"Resources": resources
})
return node_info
@ -543,14 +538,14 @@ class GlobalState(object):
ip_filename_file = {}
for filename in relevant_files:
filename = filename.decode("ascii")
filename = decode(filename)
filename_components = filename.split(":")
ip_addr = filename_components[1]
file = self.redis_client.lrange(filename, 0, -1)
file_str = []
for x in file:
y = x.decode("ascii")
y = decode(x)
file_str.append(y)
if ip_addr not in ip_filename_file:
@ -630,7 +625,7 @@ class GlobalState(object):
event_log_set, **params)
for (event, score) in event_list:
event_dict = json.loads(event.decode())
event_dict = json.loads(decode(event))
task_id = ""
for event in event_dict:
if "task_id" in event[3]:
@ -643,31 +638,29 @@ class GlobalState(object):
heap_size += 1
for event in event_dict:
if event[1] == "ray:get_task" and event[2] == 1:
if event[1] == "get_task" and event[2] == 1:
task_info[task_id]["get_task_start"] = event[0]
if event[1] == "ray:get_task" and event[2] == 2:
if event[1] == "get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if (event[1] == "ray:import_remote_function"
if (event[1] == "register_remote_function"
and event[2] == 1):
task_info[task_id]["import_remote_start"] = event[0]
if (event[1] == "ray:import_remote_function"
if (event[1] == "register_remote_function"
and event[2] == 2):
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 2:
task_info[task_id]["acquire_lock_end"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 1:
if (event[1] == "task:deserialize_arguments"
and event[2] == 1):
task_info[task_id]["get_arguments_start"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 2:
if (event[1] == "task:deserialize_arguments"
and event[2] == 2):
task_info[task_id]["get_arguments_end"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 1:
if event[1] == "task:execute" and event[2] == 1:
task_info[task_id]["execute_start"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 2:
if event[1] == "task:execute" and event[2] == 2:
task_info[task_id]["execute_end"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 1:
if event[1] == "task:store_outputs" and event[2] == 1:
task_info[task_id]["store_outputs_start"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 2:
if event[1] == "task:store_outputs" and event[2] == 2:
task_info[task_id]["store_outputs_end"] = event[0]
if "worker_id" in event[3]:
task_info[task_id]["worker_id"] = event[3]["worker_id"]
@ -685,6 +678,173 @@ class GlobalState(object):
return task_info
def _profile_table(self, component_id):
"""Get the profile events for a given component.
Args:
component_id: An identifier for a component.
Returns:
A list of the profile events for the specified component.
"""
# TODO(rkn): This method should support limiting the number of log
# events and should also support returning a window of events.
message = self._execute_command(component_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.PROFILE, "",
component_id.id())
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
profile_events = []
for i in range(gcs_entries.EntriesLength()):
profile_table_message = (
ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
gcs_entries.Entries(i), 0))
component_type = decode(profile_table_message.ComponentType())
component_id = binary_to_hex(profile_table_message.ComponentId())
node_ip_address = decode(profile_table_message.NodeIpAddress())
for j in range(profile_table_message.ProfileEventsLength()):
profile_event_message = profile_table_message.ProfileEvents(j)
profile_event = {
"event_type": decode(profile_event_message.EventType()),
"component_id": component_id,
"node_ip_address": node_ip_address,
"component_type": component_type,
"start_time": profile_event_message.StartTime(),
"end_time": profile_event_message.EndTime(),
"extra_data": json.loads(
decode(profile_event_message.ExtraData())),
}
profile_events.append(profile_event)
return profile_events
def profile_table(self):
if not self.use_raylet:
raise Exception("This method is only supported in the raylet "
"code path.")
profile_table_keys = self._keys(
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
component_identifiers_binary = [
key[len(ray.gcs_utils.TablePrefix_PROFILE_string):]
for key in profile_table_keys
]
return {
binary_to_hex(component_id): self._profile_table(
binary_to_object_id(component_id))
for component_id in component_identifiers_binary
}
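
For orientation (not part of the diff), here is a minimal sketch of how the new profile table could be inspected from a driver, assuming the raylet code path (RAY_USE_XRAY=1) and a connected cluster; the dictionary keys come from _profile_table above:

import ray

ray.init()
# profile_table() maps a hex component ID to a list of event dictionaries
# with the keys "event_type", "component_id", "node_ip_address",
# "component_type", "start_time", "end_time", and "extra_data".
for component_id, events in ray.global_state.profile_table().items():
    for event in events:
        print(component_id, event["event_type"],
              event["end_time"] - event["start_time"])
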
def chrome_tracing_dump(self,
include_task_data=False,
filename=None,
open_browser=False):
"""Return a list of profiling events that can viewed as a timeline.
To view this information as a timeline, simply dump it as a json file
using json.dumps, and then load go to chrome://tracing in the Chrome
web browser and load the dumped file. Make sure to enable "Flow events"
in the "View Options" menu.
Args:
include_task_data: If true, we will include more task metadata such
as the task specifications in the json.
filename: If a filename is provided, the timeline is dumped to that
file.
open_browser: If true, we will attempt to automatically open the
timeline visualization in Chrome.
Returns:
If filename is not provided, this returns a list of profiling
events. Each profile event is a dictionary.
"""
# TODO(rkn): Support including the task specification data in the
# timeline.
# TODO(rkn): This should support viewing just a window of time or a
# limited number of events.
if include_task_data:
raise NotImplementedError("This flag has not been implemented yet.")
if open_browser:
raise NotImplementedError("This flag has not been implemented yet.")
profile_table = self.profile_table()
all_events = []
# Colors are specified at
# https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501
default_color_mapping = defaultdict(
lambda: "generic_work", {
"get_task": "cq_build_abandoned",
"task": "rail_response",
"task:deserialize_arguments": "rail_load",
"task:execute": "rail_animation",
"task:store_outputs": "rail_idle",
"wait_for_function": "detailed_memory_dump",
"ray.get": "good",
"ray.put": "terrible",
"ray.wait": "vsync_highlight_color",
"submit_task": "background_memory_dump",
"fetch_and_run_function": "detailed_memory_dump",
"register_remote_function": "detailed_memory_dump",
})
def seconds_to_microseconds(time_in_seconds):
time_in_microseconds = 10**6 * time_in_seconds
return time_in_microseconds
for component_id_hex, component_events in profile_table.items():
for event in component_events:
new_event = {
# The category of the event.
"cat": event["event_type"],
# The string displayed on the event.
"name": event["event_type"],
# The identifier for the group of rows that the event
# appears in.
"pid": event["node_ip_address"],
# The identifier for the row that the event appears in.
"tid": event["component_type"] + ":" +
event["component_id"],
# The start time in microseconds.
"ts": seconds_to_microseconds(event["start_time"]),
# The duration in microseconds.
"dur": seconds_to_microseconds(event["end_time"] -
event["start_time"]),
# "X" is the phase type for a complete (duration) event in the Chrome
# trace event format.
"ph": "X",
# This is the name of the color to display the box in.
"cname": default_color_mapping[event["event_type"]],
# The extra user-defined data.
"args": event["extra_data"],
}
# Modify the json with the additional user-defined extra data.
# This can be used to add fields or override existing fields.
if "cname" in event["extra_data"]:
new_event["cname"] = event["extra_data"]["cname"]
if "name" in event["extra_data"]:
new_event["name"] = event["extra_data"]["name"]
all_events.append(new_event)
if filename is not None:
with open(filename, "w") as outfile:
json.dump(all_events, outfile)
else:
return all_events
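
A minimal usage sketch (not part of the diff) for dumping and viewing a timeline; the filename is illustrative:

import ray

ray.init()
# ... run some tasks ...
ray.global_state.chrome_tracing_dump(filename="/tmp/timeline.json")
# Then open chrome://tracing in Chrome, click "Load", select /tmp/timeline.json,
# and enable "Flow events" under "View Options".
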
def dump_catapult_trace(self,
path,
task_info,
@ -1047,21 +1207,20 @@ class GlobalState(object):
worker_id = binary_to_hex(worker_key[len("Workers:"):])
workers_data[worker_id] = {
"local_scheduler_socket": (
worker_info[b"local_scheduler_socket"].decode("ascii")),
"node_ip_address": (worker_info[b"node_ip_address"]
.decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
.decode("ascii")),
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
.decode("ascii"))
"local_scheduler_socket": (decode(
worker_info[b"local_scheduler_socket"])),
"node_ip_address": decode(worker_info[b"node_ip_address"]),
"plasma_manager_socket": decode(
worker_info[b"plasma_manager_socket"]),
"plasma_store_socket": decode(
worker_info[b"plasma_store_socket"])
}
if b"stderr_file" in worker_info:
workers_data[worker_id]["stderr_file"] = (
worker_info[b"stderr_file"].decode("ascii"))
workers_data[worker_id]["stderr_file"] = decode(
worker_info[b"stderr_file"])
if b"stdout_file" in worker_info:
workers_data[worker_id]["stdout_file"] = (
worker_info[b"stdout_file"].decode("ascii"))
workers_data[worker_id]["stdout_file"] = decode(
worker_info[b"stdout_file"])
return workers_data
def actors(self):
@ -1155,8 +1314,8 @@ class GlobalState(object):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
gcs_entries.Entries(i), 0)
error_message = {
"type": error_data.Type().decode("ascii"),
"message": error_data.ErrorMessage().decode("ascii"),
"type": decode(error_data.Type()),
"message": decode(error_data.ErrorMessage()),
"timestamp": error_data.Timestamp(),
}
error_messages.append(error_message)


@ -22,6 +22,7 @@ import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
@ -33,9 +34,9 @@ __all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"construct_error_message"
"GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData",
"HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix",
"TablePubsub", "construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
@ -53,6 +54,7 @@ FUNCTION_PREFIX = "RemoteFunction:"
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_ERROR_INFO_string = "ERROR_INFO"
TablePrefix_PROFILE_string = "PROFILE"
def construct_error_message(error_type, message, timestamp):


@ -10,6 +10,7 @@ import time
from ray.services import get_ip_address
from ray.services import get_port
from ray.services import logger
import ray.utils
class LogMonitor(object):
@ -70,7 +71,7 @@ class LogMonitor(object):
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(
self.node_ip_address, log_filename.decode("ascii"))
self.node_ip_address, ray.utils.decode(log_filename))
self.redis_client.rpush(redis_key, *new_lines)
# Pass if we already failed to open the log file.


@ -10,6 +10,7 @@ import subprocess
import ray.services as services
from ray.autoscaler.commands import (create_or_update_cluster,
teardown_cluster, get_head_node_ip)
import ray.utils
def check_no_existing_redis_clients(node_ip_address, redis_client):
@ -31,7 +32,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_client):
if deleted:
continue
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
if ray.utils.decode(info[b"node_ip_address"]) == node_ip_address:
raise Exception("This Redis instance is already connected to "
"clients with this IP address.")


@ -386,7 +386,7 @@ def check_version_info(redis_client):
if redis_reply is None:
return
true_version_info = tuple(json.loads(redis_reply.decode("ascii")))
true_version_info = tuple(json.loads(ray.utils.decode(redis_reply)))
version_info = _compute_version_info()
if version_info != true_version_info:
node_ip_address = ray.services.get_node_ip_address()
@ -776,7 +776,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
new_env["REDIS_ADDRESS"] = redis_address
# We generate the token used for authentication ourselves to avoid
# querying the jupyter server.
token = binascii.hexlify(os.urandom(24)).decode("ascii")
token = ray.utils.decode(binascii.hexlify(os.urandom(24)))
command = [
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
@ -1373,7 +1373,7 @@ def start_ray_processes(address_info=None,
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
redis_shards = [ray.utils.decode(shard) for shard in redis_shards]
address_info["redis_shards"] = redis_shards
# Start the log monitor, if necessary.


@ -170,6 +170,8 @@ def random_string():
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if not isinstance(byte_str, bytes):
raise ValueError("The argument must be a bytes object.")
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:


@ -562,7 +562,7 @@ class Worker(object):
Returns:
The return object IDs for this task.
"""
with log_span("ray:submit_task", worker=self):
with profile("submit_task", worker=self):
check_main_thread()
if actor_id is None:
assert actor_handle_id is None
@ -867,7 +867,7 @@ class Worker(object):
# Get task arguments from the object store.
try:
with log_span("ray:task:get_arguments", worker=self):
with profile("task:deserialize_arguments", worker=self):
arguments = self._get_arguments_for_execution(
function_name, args)
except (RayGetError, RayGetArgumentError) as e:
@ -882,7 +882,7 @@ class Worker(object):
# Execute the task.
try:
with log_span("ray:task:execute", worker=self):
with profile("task:execute", worker=self):
if task.actor_id().id() == NIL_ACTOR_ID:
outputs = function_executor(*arguments)
else:
@ -901,7 +901,7 @@ class Worker(object):
# Store the outputs in the local object store.
try:
with log_span("ray:task:store_outputs", worker=self):
with profile("task:store_outputs", worker=self):
# If this is an actor task, then the last object ID returned by
# the task is a dummy output, not returned by the function
# itself. Decrement to get the correct number of return values.
@ -976,7 +976,7 @@ class Worker(object):
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with log_span("ray:wait_for_function", worker=self):
with profile("wait_for_function", worker=self):
self._wait_for_function(function_id, driver_id)
# Execute the task.
@ -984,22 +984,26 @@ class Worker(object):
# warning to the user if we are waiting too long to acquire the lock
# because that may indicate that the system is hanging, and it'd be
# good to know where the system is hanging.
log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self)
with self.lock:
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self)
function_name = (self.function_execution_info[driver_id][
function_id.id()]).function_name
contents = {
if not self.use_raylet:
extra_data = {
"function_name": function_name,
"task_id": task.task_id().hex(),
"worker_id": binary_to_hex(self.worker_id)
}
with log_span("ray:task", contents=contents, worker=self):
else:
extra_data = {
"name": function_name,
"task_id": task.task_id().hex()
}
with profile("task", extra_data=extra_data, worker=self):
self._process_task(task)
# Push all of the log events to the global state store.
flush_log()
flush_profile_data()
# Increase the task execution counter.
self.num_task_executions[driver_id][function_id.id()] += 1
@ -1017,7 +1021,7 @@ class Worker(object):
Returns:
A task from the local scheduler.
"""
with log_span("ray:get_task", worker=self):
with profile("get_task", worker=self):
task = self.local_scheduler_client.get_task()
# Automatically restrict the GPUs available to this task.
@ -1103,7 +1107,7 @@ def _webui_url_helper(client):
The URL of the web UI as a string.
"""
result = client.hmget("webui", "url")[0]
return result.decode("ascii") if result is not None else result
return ray.utils.decode(result) if result is not None else result
def get_webui_url():
@ -1194,9 +1198,9 @@ def error_info(worker=global_worker):
if error_applies_to_driver(error_key, worker=worker):
error_contents = worker.redis_client.hgetall(error_key)
error_contents = {
"type": error_contents[b"type"].decode("ascii"),
"message": error_contents[b"message"].decode("ascii"),
"data": error_contents[b"data"].decode("ascii")
"type": ray.utils.decode(error_contents[b"type"]),
"message": ray.utils.decode(error_contents[b"message"]),
"data": ray.utils.decode(error_contents[b"data"])
}
errors.append(error_contents)
@ -1296,13 +1300,14 @@ def get_address_info_from_redis_helper(redis_address,
assert b"ray_client_id" in info
assert b"node_ip_address" in info
assert b"client_type" in info
client_node_ip_address = info[b"node_ip_address"].decode("ascii")
client_node_ip_address = ray.utils.decode(info[b"node_ip_address"])
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
if info[b"client_type"].decode("ascii") == "plasma_manager":
if ray.utils.decode(info[b"client_type"]) == "plasma_manager":
plasma_managers.append(info)
elif info[b"client_type"].decode("ascii") == "local_scheduler":
elif (ray.utils.decode(
info[b"client_type"]) == "local_scheduler"):
local_schedulers.append(info)
# Make sure that we got at least one plasma manager and local
# scheduler.
@ -1311,16 +1316,16 @@ def get_address_info_from_redis_helper(redis_address,
# Build the address information.
object_store_addresses = []
for manager in plasma_managers:
address = manager[b"manager_address"].decode("ascii")
address = ray.utils.decode(manager[b"manager_address"])
port = services.get_port(address)
object_store_addresses.append(
services.ObjectStoreAddress(
name=manager[b"store_socket_name"].decode("ascii"),
manager_name=manager[b"manager_socket_name"].decode(
"ascii"),
name=ray.utils.decode(manager[b"store_socket_name"]),
manager_name=ray.utils.decode(
manager[b"manager_socket_name"]),
manager_port=port))
scheduler_names = [
scheduler[b"local_scheduler_socket_name"].decode("ascii")
ray.utils.decode(scheduler[b"local_scheduler_socket_name"])
for scheduler in local_schedulers
]
client_info = {
@ -1343,8 +1348,8 @@ def get_address_info_from_redis_helper(redis_address,
for client_message in clients:
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = client.NodeManagerAddress().decode(
"ascii")
client_node_ip_address = ray.utils.decode(
client.NodeManagerAddress())
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
@ -1352,12 +1357,12 @@ def get_address_info_from_redis_helper(redis_address,
object_store_addresses = [
services.ObjectStoreAddress(
name=raylet.ObjectStoreSocketName().decode("ascii"),
name=ray.utils.decode(raylet.ObjectStoreSocketName()),
manager_name=None,
manager_port=None) for raylet in raylets
]
raylet_socket_names = [
raylet.RayletSocketName().decode("ascii") for raylet in raylets
ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets
]
return {
"node_ip_address": node_ip_address,
@ -1807,6 +1812,21 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
def _flush_profile_events(worker):
"""Drivers run this as a thread to flush profile data in the background."""
# Note(rkn): This is run on a background thread in the driver. It uses the
# local scheduler client. This should be ok because it doesn't read from
# the local scheduler client and we have the GIL here. However, if either
# of those things changes, then we could run into issues.
try:
while True:
time.sleep(1)
flush_profile_data(worker=worker)
except AttributeError:
# This is to suppress errors that occur at shutdown.
pass
def print_error_messages_raylet(worker):
"""Print error messages in the background on the driver.
@ -1858,7 +1878,7 @@ def print_error_messages_raylet(worker):
if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]:
continue
error_message = error_data.ErrorMessage().decode("ascii")
error_message = ray.utils.decode(error_data.ErrorMessage())
if error_message not in old_error_messages:
logger.error(error_message)
@ -1900,8 +1920,8 @@ def print_error_messages(worker):
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
for error_key in error_keys:
if error_applies_to_driver(error_key, worker=worker):
error_message = worker.redis_client.hget(
error_key, "message").decode("ascii")
error_message = ray.utils.decode(
worker.redis_client.hget(error_key, "message"))
if error_message not in old_error_messages:
logger.error(error_message)
old_error_messages.add(error_message)
@ -1915,8 +1935,8 @@ def print_error_messages(worker):
for error_key in worker.redis_client.lrange(
"ErrorKeys", num_errors_received, -1):
if error_applies_to_driver(error_key, worker=worker):
error_message = worker.redis_client.hget(
error_key, "message").decode("ascii")
error_message = ray.utils.decode(
worker.redis_client.hget(error_key, "message"))
if error_message not in old_error_messages:
logger.error(error_message)
old_error_messages.add(error_message)
@ -1939,9 +1959,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
function_name = function_name.decode("ascii")
function_name = ray.utils.decode(function_name)
max_calls = int(max_calls)
module = module.decode("ascii")
module = ray.utils.decode(module)
# This is a placeholder in case the function can't be unpickled. This will
# be overwritten if the function is successfully registered.
@ -2031,14 +2051,17 @@ def import_thread(worker, mode):
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the
# driver should import.
continue
if key.startswith(b"RemoteFunction"):
with profile("register_remote_function", worker=worker):
fetch_and_register_remote_function(key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
elif key.startswith(b"ActorClass"):
# Keep track of the fact that this actor class has been
@ -2063,9 +2086,8 @@ def import_thread(worker, mode):
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with log_span(
"ray:import_function_to_run",
worker=worker):
with profile(
"fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
# Continue because FunctionsToRun are the only things
@ -2073,13 +2095,12 @@ def import_thread(worker, mode):
continue
if key.startswith(b"RemoteFunction"):
with log_span(
"ray:import_remote_function", worker=worker):
with profile(
"register_remote_function", worker=worker):
fetch_and_register_remote_function(
key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
with log_span(
"ray:import_function_to_run", worker=worker):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
elif key.startswith(b"ActorClass"):
@ -2333,6 +2354,13 @@ def connect(info,
t.daemon = True
t.start()
if mode in [SCRIPT_MODE, SILENT_MODE] and worker.use_raylet:
t = threading.Thread(target=_flush_profile_events, args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True
t.start()
if mode in [SCRIPT_MODE, SILENT_MODE]:
# Add the directory containing the script that is running to the Python
# paths of the workers. Also add the current directory. Note that this
@ -2526,7 +2554,8 @@ class RayLogSpan(object):
def __enter__(self):
"""Log the beginning of a span event."""
log(event_type=self.event_type,
_log(
event_type=self.event_type,
contents=self.contents,
kind=LOG_SPAN_START,
worker=self.worker)
@ -2534,11 +2563,13 @@ class RayLogSpan(object):
def __exit__(self, type, value, tb):
"""Log the end of a span event. Log any exception that occurred."""
if type is None:
log(event_type=self.event_type,
_log(
event_type=self.event_type,
kind=LOG_SPAN_END,
worker=self.worker)
else:
log(event_type=self.event_type,
_log(
event_type=self.event_type,
contents={
"type": str(type),
"value": value,
@ -2548,19 +2579,109 @@ class RayLogSpan(object):
worker=self.worker)
def log_span(event_type, contents=None, worker=global_worker):
return RayLogSpan(event_type, contents=contents, worker=worker)
class RayLogSpanRaylet(object):
"""An object used to enable logging a span of events with a with statement.
Attributes:
event_type (str): The type of the event being logged.
contents: Additional information to log.
"""
def __init__(self, event_type, extra_data=None, worker=global_worker):
"""Initialize a RayLogSpan object."""
self.event_type = event_type
self.extra_data = extra_data if extra_data is not None else {}
self.worker = worker
def set_attribute(self, key, value):
"""Add a key-value pair to the extra_data dict.
This can be used to add attributes that are not available when
ray.profile was called.
Args:
key: The attribute name.
value: The attribute value.
"""
if not isinstance(key, str) or not isinstance(value, str):
raise ValueError("The extra_data argument must be a "
"dictionary mapping strings to strings.")
self.extra_data[key] = value
def __enter__(self):
"""Log the beginning of a span event.
Returns:
The object itself is returned so that if the block is opened using
"with ray.profile(...) as prof:", we can call
"prof.set_attribute" inside the block.
"""
self.start_time = time.time()
return self
def __exit__(self, type, value, tb):
"""Log the end of a span event. Log any exception that occurred."""
for key, value in self.extra_data.items():
if not isinstance(key, str) or not isinstance(value, str):
raise ValueError("The extra_data argument must be a "
"dictionary mapping strings to strings.")
event = {
"event_type": self.event_type,
"start_time": self.start_time,
"end_time": time.time(),
"extra_data": json.dumps(self.extra_data),
}
if type is not None:
event["extra_data"] = json.dumps({
"type": str(type),
"value": str(value),
"traceback": str(traceback.format_exc()),
})
self.worker.events.append(event)
def log_event(event_type, contents=None, worker=global_worker):
log(event_type, kind=LOG_POINT, contents=contents, worker=worker)
def profile(event_type, extra_data=None, worker=global_worker):
"""Profile a span of time so that it appears in the timeline visualization.
This function can be used as follows (either on the driver or within a task).
with ray.profile("custom event", extra_data={'key': 'value'}):
# Do some computation here.
Optionally, a dictionary can be passed as the "extra_data" argument, and
it can have keys "name" and "cname" if you want to override the default
timeline display text and box color. Other values will appear at the bottom
of the chrome tracing GUI when you click on the box corresponding to this
profile span.
Args:
event_type: A string describing the type of the event.
extra_data: This must be a dictionary mapping strings to strings. This
data will be added to the json objects that are used to populate
the timeline, so if you want to set a particular color, you can
simply set the "cname" attribute to an appropriate color.
Similarly, if you set the "name" attribute, then that will set the
text displayed on the box in the timeline.
Returns:
An object that can profile a span of time via a "with" statement.
"""
if not worker.use_raylet:
return RayLogSpan(event_type, contents=extra_data, worker=worker)
else:
return RayLogSpanRaylet(
event_type, extra_data=extra_data, worker=worker)
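
A hedged usage sketch of the new API, assuming the raylet code path; the event name, task body, and attribute values are illustrative, and "good" is one of the catapult color names used in the mapping above:

import ray

ray.init()

@ray.remote
def f():
    # "name" and "cname" override the displayed text and box color.
    with ray.profile("custom_sort",
                     extra_data={"name": "sort step", "cname": "good"}) as span:
        # Attributes set here appear under "args" when the box is clicked.
        span.set_attribute("num_items", "1000")
        sorted(range(1000), reverse=True)

ray.get(f.remote())
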
def log(event_type, kind, contents=None, worker=global_worker):
def _log(event_type, kind, contents=None, worker=global_worker):
"""Log an event to the global state store.
This adds the event to a buffer of events locally. The buffer can be
flushed and written to the global state store by calling flush_log().
flushed and written to the global state store by calling
flush_profile_data().
Args:
event_type (str): The type of the event.
@ -2571,6 +2692,9 @@ def log(event_type, kind, contents=None, worker=global_worker):
time, and it is LOG_SPAN_END if we are finishing logging a span of
time.
"""
if worker.use_raylet:
raise Exception(
"This method is not supported in the raylet code path.")
# TODO(rkn): This code currently takes around half a microsecond. Since we
# call it tens of times per task, this adds up. We will need to redo the
# logging code, perhaps in C.
@ -2584,13 +2708,32 @@ def log(event_type, kind, contents=None, worker=global_worker):
worker.events.append((time.time(), event_type, kind, contents))
def flush_log(worker=global_worker):
"""Send the logged worker events to the global state store."""
# TODO(rkn): Support calling this function in the middle of a task, and also
# call this periodically in the background from the driver.
def flush_profile_data(worker=global_worker):
"""Push the logged profiling data to the global control store.
By default, profiling information for a given task won't appear in the
timeline until after the task has completed. For very long-running tasks,
we may want profiling information to appear more quickly. In such cases,
this function can be called. Note that as an alternative, we could start
a thread in the background on workers that calls this automatically.
"""
if not worker.use_raylet:
event_log_key = b"event_log:" + worker.worker_id
event_log_value = json.dumps(worker.events)
if not worker.use_raylet:
worker.local_scheduler_client.log_event(event_log_key, event_log_value,
time.time())
else:
if worker.mode == WORKER_MODE:
component_type = "worker"
else:
component_type = "driver"
worker.local_scheduler_client.push_profile_events(
component_type, ray.ObjectID(worker.worker_id),
worker.node_ip_address, worker.events)
worker.events = []
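
A small sketch (raylet code path, illustrative event name) of flushing driver-side profile events explicitly instead of waiting for the background thread started in connect():

import ray

ray.init()
with ray.profile("driver_setup"):
    pass  # e.g., ray.put some large objects here.
# Driver events now appear in ray.global_state.profile_table().
ray.flush_profile_data()
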
@ -2611,7 +2754,7 @@ def get(object_ids, worker=global_worker):
A Python object or a list of Python objects.
"""
worker.check_connected()
with log_span("ray:get", worker=worker):
with profile("ray.get", worker=worker):
check_main_thread()
if worker.mode == PYTHON_MODE:
@ -2644,7 +2787,7 @@ def put(value, worker=global_worker):
The object ID assigned to this value.
"""
worker.check_connected()
with log_span("ray:put", worker=worker):
with profile("ray.put", worker=worker):
check_main_thread()
if worker.mode == PYTHON_MODE:
@ -2702,7 +2845,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
type(object_id)))
worker.check_connected()
with log_span("ray:wait", worker=worker):
with profile("ray.wait", worker=worker):
check_main_thread()
# When Ray is run in PYTHON_MODE, all functions are run immediately,


@ -165,7 +165,11 @@ static PyObject *PyObjectID_id(PyObject *self) {
static PyObject *PyObjectID_hex(PyObject *self) {
PyObjectID *s = (PyObjectID *) self;
std::string hex_id = s->object_id.hex();
PyObject *result = PyUnicode_FromString(hex_id.c_str());
#if PY_MAJOR_VERSION >= 3
PyObject *result = PyUnicode_FromStringAndSize(hex_id.data(), hex_id.size());
#else
PyObject *result = PyBytes_FromStringAndSize(hex_id.data(), hex_id.size());
#endif
return result;
}


@ -695,6 +695,8 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx,
// Check that we actually add a new entry during the append. This is only
// necessary since we implement the log with a sorted set, so all entries
// must be unique, or else we will have gaps in the log.
// TODO(rkn): We need to get rid of this uniqueness requirement. We can
// easily have multiple log events with the same message.
RAY_CHECK(flags == REDISMODULE_ZADD_ADDED) << "Appended a duplicate entry";
return REDISMODULE_OK;
} else {


@ -309,6 +309,109 @@ static PyObject *PyLocalSchedulerClient_push_error(PyObject *self,
Py_RETURN_NONE;
}
int PyBytes_or_PyUnicode_to_string(PyObject *py_string, std::string &out) {
// Handle the case where the string is a bytes object and the case where it
// is a unicode object.
if (PyUnicode_Check(py_string)) {
PyObject *ascii_string = PyUnicode_AsASCIIString(py_string);
out =
std::string(PyBytes_AsString(ascii_string), PyBytes_Size(ascii_string));
Py_DECREF(ascii_string);
} else if (PyBytes_Check(py_string)) {
out = std::string(PyBytes_AsString(py_string), PyBytes_Size(py_string));
} else {
return -1;
}
return 0;
}
static PyObject *PyLocalSchedulerClient_push_profile_events(PyObject *self,
PyObject *args) {
const char *component_type;
int component_type_length;
UniqueID component_id;
PyObject *profile_data;
const char *node_ip_address;
int node_ip_address_length;
if (!PyArg_ParseTuple(args, "s#O&s#O", &component_type,
&component_type_length, &PyObjectToUniqueID,
&component_id, &node_ip_address,
&node_ip_address_length, &profile_data)) {
return NULL;
}
ProfileTableDataT profile_info;
profile_info.component_type =
std::string(component_type, component_type_length);
profile_info.component_id = component_id.binary();
profile_info.node_ip_address =
std::string(node_ip_address, node_ip_address_length);
if (PyList_Size(profile_data) == 0) {
// Short circuit if there are no profile events.
Py_RETURN_NONE;
}
for (int64_t i = 0; i < PyList_Size(profile_data); ++i) {
ProfileEventT profile_event;
PyObject *py_profile_event = PyList_GetItem(profile_data, i);
if (!PyDict_CheckExact(py_profile_event)) {
return NULL;
}
PyObject *key, *val;
Py_ssize_t pos = 0;
while (PyDict_Next(py_profile_event, &pos, &key, &val)) {
std::string key_string;
if (PyBytes_or_PyUnicode_to_string(key, key_string) == -1) {
return NULL;
}
// TODO(rkn): If the dictionary is formatted incorrectly, that could lead
// to errors. E.g., if any of the strings are empty, that will cause
// segfaults in the node manager.
if (key_string == std::string("event_type")) {
if (PyBytes_or_PyUnicode_to_string(val, profile_event.event_type) ==
-1) {
return NULL;
}
if (profile_event.event_type.size() == 0) {
return NULL;
}
} else if (key_string == std::string("start_time")) {
profile_event.start_time = PyFloat_AsDouble(val);
} else if (key_string == std::string("end_time")) {
profile_event.end_time = PyFloat_AsDouble(val);
} else if (key_string == std::string("extra_data")) {
if (PyBytes_or_PyUnicode_to_string(val, profile_event.extra_data) ==
-1) {
return NULL;
}
if (profile_event.extra_data.size() == 0) {
return NULL;
}
} else {
return NULL;
}
}
// Note that profile_info.profile_events is a vector of unique pointers, so
// the copied event will be deallocated when profile_info goes out of scope.
profile_info.profile_events.emplace_back(new ProfileEventT(profile_event));
}
local_scheduler_push_profile_events(
reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
profile_info);
Py_RETURN_NONE;
}
static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
@ -338,6 +441,9 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
"Wait for a list of objects to be created."},
{"push_error", (PyCFunction) PyLocalSchedulerClient_push_error,
METH_VARARGS, "Push an error message to the relevant driver."},
{"push_profile_events",
(PyCFunction) PyLocalSchedulerClient_push_profile_events, METH_VARARGS,
"Store some profiling events in the GCS."},
{NULL} /* Sentinel */
};


@ -322,3 +322,17 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn,
ray::protocol::MessageType::PushErrorRequest),
fbb.GetSize(), fbb.GetBufferPointer());
}
void local_scheduler_push_profile_events(
LocalSchedulerConnection *conn,
const ProfileTableDataT &profile_events) {
flatbuffers::FlatBufferBuilder fbb;
auto message = CreateProfileTableData(fbb, &profile_events);
fbb.Finish(message);
write_message(conn->conn,
static_cast<int64_t>(
ray::protocol::MessageType::PushProfileEventsRequest),
fbb.GetSize(), fbb.GetBufferPointer());
}


@ -225,4 +225,13 @@ void local_scheduler_push_error(LocalSchedulerConnection *conn,
const std::string &error_message,
double timestamp);
/// Store some profile events in the GCS.
///
/// \param conn The connection information.
/// \param profile_events A batch of profiling event information.
/// \return Void.
void local_scheduler_push_profile_events(
LocalSchedulerConnection *conn,
const ProfileTableDataT &profile_events);
#endif


@ -17,6 +17,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_ty
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
heartbeat_table_.reset(new HeartbeatTable(context_, this));
error_table_.reset(new ErrorTable(primary_context_, this));
profile_table_.reset(new ProfileTable(context_, this));
command_type_ = command_type;
}
@ -84,6 +85,8 @@ HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
ErrorTable &AsyncGcsClient::error_table() { return *error_table_; }
ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; }
} // namespace gcs
} // namespace ray


@ -59,6 +59,7 @@ class RAY_EXPORT AsyncGcsClient {
ClientTable &client_table();
HeartbeatTable &heartbeat_table();
ErrorTable &error_table();
ProfileTable &profile_table();
// We also need something to export generic code to run on workers from the
// driver (to set the PYTHONPATH)
@ -81,6 +82,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
std::unique_ptr<HeartbeatTable> heartbeat_table_;
std::unique_ptr<ErrorTable> error_table_;
std::unique_ptr<ProfileTable> profile_table_;
std::unique_ptr<ClientTable> client_table_;
// The following contexts write to the data shard
std::shared_ptr<RedisContext> context_;


@ -15,6 +15,7 @@ enum TablePrefix:int {
TASK_RECONSTRUCTION,
HEARTBEAT,
ERROR_INFO,
PROFILE,
}
// The channel that Add operations to the Table should be published on, if any.
@ -121,6 +122,33 @@ table CustomSerializerData {
table ConfigTableData {
}
table ProfileEvent {
// The type of the event.
event_type: string;
// The start time of the event.
start_time: double;
// The end time of the event. If the event is a point event, then this should
// be the same as the start time.
end_time: double;
// Additional data associated with the event. This data must be serialized
// using JSON.
extra_data: string;
}
table ProfileTableData {
// The type of the component that generated the event, e.g., worker,
// object_manager, or node_manager.
component_type: string;
// An identifier for the component that generated the event.
component_id: string;
// An identifier for the node that generated the event.
node_ip_address: string;
// This is a batch of profiling events. We batch these together for
// performance reasons because a single task may generate many events, and
// we don't want each event to require a GCS command.
profile_events: [ProfileEvent];
}
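
For reference (illustrative values), the Python worker in this diff builds one ProfileEvent per entry of worker.events; component_type, component_id, and node_ip_address are set once per ProfileTableData batch in push_profile_events:

# One worker.events entry as built by RayLogSpanRaylet.__exit__:
event = {
    "event_type": "task:execute",       # -> ProfileEvent.event_type
    "start_time": 1530772800.123,       # -> ProfileEvent.start_time (seconds)
    "end_time": 1530772800.456,         # -> ProfileEvent.end_time
    "extra_data": "{\"name\": \"f\"}",  # -> ProfileEvent.extra_data (JSON string)
}
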
table RayResource {
// The type of the resource.
resource_name: string;


@ -219,6 +219,45 @@ Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &typ
});
}
Status ProfileTable::AddProfileEvent(const std::string &event_type,
const std::string &component_type,
const UniqueID &component_id,
const std::string &node_ip_address,
double start_time, double end_time,
const std::string &extra_data) {
auto data = std::make_shared<ProfileTableDataT>();
ProfileEventT profile_event;
profile_event.event_type = event_type;
profile_event.start_time = start_time;
profile_event.end_time = end_time;
profile_event.extra_data = extra_data;
data->component_type = component_type;
data->component_id = component_id.binary();
data->node_ip_address = node_ip_address;
data->profile_events.emplace_back(new ProfileEventT(profile_event));
return Append(JobID::nil(), component_id, data,
[](ray::gcs::AsyncGcsClient *client, const JobID &id,
const ProfileTableDataT &data) {
RAY_LOG(DEBUG) << "Profile message pushed callback";
});
}
Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) {
auto data = std::make_shared<ProfileTableDataT>();
// There is some room for optimization here because the Append function will just
// call "Pack" and undo the "UnPack".
profile_events.UnPackTo(data.get());
return Append(JobID::nil(), from_flatbuf(*profile_events.component_id()), data,
[](ray::gcs::AsyncGcsClient *client, const JobID &id,
const ProfileTableDataT &data) {
RAY_LOG(DEBUG) << "Profile message pushed callback";
});
}
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
client_added_callback_ = callback;
// Call the callback for any added clients that are cached.
@ -371,6 +410,7 @@ template class Log<TaskID, TaskReconstructionData>;
template class Table<ClientID, HeartbeatTableData>;
template class Log<JobID, ErrorTableData>;
template class Log<UniqueID, ClientTableData>;
template class Log<UniqueID, ProfileTableData>;
} // namespace gcs


@ -12,6 +12,7 @@
#include "ray/gcs/format/gcs_generated.h"
#include "ray/gcs/redis_context.h"
// TODO(rkn): Remove this include.
#include "ray/raylet/format/node_manager_generated.h"
// TODO(pcm): Remove this
@ -95,7 +96,8 @@ class Log : virtual public PubsubInterface<ID> {
///
/// \param job_id The ID of the job (= driver).
/// \param id The ID of the data that is added to the GCS.
/// \param data Data to append to the log.
/// \param data Data to append to the log. TODO(rkn): This can be made const,
/// right?
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
@ -438,7 +440,8 @@ class ErrorTable : private Log<JobID, ErrorTableData> {
/// Push an error message for a specific job.
///
/// TODO(rkn): We need to make sure that the errors are unique because
/// duplicate messages currently cause failures (the GCS doesn't allow it).
/// duplicate messages currently cause failures (the GCS doesn't allow it). A
/// natural way to do this is to have finer-grained time stamps.
///
/// \param job_id The ID of the job that generated the error. If the error
/// should be pushed to all jobs, then this should be nil.
@ -450,6 +453,37 @@ class ErrorTable : private Log<JobID, ErrorTableData> {
const std::string &error_message, double timestamp);
};
class ProfileTable : private Log<UniqueID, ProfileTableData> {
public:
ProfileTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
prefix_ = TablePrefix::PROFILE;
};
/// Add a single profile event to the profile table.
///
/// \param event_type The type of the event.
/// \param component_type The type of the component that the event came from.
/// \param component_id An identifier for the component that generated the event.
/// \param node_ip_address The IP address of the node that generated the event.
/// \param start_time The timestamp of the event start, in seconds since the
/// Unix epoch.
/// \param end_time The timestamp of the event end, in seconds since the Unix
/// epoch. If the event is a point event, this should be equal to start_time.
/// \param extra_data Additional data to associate with the event.
/// \return Status.
Status AddProfileEvent(const std::string &event_type, const std::string &component_type,
const UniqueID &component_id, const std::string &node_ip_address,
double start_time, double end_time,
const std::string &extra_data);
/// Add a batch of profiling events to the profile table.
///
/// \param profile_events The profile events to record.
/// \return Status.
Status AddProfileEventBatch(const ProfileTableData &profile_events);
};
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
using ConfigTable = Table<ConfigID, ConfigTableData>;


@ -12,7 +12,7 @@ add_custom_command(
# flatbuffers message Message, which can be used to store deserialized
# messages in data structures. This is currently used for ObjectInfo for
# example.
COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${NODE_MANAGER_FBS_SRC} --cpp --gen-object-api --gen-mutable --scoped-enums
COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} -I ${GCS_FBS_OUTPUT_DIRECTORY} ${NODE_MANAGER_FBS_SRC} --cpp --gen-object-api --gen-mutable --scoped-enums
DEPENDS ${FBS_DEPENDS}
COMMENT "Running flatc compiler on ${NODE_MANAGER_FBS_SRC}"
VERBATIM)
@ -23,7 +23,7 @@ add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES})
set(PYTHON_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/../../../python/ray/core/generated/)
add_custom_command(
TARGET gen_node_manager_fbs
COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${NODE_MANAGER_FBS_SRC}
COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} -I ${GCS_FBS_OUTPUT_DIRECTORY} ${NODE_MANAGER_FBS_SRC}
DEPENDS ${FBS_DEPENDS}
COMMENT "Running flatc compiler on ${NODE_MANAGER_FBS_SRC}"
VERBATIM)
@ -38,6 +38,7 @@ ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main p
ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
include_directories(${GCS_FBS_OUTPUT_DIRECTORY})
add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES})
target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY})


@ -1,5 +1,8 @@
// Local scheduler protocol specification
include "gcs.fbs";
// TODO(swang): We put the flatbuffer types in a separate namespace for now to
// avoid conflicts with legacy Ray types.
namespace ray.protocol;
@ -62,6 +65,9 @@ enum MessageType:int {
// Push an error to the relevant driver. This is sent from a worker to the
// node manager.
PushErrorRequest,
// Push some profiling events to the GCS. When sending this message to the
// node manager, the message itself is serialized as a ProfileTableData object.
PushProfileEventsRequest,
}
table TaskExecutionSpecification {


@ -552,6 +552,11 @@ void NodeManager::ProcessClientMessage(
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
timestamp));
} break;
case protocol::MessageType::PushProfileEventsRequest: {
auto message = flatbuffers::GetRoot<ProfileTableData>(message_data);
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message));
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;


@ -11,6 +11,18 @@ import subprocess
import sys
# This is duplicated from ray.utils so that we do not have to introduce a
# dependency on Ray to run this file.
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if not isinstance(byte_str, bytes):
raise ValueError("The argument must be a bytes object.")
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
return byte_str
def wait_for_output(proc):
"""This is a convenience method to parse a process's stdout and stderr.
@ -27,7 +39,7 @@ def wait_for_output(proc):
# NOTE(rkn): This try/except block is here because I once saw an
# exception raised here and want to print more information if that
# happens again.
stdout_data = stdout_data.decode("ascii")
stdout_data = decode(stdout_data)
except UnicodeDecodeError:
raise Exception("Failed to decode stdout_data:", stdout_data)
@ -36,7 +48,7 @@ def wait_for_output(proc):
# NOTE(rkn): This try/except block is here because I once saw an
# exception raised here and want to print more information if that
# happens again.
stderr_data = stderr_data.decode("ascii")
stderr_data = decode(stderr_data)
except UnicodeDecodeError:
raise Exception("Failed to decode stderr_data:", stderr_data)


@ -24,7 +24,8 @@ def run_string_as_driver(driver_script):
with tempfile.NamedTemporaryFile() as f:
f.write(driver_script.encode("ascii"))
f.flush()
out = subprocess.check_output([sys.executable, f.name]).decode("ascii")
out = ray.utils.decode(
subprocess.check_output([sys.executable, f.name]))
return out


@ -974,7 +974,7 @@ class APITest(unittest.TestCase):
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
"This test does not work with xray (nor is it intended to).")
def testLoggingAPI(self):
self.init_ray(driver_mode=ray.SILENT_MODE)
@ -996,38 +996,76 @@ class APITest(unittest.TestCase):
time.sleep(0.1)
print("Timing out of wait.")
@ray.remote
def test_log_event():
ray.log_event("event_type1", contents={"key": "val"})
@ray.remote
def test_log_span():
with ray.log_span("event_type2", contents={"key": "val"}):
with ray.profile("event_type2", extra_data={"key": "val"}):
pass
# Make sure that we can call ray.log_event in a remote function.
ray.get(test_log_event.remote())
# Wait for the event to appear in the event log.
wait_for_num_events(1)
self.assertEqual(len(events()), 1)
# Make sure that we can call ray.log_span in a remote function.
ray.get(test_log_span.remote())
# Wait for the events to appear in the event log.
wait_for_num_events(2)
self.assertEqual(len(events()), 2)
wait_for_num_events(1)
self.assertEqual(len(events()), 1)
@ray.remote
def test_log_span_exception():
with ray.log_span("event_type2", contents={"key": "val"}):
with ray.log_span("event_type2", extra_data={"key": "val"}):
raise Exception("This failed.")
# Make sure that logging a span works if an exception is thrown.
test_log_span_exception.remote()
# Wait for the events to appear in the event log.
wait_for_num_events(3)
self.assertEqual(len(events()), 3)
wait_for_num_events(2)
self.assertEqual(len(events()), 2)
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") != "1",
"This test only works with xray.")
def testProfilingAPI(self):
self.init_ray(num_cpus=2)
@ray.remote
def f():
with ray.profile(
"custom_event",
extra_data={"name": "custom name"}) as ray_prof:
ray_prof.set_attribute("key", "value")
ray.put(1)
object_id = f.remote()
ray.wait([object_id])
ray.get(object_id)
# Wait until all of the profiling information appears in the profile
# table.
timeout_seconds = 20
start_time = time.time()
while True:
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for information in "
"profile table.")
profile_data = ray.global_state.chrome_tracing_dump()
event_types = {event["cat"] for event in profile_data}
expected_types = [
"get_task",
"task",
"task:deserialize_arguments",
"task:execute",
"task:store_outputs",
"wait_for_function",
"ray.get",
"ray.put",
"ray.wait",
"submit_task",
"fetch_and_run_function",
"register_remote_function",
"custom_event", # This is the custom one from ray.profile.
]
if all(expected_type in event_types
for expected_type in expected_types):
break
def testIdenticalFunctionNames(self):
# Define a bunch of remote functions and make sure that we don't
@ -1116,6 +1154,10 @@ class APITestSharded(APITest):
if kwargs is None:
kwargs = {}
kwargs["start_ray_local"] = True
if os.environ.get("RAY_USE_XRAY") == "1":
print("XRAY currently supports only a single Redis shard.")
kwargs["num_redis_shards"] = 1
else:
kwargs["num_redis_shards"] = 20
kwargs["redirect_output"] = True
ray.worker._init(**kwargs)
@ -2203,7 +2245,7 @@ class GlobalStateAPI(unittest.TestCase):
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") == "1",
"This test does not work with xray yet.")
"This test does not work with xray (nor is it intended to).")
def testTaskProfileAPI(self):
ray.init(redirect_output=True)