Merge pull request #9 from amplab/tests

Improved OrchPy lib error handling, fixed tests, added tests
This commit is contained in:
Robert Nishihara 2016-03-10 23:51:40 -08:00
commit dc660e945f
11 changed files with 170 additions and 150 deletions

View file

@ -57,7 +57,7 @@ def main_loop(worker=global_worker):
arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore
outputs = worker.functions[func_name].executor(arguments) # execute the function outputs = worker.functions[func_name].executor(arguments) # execute the function
store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
# TODO(rkn): notify the scheduler that the task has completed, orchpy.lib.notify_task_completed(worker.handle) orchpy.lib.notify_task_completed(worker.handle) # notify the scheduler that the task has completed
def distributed(arg_types, return_types, worker=global_worker): def distributed(arg_types, return_types, worker=global_worker):
def distributed_decorator(func): def distributed_decorator(func):

View file

@ -63,7 +63,7 @@ message ChangeCountRequest {
uint64 objref = 1; uint64 objref = 1;
} }
message GetDebugInfoRequest { message SchedulerDebugInfoRequest {
bool do_scheduling = 1; bool do_scheduling = 1;
} }
@ -72,7 +72,7 @@ message FnTableEntry {
uint64 num_return_vals = 2; uint64 num_return_vals = 2;
} }
message GetDebugInfoReply { message SchedulerDebugInfoReply {
repeated Call task = 1; repeated Call task = 1;
repeated uint64 avail_worker = 3; repeated uint64 avail_worker = 3;
map<string, FnTableEntry> function_table = 2; map<string, FnTableEntry> function_table = 2;
@ -100,7 +100,7 @@ service Scheduler {
// used by the worker to report back and ask for more work // used by the worker to report back and ask for more work
rpc WorkerReady(WorkerReadyRequest) returns (AckReply); rpc WorkerReady(WorkerReadyRequest) returns (AckReply);
// get debugging information from the scheduler // get debugging information from the scheduler
rpc GetDebugInfo(GetDebugInfoRequest) returns (GetDebugInfoReply); rpc SchedulerDebugInfo(SchedulerDebugInfoRequest) returns (SchedulerDebugInfoReply);
} }
message DeliverObjRequest { message DeliverObjRequest {
@ -132,9 +132,9 @@ message GetObjReply {
uint64 size = 3; uint64 size = 3;
} }
message DebugInfoRequest {} message ObjStoreDebugInfoRequest {}
message DebugInfoReply { message ObjStoreDebugInfoReply {
repeated uint64 objref = 1; repeated uint64 objref = 1;
} }
@ -144,7 +144,7 @@ service ObjStore {
// Accept incoming data from another object store // Accept incoming data from another object store
rpc StreamObj(stream ObjChunk) returns (AckReply); rpc StreamObj(stream ObjChunk) returns (AckReply);
rpc GetObj(GetObjRequest) returns (GetObjReply); rpc GetObj(GetObjRequest) returns (GetObjReply);
rpc DebugInfo(DebugInfoRequest) returns (DebugInfoReply); rpc ObjStoreDebugInfo(ObjStoreDebugInfoRequest) returns (ObjStoreDebugInfoReply);
} }
message InvokeCallRequest { message InvokeCallRequest {

View file

@ -72,7 +72,7 @@ Status ObjStoreService::DeliverObj(ServerContext* context, const DeliverObjReque
return status; return status;
} }
Status ObjStoreService::DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) { Status ObjStoreService::ObjStoreDebugInfo(ServerContext* context, const ObjStoreDebugInfoRequest* request, ObjStoreDebugInfoReply* reply) {
std::lock_guard<std::mutex> memory_lock(memory_lock_); std::lock_guard<std::mutex> memory_lock(memory_lock_);
for (const auto& entry : memory_) { for (const auto& entry : memory_) {
reply->add_objref(entry.first); reply->add_objref(entry.first);

View file

@ -41,7 +41,7 @@ public:
~ObjStoreService(); ~ObjStoreService();
Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override; Status DeliverObj(ServerContext* context, const DeliverObjRequest* request, AckReply* reply) override;
Status DebugInfo(ServerContext* context, const DebugInfoRequest* request, DebugInfoReply* reply) override; Status ObjStoreDebugInfo(ServerContext* context, const ObjStoreDebugInfoRequest* request, ObjStoreDebugInfoReply* reply) override;
Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override; Status GetObj(ServerContext* context, const GetObjRequest* request, GetObjReply* reply) override;
Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override; Status StreamObj(ServerContext* context, ServerReader<ObjChunk>* reader, AckReply* reply) override;
private: private:

View file

@ -10,23 +10,6 @@
#include "types.pb.h" #include "types.pb.h"
#include "worker.h" #include "worker.h"
extern "C" {
// Error handling
static PyObject *OrchPyError;
}
// Extracts the pointer stored in a Python C API capsule.
//
// capsule: the Python object that is expected to be a capsule.
// name:    the capsule name the object must carry.
//
// Returns the capsule's pointer cast to T*. If the object is not a valid
// capsule with the given name, sets an OrchPyError and returns NULL.
template<typename T>
T* get_pointer_or_fail(PyObject* capsule, const char* name) {
  if (!PyCapsule_IsValid(capsule, name)) {
    PyErr_SetString(OrchPyError, "not a valid capsule");
    return NULL;
  }
  return static_cast<T*>(PyCapsule_GetPointer(capsule, name));
}
extern "C" { extern "C" {
// Object references // Object references
@ -110,6 +93,50 @@ PyObject* make_pyobjref(ObjRef objref) {
return result; return result;
} }
// Error handling
static PyObject *OrchPyError;
// Converter suitable for the "O&" format of PyArg_ParseTuple: unwraps a
// capsule named "call". On success stores the Call pointer in *call and
// returns 1; otherwise raises a Python TypeError and returns 0.
int PyObjectToCall(PyObject* object, Call **call) {
  if (!PyCapsule_IsValid(object, "call")) {
    PyErr_SetString(PyExc_TypeError, "must be a 'call' capsule");
    return 0;
  }
  void* raw_pointer = PyCapsule_GetPointer(object, "call");
  *call = static_cast<Call*>(raw_pointer);
  return 1;
}
// Converter suitable for the "O&" format of PyArg_ParseTuple: unwraps a
// capsule named "obj". On success stores the Obj pointer in *obj and
// returns 1; otherwise raises a Python TypeError and returns 0.
int PyObjectToObj(PyObject* object, Obj **obj) {
  const bool is_obj_capsule = PyCapsule_IsValid(object, "obj");
  if (!is_obj_capsule) {
    PyErr_SetString(PyExc_TypeError, "must be a 'obj' capsule");
    return 0;
  }
  *obj = static_cast<Obj*>(PyCapsule_GetPointer(object, "obj"));
  return 1;
}
// Converter suitable for the "O&" format of PyArg_ParseTuple: unwraps a
// capsule named "worker". On success stores the Worker pointer in *worker
// and returns 1; otherwise raises a Python TypeError and returns 0.
int PyObjectToWorker(PyObject* object, Worker **worker) {
  static const char* const kCapsuleName = "worker";
  if (PyCapsule_IsValid(object, kCapsuleName) == 0) {
    PyErr_SetString(PyExc_TypeError, "must be a 'worker' capsule");
    return 0;
  }
  *worker = static_cast<Worker*>(PyCapsule_GetPointer(object, kCapsuleName));
  return 1;
}
// Converter suitable for the "O&" format of PyArg_ParseTuple: extracts the
// object reference held by a PyObjRef instance.
//
// object: the Python object to convert.
// objref: out-parameter receiving the reference value on success.
//
// Returns 1 on success. Returns 0 with a Python exception set when the
// object is not a PyObjRefType instance or when the instance check fails.
int PyObjectToObjRef(PyObject* object, ObjRef *objref) {
  // PyObject_IsInstance returns 1 (yes), 0 (no) or -1 (error, exception set).
  // Testing it for truthiness would treat the error case as a match.
  const int is_objref = PyObject_IsInstance(object, (PyObject*)&PyObjRefType);
  if (is_objref < 0) {
    return 0;  // keep the exception raised by the instance check
  }
  if (is_objref == 0) {
    PyErr_SetString(PyExc_TypeError, "must be an ObjRef object");
    return 0;
  }
  *objref = ((PyObjRef*) object)->val;
  return 1;
}
// Serialization // Serialization
// serialize will serialize the python object val into the protocol buffer // serialize will serialize the python object val into the protocol buffer
@ -206,12 +233,8 @@ PyObject* serialize_object(PyObject* self, PyObject* args) {
} }
PyObject* deserialize_object(PyObject* self, PyObject* args) { PyObject* deserialize_object(PyObject* self, PyObject* args) {
PyObject* capsule; Obj* obj;
if (!PyArg_ParseTuple(args, "O", &capsule)) { if (!PyArg_ParseTuple(args, "O&", &PyObjectToObj, &obj)) {
return NULL;
}
Obj* obj = get_pointer_or_fail<Obj>(capsule, "obj");
if (!obj) {
return NULL; return NULL;
} }
return deserialize(*obj); return deserialize(*obj);
@ -239,9 +262,8 @@ PyObject* serialize_call(PyObject* self, PyObject* args) {
} }
PyObject* deserialize_call(PyObject* self, PyObject* args) { PyObject* deserialize_call(PyObject* self, PyObject* args) {
PyObject* capsule = PyTuple_GetItem(args, 0); Call* call;
Call* call = get_pointer_or_fail<Call>(capsule, "call"); if (!PyArg_ParseTuple(args, "O&", &PyObjectToCall, &call)) {
if (!call) {
return NULL; return NULL;
} }
PyObject* string = PyString_FromStringAndSize(call->name().c_str(), call->name().size()); PyObject* string = PyString_FromStringAndSize(call->name().c_str(), call->name().size());
@ -250,7 +272,7 @@ PyObject* deserialize_call(PyObject* self, PyObject* args) {
for (int i = 0; i < argsize; ++i) { for (int i = 0; i < argsize; ++i) {
const Value& val = call->arg(i); const Value& val = call->arg(i);
if (!val.has_obj()) { if (!val.has_obj()) {
// TODO: Deserialize object reference here PyList_SetItem(arglist, i, make_pyobjref(val.ref()));
} else { } else {
PyList_SetItem(arglist, i, deserialize(val.obj())); PyList_SetItem(arglist, i, deserialize(val.obj()));
} }
@ -280,9 +302,8 @@ PyObject* create_worker(PyObject* self, PyObject* args) {
} }
PyObject* wait_for_next_task(PyObject* self, PyObject* args) { PyObject* wait_for_next_task(PyObject* self, PyObject* args) {
PyObject* capsule = PyTuple_GetItem(args, 0); Worker* worker;
Worker* worker = get_pointer_or_fail<Worker>(capsule, "worker"); if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
if (!worker) {
return NULL; return NULL;
} }
Call* call = worker->receive_next_task(); Call* call = worker->receive_next_task();
@ -290,17 +311,9 @@ PyObject* wait_for_next_task(PyObject* self, PyObject* args) {
} }
PyObject* remote_call(PyObject* self, PyObject* args) { PyObject* remote_call(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
PyObject* call_capsule; Call* call;
if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &call_capsule)) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToCall, &call)) {
return NULL;
}
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL;
}
Call* call = get_pointer_or_fail<Call>(call_capsule, "call");
if (!call) {
return NULL; return NULL;
} }
RemoteCallRequest request; RemoteCallRequest request;
@ -315,111 +328,81 @@ PyObject* remote_call(PyObject* self, PyObject* args) {
return list; return list;
} }
PyObject* register_function(PyObject* self, PyObject* args) { PyObject* notify_task_completed(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
const char* function_name; if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
int num_return_vals;
if (!PyArg_ParseTuple(args, "Osi", &worker_capsule, &function_name, &num_return_vals)) {
return NULL; return NULL;
} }
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker"); worker->notify_task_completed();
if (!worker) { Py_RETURN_NONE;
}
PyObject* register_function(PyObject* self, PyObject* args) {
Worker* worker;
const char* function_name;
int num_return_vals;
if (!PyArg_ParseTuple(args, "O&si", &PyObjectToWorker, &worker, &function_name, &num_return_vals)) {
return NULL; return NULL;
} }
worker->register_function(std::string(function_name), num_return_vals); worker->register_function(std::string(function_name), num_return_vals);
Py_RETURN_NONE; Py_RETURN_NONE;
} }
// TODO: test this
PyObject* push_object(PyObject* self, PyObject* args) { PyObject* push_object(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
PyObject* obj_capsule; Obj* obj;
if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &obj_capsule)) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObj, &obj)) {
return NULL;
}
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL;
}
Obj* obj = get_pointer_or_fail<Obj>(obj_capsule, "obj");
if (!obj) {
return NULL; return NULL;
} }
ObjRef objref = worker->push_object(obj); ObjRef objref = worker->push_object(obj);
return make_pyobjref(objref); return make_pyobjref(objref);
} }
// TODO: test this
PyObject* put_object(PyObject* self, PyObject* args) { PyObject* put_object(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
PyObject* pyobjref; ObjRef objref;
PyObject* obj_capsule; Obj* obj;
if (!PyArg_ParseTuple(args, "OOO", &worker_capsule, &pyobjref, &obj_capsule)) { if (!PyArg_ParseTuple(args, "O&O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref, &PyObjectToObj, &obj)) {
return NULL; return NULL;
} }
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL;
}
Obj* obj = get_pointer_or_fail<Obj>(obj_capsule, "obj");
if (!obj) {
return NULL;
}
ObjRef objref = ((PyObjRef*) pyobjref)->val;
worker->put_object(objref, obj); worker->put_object(objref, obj);
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject* get_object(PyObject* self, PyObject* args) { PyObject* get_object(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
PyObject* pyobjref; ObjRef objref;
if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &pyobjref)) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
return NULL; return NULL;
} }
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL;
}
ObjRef objref = ((PyObjRef*) pyobjref)->val;
slice s = worker->get_object(objref); slice s = worker->get_object(objref);
Obj* obj = new Obj(); // TODO: Make sure this will get deleted Obj* obj = new Obj(); // TODO: Make sure this will get deleted
obj->ParseFromString(std::string(s.data, s.len)); obj->ParseFromString(std::string(s.data, s.len));
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL); return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
} }
// TODO: implement this
PyObject* pull_object(PyObject* self, PyObject* args) { PyObject* pull_object(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
PyObject* pyobjref; ObjRef objref;
if (!PyArg_ParseTuple(args, "OO", &worker_capsule, &pyobjref)) { if (!PyArg_ParseTuple(args, "O&O&", &PyObjectToWorker, &worker, &PyObjectToObjRef, &objref)) {
return NULL; return NULL;
} }
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL;
}
ObjRef objref = ((PyObjRef*) pyobjref)->val;
slice s = worker->get_object(objref); slice s = worker->get_object(objref);
Obj* obj = new Obj(); // TODO: Make sure this will get deleted Obj* obj = new Obj(); // TODO: Make sure this will get deleted
obj->ParseFromString(std::string(s.data, s.len)); obj->ParseFromString(std::string(s.data, s.len));
return PyCapsule_New(static_cast<void*>(obj), "obj", NULL); return PyCapsule_New(static_cast<void*>(obj), "obj", NULL);
} }
// TODO: test this
PyObject* start_worker_service(PyObject* self, PyObject* args) { PyObject* start_worker_service(PyObject* self, PyObject* args) {
PyObject* worker_capsule; Worker* worker;
if (!PyArg_ParseTuple(args, "O", &worker_capsule)) { if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
return NULL;
}
Worker* worker = get_pointer_or_fail<Worker>(worker_capsule, "worker");
if (!worker) {
return NULL; return NULL;
} }
worker->start_worker_service(); worker->start_worker_service();
Py_RETURN_NONE; Py_RETURN_NONE;
} }
static PyMethodDef SymphonyMethods[] = { static PyMethodDef OrchPyLibMethods[] = {
{ "serialize_object", serialize_object, METH_VARARGS, "serialize an object to protocol buffers" }, { "serialize_object", serialize_object, METH_VARARGS, "serialize an object to protocol buffers" },
{ "deserialize_object", deserialize_object, METH_VARARGS, "deserialize an object from protocol buffers" }, { "deserialize_object", deserialize_object, METH_VARARGS, "deserialize an object from protocol buffers" },
{ "serialize_call", serialize_call, METH_VARARGS, "serialize a call to protocol buffers" }, { "serialize_call", serialize_call, METH_VARARGS, "serialize a call to protocol buffers" },
@ -432,6 +415,7 @@ static PyMethodDef SymphonyMethods[] = {
{ "pull_object" , pull_object, METH_VARARGS, "pull object with a given object id from the object store" }, { "pull_object" , pull_object, METH_VARARGS, "pull object with a given object id from the object store" },
{ "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" }, { "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" },
{ "remote_call", remote_call, METH_VARARGS, "call a remote function" }, { "remote_call", remote_call, METH_VARARGS, "call a remote function" },
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
{ "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" }, { "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" },
{ NULL, NULL, 0, NULL } { NULL, NULL, 0, NULL }
}; };
@ -442,7 +426,7 @@ PyMODINIT_FUNC initliborchpylib(void) {
if (PyType_Ready(&PyObjRefType) < 0) { if (PyType_Ready(&PyObjRefType) < 0) {
return; return;
} }
m = Py_InitModule3("liborchpylib", SymphonyMethods, "Python C Extension for Orchestra"); m = Py_InitModule3("liborchpylib", OrchPyLibMethods, "Python C Extension for Orchestra");
Py_INCREF(&PyObjRefType); Py_INCREF(&PyObjRefType);
PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType); PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType);
OrchPyError = PyErr_NewException("orchpy.error", NULL, NULL); OrchPyError = PyErr_NewException("orchpy.error", NULL, NULL);

View file

@ -87,7 +87,7 @@ Status SchedulerService::WorkerReady(ServerContext* context, const WorkerReadyRe
return Status::OK; return Status::OK;
} }
Status SchedulerService::GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) { Status SchedulerService::SchedulerDebugInfo(ServerContext* context, const SchedulerDebugInfoRequest* request, SchedulerDebugInfoReply* reply) {
debug_info(*request, reply); debug_info(*request, reply);
return Status::OK; return Status::OK;
} }
@ -218,7 +218,7 @@ void SchedulerService::register_function(const std::string& name, WorkerId worke
info.add_worker(workerid); info.add_worker(workerid);
} }
void SchedulerService::debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply) { void SchedulerService::debug_info(const SchedulerDebugInfoRequest& request, SchedulerDebugInfoReply* reply) {
if (request.do_scheduling()) { if (request.do_scheduling()) {
schedule(); schedule();
} }

View file

@ -45,7 +45,7 @@ public:
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override; Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override;
Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override; Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override;
Status WorkerReady(ServerContext* context, const WorkerReadyRequest* request, AckReply* reply) override; Status WorkerReady(ServerContext* context, const WorkerReadyRequest* request, AckReply* reply) override;
Status GetDebugInfo(ServerContext* context, const GetDebugInfoRequest* request, GetDebugInfoReply* reply) override; Status SchedulerDebugInfo(ServerContext* context, const SchedulerDebugInfoRequest* request, SchedulerDebugInfoReply* reply) override;
// ask an object store to send object to another objectstore // ask an object store to send object to another objectstore
void deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId to); void deliver_object(ObjRef objref, ObjStoreId from, ObjStoreId to);
@ -66,7 +66,7 @@ public:
// register a function with the scheduler // register a function with the scheduler
void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals); void register_function(const std::string& name, WorkerId workerid, size_t num_return_vals);
// get debugging information for the scheduler // get debugging information for the scheduler
void debug_info(const GetDebugInfoRequest& request, GetDebugInfoReply* reply); void debug_info(const SchedulerDebugInfoRequest& request, SchedulerDebugInfoReply* reply);
private: private:
// pick an objectstore that holds a given object (needs protection by objtable_lock_) // pick an objectstore that holds a given object (needs protection by objtable_lock_)
ObjStoreId pick_objstore(ObjRef objref); ObjStoreId pick_objstore(ObjRef objref);

View file

@ -105,23 +105,6 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
scheduler_stub_->RegisterFunction(&context, request, &reply); scheduler_stub_->RegisterFunction(&context, request, &reply);
} }
// Communication between the WorkerServer and the Worker happens via a message
// queue. This is because the Python interpreter needs to be single threaded
// (in our case running in the main thread), whereas the WorkerService will
// run in a separate thread and potentially utilize multiple threads.
void Worker::start_worker_service() {
const char* server_address = worker_address_.c_str();
worker_server_thread_ = std::thread([server_address]() {
WorkerServiceImpl service(server_address);
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<Server> server(builder.BuildAndStart());
ORCH_LOG(ORCH_INFO, "worker server listening on " << server_address);
server->Wait();
});
}
Call* Worker::receive_next_task() { Call* Worker::receive_next_task() {
const char* message_queue_name = worker_address_.c_str(); const char* message_queue_name = worker_address_.c_str();
try { try {
@ -140,3 +123,28 @@ Call* Worker::receive_next_task() {
std::cout << ex.what() << std::endl; std::cout << ex.what() << std::endl;
} }
} }
void Worker::notify_task_completed() {
ClientContext context;
WorkerReadyRequest request;
request.set_workerid(workerid_);
AckReply reply;
scheduler_stub_->WorkerReady(&context, request, &reply);
}
// Communication between the WorkerServer and the Worker happens via a message
// queue. This is because the Python interpreter needs to be single threaded
// (in our case running in the main thread), whereas the WorkerService will
// run in a separate thread and potentially utilize multiple threads.
void Worker::start_worker_service() {
const char* server_address = worker_address_.c_str();
worker_server_thread_ = std::thread([server_address]() {
WorkerServiceImpl service(server_address);
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<Server> server(builder.BuildAndStart());
ORCH_LOG(ORCH_INFO, "worker server listening on " << server_address);
server->Wait();
});
}

View file

@ -62,6 +62,8 @@ class Worker {
void start_worker_service(); void start_worker_service();
// wait for next task from the RPC system // wait for next task from the RPC system
Call* receive_next_task(); Call* receive_next_task();
// tell the scheduler that we are done with the current task and request the next one
void notify_task_completed();
private: private:
const size_t CHUNK_SIZE = 8 * 1024; const size_t CHUNK_SIZE = 8 * 1024;

View file

@ -61,6 +61,30 @@ class SerializationTest(unittest.TestCase):
b = orchpy.lib.deserialize_object(res) b = orchpy.lib.deserialize_object(res)
self.assertTrue((a == b).all()) self.assertTrue((a == b).all())
class OrchPyLibTest(unittest.TestCase):
  """Stores an object through a connected worker and reads it back."""

  def testOrchPyLib(self):
    scheduler_port = new_scheduler_port()
    objstore_port = new_objstore_port()
    worker_port = new_worker_port()

    try:
      services.start_scheduler(address(IP_ADDRESS, scheduler_port))
      time.sleep(0.1)  # give the scheduler time to come up before the objstore connects

      services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
      time.sleep(0.2)  # give the objstore time to come up before the worker connects

      w = worker.Worker()
      worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)

      w.put_object(orchpy.lib.ObjRef(0), 'hello world')
      result = w.get_object(orchpy.lib.ObjRef(0))

      self.assertEqual(result, 'hello world')
    finally:
      # Mirror the other service-starting tests, which call services.cleanup();
      # presumably this shuts down the started services — verify in services.
      services.cleanup()
class ObjStoreTest(unittest.TestCase): class ObjStoreTest(unittest.TestCase):
"""Test setting up object stores, transfering data between them and retrieving data to a client""" """Test setting up object stores, transfering data between them and retrieving data to a client"""
@ -140,12 +164,12 @@ class SchedulerTest(unittest.TestCase):
time.sleep(0.2) time.sleep(0.2)
# value_after = worker.pull(objref, worker1) value_after = worker.pull(objref[0], worker1)
# self.assertEqual(value_before, value_after) self.assertEqual(value_before, value_after)
time.sleep(0.1) time.sleep(0.1)
reply = scheduler_stub.GetDebugInfo(orchestra_pb2.GetDebugInfoRequest(), TIMEOUT_SECONDS) reply = scheduler_stub.SchedulerDebugInfo(orchestra_pb2.SchedulerDebugInfoRequest(), TIMEOUT_SECONDS)
services.cleanup() services.cleanup()

View file

@ -7,32 +7,34 @@ from grpc.beta import implementations
import orchestra_pb2 import orchestra_pb2
import types_pb2 import types_pb2
TIMEOUT_SECONDS = 5
parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.') parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.')
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address")
parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address") parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address")
@worker.distributed([str], [str]) def connect_to_scheduler(host, port):
def print_string(string): channel = implementations.insecure_channel(host, port)
print "called print_string with", string
f = open("asdfasdf.txt", "w")
f.write("successfully called print_string with argument {}.".format(string))
return string
@worker.distributed([int, int], [int, int])
def handle_int(a, b):
return a + 1, b + 1
def connect_to_scheduler(address):
channel = implementations.insecure_channel(address)
return orchestra_pb2.beta_create_Scheduler_stub(channel) return orchestra_pb2.beta_create_Scheduler_stub(channel)
def address(host, port): def connect_to_objstore(host, port):
return host + ":" + str(port) channel = implementations.insecure_channel(host, port)
return orchestra_pb2.beta_create_ObjStore_stub(channel)
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
scheduler_stub = connect_to_scheduler(args.scheduler_address) scheduler_ip_address, scheduler_port = args.scheduler_address.split(":")
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)) scheduler_stub = connect_to_scheduler(scheduler_ip_address, int(scheduler_port))
objstore_ip_address, objstore_port = args.objstore_address.split(":")
objstore_stub = connect_to_objstore(objstore_ip_address, int(objstore_port))
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
def scheduler_debug_info():
return scheduler_stub.SchedulerDebugInfo(orchestra_pb2.SchedulerDebugInfoRequest(), TIMEOUT_SECONDS)
def objstore_debug_info():
return objstore_stub.ObjStoreDebugInfo(orchestra_pb2.ObjStoreDebugInfoRequest(), TIMEOUT_SECONDS)
import IPython import IPython
IPython.embed() IPython.embed()