ray/python/ray/local_mode_manager.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

152 lines
5.1 KiB
Python

import copy
import traceback
import ray
from ray import ObjectID
from ray.utils import format_error_message
from ray.exceptions import RayTaskError
class LocalModeObjectID(ObjectID):
    """Local-mode stand-in for ray.ObjectID that carries its payload inline.

    In local mode there is no real object store: each LocalModeObjectID holds
    the object it refers to directly in its own ``value`` attribute.

    Attributes:
        value: The stored object. The attribute is absent (rather than set to
            None) when nothing has been stored, because None is itself a
            legitimate object value.
    """

    def _clone_with_value(self):
        """Return a new LocalModeObjectID with the same binary ID.

        If this instance has a ``value``, the clone references the very same
        object (it is not copied); otherwise the clone has no ``value`` either.
        """
        clone = LocalModeObjectID(self.binary())
        try:
            payload = self.value
        except AttributeError:
            # No value stored: leave the clone "empty" as well.
            return clone
        clone.value = payload
        return clone

    def __copy__(self):
        return self._clone_with_value()

    def __deepcopy__(self, memo=None):
        # Same result as __copy__: the stored value is shared by reference,
        # not deep-copied, and ``memo`` is not consulted.
        return self._clone_with_value()
class LocalModeManager:
    """Used to emulate remote operations when running in local mode."""

    def __init__(self):
        """Initialize a LocalModeManager."""

    @staticmethod
    def _resolve_argument(value):
        """Return the value a local-mode task receives for one argument.

        ObjectIDs are dereferenced with ``ray.get``; any other value is
        deep-copied so the task cannot mutate the caller's object.
        """
        if isinstance(value, ObjectID):
            return ray.get(value)
        return copy.deepcopy(value)

    @staticmethod
    def _check_local_mode_id(object_id):
        """Raise TypeError unless ``object_id`` is a LocalModeObjectID."""
        if not isinstance(object_id, LocalModeObjectID):
            raise TypeError("Only LocalModeObjectIDs are supported "
                            "when running in LOCAL_MODE. Using "
                            "user-generated ObjectIDs will fail.")

    def execute(self, function, function_name, args, kwargs, num_return_vals):
        """Synchronously executes a "remote" function or actor method.

        Stores results directly in the generated and returned
        LocalModeObjectIDs. Any exceptions raised during function execution
        (including returning a different number of values than
        ``num_return_vals``) will be stored under all returned object IDs and
        later raised by the worker.

        Args:
            function: The function to execute.
            function_name: Name of the function to execute.
            args: Arguments to the function. These will not be modified by
                the function execution.
            kwargs: Keyword arguments to the function.
            num_return_vals: Number of expected return values specified in the
                function's decorator.

        Returns:
            LocalModeObjectIDs corresponding to the function return values.
        """
        return_ids = [
            LocalModeObjectID.from_random() for _ in range(num_return_vals)
        ]
        # ObjectID arguments are fetched; everything else is deep-copied so
        # the caller's objects are isolated from the task.
        new_args = [self._resolve_argument(arg) for arg in args]
        new_kwargs = {
            key: self._resolve_argument(value)
            for key, value in kwargs.items()
        }
        try:
            results = function(*new_args, **new_kwargs)
            if num_return_vals == 1:
                return_ids[0].value = results
            elif num_return_vals > 1:
                results = tuple(results)
                # zip() would silently truncate on a mismatch and leave some
                # return IDs without a value; fail the task explicitly instead.
                if len(results) != num_return_vals:
                    raise ValueError(
                        "Task {} returned {} objects, but num_return_vals={}."
                        .format(function_name, len(results), num_return_vals))
                for object_id, result in zip(return_ids, results):
                    object_id.value = result
        except Exception as e:
            backtrace = format_error_message(traceback.format_exc())
            task_error = RayTaskError(function_name, backtrace, e.__class__)
            # Every return ID carries the same error so any get() surfaces it.
            for object_id in return_ids:
                object_id.value = task_error
        return return_ids

    def put_object(self, value):
        """Store an object in the emulated object store.

        Implemented by generating a LocalModeObjectID and storing the value
        directly within it.

        Args:
            value: The value to store.

        Returns:
            LocalModeObjectID corresponding to the value.
        """
        object_id = LocalModeObjectID.from_random()
        object_id.value = value
        return object_id

    def get_objects(self, object_ids):
        """Fetch objects from the emulated object store.

        Accepts only LocalModeObjectIDs and reads values directly from them.

        Args:
            object_ids: A list of object IDs to fetch values for.

        Returns:
            A list of the stored values, in the same order as ``object_ids``.

        Raises:
            TypeError if any of the object IDs are not LocalModeObjectIDs.
            KeyError if any of the object IDs do not contain values.
        """
        results = []
        for object_id in object_ids:
            self._check_local_mode_id(object_id)
            if not hasattr(object_id, "value"):
                raise KeyError("Value for {} not found".format(object_id))
            results.append(object_id.value)
        return results

    def free(self, object_ids):
        """Delete objects from the emulated object store.

        Accepts only LocalModeObjectIDs and deletes their values directly.
        IDs that hold no value are skipped silently. IDs are processed in
        order, so values of IDs before an invalid one are already deleted
        when the TypeError is raised.

        Args:
            object_ids: A list of ObjectIDs to delete.

        Raises:
            TypeError if any of the object IDs are not LocalModeObjectIDs.
        """
        for object_id in object_ids:
            self._check_local_mode_id(object_id)
            try:
                del object_id.value
            except AttributeError:
                # Already freed or never stored: freeing is best-effort.
                pass