mirror of
https://github.com/vale981/ray
synced 2025-03-11 21:56:39 -04:00
883 lines
37 KiB
Python
883 lines
37 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import dis
|
|
import hashlib
|
|
import importlib
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import sys
|
|
import time
|
|
import threading
|
|
import traceback
|
|
from collections import (
|
|
namedtuple,
|
|
defaultdict,
|
|
)
|
|
|
|
import ray
|
|
from ray import profiling
|
|
from ray import ray_constants
|
|
from ray import cloudpickle as pickle
|
|
from ray.utils import (
|
|
binary_to_hex,
|
|
is_function_or_method,
|
|
is_class_method,
|
|
check_oversized_pickle,
|
|
decode,
|
|
ensure_str,
|
|
format_error_message,
|
|
push_error_to_driver,
|
|
)
|
|
|
|
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
|
|
["function", "function_name", "max_calls"])
|
|
"""FunctionExecutionInfo: A named tuple storing remote function information."""
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FunctionDescriptor(object):
|
|
"""A class used to describe a python function.
|
|
|
|
Attributes:
|
|
module_name: the module name that the function belongs to.
|
|
class_name: the class name that the function belongs to if exists.
|
|
It could be empty is the function is not a class method.
|
|
function_name: the function name of the function.
|
|
function_hash: the hash code of the function source code if the
|
|
function code is available.
|
|
function_id: the function id calculated from this descriptor.
|
|
is_for_driver_task: whether this descriptor is for driver task.
|
|
"""
|
|
|
|
def __init__(self,
|
|
module_name,
|
|
function_name,
|
|
class_name="",
|
|
function_source_hash=b""):
|
|
self._module_name = module_name
|
|
self._class_name = class_name
|
|
self._function_name = function_name
|
|
self._function_source_hash = function_source_hash
|
|
self._function_id = self._get_function_id()
|
|
|
|
def __repr__(self):
|
|
return ("FunctionDescriptor:" + self._module_name + "." +
|
|
self._class_name + "." + self._function_name + "." +
|
|
binary_to_hex(self._function_source_hash))
|
|
|
|
@classmethod
|
|
def from_bytes_list(cls, function_descriptor_list):
|
|
"""Create a FunctionDescriptor instance from list of bytes.
|
|
|
|
This function is used to create the function descriptor from
|
|
backend data.
|
|
|
|
Args:
|
|
cls: Current class which is required argument for classmethod.
|
|
function_descriptor_list: list of bytes to represent the
|
|
function descriptor.
|
|
|
|
Returns:
|
|
The FunctionDescriptor instance created from the bytes list.
|
|
"""
|
|
assert isinstance(function_descriptor_list, list)
|
|
if len(function_descriptor_list) == 0:
|
|
# This is a function descriptor of driver task.
|
|
return FunctionDescriptor.for_driver_task()
|
|
elif (len(function_descriptor_list) == 3
|
|
or len(function_descriptor_list) == 4):
|
|
module_name = ensure_str(function_descriptor_list[0])
|
|
class_name = ensure_str(function_descriptor_list[1])
|
|
function_name = ensure_str(function_descriptor_list[2])
|
|
if len(function_descriptor_list) == 4:
|
|
return cls(module_name, function_name, class_name,
|
|
function_descriptor_list[3])
|
|
else:
|
|
return cls(module_name, function_name, class_name)
|
|
else:
|
|
raise Exception(
|
|
"Invalid input for FunctionDescriptor.from_bytes_list")
|
|
|
|
@classmethod
|
|
def from_function(cls, function, pickled_function):
|
|
"""Create a FunctionDescriptor from a function instance.
|
|
|
|
This function is used to create the function descriptor from
|
|
a python function. If a function is a class function, it should
|
|
not be used by this function.
|
|
|
|
Args:
|
|
cls: Current class which is required argument for classmethod.
|
|
function: the python function used to create the function
|
|
descriptor.
|
|
pickled_function: This is factored in to ensure that any
|
|
modifications to the function result in a different function
|
|
descriptor.
|
|
|
|
Returns:
|
|
The FunctionDescriptor instance created according to the function.
|
|
"""
|
|
module_name = function.__module__
|
|
function_name = function.__name__
|
|
class_name = ""
|
|
|
|
pickled_function_hash = hashlib.sha1(pickled_function).digest()
|
|
|
|
return cls(module_name, function_name, class_name,
|
|
pickled_function_hash)
|
|
|
|
@classmethod
|
|
def from_class(cls, target_class):
|
|
"""Create a FunctionDescriptor from a class.
|
|
|
|
Args:
|
|
cls: Current class which is required argument for classmethod.
|
|
target_class: the python class used to create the function
|
|
descriptor.
|
|
|
|
Returns:
|
|
The FunctionDescriptor instance created according to the class.
|
|
"""
|
|
module_name = target_class.__module__
|
|
class_name = target_class.__name__
|
|
return cls(module_name, "__init__", class_name)
|
|
|
|
@classmethod
|
|
def for_driver_task(cls):
|
|
"""Create a FunctionDescriptor instance for a driver task."""
|
|
return cls("", "", "", b"")
|
|
|
|
@property
|
|
def is_for_driver_task(self):
|
|
"""See whether this function descriptor is for a driver or not.
|
|
|
|
Returns:
|
|
True if this function descriptor is for driver tasks.
|
|
"""
|
|
return all(
|
|
len(x) == 0
|
|
for x in [self.module_name, self.class_name, self.function_name])
|
|
|
|
@property
|
|
def module_name(self):
|
|
"""Get the module name of current function descriptor.
|
|
|
|
Returns:
|
|
The module name of the function descriptor.
|
|
"""
|
|
return self._module_name
|
|
|
|
@property
|
|
def class_name(self):
|
|
"""Get the class name of current function descriptor.
|
|
|
|
Returns:
|
|
The class name of the function descriptor. It could be
|
|
empty if the function is not a class method.
|
|
"""
|
|
return self._class_name
|
|
|
|
@property
|
|
def function_name(self):
|
|
"""Get the function name of current function descriptor.
|
|
|
|
Returns:
|
|
The function name of the function descriptor.
|
|
"""
|
|
return self._function_name
|
|
|
|
@property
|
|
def function_hash(self):
|
|
"""Get the hash code of the function source code.
|
|
|
|
Returns:
|
|
The bytes with length of ray_constants.ID_SIZE if the source
|
|
code is available. Otherwise, the bytes length will be 0.
|
|
"""
|
|
return self._function_source_hash
|
|
|
|
@property
|
|
def function_id(self):
|
|
"""Get the function id calculated from this descriptor.
|
|
|
|
Returns:
|
|
The value of ray.ObjectID that represents the function id.
|
|
"""
|
|
return self._function_id
|
|
|
|
def _get_function_id(self):
|
|
"""Calculate the function id of current function descriptor.
|
|
|
|
This function id is calculated from all the fields of function
|
|
descriptor.
|
|
|
|
Returns:
|
|
ray.ObjectID to represent the function descriptor.
|
|
"""
|
|
if self.is_for_driver_task:
|
|
return ray.FunctionID.nil()
|
|
function_id_hash = hashlib.sha1()
|
|
# Include the function module and name in the hash.
|
|
function_id_hash.update(self.module_name.encode("ascii"))
|
|
function_id_hash.update(self.function_name.encode("ascii"))
|
|
function_id_hash.update(self.class_name.encode("ascii"))
|
|
function_id_hash.update(self._function_source_hash)
|
|
# Compute the function ID.
|
|
function_id = function_id_hash.digest()
|
|
return ray.FunctionID(function_id)
|
|
|
|
def get_function_descriptor_list(self):
|
|
"""Return a list of bytes representing the function descriptor.
|
|
|
|
This function is used to pass this function descriptor to backend.
|
|
|
|
Returns:
|
|
A list of bytes.
|
|
"""
|
|
descriptor_list = []
|
|
if self.is_for_driver_task:
|
|
# Driver task returns an empty list.
|
|
return descriptor_list
|
|
else:
|
|
descriptor_list.append(self.module_name.encode("ascii"))
|
|
descriptor_list.append(self.class_name.encode("ascii"))
|
|
descriptor_list.append(self.function_name.encode("ascii"))
|
|
if len(self._function_source_hash) != 0:
|
|
descriptor_list.append(self._function_source_hash)
|
|
return descriptor_list
|
|
|
|
def is_actor_method(self):
|
|
"""Wether this function descriptor is an actor method.
|
|
|
|
Returns:
|
|
True if it's an actor method, False if it's a normal function.
|
|
"""
|
|
return len(self._class_name) > 0
|
|
|
|
|
|
class FunctionActorManager(object):
|
|
"""A class used to export/load remote functions and actors.
|
|
|
|
Attributes:
|
|
_worker: The associated worker that this manager related.
|
|
_functions_to_export: The remote functions to export when
|
|
the worker gets connected.
|
|
_actors_to_export: The actors to export when the worker gets
|
|
connected.
|
|
_function_execution_info: The map from job_id to function_id
|
|
and execution_info.
|
|
_num_task_executions: The map from job_id to function
|
|
execution times.
|
|
imported_actor_classes: The set of actor classes keys (format:
|
|
ActorClass:function_id) that are already in GCS.
|
|
"""
|
|
|
|
def __init__(self, worker):
|
|
self._worker = worker
|
|
self._functions_to_export = []
|
|
self._actors_to_export = []
|
|
# This field is a dictionary that maps a driver ID to a dictionary of
|
|
# functions (and information about those functions) that have been
|
|
# registered for that driver (this inner dictionary maps function IDs
|
|
# to a FunctionExecutionInfo object. This should only be used on
|
|
# workers that execute remote functions.
|
|
self._function_execution_info = defaultdict(lambda: {})
|
|
self._num_task_executions = defaultdict(lambda: {})
|
|
# A set of all of the actor class keys that have been imported by the
|
|
# import thread. It is safe to convert this worker into an actor of
|
|
# these types.
|
|
self.imported_actor_classes = set()
|
|
self._loaded_actor_classes = {}
|
|
self.lock = threading.Lock()
|
|
self.execution_infos = {}
|
|
|
|
def increase_task_counter(self, job_id, function_descriptor):
|
|
function_id = function_descriptor.function_id
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
self._num_task_executions[job_id][function_id] += 1
|
|
|
|
def get_task_counter(self, job_id, function_descriptor):
|
|
function_id = function_descriptor.function_id
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
return self._num_task_executions[job_id][function_id]
|
|
|
|
def compute_collision_identifier(self, function_or_class):
|
|
"""The identifier is used to detect excessive duplicate exports.
|
|
|
|
The identifier is used to determine when the same function or class is
|
|
exported many times. This can yield false positives.
|
|
|
|
Args:
|
|
function_or_class: The function or class to compute an identifier
|
|
for.
|
|
|
|
Returns:
|
|
The identifier. Note that different functions or classes can give
|
|
rise to same identifier. However, the same function should
|
|
hopefully always give rise to the same identifier. TODO(rkn):
|
|
verify if this is actually the case. Note that if the
|
|
identifier is incorrect in any way, then we may give warnings
|
|
unnecessarily or fail to give warnings, but the application's
|
|
behavior won't change.
|
|
"""
|
|
if sys.version_info[0] >= 3:
|
|
import io
|
|
string_file = io.StringIO()
|
|
if sys.version_info[1] >= 7:
|
|
dis.dis(function_or_class, file=string_file, depth=2)
|
|
else:
|
|
dis.dis(function_or_class, file=string_file)
|
|
collision_identifier = (
|
|
function_or_class.__name__ + ":" + string_file.getvalue())
|
|
else:
|
|
collision_identifier = function_or_class.__name__
|
|
|
|
# Return a hash of the identifier in case it is too large.
|
|
return hashlib.sha1(collision_identifier.encode("ascii")).digest()
|
|
|
|
def export(self, remote_function):
|
|
"""Pickle a remote function and export it to redis.
|
|
|
|
Args:
|
|
remote_function: the RemoteFunction object.
|
|
"""
|
|
if self._worker.mode == ray.worker.LOCAL_MODE:
|
|
return
|
|
if self._worker.load_code_from_local:
|
|
return
|
|
|
|
function = remote_function._function
|
|
pickled_function = pickle.dumps(function)
|
|
|
|
check_oversized_pickle(pickled_function,
|
|
remote_function._function_name,
|
|
"remote function", self._worker)
|
|
key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
|
|
+ remote_function._function_descriptor.function_id.binary())
|
|
self._worker.redis_client.hmset(
|
|
key, {
|
|
"job_id": self._worker.current_job_id.binary(),
|
|
"function_id": remote_function._function_descriptor.
|
|
function_id.binary(),
|
|
"function_name": remote_function._function_name,
|
|
"module": function.__module__,
|
|
"function": pickled_function,
|
|
"collision_identifier": self.compute_collision_identifier(
|
|
function),
|
|
"max_calls": remote_function._max_calls
|
|
})
|
|
self._worker.redis_client.rpush("Exports", key)
|
|
|
|
def fetch_and_register_remote_function(self, key):
|
|
"""Import a remote function."""
|
|
(job_id_str, function_id_str, function_name, serialized_function,
|
|
num_return_vals, module, resources,
|
|
max_calls) = self._worker.redis_client.hmget(key, [
|
|
"job_id", "function_id", "function_name", "function",
|
|
"num_return_vals", "module", "resources", "max_calls"
|
|
])
|
|
function_id = ray.FunctionID(function_id_str)
|
|
job_id = ray.JobID(job_id_str)
|
|
function_name = decode(function_name)
|
|
max_calls = int(max_calls)
|
|
module = decode(module)
|
|
|
|
# This is a placeholder in case the function can't be unpickled. This
|
|
# will be overwritten if the function is successfully registered.
|
|
def f(*args, **kwargs):
|
|
raise Exception("This function was not imported properly.")
|
|
|
|
# This function is called by ImportThread. This operation needs to be
|
|
# atomic. Otherwise, there is race condition. Another thread may use
|
|
# the temporary function above before the real function is ready.
|
|
with self.lock:
|
|
self._function_execution_info[job_id][function_id] = (
|
|
FunctionExecutionInfo(
|
|
function=f,
|
|
function_name=function_name,
|
|
max_calls=max_calls))
|
|
self._num_task_executions[job_id][function_id] = 0
|
|
|
|
try:
|
|
function = pickle.loads(serialized_function)
|
|
except Exception:
|
|
# If an exception was thrown when the remote function was
|
|
# imported, we record the traceback and notify the scheduler
|
|
# of the failure.
|
|
traceback_str = format_error_message(traceback.format_exc())
|
|
# Log the error message.
|
|
push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
|
|
"Failed to unpickle the remote function '{}' with "
|
|
"function ID {}. Traceback:\n{}".format(
|
|
function_name, function_id.hex(), traceback_str),
|
|
job_id=job_id)
|
|
else:
|
|
# The below line is necessary. Because in the driver process,
|
|
# if the function is defined in the file where the python
|
|
# script was started from, its module is `__main__`.
|
|
# However in the worker process, the `__main__` module is a
|
|
# different module, which is `default_worker.py`
|
|
function.__module__ = module
|
|
self._function_execution_info[job_id][function_id] = (
|
|
FunctionExecutionInfo(
|
|
function=function,
|
|
function_name=function_name,
|
|
max_calls=max_calls))
|
|
# Add the function to the function table.
|
|
self._worker.redis_client.rpush(
|
|
b"FunctionTable:" + function_id.binary(),
|
|
self._worker.worker_id)
|
|
|
|
def get_execution_info(self, job_id, function_descriptor):
|
|
"""Get the FunctionExecutionInfo of a remote function.
|
|
|
|
Args:
|
|
job_id: ID of the job that the function belongs to.
|
|
function_descriptor: The FunctionDescriptor of the function to get.
|
|
|
|
Returns:
|
|
A FunctionExecutionInfo object.
|
|
"""
|
|
if self._worker.load_code_from_local:
|
|
# Load function from local code.
|
|
# Currently, we don't support isolating code by jobs,
|
|
# thus always set job ID to NIL here.
|
|
job_id = ray.JobID.nil()
|
|
if not function_descriptor.is_actor_method():
|
|
self._load_function_from_local(job_id, function_descriptor)
|
|
else:
|
|
# Load function from GCS.
|
|
# Wait until the function to be executed has actually been
|
|
# registered on this worker. We will push warnings to the user if
|
|
# we spend too long in this loop.
|
|
# The driver function may not be found in sys.path. Try to load
|
|
# the function from GCS.
|
|
with profiling.profile("wait_for_function"):
|
|
self._wait_for_function(function_descriptor, job_id)
|
|
try:
|
|
function_id = function_descriptor.function_id
|
|
info = self._function_execution_info[job_id][function_id]
|
|
except KeyError as e:
|
|
message = ("Error occurs in get_execution_info: "
|
|
"job_id: %s, function_descriptor: %s. Message: %s" %
|
|
(job_id, function_descriptor, e))
|
|
raise KeyError(message)
|
|
return info
|
|
|
|
def _load_function_from_local(self, job_id, function_descriptor):
|
|
assert not function_descriptor.is_actor_method()
|
|
function_id = function_descriptor.function_id
|
|
if (job_id in self._function_execution_info
|
|
and function_id in self._function_execution_info[function_id]):
|
|
return
|
|
module_name, function_name = (
|
|
function_descriptor.module_name,
|
|
function_descriptor.function_name,
|
|
)
|
|
try:
|
|
module = importlib.import_module(module_name)
|
|
function = getattr(module, function_name)._function
|
|
self._function_execution_info[job_id][function_id] = (
|
|
FunctionExecutionInfo(
|
|
function=function,
|
|
function_name=function_name,
|
|
max_calls=0,
|
|
))
|
|
self._num_task_executions[job_id][function_id] = 0
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to load function %s.".format(function_name))
|
|
raise Exception(
|
|
"Function {} failed to be loaded from local code.".format(
|
|
function_descriptor))
|
|
|
|
def _wait_for_function(self, function_descriptor, job_id, timeout=10):
|
|
"""Wait until the function to be executed is present on this worker.
|
|
|
|
This method will simply loop until the import thread has imported the
|
|
relevant function. If we spend too long in this loop, that may indicate
|
|
a problem somewhere and we will push an error message to the user.
|
|
|
|
If this worker is an actor, then this will wait until the actor has
|
|
been defined.
|
|
|
|
Args:
|
|
function_descriptor : The FunctionDescriptor of the function that
|
|
we want to execute.
|
|
job_id (str): The ID of the job to push the error message to
|
|
if this times out.
|
|
"""
|
|
start_time = time.time()
|
|
# Only send the warning once.
|
|
warning_sent = False
|
|
while True:
|
|
with self.lock:
|
|
if (self._worker.actor_id.is_nil()
|
|
and (function_descriptor.function_id in
|
|
self._function_execution_info[job_id])):
|
|
break
|
|
elif not self._worker.actor_id.is_nil() and (
|
|
self._worker.actor_id in self._worker.actors):
|
|
break
|
|
if time.time() - start_time > timeout:
|
|
warning_message = ("This worker was asked to execute a "
|
|
"function that it does not have "
|
|
"registered. You may have to restart "
|
|
"Ray.")
|
|
if not warning_sent:
|
|
ray.utils.push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
|
warning_message,
|
|
job_id=job_id)
|
|
warning_sent = True
|
|
time.sleep(0.001)
|
|
|
|
def _publish_actor_class_to_key(self, key, actor_class_info):
|
|
"""Push an actor class definition to Redis.
|
|
|
|
The is factored out as a separate function because it is also called
|
|
on cached actor class definitions when a worker connects for the first
|
|
time.
|
|
|
|
Args:
|
|
key: The key to store the actor class info at.
|
|
actor_class_info: Information about the actor class.
|
|
"""
|
|
# We set the driver ID here because it may not have been available when
|
|
# the actor class was defined.
|
|
self._worker.redis_client.hmset(key, actor_class_info)
|
|
self._worker.redis_client.rpush("Exports", key)
|
|
|
|
def export_actor_class(self, Class, actor_method_names):
|
|
if self._worker.load_code_from_local:
|
|
return
|
|
function_descriptor = FunctionDescriptor.from_class(Class)
|
|
# `current_job_id` shouldn't be NIL, unless:
|
|
# 1) This worker isn't an actor;
|
|
# 2) And a previous task started a background thread, which didn't
|
|
# finish before the task finished, and still uses Ray API
|
|
# after that.
|
|
assert not self._worker.current_job_id.is_nil(), (
|
|
"You might have started a background thread in a non-actor task, "
|
|
"please make sure the thread finishes before the task finishes.")
|
|
job_id = self._worker.current_job_id
|
|
key = (b"ActorClass:" + job_id.binary() + b":" +
|
|
function_descriptor.function_id.binary())
|
|
actor_class_info = {
|
|
"class_name": Class.__name__,
|
|
"module": Class.__module__,
|
|
"class": pickle.dumps(Class),
|
|
"job_id": job_id.binary(),
|
|
"collision_identifier": self.compute_collision_identifier(Class),
|
|
"actor_method_names": json.dumps(list(actor_method_names))
|
|
}
|
|
|
|
check_oversized_pickle(actor_class_info["class"],
|
|
actor_class_info["class_name"], "actor",
|
|
self._worker)
|
|
|
|
self._publish_actor_class_to_key(key, actor_class_info)
|
|
# TODO(rkn): Currently we allow actor classes to be defined
|
|
# within tasks. I tried to disable this, but it may be necessary
|
|
# because of https://github.com/ray-project/ray/issues/1146.
|
|
|
|
def load_actor_class(self, job_id, function_descriptor):
|
|
"""Load the actor class.
|
|
|
|
Args:
|
|
job_id: job ID of the actor.
|
|
function_descriptor: Function descriptor of the actor constructor.
|
|
|
|
Returns:
|
|
The actor class.
|
|
"""
|
|
function_id = function_descriptor.function_id
|
|
# Check if the actor class already exists in the cache.
|
|
actor_class = self._loaded_actor_classes.get(function_id, None)
|
|
if actor_class is None:
|
|
# Load actor class.
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
# Load actor class from local code.
|
|
actor_class = self._load_actor_from_local(
|
|
job_id, function_descriptor)
|
|
else:
|
|
# Load actor class from GCS.
|
|
actor_class = self._load_actor_class_from_gcs(
|
|
job_id, function_descriptor)
|
|
# Save the loaded actor class in cache.
|
|
self._loaded_actor_classes[function_id] = actor_class
|
|
|
|
# Generate execution info for the methods of this actor class.
|
|
module_name = function_descriptor.module_name
|
|
actor_class_name = function_descriptor.class_name
|
|
actor_methods = inspect.getmembers(
|
|
actor_class, predicate=is_function_or_method)
|
|
for actor_method_name, actor_method in actor_methods:
|
|
method_descriptor = FunctionDescriptor(
|
|
module_name, actor_method_name, actor_class_name)
|
|
method_id = method_descriptor.function_id
|
|
executor = self._make_actor_method_executor(
|
|
actor_method_name,
|
|
actor_method,
|
|
actor_imported=True,
|
|
)
|
|
self._function_execution_info[job_id][method_id] = (
|
|
FunctionExecutionInfo(
|
|
function=executor,
|
|
function_name=actor_method_name,
|
|
max_calls=0,
|
|
))
|
|
self._num_task_executions[job_id][method_id] = 0
|
|
self._num_task_executions[job_id][function_id] = 0
|
|
return actor_class
|
|
|
|
def _load_actor_from_local(self, job_id, function_descriptor):
|
|
"""Load actor class from local code."""
|
|
assert isinstance(job_id, ray.JobID)
|
|
module_name, class_name = (function_descriptor.module_name,
|
|
function_descriptor.class_name)
|
|
try:
|
|
module = importlib.import_module(module_name)
|
|
actor_class = getattr(module, class_name)
|
|
if isinstance(actor_class, ray.actor.ActorClass):
|
|
return actor_class.__ray_metadata__.modified_class
|
|
else:
|
|
return actor_class
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to load actor_class %s.".format(class_name))
|
|
raise Exception(
|
|
"Actor {} failed to be imported from local code.".format(
|
|
class_name))
|
|
|
|
def _create_fake_actor_class(self, actor_class_name, actor_method_names):
|
|
class TemporaryActor(object):
|
|
pass
|
|
|
|
def temporary_actor_method(*args, **kwargs):
|
|
raise Exception(
|
|
"The actor with name {} failed to be imported, "
|
|
"and so cannot execute this method.".format(actor_class_name))
|
|
|
|
for method in actor_method_names:
|
|
setattr(TemporaryActor, method, temporary_actor_method)
|
|
|
|
return TemporaryActor
|
|
|
|
def _load_actor_class_from_gcs(self, job_id, function_descriptor):
|
|
"""Load actor class from GCS."""
|
|
key = (b"ActorClass:" + job_id.binary() + b":" +
|
|
function_descriptor.function_id.binary())
|
|
# Wait for the actor class key to have been imported by the
|
|
# import thread. TODO(rkn): It shouldn't be possible to end
|
|
# up in an infinite loop here, but we should push an error to
|
|
# the driver if too much time is spent here.
|
|
while key not in self.imported_actor_classes:
|
|
time.sleep(0.001)
|
|
|
|
# Fetch raw data from GCS.
|
|
(job_id_str, class_name, module, pickled_class,
|
|
actor_method_names) = self._worker.redis_client.hmget(
|
|
key,
|
|
["job_id", "class_name", "module", "class", "actor_method_names"])
|
|
|
|
class_name = ensure_str(class_name)
|
|
module_name = ensure_str(module)
|
|
job_id = ray.JobID(job_id_str)
|
|
actor_method_names = json.loads(ensure_str(actor_method_names))
|
|
|
|
actor_class = None
|
|
try:
|
|
with self.lock:
|
|
actor_class = pickle.loads(pickled_class)
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to load actor class %s.".format(class_name))
|
|
# The actor class failed to be unpickled, create a fake actor
|
|
# class instead (just to produce error messages and to prevent
|
|
# the driver from hanging).
|
|
actor_class = self._create_fake_actor_class(
|
|
class_name, actor_method_names)
|
|
# If an exception was thrown when the actor was imported, we record
|
|
# the traceback and notify the scheduler of the failure.
|
|
traceback_str = ray.utils.format_error_message(
|
|
traceback.format_exc())
|
|
# Log the error message.
|
|
push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
|
"Failed to unpickle actor class '{}' for actor ID {}. "
|
|
"Traceback:\n{}".format(
|
|
class_name, self._worker.actor_id.hex(), traceback_str),
|
|
job_id=job_id)
|
|
# TODO(rkn): In the future, it might make sense to have the worker
|
|
# exit here. However, currently that would lead to hanging if
|
|
# someone calls ray.get on a method invoked on the actor.
|
|
|
|
# The below line is necessary. Because in the driver process,
|
|
# if the function is defined in the file where the python script
|
|
# was started from, its module is `__main__`.
|
|
# However in the worker process, the `__main__` module is a
|
|
# different module, which is `default_worker.py`
|
|
actor_class.__module__ = module_name
|
|
return actor_class
|
|
|
|
def _make_actor_method_executor(self, method_name, method, actor_imported):
|
|
"""Make an executor that wraps a user-defined actor method.
|
|
|
|
The wrapped method updates the worker's internal state and performs any
|
|
necessary checkpointing operations.
|
|
|
|
Args:
|
|
method_name (str): The name of the actor method.
|
|
method (instancemethod): The actor method to wrap. This should be a
|
|
method defined on the actor class and should therefore take an
|
|
instance of the actor as the first argument.
|
|
actor_imported (bool): Whether the actor has been imported.
|
|
Checkpointing operations will not be run if this is set to
|
|
False.
|
|
|
|
Returns:
|
|
A function that executes the given actor method on the worker's
|
|
stored instance of the actor. The function also updates the
|
|
worker's internal state to record the executed method.
|
|
"""
|
|
|
|
def actor_method_executor(actor, *args, **kwargs):
|
|
# Update the actor's task counter to reflect the task we're about
|
|
# to execute.
|
|
self._worker.actor_task_counter += 1
|
|
|
|
# Execute the assigned method and save a checkpoint if necessary.
|
|
try:
|
|
if is_class_method(method):
|
|
method_returns = method(*args, **kwargs)
|
|
else:
|
|
method_returns = method(actor, *args, **kwargs)
|
|
except Exception as e:
|
|
# Save the checkpoint before allowing the method exception
|
|
# to be thrown, but don't save the checkpoint for actor
|
|
# creation task.
|
|
if (isinstance(actor, ray.actor.Checkpointable)
|
|
and self._worker.actor_task_counter != 1):
|
|
self._save_and_log_checkpoint(actor)
|
|
raise e
|
|
else:
|
|
# Handle any checkpointing operations before storing the
|
|
# method's return values.
|
|
# NOTE(swang): If method_returns is a pointer to the actor's
|
|
# state and the checkpointing operations can modify the return
|
|
# values if they mutate the actor's state. Is this okay?
|
|
if isinstance(actor, ray.actor.Checkpointable):
|
|
# If this is the first task to execute on the actor, try to
|
|
# resume from a checkpoint.
|
|
if self._worker.actor_task_counter == 1:
|
|
if actor_imported:
|
|
self._restore_and_log_checkpoint(actor)
|
|
else:
|
|
# Save the checkpoint before returning the method's
|
|
# return values.
|
|
self._save_and_log_checkpoint(actor)
|
|
return method_returns
|
|
|
|
# Set method_name and method as attributes to the executor clusore
|
|
# so we can make decision based on these attributes in task executor.
|
|
# Precisely, asyncio support requires to know whether:
|
|
# - the method is a ray internal method: starts with __ray
|
|
# - the method is a coroutine function: defined by async def
|
|
actor_method_executor.name = method_name
|
|
actor_method_executor.method = method
|
|
|
|
return actor_method_executor
|
|
|
|
def _save_and_log_checkpoint(self, actor):
|
|
"""Save an actor checkpoint if necessary and log any errors.
|
|
|
|
Args:
|
|
actor: The actor to checkpoint.
|
|
|
|
Returns:
|
|
The result of the actor's user-defined `save_checkpoint` method.
|
|
"""
|
|
actor_id = self._worker.actor_id
|
|
checkpoint_info = self._worker.actor_checkpoint_info[actor_id]
|
|
checkpoint_info.num_tasks_since_last_checkpoint += 1
|
|
now = int(1000 * time.time())
|
|
checkpoint_context = ray.actor.CheckpointContext(
|
|
actor_id, checkpoint_info.num_tasks_since_last_checkpoint,
|
|
now - checkpoint_info.last_checkpoint_timestamp)
|
|
# If we should take a checkpoint, notify raylet to prepare a checkpoint
|
|
# and then call `save_checkpoint`.
|
|
if actor.should_checkpoint(checkpoint_context):
|
|
try:
|
|
now = int(1000 * time.time())
|
|
checkpoint_id = (self._worker.raylet_client.
|
|
prepare_actor_checkpoint(actor_id))
|
|
checkpoint_info.checkpoint_ids.append(checkpoint_id)
|
|
actor.save_checkpoint(actor_id, checkpoint_id)
|
|
if (len(checkpoint_info.checkpoint_ids) >
|
|
ray._config.num_actor_checkpoints_to_keep()):
|
|
actor.checkpoint_expired(
|
|
actor_id,
|
|
checkpoint_info.checkpoint_ids.pop(0),
|
|
)
|
|
checkpoint_info.num_tasks_since_last_checkpoint = 0
|
|
checkpoint_info.last_checkpoint_timestamp = now
|
|
except Exception:
|
|
# Checkpoint save or reload failed. Notify the driver.
|
|
traceback_str = ray.utils.format_error_message(
|
|
traceback.format_exc())
|
|
ray.utils.push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.CHECKPOINT_PUSH_ERROR,
|
|
traceback_str,
|
|
job_id=self._worker.current_job_id)
|
|
|
|
def _restore_and_log_checkpoint(self, actor):
|
|
"""Restore an actor from a checkpoint if available and log any errors.
|
|
|
|
This should only be called on workers that have just executed an actor
|
|
creation task.
|
|
|
|
Args:
|
|
actor: The actor to restore from a checkpoint.
|
|
"""
|
|
actor_id = self._worker.actor_id
|
|
try:
|
|
checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
|
|
if len(checkpoints) > 0:
|
|
# If we found previously saved checkpoints for this actor,
|
|
# call the `load_checkpoint` callback.
|
|
checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
|
|
if checkpoint_id is not None:
|
|
# Check that the returned checkpoint id is in the
|
|
# `available_checkpoints` list.
|
|
msg = (
|
|
"`load_checkpoint` must return a checkpoint id that " +
|
|
"exists in the `available_checkpoints` list, or None.")
|
|
assert any(checkpoint_id == checkpoint.checkpoint_id
|
|
for checkpoint in checkpoints), msg
|
|
# Notify raylet that this actor has been resumed from
|
|
# a checkpoint.
|
|
(self._worker.raylet_client.
|
|
notify_actor_resumed_from_checkpoint(
|
|
actor_id, checkpoint_id))
|
|
except Exception:
|
|
# Checkpoint save or reload failed. Notify the driver.
|
|
traceback_str = ray.utils.format_error_message(
|
|
traceback.format_exc())
|
|
ray.utils.push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.CHECKPOINT_PUSH_ERROR,
|
|
traceback_str,
|
|
job_id=self._worker.current_job_id)
|