mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Function serialization (#261)
This commit is contained in:
parent
8e0ecfa1f4
commit
f2c43bec87
15 changed files with 180 additions and 43 deletions
|
@ -19,7 +19,7 @@ brew update
|
|||
brew install git cmake automake autoconf libtool boost graphviz
|
||||
sudo easy_install pip
|
||||
sudo pip install ipython --user
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz --ignore-installed six
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle --ignore-installed six
|
||||
```
|
||||
|
||||
### Build
|
||||
|
|
|
@ -15,7 +15,7 @@ First install the dependencies. We currently do not support Python 3.
|
|||
```
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
|
||||
```
|
||||
|
||||
### Build
|
||||
|
|
|
@ -31,11 +31,11 @@ if [[ $platform == "linux" ]]; then
|
|||
# These commands must be kept in sync with the installation instructions.
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
|
||||
elif [[ $platform == "macosx" ]]; then
|
||||
# These commands must be kept in sync with the installation instructions.
|
||||
brew install git cmake automake autoconf libtool boost graphviz
|
||||
sudo easy_install pip
|
||||
sudo pip install ipython --user
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz --ignore-installed six
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz cloudpickle --ignore-installed six
|
||||
fi
|
||||
|
|
7
lib/python/ray/pickling.py
Normal file
7
lib/python/ray/pickling.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
import cloudpickle
|
||||
|
||||
def dumps(func, arg_types, return_types):
|
||||
return cloudpickle.dumps((func, arg_types, return_types))
|
||||
|
||||
def loads(function):
|
||||
return cloudpickle.loads(function)
|
|
@ -67,6 +67,6 @@ def serialize_task(worker_capsule, func_name, args):
|
|||
return ray.lib.serialize_task(worker_capsule, func_name, primitive_args)
|
||||
|
||||
def deserialize_task(worker_capsule, task):
|
||||
func_name, primitive_args, return_objrefs = ray.lib.deserialize_task(worker_capsule, task)
|
||||
func_name, primitive_args, return_objrefs = task
|
||||
args = [(arg if isinstance(arg, ray.lib.ObjRef) else from_primitive(arg)) for arg in primitive_args]
|
||||
return func_name, args, return_objrefs
|
||||
|
|
|
@ -10,6 +10,7 @@ import numpy as np
|
|||
import colorama
|
||||
|
||||
import ray
|
||||
import pickling
|
||||
import serialization
|
||||
import ray.internal.graph_pb2
|
||||
import ray.graph
|
||||
|
@ -120,6 +121,7 @@ class Worker(object):
|
|||
"""Initialize a Worker object."""
|
||||
self.functions = {}
|
||||
self.handle = None
|
||||
self.mode = None
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the mode of the worker.
|
||||
|
@ -510,13 +512,26 @@ def main_loop(worker=global_worker):
|
|||
store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
|
||||
ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
|
||||
while True:
|
||||
task = ray.lib.wait_for_next_task(worker.handle)
|
||||
if task is None:
|
||||
# We use this as a mechanism to allow the scheduler to kill workers. When
|
||||
# the scheduler wants to kill a worker, it gives the worker a null task,
|
||||
# causing the worker program to exit the main loop here.
|
||||
break
|
||||
process_task(task)
|
||||
(task, function) = ray.lib.wait_for_next_message(worker.handle)
|
||||
try:
|
||||
# Currently the schedule does not ask the worker to execute a task and
|
||||
# import a function at the same time.
|
||||
assert task is None or function is None
|
||||
if task is None and function is None:
|
||||
# We use this as a mechanism to allow the scheduler to kill workers. When
|
||||
# the scheduler wants to kill a worker, it gives the worker a null task,
|
||||
# causing the worker program to exit the main loop here.
|
||||
break
|
||||
if function is not None:
|
||||
(function, arg_types, return_types) = pickling.loads(function)
|
||||
if function.__module__ is None: function.__module__ = "__main__"
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
if task is not None:
|
||||
process_task(task)
|
||||
finally:
|
||||
# Allow releasing the variables BEFORE we wait for the next message or exit the block
|
||||
del task
|
||||
del function
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
@ -526,6 +541,7 @@ def remote(arg_types, return_types, worker=global_worker):
|
|||
return_types (List[type]): List of Python types of the return values.
|
||||
"""
|
||||
def remote_decorator(func):
|
||||
to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
logging.info("Calling function {}".format(func.__name__))
|
||||
|
@ -560,6 +576,8 @@ def remote(arg_types, return_types, worker=global_worker):
|
|||
func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
|
||||
func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
|
||||
check_signature_supported(func_call)
|
||||
if to_export is not None:
|
||||
ray.lib.export_function(worker.handle, to_export)
|
||||
return func_call
|
||||
return remote_decorator
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ syntax = "proto3";
|
|||
import "types.proto";
|
||||
|
||||
message Task {
|
||||
string name = 1; // Name of the function call
|
||||
string name = 1; // Name of the function call. Must not be empty.
|
||||
repeated Value arg = 2; // List of arguments, can be either object references or protobuf descriptions of object passed by value
|
||||
repeated uint64 result = 3; // Object references for result
|
||||
}
|
||||
|
|
|
@ -52,6 +52,8 @@ service Scheduler {
|
|||
rpc TaskInfo(TaskInfoRequest) returns (TaskInfoReply);
|
||||
// Kills the workers
|
||||
rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
|
||||
// Exports function to the workers
|
||||
rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
|
||||
}
|
||||
|
||||
message AckReply {
|
||||
|
@ -228,6 +230,13 @@ message KillWorkersReply {
|
|||
bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
|
||||
}
|
||||
|
||||
message ExportFunctionRequest {
|
||||
Function function = 1;
|
||||
}
|
||||
|
||||
message ExportFunctionReply {
|
||||
}
|
||||
|
||||
// These messages are for getting information about the object store state
|
||||
|
||||
message ObjStoreInfoRequest {
|
||||
|
@ -243,6 +252,7 @@ message ObjStoreInfoReply {
|
|||
|
||||
service WorkerService {
|
||||
rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
|
||||
rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
|
||||
rpc Die(DieRequest) returns (DieReply); // Kills this worker
|
||||
}
|
||||
|
||||
|
@ -253,6 +263,13 @@ message ExecuteTaskRequest {
|
|||
message ExecuteTaskReply {
|
||||
}
|
||||
|
||||
message ImportFunctionRequest {
|
||||
Function function = 1;
|
||||
}
|
||||
|
||||
message ImportFunctionReply {
|
||||
}
|
||||
|
||||
message DieRequest {
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,11 @@ message PyObj {
|
|||
bytes data = 1;
|
||||
}
|
||||
|
||||
// Used for shipping remote functions to workers
|
||||
message Function {
|
||||
bytes implementation = 1;
|
||||
}
|
||||
|
||||
// Union of possible object types
|
||||
message Obj {
|
||||
String string_data = 1;
|
||||
|
|
|
@ -592,12 +592,7 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
|
|||
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
|
||||
}
|
||||
|
||||
static PyObject* deserialize_task(PyObject* self, PyObject* args) {
|
||||
PyObject* worker_capsule;
|
||||
Task* task;
|
||||
if (!PyArg_ParseTuple(args, "OO&", &worker_capsule, &PyObjectToTask, &task)) {
|
||||
return NULL;
|
||||
}
|
||||
static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
|
||||
std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
|
||||
PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size());
|
||||
int argsize = task->arg_size();
|
||||
|
@ -667,19 +662,36 @@ static PyObject* connected(PyObject* self, PyObject* args) {
|
|||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject* wait_for_next_task(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
|
||||
static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
|
||||
PyObject* worker_capsule;
|
||||
if (!PyArg_ParseTuple(args, "O", &worker_capsule)) {
|
||||
return NULL;
|
||||
}
|
||||
if (std::unique_ptr<Task> task = worker->receive_next_task()) {
|
||||
PyObject* pyobj = PyCapsule_New(task.get(), "task", TaskCapsule_Destructor);
|
||||
task.release(); // Now that the wrapper object was constructed successfully, release ownership
|
||||
return pyobj;
|
||||
Worker* worker;
|
||||
PyObjectToWorker(worker_capsule, &worker);
|
||||
if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
|
||||
PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
|
||||
PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_None : deserialize_task(worker_capsule, &message->task));
|
||||
PyTuple_SetItem(t, 1, message->function.empty() ? Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
|
||||
return t;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* export_function(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
const char* function;
|
||||
int function_size;
|
||||
if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) {
|
||||
return NULL;
|
||||
}
|
||||
if (worker->export_function(std::string(function, static_cast<size_t>(function_size)))) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* submit_task(PyObject* self, PyObject* args) {
|
||||
PyObject* worker_capsule;
|
||||
Task* task;
|
||||
|
@ -918,7 +930,6 @@ static PyMethodDef RayLibMethods[] = {
|
|||
{ "is_arrow", is_arrow, METH_VARARGS, "is the object in the local object store an arrow object?"},
|
||||
{ "unmap_object", unmap_object, METH_VARARGS, "unmap the object from the client's shared memory pool"},
|
||||
{ "serialize_task", serialize_task, METH_VARARGS, "serialize a task to protocol buffers" },
|
||||
{ "deserialize_task", deserialize_task, METH_VARARGS, "deserialize a task from protocol buffers" },
|
||||
{ "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
|
||||
{ "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" },
|
||||
{ "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" },
|
||||
|
@ -928,12 +939,13 @@ static PyMethodDef RayLibMethods[] = {
|
|||
{ "get_objref", get_objref, METH_VARARGS, "register a new object reference with the scheduler" },
|
||||
{ "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
|
||||
{ "alias_objrefs", alias_objrefs, METH_VARARGS, "make two objrefs refer to the same object" },
|
||||
{ "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" },
|
||||
{ "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" },
|
||||
{ "submit_task", submit_task, METH_VARARGS, "call a remote function" },
|
||||
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
|
||||
{ "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" },
|
||||
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
|
||||
{ "task_info", task_info, METH_VARARGS, "get task statuses" },
|
||||
{ "export_function", export_function, METH_VARARGS, "export function to workers" },
|
||||
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
|
||||
{ "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" },
|
||||
{ "kill_workers", kill_workers, METH_VARARGS, "kills all of the workers" },
|
||||
|
|
|
@ -291,6 +291,20 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
|
|||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
|
||||
auto workers = workers_.get();
|
||||
for (size_t i = 0; i < workers->size(); ++i) {
|
||||
ClientContext import_context;
|
||||
ImportFunctionRequest import_request;
|
||||
import_request.mutable_function()->set_implementation(request->function().implementation());
|
||||
if ((*workers)[i].current_task != ROOT_OPERATION) {
|
||||
ImportFunctionReply import_reply;
|
||||
(*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
|
||||
}
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void SchedulerService::deliver_object_async_if_necessary(ObjRef canonical_objref, ObjStoreId from, ObjStoreId to) {
|
||||
bool object_present_or_in_transit;
|
||||
{
|
||||
|
|
|
@ -70,6 +70,7 @@ public:
|
|||
Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override;
|
||||
Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
|
||||
Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
|
||||
Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override;
|
||||
|
||||
// This will ask an object store to send an object to another object store if
|
||||
// the object is not already present in that object store and is not already
|
||||
|
|
|
@ -17,16 +17,32 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
|
|||
}
|
||||
|
||||
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
|
||||
task_ = std::unique_ptr<Task>(new Task(request->task())); // Copy task
|
||||
RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
|
||||
WorkerMessage message = { &task_ };
|
||||
RAY_CHECK(send_queue_.send(&message), "error sending over IPC");
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->task = request->task();
|
||||
{
|
||||
WorkerMessage* message_ptr = message.get();
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
}
|
||||
message.release();
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->function = request->function().implementation();
|
||||
RAY_LOG(RAY_INFO, "importing function");
|
||||
{
|
||||
WorkerMessage* message_ptr = message.get();
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
}
|
||||
message.release();
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
|
||||
WorkerMessage message = { NULL };
|
||||
RAY_CHECK(send_queue_.send(&message), "error sending over IPC");
|
||||
WorkerMessage* message_ptr = NULL;
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
|
@ -285,10 +301,10 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
|
|||
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
||||
}
|
||||
|
||||
std::unique_ptr<Task> Worker::receive_next_task() {
|
||||
WorkerMessage message;
|
||||
RAY_CHECK(receive_queue_.receive(&message), "error receiving over IPC");
|
||||
return message.task ? std::move(*message.task) : std::unique_ptr<Task>();
|
||||
std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
|
||||
WorkerMessage* message_ptr;
|
||||
RAY_CHECK(receive_queue_.receive(&message_ptr), "error receiving over IPC");
|
||||
return std::unique_ptr<WorkerMessage>(message_ptr);
|
||||
}
|
||||
|
||||
void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
|
||||
|
@ -322,6 +338,16 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
|
|||
scheduler_stub_->TaskInfo(&context, request, &reply);
|
||||
}
|
||||
|
||||
bool Worker::export_function(const std::string& function) {
|
||||
RAY_CHECK(connected_, "Attempted to export function but failed.");
|
||||
ClientContext context;
|
||||
ExportFunctionRequest request;
|
||||
request.mutable_function()->set_implementation(function);
|
||||
ExportFunctionReply reply;
|
||||
Status status = scheduler_stub_->ExportFunction(&context, request, &reply);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Communication between the WorkerServer and the Worker happens via a message
|
||||
// queue. This is because the Python interpreter needs to be single threaded
|
||||
// (in our case running in the main thread), whereas the WorkerService will
|
||||
|
|
14
src/worker.h
14
src/worker.h
|
@ -24,19 +24,19 @@ using grpc::ClientContext;
|
|||
using grpc::ClientWriter;
|
||||
|
||||
struct WorkerMessage {
|
||||
std::unique_ptr<Task>* task;
|
||||
Task task;
|
||||
std::string function;
|
||||
};
|
||||
static_assert(std::is_pod<WorkerMessage>::value, "WorkerMessage must be memcpy-able");
|
||||
|
||||
class WorkerServiceImpl final : public WorkerService::Service {
|
||||
public:
|
||||
WorkerServiceImpl(const std::string& worker_address);
|
||||
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
|
||||
Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
|
||||
Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
|
||||
private:
|
||||
std::string worker_address_;
|
||||
std::unique_ptr<Task> task_; // copy of the current task
|
||||
MessageQueue<WorkerMessage> send_queue_;
|
||||
MessageQueue<WorkerMessage*> send_queue_;
|
||||
};
|
||||
|
||||
class Worker {
|
||||
|
@ -79,7 +79,7 @@ class Worker {
|
|||
// it in the message queue, which is read by the Python interpreter
|
||||
void start_worker_service();
|
||||
// wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down.
|
||||
std::unique_ptr<Task> receive_next_task();
|
||||
std::unique_ptr<WorkerMessage> receive_next_message();
|
||||
// tell the scheduler that we are done with the current task and request the
|
||||
// next one, if task_succeeded is false, this tells the scheduler that the
|
||||
// task threw an exception
|
||||
|
@ -92,13 +92,15 @@ class Worker {
|
|||
void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply);
|
||||
// get task statuses from scheduler
|
||||
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
|
||||
// export function to workers
|
||||
bool export_function(const std::string& function);
|
||||
|
||||
private:
|
||||
bool connected_;
|
||||
const size_t CHUNK_SIZE = 8 * 1024;
|
||||
std::unique_ptr<Scheduler::Stub> scheduler_stub_;
|
||||
std::thread worker_server_thread_;
|
||||
MessageQueue<WorkerMessage> receive_queue_;
|
||||
MessageQueue<WorkerMessage*> receive_queue_;
|
||||
bip::managed_shared_memory segment_;
|
||||
WorkerId workerid_;
|
||||
ObjStoreId objstoreid_;
|
||||
|
|
|
@ -272,6 +272,41 @@ class APITest(unittest.TestCase):
|
|||
|
||||
ray.services.cleanup()
|
||||
|
||||
def testDefiningRemoteFunctions(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.SCRIPT_MODE)
|
||||
|
||||
# Test that we can define a remote function in the shell.
|
||||
@ray.remote([int], [int])
|
||||
def f(x):
|
||||
return x + 1
|
||||
self.assertEqual(ray.get(f(0)), 1)
|
||||
|
||||
# Test that we can redefine the remote function.
|
||||
@ray.remote([int], [int])
|
||||
def f(x):
|
||||
return x + 10
|
||||
self.assertEqual(ray.get(f(0)), 10)
|
||||
|
||||
# Test that we can close over plain old data.
|
||||
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
|
||||
@ray.remote([], [list])
|
||||
def g():
|
||||
return data
|
||||
ray.get(g())
|
||||
|
||||
# Test that we can close over modules.
|
||||
@ray.remote([], [np.ndarray])
|
||||
def h():
|
||||
return np.zeros([3, 5])
|
||||
self.assertTrue(np.alltrue(ray.get(h()) == np.zeros([3, 5])))
|
||||
@ray.remote([], [float])
|
||||
def j():
|
||||
return time.time()
|
||||
ray.get(j())
|
||||
|
||||
ray.services.cleanup()
|
||||
|
||||
class TaskStatusTest(unittest.TestCase):
|
||||
def testFailedTask(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
|
|
Loading…
Add table
Reference in a new issue