Add ray.internal.free (#2542)

This commit is contained in:
Yuhong Guo 2018-08-15 13:01:23 +08:00 committed by Robert Nishihara
parent f13e3e22f2
commit eeb15771ba
19 changed files with 346 additions and 2 deletions

View file

@ -54,6 +54,7 @@ from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
from ray.worker import (SCRIPT_MODE, WORKER_MODE, LOCAL_MODE, SILENT_MODE,
PYTHON_MODE) # noqa: E402
from ray.worker import global_state # noqa: E402
import ray.internal # noqa: E402
# We import ray.actor because some code is run in actor.py which initializes
# some functions in the worker.
import ray.actor # noqa: F401
@ -68,7 +69,7 @@ __all__ = [
"remote", "profile", "actor", "method", "get_gpu_ids", "get_resource_ids",
"get_webui_url", "register_custom_serializer", "shutdown", "SCRIPT_MODE",
"WORKER_MODE", "LOCAL_MODE", "SILENT_MODE", "PYTHON_MODE", "global_state",
"ObjectID", "_config", "__version__"
"ObjectID", "_config", "__version__", "internal"
]
import ctypes # noqa: E402

View file

@ -0,0 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.internal.internal_api import free
__all__ = ["free"]

View file

@ -0,0 +1,48 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray.local_scheduler
import ray.worker
from ray import profiling
__all__ = ["free"]
def free(object_ids, local_only=False, worker=None):
"""Free a list of IDs from object stores.
This function is a low-level API which should be used in restricted
scenarios.
If local_only is false, the request will be send to all object stores.
This method will not return any value to indicate whether the deletion is
successful or not. This function is an instruction to object store. If
the some of the objects are in use, object stores will delete them later
when the ref count is down to 0.
Args:
object_ids (List[ObjectID]): List of object IDs to delete.
local_only (bool): Whether only deleting the list of objects in local
object store or all object stores.
"""
if worker is None:
worker = ray.worker.get_global_worker()
if isinstance(object_ids, ray.ObjectID):
object_ids = [object_ids]
if not isinstance(object_ids, list):
raise TypeError("free() expects a list of ObjectID, got {}".format(
type(object_ids)))
worker.check_connected()
with profiling.profile("ray.free", worker=worker):
if len(object_ids) == 0:
return
if worker.use_raylet:
worker.local_scheduler_client.free(object_ids, local_only)
else:
raise Exception("Free is not supported in legacy backend.")

View file

@ -414,6 +414,43 @@ static PyObject *PyLocalSchedulerClient_push_profile_events(PyObject *self,
Py_RETURN_NONE;
}
static PyObject *PyLocalSchedulerClient_free(PyObject *self, PyObject *args) {
PyObject *py_object_ids;
PyObject *py_local_only;
if (!PyArg_ParseTuple(args, "OO", &py_object_ids, &py_local_only)) {
return NULL;
}
bool local_only = static_cast<bool>(PyObject_IsTrue(py_local_only));
// Convert object ids.
PyObject *iter = PyObject_GetIter(py_object_ids);
if (!iter) {
return NULL;
}
std::vector<ObjectID> object_ids;
while (true) {
PyObject *next = PyIter_Next(iter);
ObjectID object_id;
if (!next) {
break;
}
if (!PyObjectToUniqueID(next, &object_id)) {
// Error parsing object ID.
return NULL;
}
object_ids.push_back(object_id);
}
// Invoke local_scheduler_free_objects_in_object_store.
local_scheduler_free_objects_in_object_store(
reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
object_ids, local_only);
Py_RETURN_NONE;
}
static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
@ -446,6 +483,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"push_profile_events",
(PyCFunction) PyLocalSchedulerClient_push_profile_events, METH_VARARGS,
"Store some profiling events in the GCS."},
{"free", (PyCFunction) PyLocalSchedulerClient_free, METH_VARARGS,
"Free a list of objects from object stores."},
{NULL} /* Sentinel */
};

View file

@ -351,3 +351,20 @@ void local_scheduler_push_profile_events(
ray::protocol::MessageType::PushProfileEventsRequest),
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
}
void local_scheduler_free_objects_in_object_store(
LocalSchedulerConnection *conn,
const std::vector<ray::ObjectID> &object_ids,
bool local_only) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateFreeObjectsRequest(
fbb, local_only, to_flatbuf(fbb, object_ids));
fbb.Finish(message);
int success = write_message(
conn->conn,
static_cast<int64_t>(
ray::protocol::MessageType::FreeObjectsInObjectStoreRequest),
fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
RAY_CHECK(success == 0) << "Failed to write message to raylet.";
}

View file

@ -244,4 +244,16 @@ void local_scheduler_push_profile_events(
LocalSchedulerConnection *conn,
const ProfileTableDataT &profile_events);
/// Free a list of objects from object stores.
///
/// \param conn The connection information.
/// \param object_ids A list of ObjectsIDs to be deleted.
/// \param local_only Whether keep this request with local object store
/// or send it to all the object stores.
/// \return Void.
void local_scheduler_free_objects_in_object_store(
LocalSchedulerConnection *conn,
const std::vector<ray::ObjectID> &object_ids,
bool local_only);
#endif

View file

@ -426,6 +426,10 @@ const ClientTableDataT &ClientTable::GetClient(const ClientID &client_id) const
}
}
const std::unordered_map<ClientID, ClientTableDataT> &ClientTable::GetAllClients() const {
return client_cache_;
}
template class Log<ObjectID, ObjectTableData>;
template class Log<TaskID, ray::protocol::Task>;
template class Table<TaskID, ray::protocol::Task>;

View file

@ -651,6 +651,13 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
/// \return Whether the client with ID client_id is removed.
bool IsRemoved(const ClientID &client_id) const;
/// Get the information of all clients.
///
/// Note: The return value contains ClientID::nil() which should be filtered.
///
/// \return The client ID to client information map.
const std::unordered_map<ClientID, ClientTableDataT> &GetAllClients() const;
private:
/// Handle a client table notification.
void HandleNotification(AsyncGcsClient *client, const ClientTableDataT &notifications);

View file

@ -4,7 +4,8 @@ namespace ray.object_manager.protocol;
enum MessageType:int {
ConnectClient = 1,
PushRequest,
PullRequest
PullRequest,
FreeRequest
}
table PushRequestMessage {
@ -31,3 +32,8 @@ table ConnectClientMessage {
// Whether this is a transfer connection.
is_transfer: bool;
}
table FreeRequestMessage {
// List of IDs to be deleted.
object_ids: [string];
}

View file

@ -185,4 +185,13 @@ std::vector<ObjectBufferPool::ChunkInfo> ObjectBufferPool::BuildChunks(
return chunks;
}
void ObjectBufferPool::FreeObjects(const std::vector<ObjectID> &object_ids) {
std::vector<plasma::ObjectID> plasma_ids;
plasma_ids.reserve(object_ids.size());
for (const auto &id : object_ids) {
plasma_ids.push_back(id.to_plasma_id());
}
ARROW_CHECK_OK(store_client_.Delete(plasma_ids));
}
} // namespace ray

View file

@ -123,6 +123,12 @@ class ObjectBufferPool {
/// \param chunk_index The index of the chunk.
void SealChunk(const ObjectID &object_id, uint64_t chunk_index);
/// Free a list of objects from object store.
///
/// \param object_ids the The list of ObjectIDs to be deleted.
/// \return Void.
void FreeObjects(const std::vector<ObjectID> &object_ids);
private:
/// Abort the create operation associated with an object. This destroys the buffer
/// state, including create operations in progress for all chunks of the object.

View file

@ -117,6 +117,24 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
return ray::Status::OK();
}
void ObjectDirectory::RunFunctionForEachClient(
const InfoSuccessCallback &client_function) {
const auto &clients = gcs_client_->client_table().GetAllClients();
for (const auto &client_pair : clients) {
const ClientTableDataT &data = client_pair.second;
if (client_pair.first == ClientID::nil() ||
client_pair.first == gcs_client_->client_table().GetLocalClientId() ||
!data.is_insertion) {
continue;
} else {
const auto &info =
RemoteConnectionInfo(client_pair.first, data.node_manager_address,
static_cast<uint16_t>(data.object_manager_port));
client_function(info);
}
}
}
ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) {

View file

@ -101,6 +101,13 @@ class ObjectDirectoryInterface {
/// \return Status of whether this method succeeded.
virtual ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) = 0;
/// Go through all the client information.
///
/// \param success_cb A callback which handles the success of this method.
/// This function will be called multiple times.
/// \return Void.
virtual void RunFunctionForEachClient(const InfoSuccessCallback &client_function) = 0;
};
/// Ray ObjectDirectory declaration.
@ -115,6 +122,8 @@ class ObjectDirectory : public ObjectDirectoryInterface {
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) override;
void RunFunctionForEachClient(const InfoSuccessCallback &client_function) override;
ray::Status LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) override;

View file

@ -1,4 +1,5 @@
#include "ray/object_manager/object_manager.h"
#include "common/common_protocol.h"
#include "ray/util/util.h"
namespace asio = boost::asio;
@ -655,6 +656,10 @@ void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> &c
ConnectClient(conn, message);
break;
}
case static_cast<int64_t>(object_manager_protocol::MessageType::FreeRequest): {
ReceiveFreeRequest(conn, message);
break;
}
case static_cast<int64_t>(protocol::MessageType::DisconnectClient): {
// TODO(hme): Disconnect without depending on the node manager protocol.
DisconnectClient(conn, message);
@ -755,4 +760,51 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
<< "/" << config_.max_receives;
}
void ObjectManager::ReceiveFreeRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message) {
auto free_request =
flatbuffers::GetRoot<object_manager_protocol::FreeRequestMessage>(message);
std::vector<ObjectID> object_ids = from_flatbuf(*free_request->object_ids());
// This RPC should come from another Object Manager.
// Keep this request local.
bool local_only = true;
FreeObjects(object_ids, local_only);
conn->ProcessMessages();
}
void ObjectManager::FreeObjects(const std::vector<ObjectID> &object_ids,
bool local_only) {
buffer_pool_.FreeObjects(object_ids);
if (!local_only) {
SpreadFreeObjectRequest(object_ids);
}
}
void ObjectManager::SpreadFreeObjectRequest(const std::vector<ObjectID> &object_ids) {
// This code path should be called from node manager.
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<object_manager_protocol::FreeRequestMessage> request =
object_manager_protocol::CreateFreeRequestMessage(fbb, to_flatbuf(fbb, object_ids));
fbb.Finish(request);
auto function_on_client = [this, &fbb](const RemoteConnectionInfo &connection_info) {
std::shared_ptr<SenderConnection> conn;
connection_pool_.GetSender(ConnectionPool::ConnectionType::MESSAGE,
connection_info.client_id, &conn);
if (conn == nullptr) {
conn = CreateSenderConnection(ConnectionPool::ConnectionType::MESSAGE,
connection_info);
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE,
connection_info.client_id, conn);
}
ray::Status status = conn->WriteMessage(
static_cast<int64_t>(object_manager_protocol::MessageType::FreeRequest),
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn);
}
// TODO(Yuhong): Implement ConnectionPool::RemoveSender and call it in "else".
};
object_directory_->RunFunctionForEachClient(function_on_client);
}
} // namespace ray

View file

@ -163,6 +163,13 @@ class ObjectManager : public ObjectManagerInterface {
uint64_t num_required_objects, bool wait_local,
const WaitCallback &callback);
/// Free a list of objects from object store.
///
/// \param object_ids the The list of ObjectIDs to be deleted.
/// \param local_only Whether keep this request with local object store
/// or send it to all the object stores.
void FreeObjects(const std::vector<ObjectID> &object_ids, bool local_only);
private:
friend class TestObjectManager;
@ -214,6 +221,11 @@ class ObjectManager : public ObjectManagerInterface {
/// Completion handler for Wait.
void WaitComplete(const UniqueID &wait_id);
/// Spread the Free request to all objects managers.
///
/// \param object_ids the The list of ObjectIDs to be deleted.
void SpreadFreeObjectRequest(const std::vector<ObjectID> &object_ids);
/// Handle starting, running, and stopping asio io_service.
void StartIOService();
void RunSendService();
@ -271,6 +283,9 @@ class ObjectManager : public ObjectManagerInterface {
/// Handles receiving a pull request message.
void ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message);
/// Handles freeing objects request.
void ReceiveFreeRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message);
/// Handles connect message of a new client connection.
void ConnectClient(std::shared_ptr<TcpClientConnection> &conn, const uint8_t *message);

View file

@ -68,6 +68,8 @@ enum MessageType:int {
// Push some profiling events to the GCS. When sending this message to the
// node manager, the message itself is serialized as a ProfileTableData object.
PushProfileEventsRequest,
// Free the objects in objects store.
FreeObjectsInObjectStoreRequest,
}
table TaskExecutionSpecification {
@ -177,3 +179,11 @@ table PushErrorRequest {
// The timestamp of the error message.
timestamp: double;
}
table FreeObjectsRequest {
// Whether keep this request with local object store
// or send it to all the object stores.
local_only: bool;
// List of object ids we'll delete from object store.
object_ids: [string];
}

View file

@ -712,6 +712,11 @@ void NodeManager::ProcessClientMessage(
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message));
} break;
case protocol::MessageType::FreeObjectsInObjectStoreRequest: {
auto message = flatbuffers::GetRoot<protocol::FreeObjectsRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
object_manager_.FreeObjects(object_ids, message->local_only());
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;

View file

@ -47,6 +47,7 @@ class MockObjectDirectory : public ObjectDirectoryInterface {
MOCK_METHOD3(ReportObjectAdded,
ray::Status(const ObjectID &, const ClientID &, const ObjectInfoT &));
MOCK_METHOD2(ReportObjectRemoved, ray::Status(const ObjectID &, const ClientID &));
MOCK_METHOD1(RunFunctionForEachClient, void(const InfoSuccessCallback &success_cb));
private:
std::vector<std::pair<ObjectID, OnLocationsFound>> callbacks_;

View file

@ -1205,6 +1205,84 @@ class APITest(unittest.TestCase):
# test multi-threading in the worker
ray.get(test_multi_threading_in_worker.remote())
@unittest.skipIf(
os.environ.get("RAY_USE_XRAY") != "1",
"This test only works with xray.")
def testFreeObjectsMultiNode(self):
ray.worker._init(
start_ray_local=True,
num_local_schedulers=3,
num_workers=1,
num_cpus=[1, 1, 1],
resources=[{
"Custom0": 1
}, {
"Custom1": 1
}, {
"Custom2": 1
}],
use_raylet=True)
@ray.remote(resources={"Custom0": 1})
def run_on_0():
return ray.worker.global_worker.plasma_client.store_socket_name
@ray.remote(resources={"Custom1": 1})
def run_on_1():
return ray.worker.global_worker.plasma_client.store_socket_name
@ray.remote(resources={"Custom2": 1})
def run_on_2():
return ray.worker.global_worker.plasma_client.store_socket_name
def create():
a = run_on_0.remote()
b = run_on_1.remote()
c = run_on_2.remote()
(l1, l2) = ray.wait([a, b, c], num_returns=3)
assert len(l1) == 3
assert len(l2) == 0
return (a, b, c)
def flush():
# Flush the Release History.
# Current Plasma Client Cache will maintain 64-item list.
# If the number changed, this will fail.
print("Start Flush!")
for i in range(64):
ray.get(
[run_on_0.remote(),
run_on_1.remote(),
run_on_2.remote()])
print("Flush finished!")
def run_one_test(local_only):
(a, b, c) = create()
# The three objects should be generated on different object stores.
assert ray.get(a) != ray.get(b)
assert ray.get(a) != ray.get(c)
assert ray.get(c) != ray.get(b)
ray.internal.free([a, b, c], local_only=local_only)
flush()
return (a, b, c)
# Case 1: run this local_only=False. All 3 objects will be deleted.
(a, b, c) = run_one_test(False)
(l1, l2) = ray.wait([a, b, c], timeout=10, num_returns=1)
# All the objects are deleted.
assert len(l1) == 0
assert len(l2) == 3
# Case 2: run this local_only=True. Only 1 object will be deleted.
(a, b, c) = run_one_test(True)
(l1, l2) = ray.wait([a, b, c], timeout=10, num_returns=3)
# One object is deleted and 2 objects are not.
assert len(l1) == 2
assert len(l2) == 1
# The deleted object will have the same store with the driver.
local_return = ray.worker.global_worker.plasma_client.store_socket_name
for object_id in l1:
assert ray.get(object_id) != local_return
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),